In [1]:
#step 1 import image
%matplotlib inline
import torchvision.datasets
import math
import torchvision.transforms as tvt
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wget
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as tfms
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision.utils import make_grid
from PIL import Image
from time import time
from tqdm import tqdm
from transformers import ViTConfig, ViTModel


device = torch.device('cuda:0')

image_size = 64
batch_size = 1024
dataset = torchvision.datasets.CelebA("../../celeba/datasets/",split='train', transform=tvt.Compose([
                                  tvt.Resize((image_size,image_size)),
                                  tvt.ToTensor(),
                                  tvt.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
                              ]))

test_dataset = torchvision.datasets.CelebA("../../celeba/datasets/",split='test', transform=tvt.Compose([
                                  tvt.Resize((image_size,image_size)),
                                  tvt.ToTensor(),
                                  tvt.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
                              ]))

training_data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
print('Done')

  from .autonotebook import tqdm as notebook_tqdm


Done


In [2]:
class VisionTransformer(nn.Module):
    def __init__(self, vit):
        super(VisionTransformer, self).__init__()
        self.vit = vit
        self.seq = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
    
    def forward(self, x):
        z = self.vit(x)
        m = z.last_hidden_state
        g = m[:,0]
        y = self.seq(g)
        return y 

In [3]:
import random
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import confusion_matrix
from collections import OrderedDict
import seaborn as sns


class DRO_loss(torch.nn.Module):
    def __init__(self, eta, k):
        super(DRO_loss, self).__init__()
        self.eta = eta
        self.k = k
        self.logsig = torch.nn.LogSigmoid()
        self.relu = torch.nn.ReLU()
    
    def forward(self, x, y):
        bce = -1*y*self.logsig(x) - (1-y)*self.logsig(-x)
        if self.k > 0:
            bce = self.relu(bce - self.eta)
            bce = bce**self.k
            return bce.mean()
        else:
            return bce.mean()


def seed_everything(seed):
    """
    Changes the seed for reproducibility. 
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def train_model():
    epoch = 55
    configuration = ViTConfig(num_hidden_layers = 8, num_attention_heads = 8, 
                          intermediate_size = 768, image_size= 64, patch_size = 16)
    vit = ViTModel(configuration)
    configuration = vit.config
    vit = vit.to(device)
    model = VisionTransformer(vit)
    model = model.to(device)
    
    criterion = DRO_loss(eta=0.25, k=1.5)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)

    #fair_optimizer = optim.Adam(fairnet.parameters(), lr = 1e-3)
    #fair_epoch = 30

    for epoches in range(epoch):
        with tqdm(training_data_loader, unit="batch") as tepoch:
            for train_input, attributes in tepoch:
                # Transfer data to GPU if possible. 
                train_input = train_input.to(device)
                sensitive, train_target = attributes[:,20], attributes[:,2]
               
                train_target = train_target.float().to(device)
                train_target = train_target.unsqueeze(1)
                optimizer.zero_grad()

                # Learner update step.
                outputs = model(train_input)
                loss = criterion(outputs, train_target)
                loss.backward(retain_graph=True)
                #logger_learner.add_values(logging_dict)
                optimizer.step()
                tepoch.set_description(f"epoch %2f " % epoches)
                tepoch.set_postfix(ut_loss = loss.item())
        # Reset the dataloader if out of data.
        #model.load_state_dict(torch.load(PATH), False)
        test_pred = []
        test_gt = []
        sense_gt = []
        female_predic = []
        female_gt = []
        male_predic = []
        male_gt = []



    # Evaluate on test set.
        for step, (test_input, attributes) in enumerate(test_data_loader):
            sensitive, test_target = attributes[:,20], attributes[:,2]
            test_input = test_input.to(device)
            test_target = test_target.to(device)

            gt = test_target.detach().cpu().numpy()
            sen = sensitive.detach().cpu().numpy()
            test_gt.extend(gt)
            sense_gt.extend(sen)

            # Todo: split according to sensitive attribute
            # Todo: combine all batch togather

            with torch.no_grad():
                test_pred_= model(test_input)
                test_pred_ = torch.sigmoid(test_pred_)
                test_pred.extend(torch.round(test_pred_.squeeze(1)).detach().cpu().numpy())

        for i in range(len(sense_gt)):
            if sense_gt[i] == 0:
                female_predic.append(test_pred[i])
                female_gt.append(test_gt[i])
            else:
                male_predic.append(test_pred[i])
                male_gt.append(test_gt[i])
        female_CM = confusion_matrix(female_gt, female_predic)    
        male_CM = confusion_matrix(male_gt, male_predic) 
        female_dp = (female_CM[1][1]+female_CM[0][1])/(female_CM[0][0]+female_CM[0][1]+female_CM[1][0]+female_CM[1][1])
        male_dp = (male_CM[1][1]+male_CM[0][1])/(male_CM[0][0]+male_CM[0][1]+male_CM[1][0]+male_CM[1][1])
        female_TPR = female_CM[1][1]/(female_CM[1][1]+female_CM[1][0])
        male_TPR = male_CM[1][1]/(male_CM[1][1]+male_CM[1][0])
        female_FPR = female_CM[0][1]/(female_CM[0][1]+female_CM[0][0])
        male_FPR = male_CM[0][1]/(male_CM[0][1]+male_CM[0][0])

        print('Female TPR', female_TPR)
        print('male TPR', male_TPR)
        print('DP',abs(female_dp - male_dp))
        print('EOP', abs(female_TPR - male_TPR))
        print('EoD',0.5*(abs(female_FPR-male_FPR)+ abs(female_TPR-male_TPR)))
        print('acc', accuracy_score(test_gt, test_pred))
        print('Trade off',accuracy_score(test_gt, test_pred)*(1-0.5*(abs(female_FPR-male_FPR)+ abs(female_TPR-male_TPR))) )



seed_everything(4096)    
train_model()


epoch 0.000000 : 100%|██████| 158/158 [02:30<00:00,  1.05batch/s, ut_loss=0.216]


Female TPR 0.8577652485904664
male TPR 0.46802794196668457
DP 0.5316184231069476
EOP 0.38973730662378187
EoD 0.38335848260766037
acc 0.758069490131579
Trade off 0.4674571206835741


epoch 1.000000 : 100%|██████| 158/158 [02:29<00:00,  1.06batch/s, ut_loss=0.177]


Female TPR 0.9388123148267422
male TPR 0.7254691689008043
DP 0.5079374952706557
EOP 0.21334314592593795
EoD 0.32870556824818303
acc 0.7560649671052632
Trade off 0.5075422024603838


epoch 2.000000 : 100%|██████| 158/158 [02:29<00:00,  1.06batch/s, ut_loss=0.176]


Female TPR 0.8774143703322174
male TPR 0.5509333333333334
DP 0.5137049527074082
EOP 0.326481036998884
EoD 0.3372994008506437
acc 0.7810444078947368
Trade off 0.5175985970740964


epoch 3.000000 : 100%|██████| 158/158 [02:29<00:00,  1.06batch/s, ut_loss=0.188]


Female TPR 0.8719002955158679
male TPR 0.5771090811391725
DP 0.4956878384377634
EOP 0.2947912143766954
EoD 0.3089098020337912
acc 0.7865439967105263
Trade off 0.5435728463958107


epoch 4.000000 : 100%|██████| 158/158 [02:29<00:00,  1.06batch/s, ut_loss=0.163]


Female TPR 0.8454814243476025
male TPR 0.4751469802244789
DP 0.5062563664081826
EOP 0.37033444412312366
EoD 0.3323856408195661
acc 0.7885485197368421
Trade off 0.5264463146867916


epoch 5.000000 : 100%|██████| 158/158 [02:29<00:00,  1.06batch/s, ut_loss=0.165]


Female TPR 0.8775273663876368
male TPR 0.5367292225201072
DP 0.5130140069160443
EOP 0.3407981438675296
EoD 0.331383164757198
acc 0.7906044407894737
Trade off 0.5286114391295631


epoch 6.000000 : 100%|███████| 158/158 [02:29<00:00,  1.06batch/s, ut_loss=0.17]


Female TPR 0.9300861071841665
male TPR 0.7076180257510729
DP 0.49670154742335176
EOP 0.2224680814330936
EoD 0.30540851517729556
acc 0.7744140625
Trade off 0.5379014135394576


epoch 7.000000 : 100%|██████| 158/158 [02:29<00:00,  1.06batch/s, ut_loss=0.158]


Female TPR 0.9005019951087656
male TPR 0.692390717754992
DP 0.4647941007636227
EOP 0.20811127735377366
EoD 0.2704900487024433
acc 0.7758532072368421
Trade off 0.5659926354254019


epoch 8.000000 : 100%|██████| 158/158 [02:31<00:00,  1.04batch/s, ut_loss=0.143]


Female TPR 0.8542551821810223
male TPR 0.5431635388739946
DP 0.48004858501165754
EOP 0.31109164330702765
EoD 0.30555464708977886
acc 0.7760074013157895
Trade off 0.538894733667687


epoch 9.000000 : 100%|██████| 158/158 [02:32<00:00,  1.03batch/s, ut_loss=0.133]


Female TPR 0.8481924610832369
male TPR 0.5860042735042735
DP 0.45333166241917244
EOP 0.26218818757896334
EoD 0.2711343936097874
acc 0.775133634868421
Trade off 0.5649682468118213


epoch 10.000000 : 100%|█████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.118]


Female TPR 0.8470331363986643
male TPR 0.5845905172413793
DP 0.4445697335405683
EOP 0.26244261915728495
EoD 0.2710008404696459
acc 0.763671875
Trade off 0.5567161550319696


epoch 11.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0911]


Female TPR 0.8501733213506226
male TPR 0.5950590762620838
DP 0.4464830390628612
EOP 0.25511424508853886
EoD 0.2731321883780386
acc 0.7620785361842105
Trade off 0.5539303578802848


epoch 12.000000 : 100%|████| 158/158 [02:32<00:00,  1.03batch/s, ut_loss=0.0744]


Female TPR 0.8609467455621301
male TPR 0.5970069481560663
DP 0.4510270588068269
EOP 0.2639397974060639
EoD 0.2845689148198857
acc 0.7553967927631579
Trade off 0.5404343471881239


epoch 13.000000 : 100%|█████| 158/158 [02:32<00:00,  1.03batch/s, ut_loss=0.066]


Female TPR 0.8348082595870207
male TPR 0.5689008042895443
DP 0.4424757209574043
EOP 0.2659074552974764
EoD 0.27180398008957873
acc 0.76171875
Trade off 0.5546805620411412


epoch 14.000000 : 100%|████| 158/158 [02:31<00:00,  1.04batch/s, ut_loss=0.0589]


Female TPR 0.8033419023136247
male TPR 0.5392838054516301
DP 0.4220939485392323
EOP 0.26405809686199455
EoD 0.26390488077998187
acc 0.7464021381578947
Trade off 0.549422970873412


epoch 15.000000 : 100%|████| 158/158 [02:32<00:00,  1.03batch/s, ut_loss=0.0452]


Female TPR 0.8738089106361061
male TPR 0.6581333333333333
DP 0.4341161207414519
EOP 0.21567557730277276
EoD 0.26247968738337923
acc 0.7516447368421053
Trade off 0.5543532612924271


epoch 16.000000 : 100%|████| 158/158 [02:31<00:00,  1.04batch/s, ut_loss=0.0367]


Female TPR 0.8198175041768411
male TPR 0.6189462480042576
DP 0.4047280992519582
EOP 0.2008712561725835
EoD 0.23325167183589385
acc 0.7501027960526315
Trade off 0.5751400648245767


epoch 17.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0238]


Female TPR 0.7949376846974174
male TPR 0.5485254691689008
DP 0.404621802352782
EOP 0.24641221552851666
EoD 0.24211365190764153
acc 0.7488178453947368
Trade off 0.5675188222326054


epoch 18.000000 : 100%|████| 158/158 [02:32<00:00,  1.03batch/s, ut_loss=0.0311]


Female TPR 0.8292651593011305
male TPR 0.5594405594405595
DP 0.4360394304400781
EOP 0.26982459986057106
EoD 0.2715664649608886
acc 0.7552939967105263
Trade off 0.5501814760176676


epoch 19.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0246]


Female TPR 0.7842179668423082
male TPR 0.5382978723404256
DP 0.40822899999376716
EOP 0.24592009450188257
EoD 0.24624207903169765
acc 0.7478412828947368
Trade off 0.5636912906090048


epoch 20.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0196]


Female TPR 0.8117074420391956
male TPR 0.5362860192102454
DP 0.43617647044074487
EOP 0.2754214228289502
EoD 0.27200720521500443
acc 0.7559107730263158
Trade off 0.550297596263514


epoch 21.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0279]


Female TPR 0.8227032598891896
male TPR 0.6009641135511515
DP 0.4177711970118201
EOP 0.22173914633803804
EoD 0.24658388795760236
acc 0.7533408717105263
Trade off 0.5675791506067754


epoch 22.000000 : 100%|████| 158/158 [02:33<00:00,  1.03batch/s, ut_loss=0.0272]


Female TPR 0.822247853389722
male TPR 0.5641438539989264
DP 0.43306349051055415
EOP 0.25810399939079554
EoD 0.2642929992501404
acc 0.7578638980263158
Trade off 0.5575657753935382


epoch 23.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0196]


Female TPR 0.8336750384418247
male TPR 0.577209797657082
DP 0.43743554321912803
EOP 0.25646524078474264
EoD 0.26802601981565105
acc 0.7583778782894737
Trade off 0.5551128740553078


epoch 24.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0264]


Female TPR 0.8200051427102083
male TPR 0.5634028892455859
DP 0.43368005139303567
EOP 0.2566022534646224
EoD 0.26724670765627556
acc 0.7548314144736842
Trade off 0.5531052041200626


epoch 25.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0255]


Female TPR 0.8716859716859717
male TPR 0.6349462365591397
DP 0.43968542542401606
EOP 0.236739735126832
EoD 0.27221841209339775
acc 0.7500513980263158
Trade off 0.5458735974671591


epoch 26.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0234]


Female TPR 0.8305019305019306
male TPR 0.5675094136632598
DP 0.4386943588921793
EOP 0.26299251683867075
EoD 0.2734991446880654
acc 0.7540604440789473
Trade off 0.5478255575802524


epoch 27.000000 : 100%|████| 158/158 [02:30<00:00,  1.05batch/s, ut_loss=0.0136]


Female TPR 0.8356252409715975
male TPR 0.5708177445216461
DP 0.4339189722189002
EOP 0.2648074964499514
EoD 0.26899236918828423
acc 0.7544716282894737
Trade off 0.5515245175105455


epoch 28.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0185]


Female TPR 0.8356041131105398
male TPR 0.5463806970509384
DP 0.45997487394520536
EOP 0.28922341605960145
EoD 0.2946093022972594
acc 0.7594058388157895
Trade off 0.5356778144818047


epoch 29.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0176]


Female TPR 0.8011785805790418
male TPR 0.5302624531333691
DP 0.4229478892696953
EOP 0.2709161274456727
EoD 0.26396706751779386
acc 0.7484066611842105
Trade off 0.5508519495206313


epoch 30.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0187]


Female TPR 0.8234009761109684
male TPR 0.5973118279569892
DP 0.4143501009537082
EOP 0.22608914815397918
EoD 0.24163681219111965
acc 0.7559621710526315
Trade off 0.5732938819023957


epoch 31.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0204]


Female TPR 0.8251513980157196
male TPR 0.5922953451043339
DP 0.41968015721010826
EOP 0.23285605291138567
EoD 0.2515507325791266
acc 0.7517989309210527
Trade off 0.5626833590956576


epoch 32.000000 : 100%|████| 158/158 [02:31<00:00,  1.04batch/s, ut_loss=0.0196]


Female TPR 0.8482475285659263
male TPR 0.6455764075067024
DP 0.4105792838777892
EOP 0.2026711210592239
EoD 0.2390200733027144
acc 0.7491262335526315
Trade off 0.5700700262958952


epoch 33.000000 : 100%|████| 158/158 [02:31<00:00,  1.04batch/s, ut_loss=0.0169]


Female TPR 0.8328197226502311
male TPR 0.5888352120236178
DP 0.43082243712430923
EOP 0.2439845106266133
EoD 0.2596652118131796
acc 0.7576069078947368
Trade off 0.560882749685122


epoch 34.000000 : 100%|████| 158/158 [02:31<00:00,  1.04batch/s, ut_loss=0.0212]


Female TPR 0.8315802975885069
male TPR 0.6063543349488422
DP 0.413038535530024
EOP 0.22522596263966477
EoD 0.24124879229944013
acc 0.7541118421052632
Trade off 0.5721832709386623


epoch 35.000000 : 100%|████| 158/158 [02:31<00:00,  1.04batch/s, ut_loss=0.0108]


Female TPR 0.814486041425447
male TPR 0.5812332439678284
DP 0.4187556262121079
EOP 0.23325279745761862
EoD 0.25160500986240963
acc 0.7505653782894737
Trade off 0.5617193688825675


epoch 36.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0166]


Female TPR 0.8480801335559266
male TPR 0.6042666666666666
DP 0.4377303221417122
EOP 0.24381346688925998
EoD 0.2683718664220906
acc 0.7552425986842105
Trade off 0.552556732873859


epoch 37.000000 : 100%|████| 158/158 [02:32<00:00,  1.03batch/s, ut_loss=0.0154]


Female TPR 0.8252702007205353
male TPR 0.5475427350427351
DP 0.4480967352361065
EOP 0.27772746567780016
EoD 0.28003272645308763
acc 0.7613075657894737
Trade off 0.548116532472084


epoch 38.000000 : 100%|████| 158/158 [02:32<00:00,  1.03batch/s, ut_loss=0.0135]


Female TPR 0.8314274701272003
male TPR 0.5933654360620653
DP 0.4266205402717909
EOP 0.23806203406513504
EoD 0.2554545010175103
acc 0.7564247532894737
Trade off 0.5631926453806179


epoch 39.000000 : 100%|████| 158/158 [02:33<00:00,  1.03batch/s, ut_loss=0.0153]


Female TPR 0.8149434156378601
male TPR 0.5313001605136437
DP 0.4355594765193989
EOP 0.2836432551242164
EoD 0.274781214488936
acc 0.75390625
Trade off 0.5467469750142006


epoch 40.000000 : 100%|█████| 158/158 [02:32<00:00,  1.03batch/s, ut_loss=0.013]


Female TPR 0.8327557438069567
male TPR 0.579849946409432
DP 0.4256107038959558
EOP 0.2529057973975247
EoD 0.26125128219958144
acc 0.7510279605263158
Trade off 0.5548209428710792


epoch 41.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0181]


Female TPR 0.8559354590856704
male TPR 0.5897161221210498
DP 0.45531636926080477
EOP 0.2662193369646205
EoD 0.2876946605327928
acc 0.7572471217105263
Trade off 0.5393911680905821


epoch 42.000000 : 100%|█████| 158/158 [02:32<00:00,  1.03batch/s, ut_loss=0.016]


Female TPR 0.8190806368772471
male TPR 0.5712754555198285
DP 0.4215375261976015
EOP 0.24780518135741858
EoD 0.25438722757403215
acc 0.75390625
Trade off 0.5621221292117647


epoch 43.000000 : 100%|████| 158/158 [02:33<00:00,  1.03batch/s, ut_loss=0.0109]


Female TPR 0.8551359711303003
male TPR 0.5871313672922251
DP 0.4572212828437435
EOP 0.2680046038380751
EoD 0.2885350628583374
acc 0.7587890625
Trade off 0.5398518126553435


epoch 44.000000 : 100%|███| 158/158 [02:33<00:00,  1.03batch/s, ut_loss=0.00946]


Female TPR 0.8232271325796505
male TPR 0.5531233315536572
DP 0.43667799704188137
EOP 0.2701038010259933
EoD 0.2701584066417632
acc 0.758069490131579
Trade off 0.5532706445538977


epoch 45.000000 : 100%|████| 158/158 [02:33<00:00,  1.03batch/s, ut_loss=0.0106]


Female TPR 0.8314259997428314
male TPR 0.5860415556739478
DP 0.42465159012697995
EOP 0.2453844440688836
EoD 0.26013821098922435
acc 0.7504111842105263
Trade off 0.5552005612436948


epoch 46.000000 : 100%|████| 158/158 [02:33<00:00,  1.03batch/s, ut_loss=0.0168]


Female TPR 0.8120571943836146
male TPR 0.5737265415549598
DP 0.41683635884623893
EOP 0.23833065282865473
EoD 0.2503798121597807
acc 0.7512849506578947
Trade off 0.563178365833701


epoch 47.000000 : 100%|█████| 158/158 [02:33<00:00,  1.03batch/s, ut_loss=0.015]


Female TPR 0.8406635802469136
male TPR 0.6014997321906802
DP 0.4342925413748684
EOP 0.23916384805623336
EoD 0.2677666554791957
acc 0.7503083881578947
Trade off 0.549400820482869


epoch 48.000000 : 100%|███| 158/158 [02:32<00:00,  1.03batch/s, ut_loss=0.00833]


Female TPR 0.8558013892462053
male TPR 0.6048257372654156
DP 0.4464991243739712
EOP 0.25097565198078975
EoD 0.27377166337072745
acc 0.7602796052631579
Trade off 0.5521365931034231


epoch 49.000000 : 100%|████| 158/158 [02:32<00:00,  1.04batch/s, ut_loss=0.0153]


Female TPR 0.8362445414847162
male TPR 0.583467525496511
DP 0.4335688429721283
EOP 0.2527770159882051
EoD 0.26446311959386953
acc 0.756578947368421
Trade off 0.5564917187283224


epoch 50.000000 : 100%|████| 158/158 [02:33<00:00,  1.03batch/s, ut_loss=0.0157]


Female TPR 0.843798250128667
male TPR 0.5951742627345844
DP 0.43908116824643234
EOP 0.24862398739408265
EoD 0.271237969697144
acc 0.7540090460526315
Trade off 0.5494931632680354


epoch 51.000000 : 100%|███| 158/158 [02:32<00:00,  1.03batch/s, ut_loss=0.00767]


Female TPR 0.8342314600975109
male TPR 0.5128068303094984
DP 0.4613705290784069
EOP 0.32142462978801256
EoD 0.30842393098717813
acc 0.751233552631579
Trade off 0.5195351472394842


epoch 52.000000 : 100%|███| 158/158 [02:33<00:00,  1.03batch/s, ut_loss=0.00536]


Female TPR 0.8469741744828473
male TPR 0.5657541599570585
DP 0.4567801323370031
EOP 0.2812200145257888
EoD 0.2906537576228341
acc 0.7586348684210527
Trade off 0.5381347932507694


epoch 53.000000 : 100%|████| 158/158 [02:33<00:00,  1.03batch/s, ut_loss=0.0138]


Female TPR 0.819138019497178
male TPR 0.5705850778314546
DP 0.4223056694554209
EOP 0.2485529416657234
EoD 0.2533173806196834
acc 0.7567845394736842
Trade off 0.5650778622407372


epoch 54.000000 : 100%|████| 158/158 [02:33<00:00,  1.03batch/s, ut_loss=0.0107]


Female TPR 0.8565362840967576
male TPR 0.5939167556029883
DP 0.4489990799911139
EOP 0.26261952849376935
EoD 0.2846418008618217
acc 0.7529296875
Trade off 0.5386144253276713
