In [None]:
from tqdm import tqdm
import numpy as np
from PIL import Image
from math import log, sqrt, pi
import argparse
from torch import nn, optim
from torch.autograd import Variable, grad
from scipy import linalg as la
import math
import torchvision.transforms as tvt
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wget
import zipfile
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as tfms
from torch.utils.data import DataLoader, Subset, Dataset, ConcatDataset, random_split
from torchvision.utils import make_grid
from torchvision import utils
from PIL import Image
import random
from tqdm import trange
from transformers import ViTModel

# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Shared by the data pipeline below: input resolution and mini-batch size.
image_size, batch_size = 224, 64

# Restrict PyTorch to a single thread for both kinds of CPU parallelism.
torch.set_num_threads(1)          # intra-op thread pool
torch.set_num_interop_threads(1)  # inter-op thread pool

def seed_everything(seed):
    """
    Seed every random number generator used in this file.

    Applies ``seed`` to Python's ``random`` module, to NumPy's global RNG and
    to PyTorch, then switches cuDNN to deterministic mode with auto-tuning
    (benchmark) disabled so repeated runs pick the same kernels.
    """
    # Same seeding order as before: Python, NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    
    
class ConfounderDataset(Dataset):
    """
    Base dataset that serves ``(image, label, confounder)`` triples for one split.

    The constructor is intentionally not implemented; a subclass must set the
    attributes read here before the dataset is indexed:

    * ``split``: one of ``'Train'``, ``'Valid'``, ``'Test'``.
    * ``data_dir``: directory the image filenames are relative to.
    * ``training_sample`` / ``valid_sample`` / ``test_sample``: filename
      sequences, each with matching ``<name>_y_array`` and
      ``<name>_confounder_array`` sequences.
    * ``train_transform`` (applied for ``'Train'``) and ``eval_transform``
      (applied for ``'Valid'`` and ``'Test'``).
    """

    # Maps a split name to the attribute prefix that holds that split's data.
    _SPLIT_PREFIX = {
        'Train': 'training_sample',
        'Valid': 'valid_sample',
        'Test': 'test_sample',
    }

    def __init__(self, root_dir,
                 target_name, confounder_names,
                 model_type=None, augment_data=None):
        raise NotImplementedError

    def _split_prefix(self):
        """Return the attribute prefix for ``self.split``.

        Raises:
            ValueError: if ``self.split`` is not a known split name. Without
                this check ``__len__`` would return ``None`` and
                ``__getitem__`` would fail on an unbound local.
        """
        try:
            return self._SPLIT_PREFIX[self.split]
        except (KeyError, TypeError):
            raise ValueError(
                f"Unknown split {self.split!r}; expected one of "
                f"{sorted(self._SPLIT_PREFIX)}") from None

    def __len__(self):
        """Return the number of samples in the active split."""
        return len(getattr(self, self._split_prefix()))

    def __getitem__(self, idx):
        """Load sample ``idx`` of the active split.

        Returns:
            A tuple ``(x, y, a)``: the transformed RGB image, its label and
            its confounder value.

        Raises:
            ValueError: if ``self.split`` is not a known split name.
        """
        prefix = self._split_prefix()
        y = getattr(self, f'{prefix}_y_array')[idx]
        a = getattr(self, f'{prefix}_confounder_array')[idx]
        img_filename = os.path.join(self.data_dir, getattr(self, prefix)[idx])

        # convert() returns an independent image, so the file handle can be
        # released as soon as the pixel data has been decoded.
        with Image.open(img_filename) as raw_img:
            img = raw_img.convert('RGB')

        # Only the training split uses the training transform.
        transform = self.train_transform if self.split == 'Train' else self.eval_transform
        x = transform(img)
        return x, y, a

    
    
class CUBDataset(ConfounderDataset):
    """
    CUB dataset (already cropped and centered).

    Reads ``metadata.csv`` from ``fold_dir`` and slices its ``img_filename``,
    ``y`` and ``place`` columns into train / validation / test arrays using
    the ``split`` column (0 = train, 1 = validation, 2 = test).

    Note (carried over from the previous docstring): metadata_df is one-indexed.
    """

    def __init__(self, fold_dir, split):
        """
        Args:
            fold_dir: directory containing ``metadata.csv``; image filenames
                are joined onto it when samples are loaded.
            split: name of the split this instance serves.

        Raises:
            ValueError: if ``fold_dir`` does not exist.
        """
        self.data_dir, self.split = fold_dir, split

        if not os.path.exists(self.data_dir):
            raise ValueError(
                f'{self.data_dir} does not exist yet. Please generate the dataset first.')

        # Load the metadata table once; every array below is a column of it.
        metadata_path = os.path.join(self.data_dir, 'metadata.csv')
        self.metadata_df = pd.read_csv(metadata_path)
        table = self.metadata_df

        # Targets.
        self.y_array = table['y'].values
        self.n_classes = 2

        # Only a single confounder (the 'place' column) is supported for now.
        self.confounder_array = table['place'].values
        self.n_confounders = 1

        # Filenames and the integer split code of every row.
        self.filename_array = table['img_filename'].values
        self.split_array = table['split'].values

        # Materialise the three per-split views, computing each row mask once.
        for prefix, split_code in (('training_sample', 0),
                                   ('valid_sample', 1),
                                   ('test_sample', 2)):
            in_split = self.split_array == split_code
            setattr(self, prefix, self.filename_array[in_split])
            setattr(self, f'{prefix}_y_array', self.y_array[in_split])
            setattr(self, f'{prefix}_confounder_array', self.confounder_array[in_split])

        # Augmenting transform for training, deterministic one for evaluation.
        self.train_transform = get_transform_cub(aug=True)
        self.eval_transform = get_transform_cub(aug=False)

        


def get_transform_cub(aug):
    """
    Build the image preprocessing pipeline for the CUB images.

    Args:
        aug: truthy selects the augmenting (training) pipeline, falsy the
            deterministic (evaluation) pipeline. Branching on truthiness
            means a value such as ``None`` can no longer leave the result
            undefined, which the previous ``== True`` / ``== False`` checks
            allowed.

    Returns:
        A ``tvt.Compose``. Both variants resize to 256x256, end up with a
        224x224 crop, convert to a tensor and normalise with the same
        per-channel mean / std.
    """
    if aug:
        spatial_steps = [
            tvt.Resize((256, 256)),
            tvt.RandomResizedCrop(
                (224, 224),
                scale=(0.7, 1.0),
                ratio=(0.75, 1.3333333333333333),
                interpolation=2),  # 2 is Pillow's integer code for bilinear
            tvt.RandomHorizontalFlip(),
        ]
    else:
        spatial_steps = [
            tvt.Resize((256, 256)),
            tvt.CenterCrop((224, 224)),
        ]

    # Tail shared by both variants so train and eval normalisation cannot drift.
    tensor_steps = [
        tvt.ToTensor(),
        tvt.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return tvt.Compose(spatial_steps + tensor_steps)


# Fix every RNG seed before any dataset or loader is created.
seed_everything(2048)

# Root folder of the dataset, relative to the working directory.
fold_dir = r'../../waterbird'

# One dataset object per split, built in Train / Valid / Test order.
training_dataset, valid_dataset, test_dataset = (
    CUBDataset(fold_dir, split_name) for split_name in ('Train', 'Valid', 'Test')
)

# Only the training loader shuffles; all loaders load in the main process
# (num_workers=0).
training_data_loader = DataLoader(
    training_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

valid_data_loader = DataLoader(
    valid_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

test_data_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

print('Done')

In [None]:
import random
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import confusion_matrix
from collections import OrderedDict
import seaborn as sns
import torch.nn.functional as F


    
def compute_variance(g1, g2, g3, g4):
    """
    Compare each group's feature variance with the variance of all groups pooled.

    For every group the trace of its feature covariance matrix is computed
    and divided by the trace of the covariance of the four groups stacked
    together. Shapes, traces and ratios are printed as before.

    Args:
        g1, g2, g3, g4: NumPy arrays of per-sample features, one row per
            sample and one column per feature dimension (they are stacked
            along axis 0 and transposed before ``torch.cov``).

    Returns:
        dict mapping ``'g1'`` .. ``'g4'`` to the ratio
        ``trace(cov(group)) / trace(cov(all))`` as a Python float.

    NOTE(review): a group with fewer than two rows presumably yields NaN from
    ``torch.cov``; that case is not handled here.
    """
    group_tensors = [torch.from_numpy(g) for g in (g1, g2, g3, g4)]
    # Stack directly in torch instead of passing through NumPy and back.
    pooled_tensor = torch.cat(group_tensors, dim=0)

    # torch.cov expects variables in rows, hence the transpose.
    group_covs = [torch.cov(t.T) for t in group_tensors]
    cov_all = torch.cov(pooled_tensor.T)

    g1_trace, g2_trace, g3_trace, g4_trace = (torch.trace(c) for c in group_covs)
    all_trace = torch.trace(cov_all)

    print(group_covs[0].shape)
    print(cov_all.shape)
    print("g1 trace",g1_trace,"g2 trace",g2_trace,"g3 trace",g3_trace,"g4 trace",g4_trace,"ALL trace",all_trace)

    ratios = [t / all_trace for t in (g1_trace, g2_trace, g3_trace, g4_trace)]
    print("g1",ratios[0],"g2",ratios[1],"g3",ratios[2],"g4",ratios[3])

    return {f"g{i}": ratio.item() for i, ratio in enumerate(ratios, start=1)}

    
def neural_collapse():
    """
    Extract backbone features on the test set and report per-group variance.

    Loads a ResNet-50 checkpoint, replaces the final ``fc`` layer with an
    identity so the network outputs features, runs the whole test loader
    through it, buckets the features into the four (label, confounder)
    groups and hands the buckets to ``compute_variance``.
    """
    model = models.resnet50(pretrained=False)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    state_dict = torch.load('ICML_weight/WB_ResNet_50_model_epoch_100.pth',
                            map_location=device)
    model.load_state_dict(state_dict)
    model.fc = nn.Identity()
    model.to(device)
    # Inference only: BatchNorm must use its stored running statistics and
    # must not update them from test batches.
    model.eval()

    # Keys are (label, confounder); insertion order fixes g1..g4 below.
    group_features = {(0, 0): [], (0, 1): [], (1, 0): [], (1, 1): []}

    with torch.no_grad():
        for test_input, test_target, sensitive in tqdm(test_data_loader,
                                                       total=len(test_data_loader)):
            features = model(test_input.to(device)).cpu()
            # Flatten so a trailing batch of size one still yields a 1-D label.
            label = test_target.reshape(-1).cpu()
            for (y_val, a_val), bucket in group_features.items():
                mask = (label == y_val) & (sensitive == a_val)
                bucket.extend(features[mask].numpy())

    feature_g1 = np.array(group_features[(0, 0)])
    feature_g2 = np.array(group_features[(0, 1)])
    feature_g3 = np.array(group_features[(1, 0)])
    feature_g4 = np.array(group_features[(1, 1)])
    print(feature_g1.shape, feature_g2.shape,feature_g3.shape,feature_g4.shape)

    compute_variance(feature_g1,feature_g2,feature_g3,feature_g4)




def _group_masks(label, sensitive):
    """Return boolean masks keyed by (label, sensitive) for the four groups."""
    return {
        (y_val, a_val): (label == y_val) & (sensitive == a_val)
        for y_val in (0, 1)
        for a_val in (0, 1)
    }


def _prototype_weight(model, loader):
    """Return a 2-row weight: unit-norm mean features of groups (0,0) and (1,1)."""
    # Insertion order fixes the row order: row 0 <- (y=0, a=0), row 1 <- (y=1, a=1).
    chunks = {(0, 0): [], (1, 1): []}
    with torch.no_grad():
        for valid_input, valid_target, valid_sensitive in loader:
            valid_feature = model(valid_input.to(device)).cpu()
            label = valid_target.reshape(-1).cpu()
            masks = _group_masks(label, valid_sensitive)
            for key, bucket in chunks.items():
                bucket.append(valid_feature[masks[key]])

    rows = []
    for (y_val, a_val), bucket in chunks.items():
        if not bucket or sum(chunk.shape[0] for chunk in bucket) == 0:
            raise ValueError(
                f"No validation samples with label={y_val}, sensitive={a_val}; "
                "cannot build the classifier prototype.")
        mu = torch.cat(bucket, dim=0).mean(dim=0)
        rows.append(mu / torch.norm(mu))
    return torch.stack(rows, dim=0)


def _rates_from_confusion(cm):
    """Return (positive-prediction rate, TPR, FPR) from a 2x2 confusion matrix."""
    positive_rate = (cm[1][1] + cm[0][1]) / (cm[0][0] + cm[0][1] + cm[1][0] + cm[1][1])
    tpr = cm[1][1] / (cm[1][1] + cm[1][0])
    fpr = cm[0][1] / (cm[0][1] + cm[0][0])
    return positive_rate, tpr, fpr


def _report_test_metrics(model, classifier, loader):
    """Print per-group accuracy and fairness metrics of the model on ``loader``."""
    model.eval()
    pred_chunks, gt_chunks, sens_chunks = [], [], []
    with torch.no_grad():
        for test_input, test_target, sensitive in loader:
            logits = classifier(model(test_input.to(device)))
            pred_chunks.append(logits.argmax(dim=1).cpu())
            gt_chunks.append(test_target.reshape(-1).cpu())
            sens_chunks.append(sensitive.reshape(-1).cpu())

    test_pred = torch.cat(pred_chunks).numpy()
    test_gt = torch.cat(gt_chunks).numpy()
    sense_gt = torch.cat(sens_chunks).numpy()

    for y_val in (0, 1):
        for a_val in (0, 1):
            in_group = (test_gt == y_val) & (sense_gt == a_val)
            total = int(in_group.sum())
            # An empty group reports NaN instead of raising ZeroDivisionError.
            if total:
                group_acc = float((test_pred[in_group] == y_val).sum()) / total
            else:
                group_acc = float('nan')
            print(f'Accuracy for y={y_val}, s={a_val}: {group_acc}')

    # The printed "Female" / "male" labels refer to confounder == 0 and
    # confounder != 0 respectively.
    is_a0 = sense_gt == 0
    # labels=[0, 1] guarantees a 2x2 matrix even if a class never occurs.
    cm_a0 = confusion_matrix(test_gt[is_a0], test_pred[is_a0], labels=[0, 1])
    cm_a1 = confusion_matrix(test_gt[~is_a0], test_pred[~is_a0], labels=[0, 1])
    female_dp, female_TPR, female_FPR = _rates_from_confusion(cm_a0)
    male_dp, male_TPR, male_FPR = _rates_from_confusion(cm_a1)

    eod = 0.5*(abs(female_FPR-male_FPR)+ abs(female_TPR-male_TPR))
    overall_acc = accuracy_score(test_gt, test_pred)
    print('Female TPR', female_TPR)
    print('male TPR', male_TPR)
    print('DP',abs(female_dp - male_dp))
    print('EOP', abs(female_TPR - male_TPR))
    print('EoD',eod)
    print('acc', overall_acc)
    print('Trade off',overall_acc*(1-eod) )


def train_model():
    """
    Fine-tune a ResNet-50 backbone against a fixed prototype classifier.

    Each epoch:

    1. The validation set is passed through the backbone and the linear
       classifier's weight is replaced by the unit-norm mean features of the
       (label=0, sensitive=0) and (label=1, sensitive=1) groups.
    2. The backbone alone is trained with SGD on the sum of the four
       per-group cross-entropy losses plus an MSE penalty between the mean
       features of the two sensitive groups within each label. Batches that
       lack any of the four groups are skipped.
    3. Per-group accuracy and fairness metrics on the test set are printed.

    Uses the module-level ``device`` and the training / validation / test
    data loaders.
    """
    epoch = 30
    weight_decay = 1e-3
    init_lr = 1e-4
    momentum_decay = 0.9
    schedule = True
    num_classes = 2

    model = models.resnet50(pretrained=True)
    model.fc = nn.Identity()
    model = model.to(device)

    classifier = nn.Linear(2048, num_classes).to(device)
    # The classifier is never optimised (its weight is overwritten every
    # epoch), so do not accumulate gradients for it. Gradients still flow
    # through it to the backbone features.
    classifier.requires_grad_(False)

    criterion = nn.CrossEntropyLoss()
    mean_criterion = nn.MSELoss()

    # Only the backbone is updated.
    optimizer = optim.SGD(model.parameters(), lr=init_lr,
                          momentum=momentum_decay, weight_decay=weight_decay)
    # The scheduler must wrap the optimizer that is actually stepped;
    # otherwise the learning rate used for training never changes.
    if schedule:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch)
    else:
        scheduler = None

    for epoch_idx in range(epoch):
        # Prototype pass in eval mode so validation batches neither use batch
        # statistics nor leak into the BatchNorm running statistics.
        model.eval()
        weight = _prototype_weight(model, valid_data_loader)
        # Place the new weight on the device directly so the classifier is
        # usable even if no training batch is processed this epoch.
        classifier.weight = nn.Parameter(weight.to(device), requires_grad=False)

        model.train()
        group_loss_sum = {(0, 0): 0, (0, 1): 0, (1, 0): 0, (1, 1): 0}
        last_loss_mean = float('nan')  # stays NaN if every batch is skipped

        with tqdm(training_data_loader, unit="batch") as tepoch:
            for train_input, train_target, sensitive in tepoch:
                label = train_target.detach().cpu()
                cpu_masks = _group_masks(label, sensitive)
                # Every group must be present for the group means and losses
                # to be defined. Checking before the forward pass avoids
                # paying for batches that are discarded; skipped batches
                # therefore no longer touch the BatchNorm running statistics.
                if any(int(mask.sum()) == 0 for mask in cpu_masks.values()):
                    continue
                masks = {key: mask.to(device) for key, mask in cpu_masks.items()}

                train_input = train_input.to(device)
                onehot_target = F.one_hot(train_target, num_classes=num_classes).float().to(device)

                feature = model(train_input)
                outputs = classifier(feature)

                group_mean = {key: feature[mask].mean(dim=0) for key, mask in masks.items()}
                # Pull the two sensitive groups of each label towards each other.
                loss_mean = (mean_criterion(group_mean[(0, 0)], group_mean[(0, 1)])
                             + mean_criterion(group_mean[(1, 0)], group_mean[(1, 1)]))

                loss = 0
                for key, mask in masks.items():
                    group_loss = criterion(outputs[mask], onehot_target[mask])
                    group_loss_sum[key] += group_loss.item()
                    loss = loss + group_loss
                loss = loss + loss_mean
                last_loss_mean = loss_mean.item()

                tepoch.set_postfix(ut_loss = loss.item())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                tepoch.set_description("epoch %2f " % epoch_idx)

        if scheduler is not None:
            scheduler.step()

        print("loss g1 (label=0, sensitive=0):",group_loss_sum[(0, 0)] )
        print("loss g2 (label=0, sensitive=1):",group_loss_sum[(0, 1)] )
        print("loss g3 (label=1, sensitive=0):",group_loss_sum[(1, 0)] )
        print("loss g4 (label=1, sensitive=1):",group_loss_sum[(1, 1)] )
        print('mean loss:', last_loss_mean)

        # Evaluate on test set.
        _report_test_metrics(model, classifier, test_data_loader)

# Start training (and the per-epoch test-set evaluation) when this cell runs.
train_model()