In [1]:
#step 1 import image
%matplotlib inline
import torchvision.datasets
import math
import torchvision.transforms as tvt
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wget
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as tfms
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision.utils import make_grid
from PIL import Image
from time import time
from tqdm import tqdm
from transformers import ViTModel

# Configuration shared by every later cell.
# NOTE(review): pinned to the second GPU ('cuda:1'); change this on machines with fewer GPUs.
device = torch.device('cuda:1')
image_size = 224
batch_size = 128
data_root = "../../celeba/datasets/"

# One preprocessing pipeline for both splits: resize to the ViT input size,
# convert to a [0, 1] tensor, then map each channel to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): the DINO backbone used later may have been pretrained with
# ImageNet mean/std rather than 0.5/0.5 — confirm this normalization is intended.
transform = tvt.Compose([
    tvt.Resize((image_size, image_size)),
    tvt.ToTensor(),
    tvt.Normalize(mean=[0.5, 0.5, 0.5],
                  std=[0.5, 0.5, 0.5]),
])

dataset = torchvision.datasets.CelebA(data_root, split='train', transform=transform)
test_dataset = torchvision.datasets.CelebA(data_root, split='test', transform=transform)

# Training: shuffled batches, ragged final batch dropped (every batch has exactly
# batch_size samples). Evaluation: fixed order, every test sample kept.
training_data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
print('Done')

  from .autonotebook import tqdm as notebook_tqdm


Done


In [2]:
def choose_value_patch(attention, value, k):
    """
    Select the `k` value vectors whose tokens receive the highest head-averaged attention.

    Parameters:
    - attention (tensor): Shape [Batch, Head, token]; attention weights from a single
      query position to every token (in this notebook, the row for query token 0).
    - value (tensor): Shape [Batch, token, dim]; value vectors to gather from.
    - k (int): Number of top-attended tokens to select (required; must not exceed
      the number of tokens).

    Returns:
    - top_k_values (tensor): Shape [Batch, k, dim], ordered from highest to lowest
      average attention.
    """
    # Average over heads -> [Batch, token].
    avg_attention = attention.mean(dim=1)

    # Indices of the k most-attended tokens per sample -> [Batch, k] (topk sorts descending).
    _, top_k_indices = avg_attention.topk(k, dim=-1)

    # Advanced indexing pairs each batch row with its own k token indices -> [Batch, k, dim].
    # The row-index tensor is created directly on the indices' device (no CPU round-trip).
    n_batch = top_k_indices.shape[0]
    batch_indices = torch.arange(n_batch, device=top_k_indices.device).unsqueeze(1)
    top_k_values = value[batch_indices, top_k_indices].reshape(n_batch, k, -1)

    return top_k_values

class Last_Attention(nn.Module):
    """Multi-head self-attention with mean-centred, absolute-valued queries and keys,
    plus a small projection head over the most-attended value vectors.

    ``forward(x, training)`` takes ``x`` of shape [B, N, emb_size] and returns
    ``(out, z, atten)``:

    - ``out``:   [B, N, emb_size], attention output after ``self.projection``.
    - ``z``:     [B, proj_out], L2-normalised embedding of the ``p_dim`` value
                 vectors that query token 0 attends to most (averaged over heads).
    - ``atten``: [B, head, N, N], softmax attention weights.

    Per-(head, token, channel) means and stds of q and k are kept in buffers and
    updated with an exponential moving average (``momentum``). In training, the
    current batch mean centres q/k. In eval, the running mean is used, which
    requires N == num_tokens. The std buffers are updated but never applied in
    ``forward``.

    The constructor defaults reproduce the original hard-coded configuration
    (384-dim embeddings, 6 heads, 197 tokens, 5 selected patches, momentum 0.1).
    """

    def __init__(self, emb_size=384, head=6, p_dim=5, num_tokens=197,
                 momentum=0.1, proj_hidden=256, proj_out=128):
        super(Last_Attention, self).__init__()
        if emb_size % head != 0:
            raise ValueError('emb_size must be divisible by head')
        self.p_dim = p_dim
        self.emb_size = emb_size
        self.head = head
        self.head_dim = emb_size // head
        self.Q = nn.Linear(emb_size, emb_size)
        self.K = nn.Linear(emb_size, emb_size)
        self.V = nn.Linear(emb_size, emb_size)
        self.projection = nn.Linear(emb_size, emb_size)
        self.soft_max = nn.Softmax(dim=-1)
        # Maps the p_dim concatenated value vectors to the contrastive embedding.
        self.projector = nn.Sequential(
            nn.Linear(p_dim * emb_size, proj_hidden, bias=False),
            nn.ReLU(),
            nn.Linear(proj_hidden, proj_out, bias=False),
        )
        self.momentum = momentum
        # Leading dim of 1 lets the statistics broadcast over the batch.
        stat_shape = (1, head, num_tokens, self.head_dim)
        self.register_buffer('running_mean_q', torch.zeros(stat_shape))
        self.register_buffer('running_std_q', torch.ones(stat_shape))
        self.register_buffer('running_mean_k', torch.zeros(stat_shape))
        self.register_buffer('running_std_k', torch.ones(stat_shape))

    def forward(self, x, training=None):
        """Apply the attention layer to ``x`` of shape [B, N, emb_size].

        ``training=True`` centres q/k with the batch mean and updates the running
        buffers. ``training=False`` centres with the running means. When
        ``training`` is omitted, the module's own ``self.training`` flag is used.
        """
        if training is None:
            training = self.training
        B, N, C = x.shape

        def split_heads(t):
            # [B, N, C] -> [B, head, N, C // head]
            return t.reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)

        q = split_heads(self.Q(x))
        k = split_heads(self.K(x))
        v = split_heads(self.V(x))

        if training:
            with torch.no_grad():
                # NOTE: torch.std is unbiased, so a batch of one sample gives NaN stds.
                q_mean, q_std = torch.mean(q, 0, keepdim=True), torch.std(q, 0, keepdim=True)
                k_mean, k_std = torch.mean(k, 0, keepdim=True), torch.std(k, 0, keepdim=True)
                # In-place EMA update: the buffers stay registered and on the module's
                # own device, so no global ``device`` is needed.
                keep = 1 - self.momentum
                self.running_mean_q.mul_(keep).add_(self.momentum * q_mean)
                self.running_std_q.mul_(keep).add_(self.momentum * q_std)
                self.running_mean_k.mul_(keep).add_(self.momentum * k_mean)
                self.running_std_k.mul_(keep).add_(self.momentum * k_std)
        else:
            q_mean = self.running_mean_q
            k_mean = self.running_mean_k

        # Centre q and k, then take magnitudes, so every attention score is >= 0.
        q = torch.abs(q - q_mean)
        k = torch.abs(k - k_mean)

        scores = (q @ k.transpose(-2, -1)) * (self.head_dim ** (-0.5))
        atten = self.soft_max(scores)
        out = self.projection((atten @ v).transpose(1, 2).reshape(B, N, C))

        # Head-wise attention from query token 0 to every token: [B, head, N].
        token0_attention = atten[:, :, 0, :]
        # Merge heads back so value vectors are [B, N, C] before top-k selection.
        v_tokens = v.transpose(1, 2).reshape(B, N, C)
        top_values = choose_value_patch(token0_attention, v_tokens, self.p_dim)
        z = F.normalize(self.projector(top_values.reshape(B, -1)), dim=1)
        return out, z, atten

    
class Last_ATBlock(nn.Module):
    """Pre-norm transformer block wrapping ``Last_Attention`` (384-dim tokens).

    ``forward(x, training)`` returns ``(output, vz, att)``. ``output`` has the same
    shape as ``x``; ``vz`` and ``att`` are the embedding and attention weights
    passed through unchanged from ``Last_Attention``.
    """

    def __init__(self):
        super().__init__()
        width = 384
        self.norm = nn.LayerNorm(width)
        self.attention = Last_Attention()
        self.norm2 = nn.LayerNorm(width)
        self.feedforward = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
        )

    def forward(self, x, training):
        attn_out, vz, att = self.attention(self.norm(x), training)
        # Residual around the attention sub-layer.
        hidden = attn_out + x
        # Residual around the feed-forward sub-layer.
        output = self.feedforward(self.norm2(hidden)) + hidden
        return output, vz, att

In [3]:
class VisionTransformer(nn.Module):
    """Classifier head applied to a token sequence of width 384.

    Runs one ``Last_ATBlock`` over the tokens, then passes token 0's output
    through a small MLP ending in a sigmoid. ``forward(x, training)`` returns
    ``(y, vz, att)``, where ``y`` has shape [B, 1]. In this notebook ``x`` is the
    backbone's ``last_hidden_state``.
    """

    def __init__(self):
        super(VisionTransformer, self).__init__()
        self.last_layer = Last_ATBlock()
        hidden = 128
        self.seq = nn.Sequential(
            nn.Linear(384, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, training):
        tokens, vz, att = self.last_layer(x, training)
        first_token = tokens[:, 0]
        return self.seq(first_token), vz, att

In [4]:
import random
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import confusion_matrix
from collections import OrderedDict
import seaborn as sns


def seed_everything(seed):
    """
    Make runs reproducible: seed Python's ``random``, NumPy and PyTorch with
    ``seed``, force deterministic cuDNN kernels and turn off cuDNN autotuning.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False



class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR.

    All helper tensors are created on ``features.device``, so the loss works on
    whatever device the features live on and needs no global device variable.
    """
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar. An anchor with no positive pair (e.g. the only sample
            of its class when n_views == 1) contributes 0 instead of NaN.
        Raises:
            ValueError: if `features` has fewer than 3 dims, if both `labels` and
                `mask` are given, if the label count differs from the batch size,
                or if `contrast_mode` is unknown.
        """
        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        feat_device = features.device
        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32, device=feat_device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(feat_device)
        else:
            mask = mask.float().to(feat_device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count, device=feat_device).view(-1, 1),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Mean log-likelihood over positives. Rows with no positives would be 0/0 = NaN,
        # so their denominator is replaced by 1 (their numerator is already 0).
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6,
                                     torch.ones_like(mask_pos_pairs),
                                     mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
    
    
    



def _safe_div(num, den):
    """Return num / den, or NaN when the denominator is zero."""
    return num / den if den else float('nan')


def _group_rates(gt, pred):
    """Return (positive-prediction rate, TPR, FPR) for one group's binary labels/predictions.

    ``labels=[0, 1]`` forces a 2x2 confusion matrix even when only one class
    occurs in the group, so indexing never fails.
    """
    cm = confusion_matrix(gt, pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    return _safe_div(tp + fp, cm.sum()), _safe_div(tp, tp + fn), _safe_div(fp, fp + tn)


def _fairness_report(test_gt, test_pred, sense_gt):
    """Print per-group TPRs, the DP / EOP / EoD gaps, and overall accuracy.

    Samples with sensitive value 0 are reported as "Female"; all others as "male".
    """
    gt = np.asarray(test_gt)
    pred = np.asarray(test_pred)
    is_female = np.asarray(sense_gt) == 0
    female_dp, female_TPR, female_FPR = _group_rates(gt[is_female], pred[is_female])
    male_dp, male_TPR, male_FPR = _group_rates(gt[~is_female], pred[~is_female])

    print('Female TPR', female_TPR)
    print('male TPR', male_TPR)
    print('DP', abs(female_dp - male_dp))
    print('EOP', abs(female_TPR - male_TPR))
    print('EoD', 0.5 * (abs(female_FPR - male_FPR) + abs(female_TPR - male_TPR)))
    print('acc', accuracy_score(test_gt, test_pred))


def train_model():
    """Train the attention head and classifier on frozen DINO ViT-S/16 tokens, and
    print fairness metrics on the test split after every epoch.

    The target is attribute column 9 and the sensitive attribute is column 20.
    NOTE(review): in the standard CelebA attribute order these are 'Blond_Hair'
    and 'Male' — confirm.
    The loss is BCE on the classifier output plus ``SupConLoss`` on the projected
    embedding, with the target labels defining the positives.

    Relies on the notebook-level ``device``, ``training_data_loader`` and
    ``test_data_loader``.
    """
    num_epochs = 12
    DINO = ViTModel.from_pretrained('facebook/dino-vits16').to(device)
    model = VisionTransformer().to(device)
    criterion = nn.BCELoss()
    fair_criterion = SupConLoss()
    # The backbone is frozen and kept in inference mode; only `model` is optimised.
    for param in DINO.parameters():
        param.requires_grad = False
    DINO.eval()
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

    for epoch_idx in range(num_epochs):
        model.train()
        with tqdm(training_data_loader, unit="batch") as tepoch:
            tepoch.set_description(f"epoch {epoch_idx}")
            for train_input, attributes in tepoch:
                train_input = train_input.to(device)
                train_target = attributes[:, 9].float().to(device)

                optimizer.zero_grad()
                with torch.no_grad():
                    features = DINO(train_input).last_hidden_state
                outputs, value, _ = model(features, True)

                # SupConLoss expects [bsz, n_views, dim]; there is a single view here.
                fair_loss = fair_criterion(value.unsqueeze(1), train_target)
                ut_loss = criterion(outputs, train_target.unsqueeze(1))
                loss = ut_loss + fair_loss
                tepoch.set_postfix(ul=ut_loss.item(), fl=fair_loss.item())

                loss.backward()
                optimizer.step()

        # Evaluate on the test set, collecting labels, sensitive attributes and predictions.
        model.eval()
        test_pred, test_gt, sense_gt = [], [], []
        with torch.no_grad():
            for test_input, attributes in test_data_loader:
                sense_gt.extend(attributes[:, 20].numpy())
                test_gt.extend(attributes[:, 9].numpy())
                features = DINO(test_input.to(device)).last_hidden_state
                test_prob, _, _ = model(features, False)
                # torch.round maps probabilities to hard 0/1 predictions (0.5 rounds to 0).
                test_pred.extend(torch.round(test_prob.squeeze(1)).long().cpu().numpy())

        _fairness_report(test_gt, test_pred, sense_gt)


# Seed every RNG before building the model, so weight init and data shuffling are reproducible.
seed_everything(seed=0)
train_model()

Some weights of ViTModel were not initialized from the model checkpoint at facebook/dino-vits16 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
epoch 0.000000 : 100%|████████████████████████████████████████| 1271/1271 [08:36<00:00,  2.46batch/s, fl=4.66, ul=0.105]


Female TPR 0.8483870967741935
male TPR 0.25555555555555554
DP 0.19888178733681178
EOP 0.592831541218638
EoD 0.3175224738485421
acc 0.950806532411582


epoch 1.000000 : 100%|███████████████████████████████████████| 1271/1271 [09:24<00:00,  2.25batch/s, fl=4.76, ul=0.0865]


Female TPR 0.8370967741935483
male TPR 0.3055555555555556
DP 0.1909717464369308
EOP 0.5315412186379928
EoD 0.2840977375786473
acc 0.952459673379421


epoch 2.000000 : 100%|███████████████████████████████████████| 1271/1271 [09:54<00:00,  2.14batch/s, fl=4.71, ul=0.0984]


Female TPR 0.8641129032258065
male TPR 0.37777777777777777
DP 0.19860538544469286
EOP 0.4863351254480287
EoD 0.2640562668764869
acc 0.9528103396453261


epoch 3.000000 : 100%|███████████████████████████████████████| 1271/1271 [09:56<00:00,  2.13batch/s, fl=4.67, ul=0.0666]


Female TPR 0.7810483870967742
male TPR 0.21666666666666667
DP 0.17405312692869987
EOP 0.5643817204301075
EoD 0.2954877582307912
acc 0.9514577697625488


epoch 4.000000 : 100%|████████████████████████████████████████| 1271/1271 [09:54<00:00,  2.14batch/s, fl=4.73, ul=0.117]


Female TPR 0.7923387096774194
male TPR 0.20555555555555555
DP 0.1789043209280398
EOP 0.5867831541218638
EoD 0.30817876621248963
acc 0.9509568179541128


epoch 5.000000 : 100%|████████████████████████████████████████| 1271/1271 [07:42<00:00,  2.75batch/s, fl=4.67, ul=0.111]


Female TPR 0.8451612903225807
male TPR 0.35
DP 0.1916220783049439
EOP 0.4951612903225807
EoD 0.26625483133687217
acc 0.9511071034966436


epoch 6.000000 : 100%|███████████████████████████████████████| 1271/1271 [08:47<00:00,  2.41batch/s, fl=4.66, ul=0.0742]


Female TPR 0.867741935483871
male TPR 0.42777777777777776
DP 0.19714814759348792
EOP 0.43996415770609326
EoD 0.24074957804065406
acc 0.94980462879471


epoch 7.000000 : 100%|████████████████████████████████████████| 1271/1271 [09:22<00:00,  2.26batch/s, fl=4.69, ul=0.101]


Female TPR 0.8169354838709677
male TPR 0.34444444444444444
DP 0.18310565932239095
EOP 0.47249103942652326
EoD 0.2529781436669092
acc 0.9502554854223024


epoch 8.000000 : 100%|████████████████████████████████████████| 1271/1271 [10:30<00:00,  2.02batch/s, fl=4.69, ul=0.102]


Female TPR 0.8149193548387097
male TPR 0.26666666666666666
DP 0.18615424010885043
EOP 0.5482526881720431
EoD 0.291784202172808
acc 0.9492034866245868


epoch 9.000000 : 100%|████████████████████████████████████████| 1271/1271 [14:05<00:00,  1.50batch/s, fl=4.6, ul=0.0641]


Female TPR 0.7435483870967742
male TPR 0.21666666666666667
DP 0.16393445329582215
EOP 0.5268817204301075
EoD 0.27548267538441223
acc 0.9454964432421601


epoch 10.000000 : 100%|██████████████████████████████████████| 1271/1271 [14:42<00:00,  1.44batch/s, fl=4.68, ul=0.0303]


Female TPR 0.8213709677419355
male TPR 0.4722222222222222
DP 0.18133833190780754
EOP 0.3491487455197133
EoD 0.19229504137916428
acc 0.9448953010720369


epoch 11.000000 : 100%|██████████████████████████████████████| 1271/1271 [15:09<00:00,  1.40batch/s, fl=4.58, ul=0.0737]


Female TPR 0.8358870967741936
male TPR 0.49444444444444446
DP 0.18288796478574698
EOP 0.34144265232974913
EoD 0.18852179711138548
acc 0.9419897805831079
