In [1]:
import torchvision.transforms as tvt
from torchvision.datasets import CelebA
from typing import Any, Callable, Optional, Tuple, Union, List
#step 1 import image
%matplotlib inline
import torchvision.datasets
import math
import torchvision.transforms as tvt
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wget
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as tfms
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision.utils import make_grid
from PIL import Image
from time import time
from tqdm import tqdm
from transformers import ViTConfig, ViTModel

image_size = 64
batch_size = 256
device = torch.device('cuda:2')

class CelebAWithIndex(CelebA):
    def __init__(
            self,
            root: str,
            split: str = "train",
            target_type: Union[List[str], str] = "attr",
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:
        super(CelebAWithIndex, self).__init__(
            root,
            split=split,
            target_type=target_type,
            transform=transform,
            target_transform=target_transform,
            download=download
        )

    def __getitem__(self, index: int) -> Tuple[Any, Any, Any]:
        X, y = super(CelebAWithIndex, self).__getitem__(index)

        return index, X, y

dataset = CelebAWithIndex("../../celeba/datasets/",split='train', transform=tvt.Compose([
                              tvt.Resize((image_size,image_size)),
                              tvt.ToTensor(),
                              tvt.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
                          ]))

test_dataset = CelebAWithIndex("../../celeba/datasets/",split='test', transform=tvt.Compose([
                              tvt.Resize((image_size,image_size)),
                              tvt.ToTensor(),
                              tvt.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
                          ]))

training_data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
valid_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
print('Done')


  from .autonotebook import tqdm as notebook_tqdm


Done


In [2]:
class VisionTransformer(nn.Module):
    def __init__(self, vit):
        super(VisionTransformer, self).__init__()
        self.vit = vit
        self.seq = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )
    
    def forward(self, x):
        z = self.vit(x)
        m = z.last_hidden_state
        g = m[:,0]
        y = self.seq(g)
        return y 

In [None]:
import os
import pickle
from tqdm import tqdm
from datetime import datetime
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import confusion_matrix
#test code

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset

import warnings

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    from torch.utils.tensorboard import SummaryWriter



class MultiDimAverageMeter(object):
    def __init__(self, dims):
        self.dims = dims
        self.cum = torch.zeros(np.prod(dims))
        self.cnt = torch.zeros(np.prod(dims))
        self.idx_helper = torch.arange(np.prod(dims), dtype=torch.long).reshape(
            *dims
        )

    def add(self, vals, idxs):
        flattened_idx = torch.stack(
            [self.idx_helper[tuple(idxs[i])] for i in range(idxs.size(0))],
            dim=0,
        )
        self.cum.index_add_(0, flattened_idx, vals.view(-1).float())
        self.cnt.index_add_(
            0, flattened_idx, torch.ones_like(vals.view(-1), dtype=torch.float)
        )
        
    def get_mean(self):
        return (self.cum / self.cnt).reshape(*self.dims)

    def reset(self):
        self.cum.zero_()
        self.cnt.zero_()


class EMA:
    
    def __init__(self, label, alpha=0.9):
        self.label = label
        self.alpha = alpha
        self.parameter = torch.zeros(label.size(0))
        self.updated = torch.zeros(label.size(0))
        
    def update(self, data, index):
        self.parameter[index] = self.alpha * self.parameter[index] + (1-self.alpha*self.updated[index]) * data
        self.updated[index] = 1
        
    def max_loss(self, label):
        label_index = np.where(self.label == label)[0]
        return self.parameter[label_index].max()

class GeneralizedCELoss(nn.Module):
    
    def __init__(self, q=0.3):
        super(GeneralizedCELoss, self).__init__()
        self.q = q
             
    def forward(self, logits, targets):
        p = F.softmax(logits, dim=1)
        if np.isnan(p.mean().item()):
            raise NameError('GCE_p')
        Yg = torch.gather(p, 1, torch.unsqueeze(targets, 1))
        # modify gradient of cross entropy
        loss_weight = (Yg.squeeze().detach()**self.q)*self.q
        if np.isnan(Yg.mean().item()):
            raise NameError('GCE_Yg')

        loss = F.cross_entropy(logits, targets, reduction='none') * loss_weight

        return loss



def train(
    log_dir,
    target_attr_idx,
    bias_attr_idx,
    main_num_steps,
    main_valid_freq,
    main_optimizer_tag,
    main_learning_rate,
    main_weight_decay,
):
    
    start_time = datetime.now()
    writer = SummaryWriter(os.path.join(log_dir, "summary", 'CelebA'))
    
    # define model and optimizer
    configuration = ViTConfig(num_hidden_layers = 8, num_attention_heads = 8, 
                          intermediate_size = 768, image_size= 64, patch_size = 16)
    vit = ViTModel(configuration)
    configuration = vit.config
    vit = vit.to(device)
    model_1 = VisionTransformer(vit)
    model_2 = VisionTransformer(vit)
    model_b = model_1.to(device)
    model_d = model_2.to(device)
    
    if main_optimizer_tag == "Adam":
        optimizer_b = torch.optim.Adam(
            model_b.parameters(),
            lr=main_learning_rate,
            weight_decay=main_weight_decay,
        )
        optimizer_d = torch.optim.Adam(
            model_d.parameters(),
            lr=main_learning_rate,
            weight_decay=main_weight_decay,
        )
    elif main_optimizer_tag == "AdamW":
        optimizer_b = torch.optim.AdamW(
            model_b.parameters(),
            lr=main_learning_rate,
            weight_decay=main_weight_decay,
        )
        optimizer_d = torch.optim.AdamW(
            model_d.parameters(),
            lr=main_learning_rate,
            weight_decay=main_weight_decay,
        )
    else:
        raise NotImplementedError
    
    # define loss
    criterion = nn.CrossEntropyLoss(reduction='none')
    bias_criterion = GeneralizedCELoss()
    
    train_target_attr = dataset.attr[:, target_attr_idx]
    train_bias_attr = dataset.attr[:, bias_attr_idx]
    attr_dims = []
    attr_dims.append(torch.max(train_target_attr).item() + 1)
    attr_dims.append(torch.max(train_bias_attr).item() + 1)
    num_classes = attr_dims[0]
    print('num',num_classes)
    

    sample_loss_ema_b = EMA(torch.LongTensor(train_target_attr), alpha=0.7)
    sample_loss_ema_d = EMA(torch.LongTensor(train_target_attr), alpha=0.7)

    # define evaluation function
    def evaluate(model, data_loader):
        print('valid')
        model.eval()
        acc = 0
        test_pred = []
        test_gt = []
        
        attrwise_acc_meter = MultiDimAverageMeter(attr_dims)
        for index, data, attr in tqdm(data_loader, leave=False):
            label = attr[:, target_attr_idx]
            data = data.to(device)
            attr = attr.to(device)
            label = label.to(device)
            test_gt.extend(label.detach().cpu().numpy())
            with torch.no_grad():
                logit = model(data)
                pred = logit.data.max(1, keepdim=True)[1].squeeze(1)
                correct = (pred == label).long()
                
                test_pred_ = torch.argmax(logit, dim=1)
                test_pred.extend(test_pred_.detach().cpu().numpy())
                

            attr = attr[:, [target_attr_idx, bias_attr_idx]]

            attrwise_acc_meter.add(correct.cpu(), attr.cpu())

        accs = attrwise_acc_meter.get_mean()
        ACC = accuracy_score(test_gt, test_pred)
        print('Test ACC', ACC)

        model.train()

        return accs

    # jointly training biased/de-biased model
    valid_attrwise_accs_list = []
    num_updated = 0
    
    for step in tqdm(range(main_num_steps)):
        # train main model
        try:
            index, data, attr = next(train_iter)
        except:
            train_iter = iter(training_data_loader)
            index, data, attr = next(train_iter)

        data = data.to(device)
        attr = attr.to(device)
        label = attr[:, target_attr_idx]
        bias_label = attr[:, bias_attr_idx]
        
        logit_b = model_b(data)
        if np.isnan(logit_b.mean().item()):
            print(logit_b)
            raise NameError('logit_b')
        logit_d = model_d(data)
        
        loss_b = criterion(logit_b, label).cpu().detach()
        loss_d = criterion(logit_d, label).cpu().detach()
                
        if np.isnan(loss_b.mean().item()):
            raise NameError('loss_b')
        if np.isnan(loss_d.mean().item()):
            raise NameError('loss_d')
        
        loss_per_sample_b = loss_b
        loss_per_sample_d = loss_d
        
        # EMA sample loss
        sample_loss_ema_b.update(loss_b, index)
        sample_loss_ema_d.update(loss_d, index)
        
        # class-wise normalize
        loss_b = sample_loss_ema_b.parameter[index].clone().detach()
        loss_d = sample_loss_ema_d.parameter[index].clone().detach()
        
        if np.isnan(loss_b.mean().item()):
            raise NameError('loss_b_ema')
        if np.isnan(loss_d.mean().item()):
            raise NameError('loss_d_ema')
        
        label_cpu = label.cpu()
        
        for c in range(num_classes):
            class_index = np.where(label_cpu == c)[0]
            max_loss_b = sample_loss_ema_b.max_loss(c)
            max_loss_d = sample_loss_ema_d.max_loss(c)
            loss_b[class_index] /= max_loss_b
            loss_d[class_index] /= max_loss_d
            
        # re-weighting based on loss value / generalized CE for biased model
        loss_weight = loss_b / (loss_b + loss_d + 1e-8)
        if np.isnan(loss_weight.mean().item()):
            raise NameError('loss_weight')
            
        loss_b_update = bias_criterion(logit_b, label)

        if np.isnan(loss_b_update.mean().item()):
            raise NameError('loss_b_update')
        loss_d_update = criterion(logit_d, label) * loss_weight.to(device)
        if np.isnan(loss_d_update.mean().item()):
            raise NameError('loss_d_update')
        loss = loss_b_update.mean() + loss_d_update.mean()
        
        num_updated += loss_weight.mean().item() * data.size(0)

        optimizer_b.zero_grad()
        optimizer_d.zero_grad()
        loss.backward()
        optimizer_b.step()
        optimizer_d.step()
        
        main_log_freq = 100
        if step % main_log_freq == 0:
        
            writer.add_scalar("loss/b_train", loss_per_sample_b.mean(), step)
            writer.add_scalar("loss/d_train", loss_per_sample_d.mean(), step)

            bias_attr = attr[:, bias_attr_idx]

            aligned_mask = (label == bias_attr)
            skewed_mask = (label != bias_attr)
            
            writer.add_scalar('loss_variance/b_ema', sample_loss_ema_b.parameter.var(), step)
            writer.add_scalar('loss_std/b_ema', sample_loss_ema_b.parameter.std(), step)
            writer.add_scalar('loss_variance/d_ema', sample_loss_ema_d.parameter.var(), step)
            writer.add_scalar('loss_std/d_ema', sample_loss_ema_d.parameter.std(), step)

            if aligned_mask.any().item():
                writer.add_scalar("loss/b_train_aligned", loss_per_sample_b[aligned_mask].mean(), step)
                writer.add_scalar("loss/d_train_aligned", loss_per_sample_d[aligned_mask].mean(), step)
                writer.add_scalar('loss_weight/aligned', loss_weight[aligned_mask].mean(), step)

            if skewed_mask.any().item():
                writer.add_scalar("loss/b_train_skewed", loss_per_sample_b[skewed_mask].mean(), step)
                writer.add_scalar("loss/d_train_skewed", loss_per_sample_d[skewed_mask].mean(), step)
                writer.add_scalar('loss_weight/skewed', loss_weight[skewed_mask].mean(), step)

        if step % main_valid_freq == 0:
            valid_attrwise_accs_b = evaluate(model_b, valid_loader)
            valid_attrwise_accs_d = evaluate(model_d, valid_loader)
            valid_attrwise_accs_list.append(valid_attrwise_accs_d)
            valid_accs_b = torch.mean(valid_attrwise_accs_b)
            writer.add_scalar("acc/b_valid", valid_accs_b, step)
            valid_accs_d = torch.mean(valid_attrwise_accs_d)
            writer.add_scalar("acc/d_valid", valid_accs_d, step)
            print('valid_accs_d',valid_accs_d)

            eye_tsr = torch.eye(attr_dims[0]).long()
            
            writer.add_scalar(
                "acc/b_valid_aligned",
                valid_attrwise_accs_b[eye_tsr == 1].mean(),
                step,
            )
            writer.add_scalar(
                "acc/b_valid_skewed",
                valid_attrwise_accs_b[eye_tsr == 0].mean(),
                step,
            )
            writer.add_scalar(
                "acc/d_valid_aligned",
                valid_attrwise_accs_d[eye_tsr == 1].mean(),
                step,
            )
            writer.add_scalar(
                "acc/d_valid_skewed",
                valid_attrwise_accs_d[eye_tsr == 0].mean(),
                step,
            )
            
            num_updated_avg = num_updated / batch_size / main_valid_freq
            writer.add_scalar("num_updated/all", num_updated_avg, step)
            num_updated = 0

    os.makedirs(os.path.join(log_dir, "result"), exist_ok=True)
    result_path = os.path.join(log_dir, "result", "result.th")
    model_path = os.path.join(log_dir, "result", "model.th")
    valid_attrwise_accs_list = torch.stack(valid_attrwise_accs_list)
    print('valid_attrwise_accs_list',valid_attrwise_accs_list)
    with open(result_path, "wb") as f:
        torch.save({"valid/attrwise_accs": valid_attrwise_accs_list}, f)
    state_dict = {
        #'steps': step, 
        'state_dict': model_d.state_dict(), 
        'state_dict_b': model_b.state_dict(),
        #'optimizer': optimizer_d.state_dict(), 
    }
    with open(model_path, "wb") as f:
        torch.save(state_dict, f)
    
it = int(len(dataset)/batch_size)
train(log_dir=r'./log', target_attr_idx=2, bias_attr_idx=20, main_num_steps=636 * 200, main_valid_freq=636, main_optimizer_tag="Adam",main_learning_rate=5e-5,
    main_weight_decay=0
)


In [None]:
int(len(dataset)/batch_size)*20

In [9]:
model_path = './log/result/model.th'
checkpoint = torch.load(model_path)


configuration = ViTConfig(num_hidden_layers = 8, num_attention_heads = 8, 
                          intermediate_size = 768, image_size= 64, patch_size = 16)
vit = ViTModel(configuration)
configuration = vit.config
vit = vit.to(device)

model_d = VisionTransformer(vit)
model_d.load_state_dict(checkpoint['state_dict'])

model_d = model_d.to(device)

test_pred = []
test_gt = []
sense_gt = []
female_predic = []
female_gt = []
male_predic = []
male_gt = []


model_d.eval()
# Evaluate on test set.
for step, (index, test_input, attributes) in enumerate(valid_loader):
    sensitive, test_target = attributes[:,20], attributes[:,2]
    test_input = test_input.to(device)
    test_target = test_target.to(device)

    gt = test_target.detach().cpu().numpy()
    sen = sensitive.detach().cpu().numpy()
    test_gt.extend(gt)
    sense_gt.extend(sen)

    with torch.no_grad():
        prediction= model_d(test_input)
        test_pred_ = torch.argmax(prediction, dim=1)
        test_pred.extend(test_pred_.detach().cpu().numpy())


for i in range(len(sense_gt)):
    if sense_gt[i] == 0:
        female_predic.append(test_pred[i])
        female_gt.append(test_gt[i])
    else:
        male_predic.append(test_pred[i])
        male_gt.append(test_gt[i])
female_CM = confusion_matrix(female_gt, female_predic)    
male_CM = confusion_matrix(male_gt, male_predic) 
female_dp = (female_CM[1][1]+female_CM[0][1])/(female_CM[0][0]+female_CM[0][1]+female_CM[1][0]+female_CM[1][1])
male_dp = (male_CM[1][1]+male_CM[0][1])/(male_CM[0][0]+male_CM[0][1]+male_CM[1][0]+male_CM[1][1])
female_TPR = female_CM[1][1]/(female_CM[1][1]+female_CM[1][0])
male_TPR = male_CM[1][1]/(male_CM[1][1]+male_CM[1][0])
female_FPR = female_CM[0][1]/(female_CM[0][1]+female_CM[0][0])
male_FPR = male_CM[0][1]/(male_CM[0][1]+male_CM[0][0])

print('Female TPR', female_TPR)
print('male TPR', male_TPR)
print('DP',abs(female_dp - male_dp))
print('EOP', abs(female_TPR - male_TPR))
print('EoD',0.5*(abs(female_FPR-male_FPR)+ abs(female_TPR-male_TPR)))
print('acc', accuracy_score(test_gt, test_pred))
print('Trade off',accuracy_score(test_gt, test_pred)*(1-0.5*(abs(female_FPR-male_FPR)+ abs(female_TPR-male_TPR))) )


Female TPR 0.8203907815631263
male TPR 0.6718913270637409
DP 0.3459720451596834
EOP 0.1484994544993854
EoD 0.1979034947919354
acc 0.7087466185752931
Trade off 0.5684831858372758
