In [None]:
from tqdm import tqdm
import os.path as osp
import numpy as np
from PIL import Image
from math import log, sqrt, pi
import argparse
from torch import nn, optim
from torch.autograd import Variable, grad
from scipy import linalg as la
import math
import torchvision.transforms as tvt
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wget
import zipfile
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as tfms
from torch.utils.data import DataLoader, Subset, Dataset, ConcatDataset, random_split
from torchvision.utils import make_grid
from torchvision import utils
from PIL import Image
import random
from tqdm import trange
from transformers import ViTModel
from sklearn.model_selection import train_test_split

# Run on the first CUDA device when one is visible, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Side length of the network input and the mini-batch size used by the loaders.
image_size, batch_size = 224, 32

# Cap the CPU thread pools PyTorch may use.
torch.set_num_threads(5)          # threads used for intra-op parallelism
torch.set_num_interop_threads(5)  # threads used for inter-op parallelism

def seed_everything(seed):
    """
    Seed every random number generator used here, for reproducibility.

    Applies ``seed`` to Python's ``random`` module, NumPy's global generator
    and PyTorch, then switches cuDNN to deterministic mode with autotuning
    (benchmark) turned off.
    """
    # The same seed goes to each generator, in the order random -> numpy -> torch.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    # The two cuDNN flags are independent of each other.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

       
        
def get_transform_ISIC(aug):
    """Build the torchvision preprocessing pipeline for ISIC images.

    Every image is resized to 256x256, centre-cropped to 224x224, converted
    to a tensor and normalised with the mean/std constants below. When
    ``aug`` is truthy, random horizontal/vertical flips, a random resized
    crop (scale 0.75-1.0, output 224) and a random rotation of up to 45
    degrees are inserted between the centre crop and the tensor conversion.

    Args:
        aug: Truthy to include the random augmentations, falsy for the
            deterministic pipeline. Any value is accepted; the previous
            ``== True`` / ``== False`` comparisons left the result unbound
            (``UnboundLocalError``) for values such as ``None``.

    Returns:
        A ``tvt.Compose`` holding the selected transforms.
    """
    steps = [
        tvt.Resize((256, 256)),
        tvt.CenterCrop((224, 224)),
    ]
    if aug:
        steps += [
            tvt.RandomHorizontalFlip(),
            tvt.RandomVerticalFlip(),
            tvt.RandomResizedCrop(224, scale=(0.75, 1.0)),
            tvt.RandomRotation(45),
        ]
    steps += [
        tvt.ToTensor(),
        tvt.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return tvt.Compose(steps)


class ConfounderDataset(Dataset):
    """Base class for image datasets that carry a label and a confounder.

    Subclasses must provide ``__init__`` and set, before any item is read:
    ``split`` ('train', 'val' or 'test'), ``data_dir``, ``train_transform``,
    ``eval_transform`` and, for each split prefix ``training_sample``,
    ``valid_sample`` and ``test_sample``, the three indexables ``<prefix>``
    (file names relative to ``data_dir``), ``<prefix>_y_array`` and
    ``<prefix>_confounder_array``.
    """

    # split name -> (attribute prefix of its arrays, transform attribute).
    # Only the training split uses the training transform.
    _SPLIT_ATTRS = {
        'train': ('training_sample', 'train_transform'),
        'val': ('valid_sample', 'eval_transform'),
        'test': ('test_sample', 'eval_transform'),
    }

    def __init__(self, root_dir,
                 target_name, confounder_names,
                 model_type=None, augment_data=None):
        raise NotImplementedError

    def _split_attrs(self):
        """Return ``(array_prefix, transform_attr)`` for ``self.split``.

        Raises:
            ValueError: if ``self.split`` is not 'train', 'val' or 'test'.
                Previously ``__len__`` returned ``None`` and ``__getitem__``
                failed with ``UnboundLocalError`` in that case.
        """
        try:
            return self._SPLIT_ATTRS[self.split]
        except KeyError:
            raise ValueError(
                f"unknown split {self.split!r}; "
                f"expected one of {sorted(self._SPLIT_ATTRS)}"
            ) from None

    def __len__(self):
        """Return the number of samples in the active split."""
        prefix, _ = self._split_attrs()
        return len(getattr(self, prefix))

    def __getitem__(self, idx):
        """Load sample ``idx`` of the active split.

        Returns:
            ``(x, y, a)``: the transformed RGB image, the label as a tensor
            and the confounder value as a tensor.
        """
        prefix, transform_attr = self._split_attrs()

        y = torch.tensor(getattr(self, f'{prefix}_y_array')[idx])
        a = torch.tensor(getattr(self, f'{prefix}_confounder_array')[idx])

        img_filename = os.path.join(self.data_dir, getattr(self, prefix)[idx])
        # convert() returns an independent, fully loaded image, so the file
        # handle can be released as soon as the conversion is done.
        with Image.open(img_filename) as handle:
            img = handle.convert('RGB')

        x = getattr(self, transform_attr)(img)
        return x, y, a



    
class ISICDataset(ConfounderDataset):
    """ISIC dataset whose splits are described by CSV files.

    The CSVs are read from ``<root_dir>/trap-sets`` and the images from
    ``<root_dir>/ISIC2018_Task1-2_Training_Input``. Each CSV must contain an
    ``image`` column (file name relative to the image directory) plus the
    target and confounder columns.
    """

    def __init__(self,
                 root_dir,
                 seed,
                 split,
                 target_name=None,
                 confounder_names=None,
                 model_type=None,
                 augment_data=False,
                 mix_up=False,
                 group_id=None,
                 id_val=True):
        """Read the split CSVs and cache file names, labels and confounders.

        Args:
            root_dir: Directory containing ``trap-sets`` and the image folder.
            seed: Suffix of the CSV files (``isic_annotated_train{seed}.csv``
                and so on).
            split: Which split this instance serves: 'train', 'val' or 'test'.
            target_name: Column selector for the label; defaults to
                ``['label']``. Passing a list selects a DataFrame, so the
                label array then has one column per name.
            confounder_names: Confounder column names; only the first entry
                is used. Defaults to ``['patches']``.
            model_type, augment_data, mix_up, group_id: Stored on the
                instance; not otherwise used by this class.
            id_val: If true, 'val' and 'test' are a fixed 20 % / 80 % split
                (``random_state=0``) of the test CSV. Otherwise the test and
                val CSVs are used as-is and every row that also appears in
                the val CSV is removed from the training set.
        """
        # Build the default lists per call so instances never alias one
        # shared list object (mutable-default pitfall).
        if target_name is None:
            target_name = ['label']
        if confounder_names is None:
            confounder_names = ['patches']

        self.split = split
        self.augment_data = augment_data
        self.group_id = group_id
        self.mix_up = mix_up
        self.model_type = model_type
        self.target_name = target_name
        self.confounder_names = confounder_names
        self.split_dir = osp.join(root_dir, 'trap-sets')
        self.data_dir = osp.join(root_dir, 'ISIC2018_Task1-2_Training_Input')

        metadata = {}
        metadata['train'] = pd.read_csv(osp.join(self.split_dir, f'isic_annotated_train{seed}.csv'))
        # The test CSV is needed in both modes, so read it once.
        test_frame = pd.read_csv(osp.join(self.split_dir, f'isic_annotated_test{seed}.csv'))
        if id_val:
            # Fixed random_state keeps val/test identical across instances.
            idx_val, idx_test = train_test_split(np.arange(len(test_frame)),
                                                 test_size=0.8, random_state=0)
            metadata['test'] = test_frame.iloc[idx_test]
            metadata['val'] = test_frame.iloc[idx_val]
        else:
            metadata['test'] = test_frame
            metadata['val'] = pd.read_csv(osp.join(self.split_dir, f'isic_annotated_val{seed}.csv'))
            # Data-frame subtraction: keep only the training rows that have
            # no match in the validation frame.
            metadata_new = metadata['train'].merge(metadata['val'], how='left', indicator=True)
            metadata_new = metadata_new[metadata_new['_merge'] == 'left_only']
            metadata['train'] = metadata_new.drop(columns=['_merge'])

        self.train_transform = get_transform_ISIC(aug=True)
        self.eval_transform = get_transform_ISIC(aug=False)

        self.precomputed = False
        self.pretransformed = False
        self.n_classes = 2
        self.n_confounders = 1
        confounder = confounder_names[0]

        # Arrays for all three splits are always populated; ``self.split``
        # selects which ones __len__ / __getitem__ read.
        self.training_sample = metadata['train']['image'].values
        self.training_sample_y_array = metadata['train'][target_name].values
        self.training_sample_confounder_array = metadata['train'][confounder].values

        self.valid_sample = metadata['val']['image'].values
        self.valid_sample_y_array = metadata['val'][target_name].values
        self.valid_sample_confounder_array = metadata['val'][confounder].values

        self.test_sample = metadata['test']['image'].values
        self.test_sample_y_array = metadata['test'][target_name].values
        self.test_sample_confounder_array = metadata['test'][confounder].values
        

    
# Root directory handed to ISICDataset, and the seed that selects the split CSVs.
data_dir = r"../../isic"
seed = 1

# One dataset object per split.
training_isic_dataset = ISICDataset(data_dir, seed, split='train')
valid_isic_dataset = ISICDataset(data_dir, seed, split='val')
test_isic_dataset = ISICDataset(data_dir, seed, split='test')

# Only the training loader shuffles; all loaders load in the main process.
training_data_loader = DataLoader(training_isic_dataset, batch_size=batch_size,
                                  shuffle=True, num_workers=0)
valid_data_loader = DataLoader(valid_isic_dataset, batch_size=batch_size,
                               shuffle=False, num_workers=0)
test_data_loader = DataLoader(test_isic_dataset, batch_size=batch_size,
                              shuffle=False, num_workers=0)
print('Done')

In [None]:
import random
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import confusion_matrix
from collections import OrderedDict
import seaborn as sns
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

def augment_confusion_matrix(CM, gt):
    """Pad a degenerate 1x1 confusion matrix out to 2x2.

    A 1x1 matrix is returned as a new 2x2 float matrix of zeros whose single
    count sits at [0, 0] when the only value in ``gt`` is 0 and at [1, 1]
    otherwise. Any other shape is returned unchanged (same object).
    """
    if CM.shape != (1, 1):
        return CM

    # Diagonal slot that receives the single count.
    slot = 0 if np.unique(gt).item() == 0 else 1
    padded = np.zeros((2, 2))
    padded[slot, slot] = CM[0, 0]
    return padded
    
def _group_masks(label, sensitive):
    """Return boolean masks for the four (label, sensitive) subgroups.

    Order: (y=0, a=0), (y=0, a=1), (y=1, a=0), (y=1, a=1).
    """
    return (
        (label == 0) & (sensitive == 0),
        (label == 0) & (sensitive == 1),
        (label == 1) & (sensitive == 0),
        (label == 1) & (sensitive == 1),
    )


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when the denominator is 0."""
    return numerator / denominator if denominator else 0


def _class_prototypes(model, loader):
    """Return a (2, feature_dim) CPU tensor of unit-norm class prototypes.

    Row ``c`` is the L2-normalised mean backbone feature of the samples in
    ``loader`` with label ``c`` and sensitive attribute 0.

    Raises:
        ValueError: if ``loader`` yields no sample for one of the two groups.
    """
    per_class = ([], [])
    with torch.no_grad():
        for valid_input, valid_target, valid_sensitive in loader:
            valid_feature = model(valid_input.to(device))
            label = valid_target.reshape(-1).detach().cpu()
            mask_00, _, mask_10, _ = _group_masks(label, valid_sensitive.reshape(-1))
            per_class[0].append(valid_feature[mask_00].detach().cpu())
            per_class[1].append(valid_feature[mask_10].detach().cpu())

    prototypes = []
    for cls, chunks in enumerate(per_class):
        if sum(chunk.shape[0] for chunk in chunks) == 0:
            raise ValueError(
                f"validation data has no sample with label={cls} and "
                "sensitive=0; cannot build the class prototype"
            )
        mu = torch.mean(torch.cat(chunks, 0), 0)
        prototypes.append(mu / torch.norm(mu))
    return torch.stack(prototypes, 0)


def train_model(*, epoch=100, weight_decay=1e-3, init_lr=1e-3,
                momentum_decay=0.9, schedule=False):
    """Train a ResNet-50 backbone with a prototype-based linear head.

    Reads the module-level ``training_data_loader``, ``valid_data_loader``,
    ``test_data_loader`` and ``device``.

    Each epoch:
      1. The rows of the 2-way linear head are overwritten with the
         unit-norm mean features of the validation samples whose sensitive
         attribute is 0 (one row per class). The head's bias keeps its
         initial value because the optimizer only holds backbone parameters.
      2. The backbone is trained with the sum of one cross-entropy term per
         (label, sensitive) subgroup present in the batch, plus an MSE term
         between the mean features of the two sensitive groups of each class.
      3. The model is evaluated on the test loader; subgroup accuracy, ROC
         AUC and fairness gaps are printed. The backbone ``state_dict`` is
         written to ``ISIC_best_ours.pth`` whenever the accuracies of both
         sensitive=0 subgroups exceed 0.7.

    Args:
        epoch: Number of epochs.
        weight_decay: SGD weight decay.
        init_lr: SGD learning rate.
        momentum_decay: SGD momentum.
        schedule: If true, anneal the learning rate with
            ``CosineAnnealingLR(T_max=epoch)``, stepped once per epoch.

    Raises:
        ValueError: if the validation loader has no sample for one of the
            two prototype groups.
    """
    resnet50 = models.resnet50(pretrained=True)
    resnet50.fc = nn.Identity()  # the backbone now returns its 2048-d features
    num_classes = 2
    classifier = nn.Linear(2048, num_classes)

    # Move the model to the appropriate device
    model = resnet50.to(device)
    classifier = classifier.to(device)

    criterion = nn.CrossEntropyLoss()
    mean_criterion = nn.MSELoss()
    # Only the backbone is optimised; the head is rebuilt from prototypes.
    optimizer = optim.SGD(model.parameters(), lr=init_lr,
                          momentum=momentum_decay, weight_decay=weight_decay)
    # The scheduler must wrap the optimizer that is actually stepped.
    scheduler = (
        torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch)
        if schedule else None
    )

    for epoches in range(epoch):
        with tqdm(training_data_loader, unit="batch") as tepoch:
            model.train()
            # Running sums of the per-subgroup cross-entropy, for logging.
            loss00 = 0
            loss01 = 0
            loss10 = 0
            loss11 = 0

            # NOTE(review): the prototype pass runs while the model is in
            # train() mode, as in the original code - confirm that is
            # intended rather than eval() mode.
            weight = _class_prototypes(model, valid_data_loader)
            print(weight, "sim:", F.cosine_similarity(weight[0:1], weight[1:2]))
            classifier.weight = nn.Parameter(weight.to(device))

            for train_input, train_target, sensitive in tepoch:
                train_input = train_input.to(device)
                # reshape(-1) instead of squeeze(): squeeze() would also drop
                # the batch axis of a final batch holding a single sample.
                label = train_target.reshape(-1).detach().cpu()
                sensitive = sensitive.reshape(-1).detach().cpu()
                train_target = F.one_hot(label, num_classes=num_classes).float().to(device)

                feature = model(train_input)
                outputs = classifier(feature)

                mask_00, mask_01, mask_10, mask_11 = _group_masks(label, sensitive)
                count_00 = mask_00.sum()
                count_01 = mask_01.sum()
                count_10 = mask_10.sum()
                count_11 = mask_11.sum()

                # Pull the mean features of the two sensitive groups of each
                # class together (only when both groups occur in the batch).
                if count_00 > 0 and count_01 > 0:
                    l1 = mean_criterion(torch.mean(feature[mask_00], 0),
                                        torch.mean(feature[mask_01], 0))
                else:
                    l1 = torch.tensor(0)

                if count_10 > 0 and count_11 > 0:
                    l2 = mean_criterion(torch.mean(feature[mask_10], 0),
                                        torch.mean(feature[mask_11], 0))
                else:
                    l2 = torch.tensor(0)

                loss_mean = l1 + l2

                # One cross-entropy term per subgroup present in the batch.
                if count_00 > 0:
                    loss_00 = criterion(outputs[mask_00], train_target[mask_00])
                    loss00 += loss_00.item()
                else:
                    loss_00 = torch.tensor(0)
                if count_01 > 0:
                    loss_01 = criterion(outputs[mask_01], train_target[mask_01])
                    loss01 += loss_01.item()
                else:
                    loss_01 = torch.tensor(0)
                if count_10 > 0:
                    loss_10 = criterion(outputs[mask_10], train_target[mask_10])
                    loss10 += loss_10.item()
                else:
                    loss_10 = torch.tensor(0)
                if count_11 > 0:
                    loss_11 = criterion(outputs[mask_11], train_target[mask_11])
                    loss11 += loss_11.item()
                else:
                    loss_11 = torch.tensor(0)

                loss = loss_00 + loss_01 + loss_10 + loss_11 + loss_mean
                tepoch.set_postfix(ut_loss=loss.item())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                tepoch.set_description(f"epoch %2f " % epoches)
            if scheduler is not None:
                scheduler.step()

        print("loss g1 (label=0, sensitive=0):", loss00)
        print("loss g2 (label=0, sensitive=1):", loss01)
        print("loss g3 (label=1, sensitive=0):", loss10)
        print("loss g4 (label=1, sensitive=1):", loss11)
        # Value from the last training batch only.
        print('mean loss:', loss_mean.item())

        # ---- evaluation on the test loader ----
        model.eval()
        test_pred = []
        test_prop = []
        test_gt = []
        sense_gt = []
        correct = [0, 0, 0, 0]
        total = [0, 0, 0, 0]
        for test_input, test_target, sensitive in tqdm(test_data_loader, total=len(test_data_loader)):
            test_input = test_input.to(device)
            label = test_target.reshape(-1).detach().cpu()
            sensitive = sensitive.reshape(-1).detach().cpu()
            test_gt.extend(label.numpy())
            sense_gt.extend(sensitive.numpy())

            with torch.no_grad():
                test_pred_ = classifier(model(test_input))
                test_prop_ = F.softmax(test_pred_, 1)[:, 1].view(-1)
                _, predic = torch.max(test_pred_, 1)

            predic = predic.detach().cpu()
            test_pred.extend(predic.numpy())
            test_prop.extend(test_prop_.detach().cpu().numpy())

            for group, mask in enumerate(_group_masks(label, sensitive)):
                correct[group] += (predic[mask] == label[mask]).float().sum().item()
                total[group] += mask.float().sum().item()

        # An empty subgroup reports accuracy 0 instead of dividing by zero.
        acc_00, acc_01, acc_10, acc_11 = (
            _safe_ratio(c, t) for c, t in zip(correct, total)
        )
        total_00, total_01, total_10, total_11 = total

        print(f'Accuracy for y=0, s=0: {acc_00}', total_00)
        print(f'Accuracy for y=0, s=1: {acc_01}', total_01)
        print(f'Accuracy for y=1, s=0: {acc_10}', total_10)
        print(f'Accuracy for y=1, s=1: {acc_11}', total_11)
        roc = roc_auc_score(test_gt, test_prop)
        print('**************')
        print('ROC:', roc)

        # "female" / "male" are the names this script uses for the
        # sensitive == 0 group and for every other sample, respectively.
        female_predic, female_gt = [], []
        male_predic, male_gt = [], []
        for sens, truth, pred in zip(sense_gt, test_gt, test_pred):
            if sens == 0:
                female_predic.append(pred)
                female_gt.append(truth)
            else:
                male_predic.append(pred)
                male_gt.append(truth)

        # Pad both matrices: either group may contain a single class only.
        female_CM = augment_confusion_matrix(
            confusion_matrix(female_gt, female_predic), female_gt)
        male_CM = augment_confusion_matrix(
            confusion_matrix(male_gt, male_predic), male_gt)
        print('cm for male', male_CM)

        # Share of positive predictions (demographic parity) per group.
        female_dp = (female_CM[1][1]+female_CM[0][1])/(female_CM[0][0]+female_CM[0][1]+female_CM[1][0]+female_CM[1][1])
        male_dp = (male_CM[1][1]+male_CM[0][1])/(male_CM[0][0]+male_CM[0][1]+male_CM[1][0]+male_CM[1][1])
        # Formulas are kept exactly as before so reported numbers do not move;
        # a zero denominator yields nan/inf (NumPy warning), not an exception.
        female_TPR = female_CM[1][1]/(female_CM[1][1]+female_CM[1][0])
        male_TPR = male_CM[1][1]/(male_CM[1][1]+male_CM[1][0]+1e-10)
        female_FPR = female_CM[0][1]/(female_CM[0][1]+female_CM[0][0])
        male_FPR = male_CM[0][1]/(male_CM[0][1]+male_CM[0][0])

        overall_acc = accuracy_score(test_gt, test_pred)
        eod = 0.5*(abs(female_FPR-male_FPR) + abs(female_TPR-male_TPR))

        # NOTE(review): the checkpoint is selected on test-set accuracy and
        # holds the backbone only (the linear head is not saved) - confirm
        # this is what downstream loading code expects.
        if acc_00 > 0.7 and acc_10 > 0.7:
            torch.save(model.state_dict(), f'ISIC_best_ours.pth')

        print('Female TPR', female_TPR)
        print('male TPR', male_TPR)
        print('DP', abs(female_dp - male_dp))
        print('EOP', abs(female_TPR - male_TPR))
        print('EoD', eod)
        print('acc', overall_acc)
        print('Trade off', overall_acc*(1-eod))
        



# Fix the RNG seeds first so the run started below is reproducible.
seed_everything(2048)    
# Runs the full train / evaluate loop; metrics are printed each epoch.
train_model()