In [1]:
# Step 1: imports, device and seed setup, and CelebA image loading
%matplotlib inline
import torchvision.datasets
import math
import torchvision.transforms as tvt
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wget
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as tfms
from torch.utils.data import DataLoader, Subset, Dataset, random_split
from torchvision.utils import make_grid
from PIL import Image
from time import time
from tqdm import tqdm
import random
from transformers import ViTModel

# Prefer the second GPU (the original hard-coded choice), but fall back to the
# first GPU or the CPU so the notebook still runs on machines without 'cuda:1'.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

def seed_everything(seed):
    """
    Seed every random number generator used in this notebook.

    Applies `seed` to Python's `random` module, NumPy's global generator and
    PyTorch, then puts cuDNN into deterministic mode with benchmarking off.
    """
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


seed_everything(1024)
image_size = 64
batch_size = 256


def _celeba_split(split):
    """Return the CelebA `split`, resized to image_size x image_size and normalised per channel."""
    preprocessing = tvt.Compose([
        tvt.Resize((image_size, image_size)),
        tvt.ToTensor(),
        # mean 0.5 / std 0.5 per channel maps ToTensor's [0, 1] range onto [-1, 1].
        tvt.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    return torchvision.datasets.CelebA("../celeba/datasets/", split=split, transform=preprocessing)


dataset = _celeba_split('train')
test_dataset = _celeba_split('test')

# Training batches are reshuffled and the ragged last batch dropped;
# the test loader keeps its order and every sample.
training_data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
print('Done')

  from .autonotebook import tqdm as notebook_tqdm


Done


In [2]:
from transformers import ViTConfig, ViTModel
# Sanity check: push one test batch through a freshly initialised ViT and look
# at the shape of the token embeddings it returns.
configuration = ViTConfig(num_hidden_layers=8, num_attention_heads=8,
                          intermediate_size=768, image_size=64, patch_size=16)
model = ViTModel(configuration)
configuration = model.config
t = iter(test_data_loader)
img, label = next(t)
# Inference only: no_grad avoids building an autograd graph for the whole batch.
with torch.no_grad():
    y = model(img)
# Expected (batch, 1 + (64 // 16) ** 2 tokens, hidden size).
y.last_hidden_state.shape

torch.Size([256, 17, 768])

In [3]:
class VisionTransformer(nn.Module):
    """
    Binary classifier: a ViT backbone followed by a small MLP head.

    Args:
        vit: backbone module whose forward output exposes `last_hidden_state`
            with the token axis at dimension 1 (e.g. a Hugging Face `ViTModel`).
        head_dim: width of the hidden layer in the classification head.

    The head's input width is read from `vit.config.hidden_size` when the
    backbone provides it, and falls back to 768 (the previously hard-coded
    value) otherwise, so backbones with a different hidden size work too.
    """
    def __init__(self, vit, head_dim=256):
        super(VisionTransformer, self).__init__()
        self.vit = vit
        hidden_size = getattr(getattr(vit, "config", None), "hidden_size", 768)
        self.seq = nn.Sequential(
            nn.Linear(hidden_size, head_dim),
            nn.ReLU(),
            nn.Linear(head_dim, 1),
            # Sigmoid keeps the output a probability in [0, 1].
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return a (batch, 1) tensor of probabilities for the input images `x`."""
        z = self.vit(x)
        m = z.last_hidden_state
        # Use only the first token's embedding as the image representation
        # (the [CLS] token for Hugging Face ViTModel).
        g = m[:, 0]
        y = self.seq(g)
        return y

In [4]:
import random
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import confusion_matrix
from collections import OrderedDict
import seaborn as sns

def seed_everything(seed):
    """
    Fix all RNG seeds so that repeated runs are reproducible.

    Re-definition of the helper from the first cell: it seeds `random`, NumPy
    and PyTorch with `seed` and makes cuDNN deterministic (benchmarking off).
    """
    backend = torch.backends.cudnn
    backend.benchmark = False
    backend.deterministic = True

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def _group_rates(y_true, y_pred):
    """
    Compute (positive prediction rate, TPR, FPR) for one group of samples.

    Counts are taken directly from the 0/1 arrays, so the result is well
    defined even when a group contains a single class or is empty (the
    affected rates are then NaN instead of raising an indexing error).
    """
    actual = np.asarray(y_true) == 1
    predicted = np.asarray(y_pred) == 1
    tp = int(np.count_nonzero(actual & predicted))
    fp = int(np.count_nonzero(~actual & predicted))
    fn = int(np.count_nonzero(actual & ~predicted))
    tn = int(np.count_nonzero(~actual & ~predicted))
    positive_rate = _safe_ratio(tp + fp, tp + fp + fn + tn)
    tpr = _safe_ratio(tp, tp + fn)
    fpr = _safe_ratio(fp, fp + tn)
    return positive_rate, tpr, fpr


def train_model(epoch=30, lr=1e-5):
    """
    Train a ViT-based binary classifier and report accuracy/fairness per epoch.

    Attribute column 9 is the prediction target and column 20 the sensitive
    attribute (value 0 is treated as the "female" group, anything else as
    "male") - presumably CelebA's Blond_Hair and Male columns; confirm against
    the dataset's attribute names.

    Args:
        epoch: number of passes over `training_data_loader`.
        lr: AdamW learning rate.

    Returns:
        (valid_acc, valid_eod): two lists with one entry per epoch holding the
        accuracy and the equalized-odds gap. NOTE: despite the names, both are
        measured on `test_data_loader`, i.e. the test split.
    """
    configuration = ViTConfig(num_hidden_layers=8, num_attention_heads=8,
                              intermediate_size=768, image_size=64, patch_size=16)
    # Moving the wrapper moves the backbone too; no separate .to() is needed.
    model = VisionTransformer(ViTModel(configuration)).to(device)

    # The model ends in a Sigmoid, so plain BCE on probabilities is used.
    criterion = nn.BCELoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    valid_acc = []
    valid_eod = []

    for epoches in range(epoch):
        # ---- training pass ----
        model.train()
        with tqdm(training_data_loader, unit="batch") as tepoch:
            tepoch.set_description(f"epoch {epoches}")
            for train_input, attributes in tepoch:
                train_input = train_input.to(device)
                # (batch,) -> (batch, 1) float targets to match the model output.
                train_target = attributes[:, 9].float().unsqueeze(1).to(device)

                optimizer.zero_grad()
                outputs = model(train_input)
                loss = criterion(outputs, train_target)
                loss.backward()
                optimizer.step()
                tepoch.set_postfix(ut_loss=loss.item())

        # ---- evaluation pass ----
        model.eval()
        batch_pred, batch_gt, batch_sense = [], [], []
        with torch.no_grad():
            for test_input, attributes in test_data_loader:
                test_pred_ = model(test_input.to(device))
                # Threshold the probabilities to hard 0/1 predictions.
                batch_pred.append(torch.round(test_pred_.squeeze(1)).cpu())
                # Labels never need to leave the CPU.
                batch_gt.append(attributes[:, 9])
                batch_sense.append(attributes[:, 20])

        test_pred = torch.cat(batch_pred).numpy().astype(int)
        test_gt = torch.cat(batch_gt).numpy()
        sense_gt = torch.cat(batch_sense).numpy()

        # Split by sensitive attribute with a boolean mask.
        is_female = sense_gt == 0
        female_dp, female_TPR, female_FPR = _group_rates(test_gt[is_female], test_pred[is_female])
        male_dp, male_TPR, male_FPR = _group_rates(test_gt[~is_female], test_pred[~is_female])

        acc = accuracy_score(test_gt, test_pred)
        # Equalized-odds gap: mean of the absolute FPR and TPR differences.
        eod = 0.5 * (abs(female_FPR - male_FPR) + abs(female_TPR - male_TPR))

        valid_acc.append(acc)
        valid_eod.append(eod)

        print('Female TPR', female_TPR)
        print('male TPR', male_TPR)
        print('DP', abs(female_dp - male_dp))
        print('EOP', abs(female_TPR - male_TPR))
        print('EoD', eod)
        print('acc', acc)
        print('Trade off', acc * (1 - eod))

    return valid_acc, valid_eod


# Re-seed immediately before training so the random draws made inside
# train_model() (e.g. weight initialisation) start from a fixed RNG state.
seed_everything(1024)    
# va / ve receive the two per-epoch lists returned by train_model()
# (valid_acc and valid_eod respectively).
va, ve = train_model()

epoch 0.000000 : 100%|██████████████████████████████████████████████| 635/635 [02:51<00:00,  3.70batch/s, ut_loss=0.164]


Female TPR 0.6016129032258064
male TPR 0.1388888888888889
DP 0.13369441831906564
EOP 0.4627240143369175
EoD 0.24159508282121442
acc 0.9268610359683398
Trade off 0.7029359672198121


epoch 1.000000 : 100%|██████████████████████████████████████████████| 635/635 [02:53<00:00,  3.66batch/s, ut_loss=0.142]


Female TPR 0.825
male TPR 0.39444444444444443
DP 0.19000377888250808
EOP 0.4305555555555555
EoD 0.23719066388336102
acc 0.9377817853922452
Trade off 0.7153487011373351


epoch 2.000000 : 100%|██████████████████████████████████████████████| 635/635 [02:50<00:00,  3.72batch/s, ut_loss=0.147]


Female TPR 0.7407258064516129
male TPR 0.25
DP 0.16693217977489797
EOP 0.4907258064516129
EoD 0.26017482933235714
acc 0.9423905420298567
Trade off 0.6972042435928112


epoch 3.000000 : 100%|███████████████████████████████████████████████| 635/635 [02:40<00:00,  3.97batch/s, ut_loss=0.15]


Female TPR 0.805241935483871
male TPR 0.3055555555555556
DP 0.18661672325641562
EOP 0.49968637992831544
EoD 0.2697971959411115
acc 0.9444945396252881
Trade off 0.6896725612526943


epoch 4.000000 : 100%|██████████████████████████████████████████████| 635/635 [02:40<00:00,  3.96batch/s, ut_loss=0.171]


Female TPR 0.728225806451613
male TPR 0.24444444444444444
DP 0.16017329835587124
EOP 0.4837813620071685
EoD 0.25377698150670036
acc 0.9451457769762549
Trade off 0.705289534611416


epoch 5.000000 : 100%|█████████████████████████████████████████████| 635/635 [02:46<00:00,  3.82batch/s, ut_loss=0.0886]


Female TPR 0.7254032258064517
male TPR 0.2388888888888889
DP 0.1616236356850337
EOP 0.48651433691756274
EoD 0.25630003880432217
acc 0.9438433022743212
Trade off 0.7019362272762131


epoch 6.000000 : 100%|█████████████████████████████████████████████| 635/635 [02:51<00:00,  3.70batch/s, ut_loss=0.0607]


Female TPR 0.7967741935483871
male TPR 0.3333333333333333
DP 0.1783852048150615
EOP 0.46344086021505376
EoD 0.24814406972277178
acc 0.9459973950505961
Trade off 0.7112537514956004


epoch 7.000000 : 100%|█████████████████████████████████████████████| 635/635 [02:51<00:00,  3.69batch/s, ut_loss=0.0793]


Female TPR 0.7741935483870968
male TPR 0.2833333333333333
DP 0.17460865070398818
EOP 0.4908602150537634
EoD 0.26145743272421856
acc 0.9443943492636009
Trade off 0.6974754272258809


epoch 8.000000 : 100%|█████████████████████████████████████████████| 635/635 [02:50<00:00,  3.72batch/s, ut_loss=0.0757]


Female TPR 0.8129032258064516
male TPR 0.3333333333333333
DP 0.19144962875561838
EOP 0.47956989247311826
EoD 0.2623517208983895
acc 0.9419897805831079
Trade off 0.6948571405784333


epoch 9.000000 : 100%|█████████████████████████████████████████████| 635/635 [02:47<00:00,  3.79batch/s, ut_loss=0.0613]


Female TPR 0.7483870967741936
male TPR 0.25555555555555554
DP 0.17114555174833246
EOP 0.49283154121863804
EoD 0.26320152364509913
acc 0.9397354974451457
Trade off 0.6923956826941983


epoch 10.000000 : 100%|████████████████████████████████████████████| 635/635 [02:42<00:00,  3.90batch/s, ut_loss=0.0452]


Female TPR 0.7971774193548387
male TPR 0.3
DP 0.18302625040078857
EOP 0.49717741935483867
EoD 0.26775600439279457
acc 0.9401863540727382
Trade off 0.6884458125215926


epoch 11.000000 : 100%|████████████████████████████████████████████| 635/635 [03:00<00:00,  3.52batch/s, ut_loss=0.0214]


Female TPR 0.7709677419354839
male TPR 0.32222222222222224
DP 0.17809340375182017
EOP 0.44874551971326165
EoD 0.24402541795148852
acc 0.9372307384029657
Trade off 0.7085226157471998


epoch 12.000000 : 100%|████████████████████████████████████████████| 635/635 [02:52<00:00,  3.68batch/s, ut_loss=0.0313]


Female TPR 0.744758064516129
male TPR 0.3277777777777778
DP 0.17046088660807113
EOP 0.41698028673835125
EoD 0.22666197437059293
acc 0.9362789299669372
Trade off 0.7240600991390451


epoch 13.000000 : 100%|███████████████████████████████████████████| 635/635 [02:53<00:00,  3.66batch/s, ut_loss=0.00795]


Female TPR 0.7935483870967742
male TPR 0.31666666666666665
DP 0.1843715135231446
EOP 0.47688172043010757
EoD 0.259200831650359
acc 0.9379821661156197
Trade off 0.6948564085852459


epoch 14.000000 : 100%|████████████████████████████████████████████| 635/635 [02:58<00:00,  3.55batch/s, ut_loss=0.0291]


Female TPR 0.7366935483870968
male TPR 0.26666666666666666
DP 0.1686674176452593
EOP 0.4700268817204301
EoD 0.25192244385494245
acc 0.9380322612964633
Trade off 0.7017208816158803


epoch 15.000000 : 100%|███████████████████████████████████████████| 635/635 [02:58<00:00,  3.56batch/s, ut_loss=0.00399]


Female TPR 0.7879032258064517
male TPR 0.32222222222222224
DP 0.18260370984553678
EOP 0.4656810035842294
EoD 0.2532155591419695
acc 0.9382827372006813
Trade off 0.7006949492671531


epoch 16.000000 : 100%|████████████████████████████████████████████| 635/635 [02:54<00:00,  3.64batch/s, ut_loss=0.0111]


Female TPR 0.7600806451612904
male TPR 0.34444444444444444
DP 0.17112950697622142
EOP 0.4156362007168459
EoD 0.22488649059954816
acc 0.9380322612964633
Trade off 0.7270814779843433


epoch 17.000000 : 100%|████████████████████████████████████████████| 635/635 [03:13<00:00,  3.29batch/s, ut_loss=0.0257]


Female TPR 0.8241935483870968
male TPR 0.3888888888888889
DP 0.19273088212749445
EOP 0.4353046594982079
EoD 0.24107249878518955
acc 0.9378819757539325
Trade off 0.7117844242933414


epoch 18.000000 : 100%|███████████████████████████████████████████| 635/635 [03:17<00:00,  3.22batch/s, ut_loss=0.00657]


Female TPR 0.7846774193548387
male TPR 0.34444444444444444
DP 0.18064917931149407
EOP 0.44023297491039426
EoD 0.24001564546794935
acc 0.938433022743212
Trade off 0.7131944150610612


epoch 19.000000 : 100%|████████████████████████████████████████████| 635/635 [03:21<00:00,  3.16batch/s, ut_loss=0.0087]


Female TPR 0.7923387096774194
male TPR 0.3277777777777778
DP 0.1836794715978164
EOP 0.4645609318996416
EoD 0.25280340150961145
acc 0.9390842600941789
Trade off 0.7016805648382338


epoch 20.000000 : 100%|████████████████████████████████████████████| 635/635 [03:20<00:00,  3.17batch/s, ut_loss=0.0184]


Female TPR 0.7967741935483871
male TPR 0.3611111111111111
DP 0.18435210315899445
EOP 0.435663082437276
EoD 0.2386559333645291
acc 0.9399859733493638
Trade off 0.7156527435301061


epoch 21.000000 : 100%|████████████████████████████████████████████| 635/635 [03:14<00:00,  3.26batch/s, ut_loss=0.0149]


Female TPR 0.7681451612903226
male TPR 0.28888888888888886
DP 0.17473024594593006
EOP 0.47925627240143376
EoD 0.25690882536236376
acc 0.9399358781685202
Trade off 0.6984580557923039


epoch 22.000000 : 100%|████████████████████████████████████████████| 635/635 [03:21<00:00,  3.15batch/s, ut_loss=0.0458]


Female TPR 0.7649193548387097
male TPR 0.28888888888888886
DP 0.1765357484878252
EOP 0.4760304659498208
EoD 0.25673308419024493
acc 0.9389339745516482
Trade off 0.6978785594139987


epoch 23.000000 : 100%|███████████████████████████████████████████| 635/635 [02:51<00:00,  3.70batch/s, ut_loss=0.00185]


Female TPR 0.8068548387096774
male TPR 0.3888888888888889
DP 0.18177113857714092
EOP 0.41796594982078855
EoD 0.22794184328540326
acc 0.9384831179240557
Trade off 0.724563546132214


epoch 24.000000 : 100%|████████████████████████████████████████████| 635/635 [02:41<00:00,  3.93batch/s, ut_loss=0.0196]


Female TPR 0.777016129032258
male TPR 0.3277777777777778
DP 0.17497632575882854
EOP 0.44923835125448025
EoD 0.2416458371186903
acc 0.9404869251577999
Trade off 0.7132221748288603


epoch 25.000000 : 100%|██████████████████████████████████████████| 635/635 [02:40<00:00,  3.96batch/s, ut_loss=0.000406]


Female TPR 0.7870967741935484
male TPR 0.35
DP 0.1789692832045686
EOP 0.4370967741935484
EoD 0.2374218616823555
acc 0.9378318805730889
Trade off 0.7151700895423616


epoch 26.000000 : 100%|████████████████████████████████████████████| 635/635 [02:40<00:00,  3.95batch/s, ut_loss=0.0141]


Female TPR 0.7883064516129032
male TPR 0.37222222222222223
DP 0.1783171838715538
EOP 0.416084229390681
EoD 0.22687203294517344
acc 0.9367297865945297
Trade off 0.7242119955895302


epoch 27.000000 : 100%|███████████████████████████████████████████| 635/635 [02:39<00:00,  3.97batch/s, ut_loss=0.00828]


Female TPR 0.7673387096774194
male TPR 0.3333333333333333
DP 0.17313713554567386
EOP 0.43400537634408604
EoD 0.23426068537277903
acc 0.9385332131048993
Trade off 0.7186717793578292


epoch 28.000000 : 100%|███████████████████████████████████████████| 635/635 [02:40<00:00,  3.96batch/s, ut_loss=0.00448]


Female TPR 0.8129032258064516
male TPR 0.40555555555555556
DP 0.18600793210775335
EOP 0.407347670250896
EoD 0.2248245295164369
acc 0.9367798817753732
Trade off 0.7261687855947616


epoch 29.000000 : 100%|███████████████████████████████████████████| 635/635 [02:41<00:00,  3.93batch/s, ut_loss=0.00961]


Female TPR 0.7721774193548387
male TPR 0.35555555555555557
DP 0.17738820638339564
EOP 0.4166218637992832
EoD 0.2279598261183482
acc 0.9368800721370604
Trade off 0.7233090537989506
