In [None]:
# Step 1: imports, device selection, and CelebA dataset / data-loader setup
%matplotlib inline
import torchvision.datasets
import math
import torchvision.transforms as tvt
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wget
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as tfms
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision.utils import make_grid
from PIL import Image
from time import time
from tqdm import tqdm
from transformers import ViTConfig, ViTModel

# Compute device used by the rest of the notebook. The original setup targets the
# second GPU ('cuda:1'); fall back to the first GPU, or to the CPU, when that device
# does not exist so the notebook still runs on other machines.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

# Disabled one-off download of the raw CelebA files. The code is wrapped in a string
# literal, so it is never executed; it is kept only as a record of where the files
# under ../celeba/datasets came from.
# NOTE(review): base_url already ends with "/" and the URL is built as
# f"{base_url}/{file}", which yields a double slash — fix before re-enabling.
"""
data_root = "../celeba/datasets"

base_url = "https://graal.ift.ulaval.ca/public/celeba/"

file_list = [
    "img_align_celeba.zip",
    "list_attr_celeba.txt",
    "identity_CelebA.txt",
    "list_bbox_celeba.txt",
    "list_landmarks_align_celeba.txt",
    "list_eval_partition.txt",
]

# Path to folder with the dataset
dataset_folder = f"{data_root}/celeba"
os.makedirs(dataset_folder, exist_ok=True)

for file in file_list:
    url = f"{base_url}/{file}"
    if not os.path.exists(f"{dataset_folder}/{file}"):
        wget.download(url, f"{dataset_folder}/{file}")

with zipfile.ZipFile(f"{dataset_folder}/img_align_celeba.zip", "r") as ziphandler:
    ziphandler.extractall(dataset_folder)
"""

image_size = 64
batch_size = 256

# Both splits use the same preprocessing: resize to a square image, convert to a
# tensor, then normalise every channel with mean 0.5 and std 0.5.
celeba_root = "../celeba/datasets/"
celeba_transform = tvt.Compose([
    tvt.Resize((image_size, image_size)),
    tvt.ToTensor(),
    tvt.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

dataset = torchvision.datasets.CelebA(celeba_root, split='train', transform=celeba_transform)
test_dataset = torchvision.datasets.CelebA(celeba_root, split='test', transform=celeba_transform)

# Training batches are shuffled and the last partial batch is dropped; the test
# loader keeps the dataset order and every sample.
training_data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_data_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)
print('Done')

In [None]:
from transformers import ViTConfig, ViTModel
# Smoke test: push one test batch through an untrained ViT and look at the shape of
# the token embeddings it returns.
configuration = ViTConfig(num_hidden_layers = 8, num_attention_heads = 8,
                          intermediate_size = 768, image_size= 64, patch_size = 16)
model = ViTModel(configuration)
# Re-read the configuration from the model so it reflects what the model actually uses.
configuration = model.config
t = iter(test_data_loader)
img, label = next(t)
# Shape check only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    y = model(img)
y.last_hidden_state.shape

In [None]:
def choose_value_patch(atten, value, p_dim):
    """Gather the value vectors of the ``p_dim`` most-attended non-class tokens.

    Args:
        atten: attention weights of one query over all tokens, shape
            (Batch, Head, Token). Token 0 is excluded from the selection
            (the caller passes the class-token row, so this is the class token).
        value: per-token value vectors, shape (Batch, Head, Token, Dim), indexed
            with the same token positions as ``atten``.
        p_dim: number of tokens to select per head; must not exceed Token - 1.

    Returns:
        Tensor of shape (Batch, Head, p_dim, Dim). The order of the selected
        tokens is unspecified because ``topk`` is called with ``sorted=False``.
    """
    # Rank only the non-class tokens.
    patch_atten = atten[:, :, 1:]
    _, top_k_indices = torch.topk(patch_atten, k=p_dim, dim=2, sorted=False)
    # The indices are relative to the slice that starts at token 1; shift them back
    # so they address the matching rows of ``value`` (which still has token 0).
    token_indices = top_k_indices + 1
    # token_indices: (Batch, Head, p_dim) -> (Batch, Head, p_dim, Dim) for gather.
    gather_index = token_indices.unsqueeze(-1).expand(-1, -1, -1, value.size(-1))
    return torch.gather(value, 2, gather_index)
    
class Last_Attention(nn.Module):
    """Multi-head self-attention whose queries and keys are standardised and rectified.

    Queries and keys are standardised per (head, token, channel) across the batch
    and passed through ``abs`` before the scaled dot product. Running estimates of
    the batch statistics are kept as buffers and used when ``training`` is False.
    Besides the attended tokens, the layer returns an L2-normalised embedding built
    from the value vectors of the ``p_dim`` tokens that token 0 attends to most.

    Args:
        emb_size: token embedding width (must be divisible by ``head``).
        head: number of attention heads.
        p_dim: number of tokens selected for the embedding.
        num_tokens: sequence length the running statistics are sized for.
        temperature: divisor applied to the attention logits before softmax.
        momentum: weight of the current batch in the running statistics.
        eps: lower bound on the standard deviation used as divisor.

    The defaults reproduce the original hard-coded configuration.
    """

    def __init__(self, emb_size=768, head=8, p_dim=3, num_tokens=17,
                 temperature=1, momentum=0.05, eps=1e-6):
        super(Last_Attention, self).__init__()
        if emb_size % head != 0:
            raise ValueError('emb_size must be divisible by head')
        self.p_dim = p_dim
        self.emb_size = emb_size
        self.head = head
        self.temperature = temperature
        self.head_dim = self.emb_size // self.head
        self.eps = eps
        self.Q = nn.Linear(emb_size, emb_size)
        self.K = nn.Linear(emb_size, emb_size)
        self.V = nn.Linear(emb_size, emb_size)
        self.projection = nn.Linear(emb_size, emb_size)
        self.soft_max = nn.Softmax(dim=-1)
        # Maps the concatenated selected value vectors (p_dim * emb_size) to 128-d.
        self.projector = nn.Sequential(
            nn.Linear(self.p_dim * emb_size, 256, bias=False),
            nn.ReLU(),
            nn.Linear(256, 128, bias=False),
        )
        self.momentum = momentum
        # Real nn.Module buffers: they follow .to()/.cuda(), are saved in the
        # state_dict and never require gradients.
        stat_shape = (1, head, num_tokens, self.head_dim)
        self.register_buffer('running_mean_q', torch.zeros(stat_shape))
        self.register_buffer('running_std_q', torch.ones(stat_shape))
        self.register_buffer('running_mean_k', torch.zeros(stat_shape))
        self.register_buffer('running_std_k', torch.ones(stat_shape))

    def forward(self, x, training=True):
        """Apply the attention layer.

        Args:
            x: tokens of shape (B, N, C) with C == emb_size; N must match the
                ``num_tokens`` the running statistics were created with.
            training: when True, standardise with the current batch statistics
                and update the running estimates; when False, use the running
                estimates. This flag is independent of ``nn.Module.train()``.

        Returns:
            (out, z): ``out`` has shape (B, N, C); ``z`` has shape (B, 128) and
            unit L2 norm along dim 1.
        """
        B, N, C = x.shape
        origin_k = self.K(x)
        origin_q = self.Q(x)
        origin_v = self.V(x)

        def split_heads(t):
            # (B, N, C) -> (B, head, N, C // head)
            return t.reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)

        q = split_heads(origin_q)
        k = split_heads(origin_k)
        v = split_heads(origin_v)

        if training:
            # Statistics are treated as constants: no gradient flows through them.
            with torch.no_grad():
                q_mean, q_std = torch.mean(q, 0, keepdim=True), torch.std(q, 0, keepdim=True)
                k_mean, k_std = torch.mean(k, 0, keepdim=True), torch.std(k, 0, keepdim=True)

                # Exponential moving average, updated in place so the tensors
                # stay registered buffers on the module's device.
                keep = 1 - self.momentum
                self.running_mean_q.mul_(keep).add_(q_mean, alpha=self.momentum)
                self.running_std_q.mul_(keep).add_(q_std, alpha=self.momentum)
                self.running_mean_k.mul_(keep).add_(k_mean, alpha=self.momentum)
                self.running_std_k.mul_(keep).add_(k_std, alpha=self.momentum)
        else:
            q_mean, q_std = self.running_mean_q, self.running_std_q
            k_mean, k_std = self.running_mean_k, self.running_std_k

        # Clamp the divisor so a zero standard deviation cannot produce inf/NaN.
        q = torch.abs((q - q_mean) / q_std.clamp_min(self.eps))
        k = torch.abs((k - k_mean) / k_std.clamp_min(self.eps))

        attention = (q @ k.transpose(-2, -1)) * (self.head_dim ** (-0.5))
        atten = self.soft_max(attention / self.temperature)
        out = (atten @ v).transpose(1, 2).reshape(B, N, C)
        out = self.projection(out)

        # Attention row of token 0 over all tokens: (B, head, N).
        attentions = atten[:, :, 0, :]
        mst_val = choose_value_patch(attentions, v, self.p_dim)
        mst_val = mst_val.reshape(B, -1)
        mst_val = self.projector(mst_val)
        z = F.normalize(mst_val, dim=1)
        return out, z

   

    
class Last_ATBlock(nn.Module):
    """Pre-norm transformer block: ``Last_Attention`` then a two-layer MLP.

    Both sub-layers are preceded by a LayerNorm and wrapped in a residual
    connection. The extra embedding produced by the attention layer is passed
    through unchanged.
    """

    def __init__(self):
        super().__init__()
        dim = 768
        self.norm = nn.LayerNorm(dim)
        self.attention = Last_Attention()
        self.norm2 = nn.LayerNorm(dim)
        self.feedforward = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x, training=True):
        """Return ``(tokens, vz)``; ``training`` is forwarded to the attention layer."""
        # Attention sub-layer with residual connection.
        attended, vz = self.attention(self.norm(x), training)
        hidden = attended + x
        # Feed-forward sub-layer with residual connection.
        output = self.feedforward(self.norm2(hidden)) + hidden
        return output, vz

In [None]:
class VisionTransformer(nn.Module):
    """Binary classifier: a ViT backbone, one ``Last_ATBlock`` and an MLP head.

    ``vit`` is called on the input and must return an object exposing
    ``last_hidden_state``. The head reads token 0 of the final block and ends
    in a sigmoid, so the prediction lies in (0, 1).
    """

    def __init__(self, vit):
        super().__init__()
        self.vit = vit
        self.last_encoder = Last_ATBlock()
        head_layers = [
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Linear(768, 1),
            nn.Sigmoid(),
        ]
        self.seq = nn.Sequential(*head_layers)

    def forward(self, x, training=True):
        """Return ``(y, vz)``: the head output for token 0 and the block's embedding."""
        tokens = self.vit(x).last_hidden_state
        tokens, vz = self.last_encoder(tokens, training)
        cls_token = tokens[:, 0]
        return self.seq(cls_token), vz

In [None]:
import random
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import confusion_matrix
from collections import OrderedDict
import seaborn as sns

def seed_everything(seed):
    """
    Seed the Python, NumPy and PyTorch random generators with ``seed`` and
    switch cuDNN to deterministic mode (benchmark mode off) for reproducibility.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    

class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        All intermediate tensors are created on the device of `features`.
        Anchors that have no positive in the batch contribute 0 to the loss
        instead of NaN.

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        Raises:
            ValueError: if `features` has fewer than 3 dimensions, if both
                `labels` and `mask` are given, if the number of labels does not
                match the batch size, or if `contrast_mode` is unknown.
        """
        # Follow the input instead of relying on a module-level `device` global.
        dev = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32, device=dev)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(dev)
        else:
            mask = mask.float().to(dev)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count, device=dev).view(-1, 1),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive; an anchor without any
        # positive has a zero numerator, so use 1 as its divisor to avoid 0/0
        pos_per_anchor = mask.sum(1)
        pos_per_anchor = torch.where(pos_per_anchor < 1e-6,
                                     torch.ones_like(pos_per_anchor),
                                     pos_per_anchor)
        mean_log_prob_pos = (mask * log_prob).sum(1) / pos_per_anchor

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

    

def _group_rates(y_true, y_pred):
    """Return (positive-prediction rate, TPR, FPR) for one group of binary labels."""
    # labels=[0, 1] pins a 2x2 matrix even when the group holds a single class;
    # without it sklearn returns a 1x1 matrix and the indexing below fails.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    positive_rate = (tp + fp) / (tn + fp + fn + tp)
    tpr = tp / (tp + fn)
    fpr = fp / (fp + tn)
    return positive_rate, tpr, fpr


def train_model():
    """Train the ViT classifier with a BCE + supervised-contrastive loss and
    print accuracy and group-fairness metrics on the test split after each epoch.

    Reads the module-level ``device``, ``training_data_loader`` and
    ``test_data_loader``. Attribute column 2 is the prediction target and
    column 20 is the sensitive attribute (0 is reported as female, anything
    else as male).
    NOTE(review): column 2 is presumably 'Attractive' and column 20 'Male' in
    the standard CelebA attribute order — confirm against the dataset.
    """
    configuration = ViTConfig(num_hidden_layers = 7, num_attention_heads = 8,
                          intermediate_size = 768, image_size= 64, patch_size = 16)
    vit = ViTModel(configuration).to(device)
    model = VisionTransformer(vit).to(device)
    num_epochs = 20
    criterion = nn.BCELoss()
    fair_criterion = SupConLoss()
    fair_optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

    for epoch_idx in range(num_epochs):
        model.train()
        with tqdm(training_data_loader, unit="batch") as tepoch:
            # The epoch index is an integer; the old "%2f" format printed it as a float.
            tepoch.set_description(f"epoch {epoch_idx:2d} ")
            for train_input, attributes in tepoch:
                # Transfer data to GPU if possible.
                train_input = train_input.to(device)
                train_target = attributes[:, 2].float().to(device)
                fair_optimizer.zero_grad()

                outputs, value = model(train_input)
                # SupConLoss expects [bsz, n_views, dim]; there is a single view here.
                fair_loss = fair_criterion(value.unsqueeze(1), train_target)
                # The model outputs shape (bsz, 1), so give the target the same shape.
                ut_loss = criterion(outputs, train_target.unsqueeze(1))
                loss = ut_loss + fair_loss
                tepoch.set_postfix(ul = ut_loss.item(),fl = fair_loss.item())
                loss.backward()
                fair_optimizer.step()

        # Evaluate on the test split.
        model.eval()
        pred_chunks, gt_chunks, sens_chunks = [], [], []
        with torch.no_grad():
            for test_input, attributes in test_data_loader:
                sens_chunks.append(attributes[:, 20].cpu().numpy())
                gt_chunks.append(attributes[:, 2].cpu().numpy())
                # training=False: normalise with the running statistics.
                test_pred_, _ = model(test_input.to(device), False)
                pred_chunks.append(torch.round(test_pred_.squeeze(1)).cpu().numpy())

        test_gt = np.concatenate(gt_chunks)
        # Rounded sigmoid outputs are exactly 0.0 or 1.0; store them as integer labels.
        test_pred = np.concatenate(pred_chunks).astype(np.int64)
        sense_gt = np.concatenate(sens_chunks)

        is_female = sense_gt == 0
        female_dp, female_TPR, female_FPR = _group_rates(test_gt[is_female], test_pred[is_female])
        male_dp, male_TPR, male_FPR = _group_rates(test_gt[~is_female], test_pred[~is_female])

        acc = accuracy_score(test_gt, test_pred)
        eod = 0.5*(abs(female_FPR-male_FPR)+ abs(female_TPR-male_TPR))

        print('Female TPR', female_TPR)
        print('male TPR', male_TPR)
        print('DP',abs(female_dp - male_dp))
        print('EOP', abs(female_TPR - male_TPR))
        print('EoD', eod)
        print('acc', acc)
        print('Trade off', acc*(1-eod))
        

# Fix every random generator before training so runs are repeatable.
SEED = 4096
seed_everything(SEED)
# Train for the configured number of epochs; metrics are printed after each epoch.
train_model()
