In [1]:
import clip
import torchvision.datasets
import math
import torchvision.transforms as tvt
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wget
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as tfms
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision.utils import make_grid
from PIL import Image
from tqdm import tqdm
import random
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import confusion_matrix
import open_clip

device = torch.device("cuda:0")

# model,_, preprocess =  open_clip.create_model_and_transforms("ViT-B/32", pretrained='openai') #ViTB/32
# model = model.to(device)
# tokenizer = open_clip.get_tokenizer('ViT-B-32')


# model, preprocess = clip.load('RN50', device)
# model = model.to(device)
# tokenizer = open_clip.get_tokenizer('RN50')


model,_, preprocess =  open_clip.create_model_and_transforms("ViT-L-14", pretrained='laion2b_s32b_b82k') #ViTL/14
model = model.to(device)
tokenizer = open_clip.get_tokenizer('ViT-L-14')



torch.set_num_threads(5)   # Sets the number of threads used for intra-operations
torch.set_num_interop_threads(5)   # Sets the number of threads used for inter-operations




def seed_everything(seed):
    """
    Changes the seed for reproducibility. 
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
seed_everything(4096)   

device_id = 1
image_size = 64
batch_size = 512



root_dir =  '../celeba/datasets/celeba/img_align_celeba/'
csv_file = '../celeba/datasets/celeba/metadata.csv'
data_frame = pd.read_csv(csv_file)
data_frame.replace(-1, 0, inplace=True)



class CustomDataset(Dataset):
    def __init__(self, csv_file, y, a, root_dir,split, transform):
        self.data_frame = csv_file
        self.data_frame = self.data_frame[self.data_frame['split'] == split].reset_index(drop=True)
        self.root_dir = root_dir
        self.transform = transform
        self.targets = self.data_frame[y].values
        self.biases = self.data_frame[a].values
        self.i = list(range(self.targets.shape[0]))

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.data_frame.iloc[idx, 0])
        image = Image.open(img_name)
        target = self.targets[idx]
        targets = torch.tensor(target) 
        sensitive = self.biases[idx]
        biases = torch.tensor(sensitive)
        img = preprocess(image)
        img_for_res = self.transform(image)
            
        return img, targets, biases, img_for_res

target = 'Blond_Hair'
sensitive = 'Male'
    
transform=tvt.Compose([tvt.Resize((256,256)),
                               tvt.RandomResizedCrop(
                                    (224,224),
                                    scale=(0.7, 1.0),
                                    ratio=(0.75, 1.3333333333333333),
                                    interpolation=2),
                                tvt.RandomHorizontalFlip(),
                                tvt.ToTensor(),
                                tvt.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])                                  
                                ])


valid_transform=tvt.Compose([tvt.Resize((256,256)),
                               tvt.CenterCrop((224,224)),
                                tvt.ToTensor(),
                                tvt.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])                                  
                                ])

train_set = CustomDataset(csv_file=data_frame, y = target, a= sensitive, root_dir=root_dir, split=0, transform=transform)
training_data_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_set = CustomDataset(csv_file=data_frame, y = target, a= sensitive, root_dir=root_dir, split=1, transform=valid_transform)
valid_data_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)
test_set = CustomDataset(csv_file=data_frame, y = target, a= sensitive, root_dir=root_dir, split=2, transform=valid_transform)
test_data_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
print('Done')
data_frame

  from .autonotebook import tqdm as notebook_tqdm


Done


  "Argument interpolation should be of type InterpolationMode instead of int. "


Unnamed: 0,image_id,partition,5_o_Clock_Shadow,Arched_Eyebrows,Attractive,Bags_Under_Eyes,Bald,Bangs,Big_Lips,Big_Nose,...,Smiling,Straight_Hair,Wavy_Hair,Wearing_Earrings,Wearing_Hat,Wearing_Lipstick,Wearing_Necklace,Wearing_Necktie,Young,split
0,000001.jpg,0,0,1,1,0,0,0,0,0,...,1,1,0,1,0,1,0,0,1,0
1,000002.jpg,0,0,0,0,1,0,0,0,1,...,1,0,0,0,0,0,0,0,1,0
2,000003.jpg,0,0,0,0,0,0,0,1,0,...,0,0,1,0,0,0,0,0,1,0
3,000004.jpg,0,0,0,1,0,0,0,0,0,...,0,1,0,1,0,1,1,0,1,0
4,000005.jpg,0,0,1,1,0,0,0,1,0,...,0,0,0,0,0,1,0,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
202594,202595.jpg,2,0,0,1,0,0,0,1,0,...,0,0,0,0,0,1,0,0,1,2
202595,202596.jpg,2,0,0,0,0,0,1,1,0,...,1,1,0,0,0,0,0,0,1,2
202596,202597.jpg,2,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,1,2
202597,202598.jpg,2,0,1,1,0,0,0,1,0,...,1,0,1,1,0,1,0,0,1,2


In [2]:
t = iter(training_data_loader)
image,y,a,_ = next(t)

In [3]:
image.shape

torch.Size([512, 3, 224, 224])

In [4]:
texts = ["A photo of a male", "A photo of a female"] 
texts = ["male", "female"] 
text = tokenizer(texts).to(device)
text_features = model.encode_text(text)
male = text_features[0].unsqueeze(0).to(device)
female = text_features[1].unsqueeze(0).to(device)


In [5]:
def training_a():
    epoch =200
    weight_decay=1e-3
    init_lr=1e-4
    momentum_decay = 0.9
    schedule = False
    resnet18 = models.resnet18(pretrained=True)
    num_classes = 2 
    resnet18.fc = nn.Linear(resnet18.fc.in_features, num_classes)
    
    res_model = resnet18
    res_model = res_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(res_model.parameters(), lr=init_lr, momentum=momentum_decay, weight_decay = weight_decay)
    for epoches in range(epoch):
        with tqdm(training_data_loader, unit="batch") as tepoch:
            res_model.train()         
            for _, _, sensitive, train_input in tepoch:  #img_clip, y, a, x_for_resnet
                train_input = train_input.to(device)
                label = sensitive.detach().cpu()
                one_hot_labels = F.one_hot(label, num_classes=2)
                train_target = one_hot_labels.float().to(device)
                outputs = res_model(train_input)
                loss = criterion(outputs, train_target)
                tepoch.set_postfix(ut_loss = loss.item()) 
                optimizer.zero_grad()    
                loss.backward()
                optimizer.step()
                tepoch.set_description(f"epoch %2f " % epoches)
        
        if schedule:
            scheduler.step()
    super_a_test(res_model)
    torch.save(res_model.state_dict(), 'res_net_celebA.pth')
    
def super_a_test(model):
    model.eval()
    correct_00, total_00 = 0, 0
    correct_01, total_01 = 0, 0
    correct_10, total_10 = 0, 0
    correct_11, total_11 = 0, 0
    
    for step, (_, test_target, sensitive,test_input) in enumerate(tqdm(test_data_loader, desc="Testing")):
        with torch.no_grad():
            test_target = test_target
            sensitive = sensitive
            test_input = test_input.to(device)

            test_pred_ = model(test_input)
            _, predic = torch.max(test_pred_.data, 1)
            predic = predic.detach().cpu()
                
            
            mask_00 = ((test_target == 0) & (sensitive == 0))
            mask_01 = ((test_target == 0) & (sensitive == 1))
            mask_10 = ((test_target == 1) & (sensitive == 0))
            mask_11 = ((test_target == 1) & (sensitive == 1))


            correct_00 += (predic[mask_00] == sensitive[mask_00]).float().sum().item()
            total_00 += mask_00.float().sum().item()

            correct_01 += (predic[mask_01] == sensitive[mask_01]).float().sum().item()
            total_01 += mask_01.float().sum().item()

            correct_10 += (predic[mask_10] == sensitive[mask_10]).float().sum().item()
            total_10 += mask_10.float().sum().item()

            correct_11 += (predic[mask_11] == sensitive[mask_11]).float().sum().item()
            total_11 += mask_11.float().sum().item() 
            
            
    acc_00 = correct_00 / total_00
    acc_01 = correct_01 / total_01
    acc_10 = correct_10 / total_10
    acc_11 = correct_11 / total_11

    print(f'Accuracy for y=0, s=0: {acc_00}')
    print(f'Accuracy for y=0, s=1: {acc_01}')
    print(f'Accuracy for y=1, s=0: {acc_10}')
    print(f'Accuracy for y=1, s=1: {acc_11}')   

    

def inference_a_test(vlm, spu_v0, spu_v1):
    correct_00, total_00 = 0, 0
    correct_01, total_01 = 0, 0
    correct_10, total_10 = 0, 0
    correct_11, total_11 = 0, 0
    
    for step, (test_input, test_target, sensitive, _) in enumerate(tqdm(test_data_loader, desc="Testing")):
        with torch.no_grad():
            test_target = test_target.to(device)
            sensitive = sensitive.to(device)
            test_input = test_input.to(device)
            z = vlm.encode_image(test_input)
            infered_a = inference_a(vlm, female, male,z )
            
            mask_00 = ((test_target == 0) & (sensitive == 0))
            mask_01 = ((test_target == 0) & (sensitive == 1))
            mask_10 = ((test_target == 1) & (sensitive == 0))
            mask_11 = ((test_target == 1) & (sensitive == 1))


            correct_00 += (infered_a[mask_00] == sensitive[mask_00]).float().sum().item()
            total_00 += mask_00.float().sum().item()

            correct_01 += (infered_a[mask_01] == sensitive[mask_01]).float().sum().item()
            total_01 += mask_01.float().sum().item()

            correct_10 += (infered_a[mask_10] == sensitive[mask_10]).float().sum().item()
            total_10 += mask_10.float().sum().item()

            correct_11 += (infered_a[mask_11] == sensitive[mask_11]).float().sum().item()
            total_11 += mask_11.float().sum().item() 
    acc_00 = correct_00 / total_00
    acc_01 = correct_01 / total_01
    acc_10 = correct_10 / total_10
    acc_11 = correct_11 / total_11

    print(f'Accuracy for y=0, s=0: {acc_00}')
    print(f'Accuracy for y=0, s=1: {acc_01}')
    print(f'Accuracy for y=1, s=0: {acc_10}')
    print(f'Accuracy for y=1, s=1: {acc_11}')   

            



def inference_a(vlm, spu_v0, spu_v1, z):
    text_embeddings = torch.cat((spu_v0, spu_v1), dim=0)
    norm_img_embeddings = z 
    norm_text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
    cosine_similarity = torch.mm(norm_img_embeddings, norm_text_embeddings.t())
    logits_per_image = cosine_similarity 
    probs = logits_per_image.softmax(dim=1)
    _, predic = torch.max(probs.data, 1)
    return predic

            
def supervised_inference_a(img):
    resnet18 = models.resnet18(pretrained=False)
    num_classes = 2 
    resnet18.fc = nn.Linear(resnet18.fc.in_features, num_classes)
    res_model = resnet18
    res_model.load_state_dict(torch.load('res_net.pth'))
    res_model = res_model.to(device)
    res_model.eval()
    img = img.to(device)
    test_pred_ = res_model(img)
    _, predic = torch.max(test_pred_.data, 1)
    return predic            
            
    
def compute_scale(vlm, spu_v0, spu_v1):
    vlm = vlm.to(device)
    scale_0 = []
    scale_1 = []
    spu0 = spu_v0  / spu_v0.norm(dim=1, keepdim=True)
    spu1 = spu_v1 / spu_v1.norm(dim=1, keepdim=True)
    
    for step, (test_input, _, sensitive, img) in enumerate(tqdm(training_data_loader, desc="Computing Scale")):  
        with torch.no_grad():
            test_input = test_input.to(device)
            z = vlm.encode_image(test_input)
            if a ==True:
                sensitive = sensitive
            else:
                if partial_a == False:
                    sensitive = inference_a(vlm, female, male,z )
                elif partial_a == True:
                    sensitive = supervised_inference_a(img)
            
            mask_0 = sensitive == 0
            mask_0 = mask_0.to(device)
            h = z[mask_0]
            inner_land = torch.mm(h/ h.norm(dim=1, keepdim=True), spu0.t())
            scale_0.extend(inner_land.detach().cpu().numpy())
                
            mask_1 = sensitive == 1
            mask_1 = mask_1.to(device)
            g = z[mask_1]
            inner_water = torch.mm(g/ g.norm(dim=1, keepdim=True), spu1.t())
            scale_1.extend(inner_water.detach().cpu().numpy())
    scale_0 = np.array(scale_0)
    scale_1 = np.array(scale_1)
    print(np.mean(scale_0))
    print(np.mean(scale_1))
    return torch.tensor(np.mean(scale_0)), torch.tensor(np.mean(scale_1))



def test_epoch(vlm,   dataloader):
    scale_0, scale_1 = compute_scale(model, female, male)
    texts_label = ['a photo of a celebrity with dark hair', 'a photo of a celebrity with blonde hair']  
    text_label_tokened = tokenizer(texts_label).to(device)
    
    vlm = vlm.to(device)
    vlm.eval()   
    test_pred = []
    test_gt = []
    sense_gt = []
    female_predic = []
    female_gt = []
    male_predic = []
    male_gt = []
    correct_00, total_00 = 0, 0
    correct_01, total_01 = 0, 0
    correct_10, total_10 = 0, 0
    correct_11, total_11 = 0, 0
    cos = nn.CosineSimilarity(dim = 0)
    feature_a0 = []
    feature_a1 = []

    for step, (test_input, test_target, sensitive_real,img) in enumerate(tqdm(dataloader, desc="Zero Shot Testing")):
        with torch.no_grad():
            gt = test_target.detach().cpu().numpy()
            sen = sensitive_real.detach().cpu().numpy()
            test_gt.extend(gt)
            sense_gt.extend(sen)
            test_input = test_input.to(device)


            z = vlm.encode_image(test_input)
            if a == True:
                sensitive = sensitive_real
            if a == False:
                if partial_a == False:
                    sensitive = inference_a(vlm, female, male,z )
                    sensitive = torch.tensor(sensitive)
                elif partial_a == True:
                    sensitive = supervised_inference_a(img)
            
            mask_0 = sensitive == 0
            mask_0 = mask_0.to(device)
            z[mask_0] -= scale_0 * female/ female.norm(dim=1, keepdim=True)
                
            mask_1 = sensitive == 1
            mask_1 = mask_1.to(device)
            z[mask_1] -= scale_1 * male/ male.norm(dim=1, keepdim=True)
            
            feature_a0.extend(z[mask_0].detach().cpu().numpy())
            feature_a1.extend(z[mask_1].detach().cpu().numpy())
            
            text_embeddings = vlm.encode_text(text_label_tokened)
            img_embeddings = z
            norm_img_embeddings = img_embeddings / img_embeddings.norm(dim=1, keepdim=True)
            norm_text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
            cosine_similarity = torch.mm(norm_img_embeddings, norm_text_embeddings.t())
                    
            logits_per_image = cosine_similarity 
            probs = logits_per_image.softmax(dim=1)
            _, predic = torch.max(probs.data, 1)
            predic = predic.detach().cpu()
            test_pred.extend(predic.numpy())
            label = test_target.squeeze().detach().cpu()
            mask_00 = ((label == 0) & (sensitive_real == 0))
            mask_01 = ((label == 0) & (sensitive_real == 1))
            mask_10 = ((label == 1) & (sensitive_real == 0))
            mask_11 = ((label == 1) & (sensitive_real == 1))


            correct_00 += (predic[mask_00] == label[mask_00]).float().sum().item()
            total_00 += mask_00.float().sum().item()

            correct_01 += (predic[mask_01] == label[mask_01]).float().sum().item()
            total_01 += mask_01.float().sum().item()

            correct_10 += (predic[mask_10] == label[mask_10]).float().sum().item()
            total_10 += mask_10.float().sum().item()

            correct_11 += (predic[mask_11] == label[mask_11]).float().sum().item()
            total_11 += mask_11.float().sum().item() 
    acc_00 = correct_00 / total_00
    acc_01 = correct_01 / total_01
    acc_10 = correct_10 / total_10
    acc_11 = correct_11 / total_11

    print(f'Accuracy for y=0, s=0: {acc_00}')
    print(f'Accuracy for y=0, s=1: {acc_01}')
    print(f'Accuracy for y=1, s=0: {acc_10}')
    print(f'Accuracy for y=1, s=1: {acc_11}')       
    
    feature_a0 = np.array(feature_a0)
    feature_a1 = np.array(feature_a1)
    a0_tensor = torch.from_numpy(np.mean(feature_a0,0))
    a1_tensor = torch.from_numpy(np.mean(feature_a1,0))

    for i in range(len(sense_gt)):
        if sense_gt[i] == 0:
            female_predic.append(test_pred[i])
            female_gt.append(test_gt[i])
        else:
            male_predic.append(test_pred[i])
            male_gt.append(test_gt[i])
    female_CM = confusion_matrix(female_gt, female_predic)    
    male_CM = confusion_matrix(male_gt, male_predic) 
    female_dp = (female_CM[1][1]+female_CM[0][1])/(female_CM[0][0]+female_CM[0][1]+female_CM[1][0]+female_CM[1][1])
    male_dp = (male_CM[1][1]+male_CM[0][1])/(male_CM[0][0]+male_CM[0][1]+male_CM[1][0]+male_CM[1][1])
    female_TPR = female_CM[1][1]/(female_CM[1][1]+female_CM[1][0])
    male_TPR = male_CM[1][1]/(male_CM[1][1]+male_CM[1][0])
    female_FPR = female_CM[0][1]/(female_CM[0][1]+female_CM[0][0])
    male_FPR = male_CM[0][1]/(male_CM[0][1]+male_CM[0][0])
    acc = accuracy_score(test_gt, test_pred)
    #print('Female TPR', female_TPR)
    #print('male TPR', male_TPR)
    # print('DP',abs(female_dp - male_dp))
    # print('EOP', abs(female_TPR - male_TPR))
    # print('EoD',0.5*(abs(female_FPR-male_FPR)+ abs(female_TPR-male_TPR)))
    # print('acc', accuracy_score(test_gt, test_pred))

a = True
partial_a = False
    

model = model.to(device)
#inference_a_test(model, female, male)
test_epoch(model, test_data_loader)

Computing Scale: 100%|████████████████████████████████████████████████████████████████| 318/318 [28:56<00:00,  5.46s/it]


0.20750567
0.19806518


Zero Shot Testing: 100%|████████████████████████████████████████████████████████████████| 39/39 [03:19<00:00,  5.12s/it]

Accuracy for y=0, s=0: 0.8467287805876933
Accuracy for y=0, s=1: 0.8460517584605176
Accuracy for y=1, s=0: 0.9661290322580646
Accuracy for y=1, s=1: 0.8888888888888888



