In [1]:
#step 1 import image
%matplotlib inline
import torchvision.datasets
import math
import torchvision.transforms as tvt
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wget
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as tfms
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision.utils import make_grid
from PIL import Image
from time import time
from tqdm import tqdm
import random
from transformers import ViTConfig, ViTModel



def seed_everything(seed):
    """
    Changes the seed for reproducibility. 
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(4096)
device = torch.device('cuda:3')

image_size = 64
batch_size = 1024
dataset = torchvision.datasets.CelebA("../../celeba/datasets/",split='train', transform=tvt.Compose([
                                  tvt.Resize((image_size,image_size)),
                                  tvt.ToTensor(),
                                  tvt.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
                              ]))

test_dataset = torchvision.datasets.CelebA("../../celeba/datasets/",split='test', transform=tvt.Compose([
                                  tvt.Resize((image_size,image_size)),
                                  tvt.ToTensor(),
                                  tvt.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
                              ]))

training_data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
print('Done')

  from .autonotebook import tqdm as notebook_tqdm


Done


In [2]:
class VisionTransformer(nn.Module):
    def __init__(self, vit):
        super(VisionTransformer, self).__init__()
        self.vit = vit
        self.seq = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
    
    def forward(self, x):
        z = self.vit(x)
        m = z.last_hidden_state
        g = m[:,0]
        y = self.seq(g)
        return y 

In [3]:
def get_model(data, hidden_layers = 8):
    if data == 'celebA':
        configuration = ViTConfig(num_hidden_layers = hidden_layers, num_attention_heads = 8, 
                          intermediate_size = 768, image_size= 64, patch_size = 16)
        vit = ViTModel(configuration)
        configuration = vit.config
        vit = vit.to(device)
        model = VisionTransformer(vit)
   
    else:
        return None
  
    return model

class FFN(torch.nn.Module):
    def __init__(self, input_size, p_dropout=0.5, n_class=1):
        super(FFN, self).__init__()
        self.fc1 = nn.Linear(input_size, 32)
        self.fc2 = nn.Linear(32, n_class)
    
        self.dropout = nn.Dropout(p=p_dropout)
        
    def forward(self, x):
        x = self.dropout(self.fc1(x).relu())
        #x = self.fc1(x).relu()
        x = self.fc2(x)
        return x

In [4]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, precision_score



model_learner = get_model('celebA').to(device)
opt_learner = torch.optim.Adam(model_learner.parameters(), lr = 1e-4)
model_adv = get_model('celebA',hidden_layers=2).to(device)
opt_adv = torch.optim.Adam(model_adv.parameters(), lr = 1e-4)

tr_loss_sum = 0
epoch = 30

for epoches in range(epoch):
    with tqdm(training_data_loader, unit="batch") as tepoch:
        for train_input, attributes in tepoch:
            # Transfer data to GPU if possible. 
            train_input = train_input.to(device)
            sensitive, train_target = attributes[:,20], attributes[:,2]
            
            train_target = train_target.float().to(device)
            train_target = train_target.unsqueeze(1)

            w = model_adv(train_input).sigmoid()
            w = 1 + (w.shape[0]*(w/w.sum()))

            # training learner on weighted data
            tr_out = model_learner(train_input)
            tr_loss = F.binary_cross_entropy_with_logits(tr_out, train_target, reduction = 'none')
            tr_loss_sum += tr_loss.mean().item()
            
            tr_loss *= w.detach()
            opt_learner.zero_grad()
            tr_loss.mean().backward()
            opt_learner.step()
            
            # training adversary to maximize weighted loss
            adv_loss = -(tr_loss.detach()*w).mean()
            opt_adv.zero_grad()
            adv_loss.backward()
            opt_adv.step()

            tepoch.set_description(f"epoch %2f " % epoches)
            tepoch.set_postfix(ut_loss = tr_loss.mean().item(),adv=adv_loss.item())

    test_pred = []
    test_gt = []
    sense_gt = []
    female_predic = []
    female_gt = []
    male_predic = []
    male_gt = []



# Evaluate on test set.
    for step, (test_input, attributes) in enumerate(test_data_loader):
        sensitive, test_target = attributes[:,20], attributes[:,2]
        test_input = test_input.to(device)
        test_target = test_target.to(device)

        gt = test_target.detach().cpu().numpy()
        sen = sensitive.detach().cpu().numpy()
        test_gt.extend(gt)
        sense_gt.extend(sen)

        # Todo: split according to sensitive attribute
        # Todo: combine all batch togather

        with torch.no_grad():
            test_pred_ = model_learner(test_input)
            test_pred_ = nn.Sigmoid()(test_pred_)
            test_pred.extend(torch.round(test_pred_.squeeze(1)).detach().cpu().numpy())

    for i in range(len(sense_gt)):
        if sense_gt[i] == 0:
            female_predic.append(test_pred[i])
            female_gt.append(test_gt[i])
        else:
            male_predic.append(test_pred[i])
            male_gt.append(test_gt[i])
    female_CM = confusion_matrix(female_gt, female_predic)    
    male_CM = confusion_matrix(male_gt, male_predic) 
    female_dp = (female_CM[1][1]+female_CM[0][1])/(female_CM[0][0]+female_CM[0][1]+female_CM[1][0]+female_CM[1][1])
    male_dp = (male_CM[1][1]+male_CM[0][1])/(male_CM[0][0]+male_CM[0][1]+male_CM[1][0]+male_CM[1][1])
    female_TPR = female_CM[1][1]/(female_CM[1][1]+female_CM[1][0])
    male_TPR = male_CM[1][1]/(male_CM[1][1]+male_CM[1][0])
    female_FPR = female_CM[0][1]/(female_CM[0][1]+female_CM[0][0])
    male_FPR = male_CM[0][1]/(male_CM[0][1]+male_CM[0][0])

    print('Female TPR', female_TPR)
    print('male TPR', male_TPR)
    print('DP',abs(female_dp - male_dp))
    print('EOP', abs(female_TPR - male_TPR))
    print('EoD',0.5*(abs(female_FPR-male_FPR)+ abs(female_TPR-male_TPR)))
    print('acc', accuracy_score(test_gt, test_pred))
    print('Trade off',accuracy_score(test_gt, test_pred)*(1-0.5*(abs(female_FPR-male_FPR)+ abs(female_TPR-male_TPR))) )

epoch 0.000000 : 100%|█| 158/158 [02:37<00:00,  1.01batch/s, adv=-1.92, ut_loss=


Female TPR 0.8635075972186453
male TPR 0.47665056360708535
DP 0.5408845045055581
EOP 0.38685703361156
EoD 0.39149303402912994
acc 0.7586862664473685
Trade off 0.4616658781196553


epoch 1.000000 : 100%|█| 158/158 [02:38<00:00,  1.01s/batch, adv=-4.67, ut_loss=


Female TPR 0.8637647663071392
male TPR 0.5042780748663102
DP 0.5216423363216857
EOP 0.359486691440829
EoD 0.36102956349854043
acc 0.7676809210526315
Trade off 0.49052541321884247


epoch 2.000000 : 100%|█| 158/158 [02:38<00:00,  1.00s/batch, adv=-3.65, ut_loss=


Female TPR 0.9042362002567395
male TPR 0.6270912034538586
DP 0.5201339371996461
EOP 0.27714499680288085
EoD 0.3367464746984008
acc 0.7756990131578947
Trade off 0.5144851050499453


epoch 3.000000 : 100%|█| 158/158 [02:37<00:00,  1.00batch/s, adv=-1.91, ut_loss=


Female TPR 0.8995509942270686
male TPR 0.5794291868605277
DP 0.5343122329415513
EOP 0.32012180736654094
EoD 0.3577544698378311
acc 0.7773951480263158
Trade off 0.49927855898965895


epoch 4.000000 : 100%|█| 158/158 [02:38<00:00,  1.00s/batch, adv=-11.9, ut_loss=


Female TPR 0.8907012586694066
male TPR 0.5082932049224184
DP 0.5527475560160776
EOP 0.3824080537469883
EoD 0.3882891979693444
acc 0.7762129934210527
Trade off 0.47481787275220816


epoch 5.000000 : 100%|█| 158/158 [02:37<00:00,  1.00batch/s, adv=-1.82, ut_loss=


Female TPR 0.8573629477468224
male TPR 0.4965184788430637
DP 0.5210614501008267
EOP 0.3608444689037587
EoD 0.3509678610470798
acc 0.7801192434210527
Trade off 0.5063224611958996


epoch 6.000000 : 100%|█| 158/158 [02:37<00:00,  1.00batch/s, adv=-1.86, ut_loss=


Female TPR 0.9066598097197223
male TPR 0.6307609860664523
DP 0.5170241249124217
EOP 0.27589882365327
EoD 0.33009463702272135
acc 0.780993009868421
Trade off 0.523191405758622


epoch 7.000000 : 100%|█| 158/158 [02:38<00:00,  1.00s/batch, adv=-1.91, ut_loss=


Female TPR 0.8681191182158051
male TPR 0.513903743315508
DP 0.5155259421542145
EOP 0.3542153749002971
EoD 0.3371081196896194
acc 0.7902446546052632
Trade off 0.5238467649965102


epoch 8.000000 : 100%|█| 158/158 [02:38<00:00,  1.00s/batch, adv=-1.95, ut_loss=


Female TPR 0.8929212618620159
male TPR 0.6704240472356414
DP 0.47013539567990165
EOP 0.2224972146263745
EoD 0.26924007528438465
acc 0.7881373355263158
Trade off 0.5759391799747762


epoch 9.000000 : 100%|█| 158/158 [02:36<00:00,  1.01batch/s, adv=-1.81, ut_loss=


Female TPR 0.8952234206471494
male TPR 0.6591519055287172
DP 0.4824791623292409
EOP 0.23607151511843227
EoD 0.2861085560853548
acc 0.7853104440789473
Trade off 0.5606264068447709


epoch 10.000000 : 100%|█| 158/158 [02:38<00:00,  1.00s/batch, adv=-1.95, ut_loss


Female TPR 0.8978936552787053
male TPR 0.6770721205597416
DP 0.4722519814804168
EOP 0.2208215347189637
EoD 0.2746766431013564
acc 0.7836657072368421
Trade off 0.568411041459376


epoch 11.000000 : 100%|█| 158/158 [02:38<00:00,  1.00s/batch, adv=-2.28, ut_loss


Female TPR 0.8782709081580298
male TPR 0.5639784946236559
DP 0.5095769751802984
EOP 0.31429241353437387
EoD 0.33744441642206824
acc 0.7715871710526315
Trade off 0.5112193883980217


epoch 12.000000 : 100%|█| 158/158 [02:38<00:00,  1.00s/batch, adv=-1.85, ut_loss


Female TPR 0.8608940694248751
male TPR 0.552319309600863
DP 0.49711786516366685
EOP 0.30857475982401217
EoD 0.31319485018496274
acc 0.7852076480263158
Trade off 0.5392846563386269


epoch 13.000000 : 100%|█| 158/158 [02:38<00:00,  1.00s/batch, adv=-1.62, ut_loss


Female TPR 0.847457627118644
male TPR 0.5398182789951897
DP 0.4765082293603238
EOP 0.3076393481234543
EoD 0.29287110766428126
acc 0.7870579769736842
Trade off 0.556551435461393


epoch 14.000000 : 100%|█| 158/158 [02:38<00:00,  1.00s/batch, adv=-3.05, ut_loss


Female TPR 0.8618167821401078
male TPR 0.5651474530831099
DP 0.4777971462978896
EOP 0.29666932905699794
EoD 0.2936107431793377
acc 0.7846936677631579
Trade off 0.5542991768030967


epoch 15.000000 : 100%|█| 158/158 [02:34<00:00,  1.02batch/s, adv=-10.3, ut_loss


Female TPR 0.9284611425630468
male TPR 0.7741591030432461
DP 0.4540525037557252
EOP 0.1543020395198007
EoD 0.2677776848786577
acc 0.7543688322368421
Trade off 0.552365692795844


epoch 16.000000 : 100%|█| 158/158 [02:34<00:00,  1.02batch/s, adv=nan, ut_loss=n


ValueError: Input contains NaN, infinity or a value too large for dtype('float32').