In [1]:
#step 1 import image
%matplotlib inline
import torchvision.datasets
import math
import torchvision.transforms as tvt
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wget
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as tfms
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision.utils import make_grid
from PIL import Image



# Evaluation device: keep the originally configured "cuda:1", but fall back to
# "cuda:0" on single-GPU machines, where ordinal 1 does not exist and the first
# `.to(device)` call would fail.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

image_size = 64   # images are resized to image_size x image_size
batch_size = 128
celeba_root = "../../../celeba/datasets/"

# CelebA test split: resize, convert to a tensor in [0, 1], then normalize each
# channel with mean 0.5 / std 0.5, which maps values to [-1, 1].
test_dataset = torchvision.datasets.CelebA(
    celeba_root,
    split='test',
    transform=tvt.Compose([
        tvt.Resize((image_size, image_size)),
        tvt.ToTensor(),
        tvt.Normalize(mean=(0.5000, 0.5000, 0.5000), std=(0.5000, 0.5000, 0.5000)),
    ]),
)

# Fixed order (no shuffling) and keep the final partial batch so every test
# image is evaluated exactly once.
test_data_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
print('Done')


  from .autonotebook import tqdm as notebook_tqdm


Done


In [2]:
from transformers import ViTConfig, ViTModel
import torch
import torch.nn as nn
# ViT backbone for `image_size` x `image_size` inputs: 8 encoder layers with
# 8 attention heads each, a 768-wide feed-forward layer, and 16x16 patches
# (so a 64x64 image becomes a 4x4 grid of patches).
configuration = ViTConfig(
    image_size=image_size,
    patch_size=16,
    num_hidden_layers=8,
    num_attention_heads=8,
    intermediate_size=768,
)
vit = ViTModel(configuration)

class VisionTransformer(nn.Module):
    """Classifier head on top of a Hugging Face-style ViT backbone.

    The backbone's output must expose ``last_hidden_state``; the first token of
    that sequence (the [CLS] position for HF ViT models) is passed through a
    two-layer MLP to produce ``num_classes`` logits.

    Args:
        vit: backbone module. If it has ``config.hidden_size`` that is used as
            the head's input width; otherwise 768 is assumed (the previous
            hard-coded value).
        hidden_dim: width of the MLP's hidden layer (default 256).
        num_classes: number of output logits (default 2).
    """

    def __init__(self, vit, hidden_dim=256, num_classes=2):
        super(VisionTransformer, self).__init__()
        self.vit = vit
        # Read the embedding width from the backbone instead of hard-coding it,
        # so backbones with a non-default hidden_size also work.
        embed_dim = getattr(getattr(vit, "config", None), "hidden_size", 768)
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for input batch ``x``."""
        hidden = self.vit(x).last_hidden_state
        first_token = hidden[:, 0]
        return self.fc(first_token)

In [6]:
import time
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score


        
def _safe_div(numerator, denominator):
    """Return numerator / denominator as a float, or NaN when the denominator is 0."""
    return float(numerator) / float(denominator) if denominator else float("nan")


def _group_rates(gt, pred):
    """Return (2x2 confusion matrix, positive-prediction rate, TPR, FPR) for one group.

    Rows of the matrix are true labels and columns are predicted labels, in the
    fixed order [0, 1]. Passing ``labels`` keeps the matrix 2x2 even when the
    group contains only one class; without it sklearn returns a 1x1 matrix and
    indexing ``cm[1][1]`` would raise.
    """
    if len(gt) == 0:
        cm = np.zeros((2, 2), dtype=np.int64)
    else:
        cm = confusion_matrix(gt, pred, labels=[0, 1])
    (tn, fp), (fn, tp) = cm
    positive_rate = _safe_div(tp + fp, tn + fp + fn + tp)
    tpr = _safe_div(tp, tp + fn)
    fpr = _safe_div(fp, fp + tn)
    return cm, positive_rate, tpr, fpr


def test(test_loader, model, print_fairness, target_idx=9, sensitive_idx=20):
    """Evaluate ``model`` on ``test_loader`` and compute group-fairness metrics.

    Each batch is expected to be an ``(images, attributes)`` pair. Column
    ``target_idx`` of ``attributes`` is the binary label the model predicts,
    and column ``sensitive_idx`` is the binary group attribute. Rows where it
    equals 0 are reported as "female" and all others as "male". The defaults
    (9, 20) are the columns this code has always used; in torchvision's CelebA
    attribute order they are presumably Blond_Hair and Male — confirm against
    ``test_dataset.attr_names``.

    Args:
        test_loader: iterable of (images, attributes) batches.
        model: classifier returning per-class logits of shape (batch, n_classes).
        print_fairness: if truthy, print the confusion matrices and metrics.
        target_idx: attribute column used as the ground-truth label.
        sensitive_idx: attribute column used as the sensitive group.

    Returns:
        dict with per-group confusion matrices and TPRs, the DP / EOP / EoD gaps,
        overall accuracy, and the accuracy * (1 - EoD) trade-off.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    pred_chunks, gt_chunks, sensitive_chunks = [], [], []
    with torch.no_grad():
        with tqdm(test_loader, unit="batch") as tepoch:
            for test_images, test_attributes in tepoch:
                logits = model(test_images.to(device))
                # reshape(-1), not squeeze(): squeeze() turns a batch of one
                # into a 0-d array that cannot be concatenated/extended.
                pred_chunks.append(torch.argmax(logits, dim=1).reshape(-1).cpu().numpy())
                gt_chunks.append(test_attributes[:, target_idx].cpu().numpy())
                sensitive_chunks.append(test_attributes[:, sensitive_idx].cpu().numpy())
    if not pred_chunks:
        raise ValueError("test_loader yielded no batches")

    test_pred = np.concatenate(pred_chunks)
    test_gt = np.concatenate(gt_chunks)
    sense_gt = np.concatenate(sensitive_chunks)

    is_female = sense_gt == 0
    female_CM, female_dp, female_TPR, female_FPR = _group_rates(test_gt[is_female], test_pred[is_female])
    male_CM, male_dp, male_TPR, male_FPR = _group_rates(test_gt[~is_female], test_pred[~is_female])

    acc = accuracy_score(test_gt, test_pred)
    dp_gap = abs(female_dp - male_dp)                 # demographic parity gap
    eop_gap = abs(female_TPR - male_TPR)              # equal opportunity gap
    eod = 0.5 * (abs(female_FPR - male_FPR) + abs(female_TPR - male_TPR))  # equalized odds
    trade_off = acc * (1 - eod)

    if print_fairness:
        print(female_CM)
        print(male_CM)
        print('Female TPR', female_TPR)
        print('male TPR', male_TPR)
        print('DP', dp_gap)
        print('EOP', eop_gap)
        print('EoD', eod)
        print('acc', acc)
        print('Trade off', trade_off)

    return {
        "female_cm": female_CM,
        "male_cm": male_CM,
        "female_tpr": female_TPR,
        "male_tpr": male_TPR,
        "dp": dp_gap,
        "eop": eop_gap,
        "eod": eod,
        "acc": acc,
        "trade_off": trade_off,
    }
 
        
# Checkpoint evaluated by default. Alternative checkpoint previously used here:
# 'results/CelebA/CelebA_sample_exp/ERM_upweight_0_epochs_50_lr_0.0001_weight_decay_0.1/model_outputs/30_model.pth'
DEFAULT_CHECKPOINT = (
    'results/CelebA/CelebA_sample_exp/'
    'train_downstream_ERM_upweight_0_epochs_50_lr_0.0001_weight_decay_0.0001/'
    'final_epoch1/JTT_upweight_50_epochs_50_lr_0.0001_weight_decay_0.0001/'
    'model_outputs/20_model.pth'
)


def main(pre_trained_weight=DEFAULT_CHECKPOINT):
    """Load a trained classifier checkpoint and print its fairness report on the test set.

    Args:
        pre_trained_weight: path to a file written by ``torch.save``. It may hold
            either a whole pickled model or a ``state_dict`` for
            ``VisionTransformer(vit)``.
    """
    import inspect

    load_kwargs = {"map_location": device}
    # PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses to
    # unpickle a whole nn.Module. Opt out explicitly where the argument exists.
    # SECURITY: weights_only=False unpickles arbitrary objects (can execute
    # code) — only load checkpoints from trusted sources.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    checkpoint = torch.load(pre_trained_weight, **load_kwargs)

    if isinstance(checkpoint, nn.Module):
        model = checkpoint
    else:
        # Not a pickled module: assume a state_dict for the architecture above.
        model = VisionTransformer(vit)
        model.load_state_dict(checkpoint)
    model = model.to(device)
    test(test_data_loader, model, print_fairness=True)


if __name__ == '__main__':
    main()
    

100%|██████████████████████████████████████████████████████████████████████████████| 156/156 [00:16<00:00,  9.28batch/s]

[[9118  649]
 [ 522 1958]]
[[7362  173]
 [ 103   77]]
Female TPR 0.7895161290322581
male TPR 0.42777777777777776
DP 0.1804640505820966
EOP 0.36173835125448034
EoD 0.20261353655605882
acc 0.9275122733193066
Trade off 0.739585731422932



