In [1]:
from torchvision.datasets import Food101
import numpy as np
import os
import torch
import clip
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import copy
import pandas as pd 
import matplotlib.pyplot as plt
from IPython.display import Image, display
%matplotlib inline
import sys
sys.path.insert(0, '../')
import utils as ut
import scipy.io
from sklearn.metrics import accuracy_score
from scipy.special import softmax
import json 
import shutil

In [2]:
# Prefer the first GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA (clip.load accepts either device string).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
# Inference only: switch the model to evaluation mode.
model.eval()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [3]:
# Two views of the Food101 test split: `dataset` applies CLIP's preprocessing
# transform to every image, `dataset_orig` is built without a transform.
dataset = Food101(root="../../Food101", split="test", transform=preprocess, download=True)
dataset_orig = Food101(root="../../Food101", split="test", download=True)

In [4]:
# Number of samples in the preprocessed Food101 test split.
len(dataset)

25250

In [5]:
# The labels file is read without a header row, so the class names end up in
# the integer-named column 0 (accessed later as `foods[0]`).
foods = pd.read_csv(
    "../../Food101/food-101/meta/labels.txt",
    header=None,
)

In [6]:
# Display the class-name table read from labels.txt.
foods

Unnamed: 0,0
0,Apple pie
1,Baby back ribs
2,Baklava
3,Beef carpaccio
4,Beef tartare
...,...
96,Tacos
97,Takoyaki
98,Tiramisu
99,Tuna tartare


In [7]:
def get_features_food(dataset, model, device=None, batch_size=100):
    """Encode every image of ``dataset`` with ``model.encode_image``.

    Parameters
    ----------
    dataset : map-style dataset yielding ``(image_tensor, label)`` pairs.
    model : object exposing ``encode_image``.
    device : optional device the image batches are moved to. When omitted it
        is taken from the model's first parameter, so the function no longer
        depends on a notebook-level ``device`` variable matching the model.
        Falls back to the CPU if the model exposes no parameters.
    batch_size : number of images per forward pass (default 100).

    Returns
    -------
    (features, labels) : ``features`` is the concatenation of all
        ``encode_image`` outputs, left on whatever device the model produced
        them on; ``labels`` is a NumPy array of the dataset labels in
        iteration order.

    The dataset is expected to be non-empty (``torch.cat`` needs at least one
    batch).
    """
    if device is None:
        parameters = getattr(model, "parameters", None)
        first_param = next(iter(parameters()), None) if callable(parameters) else None
        device = first_param.device if first_param is not None else torch.device("cpu")

    all_features = []
    all_labels = []

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=batch_size)):
            features = model.encode_image(images.to(device))
            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features), torch.cat(all_labels).cpu().numpy()

In [8]:
# Encode the whole preprocessed test split with the CLIP image encoder.
# `features` is a torch tensor; `labels` is a NumPy array (see get_features_food).
features, labels =  get_features_food(dataset, model)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 253/253 [02:43<00:00,  1.55it/s]


In [9]:
# Scale every image embedding to unit L2 norm (in place) so that the dot
# products computed below are cosine similarities.
feature_norms = features.norm(dim=-1, keepdim=True)
features /= feature_norms

In [21]:
# Zero-shot baseline: classify each image by its most similar class prompt.
class_prompts = [f"a photo of {word.lower()}, a type of food." for word in foods[0].values]
text_inputs = clip.tokenize(class_prompts).to(device)

with torch.no_grad():
    text_features = model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    # Scaled similarities between every image and every class prompt.
    logits = 100.0 * features @ text_features.T
    similarity = logits.softmax(dim=-1).cpu().numpy().astype(np.float64)

predictions = similarity.argmax(axis=1)
print(f'{accuracy_score(predictions, labels):.2f}')

0.82


In [11]:
# 0.8699009900990099

In [12]:
# 0.8713267326732673

In [13]:
projection_GT,projection_inferred, MI_GT, MI_inferred, train_features, train_labels = ut.calculate_projections_ff(model, preprocess, device)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 868/868 [03:27<00:00,  4.18it/s]


Error of predicting gender train = 0.06
 unique attr 7
Error of predicting race train = 0.37


In [20]:
print("********** Fair PCA GT ***************")

# The class prompts do not depend on the protected attribute, so they are
# tokenized and encoded once instead of once per loop iteration.
text_inputs = clip.tokenize(
    [f"a photo of {word.lower()}, a type of food." for word in foods[0].values]
).to(device)

with torch.no_grad():
    text_features = model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)

for attr in ['gender', 'race']:
    # Projection obtained from the ground-truth attribute labels.
    projection_train = projection_GT[attr]
    # Each pass hands just_transform a fresh float64 copy of the embeddings.
    all_features_val_transf = projection_train.just_transform(features.cpu().numpy().astype(np.float64))
    text_features_pca = projection_train.just_transform(text_features.cpu().numpy().astype(np.float64))
    similarity = softmax(100.0 * np.matmul(all_features_val_transf, np.transpose(text_features_pca)), axis=1)
    predictions = np.argmax(similarity, axis=1)
    # accuracy_score expects (y_true, y_pred).
    print(f'{accuracy_score(labels, predictions):.2f}')

********** Fair PCA GT ***************
0.82
0.81


In [15]:
# ********** Fair PCA GT ***************
# 0.8714059405940594
# 0.8624950495049505

In [22]:
print("********** Fair PCA Inf ***************")

# The class prompts do not depend on the protected attribute, so they are
# tokenized and encoded once instead of once per loop iteration.
text_inputs = clip.tokenize(
    [f"a photo of {word.lower()}, a type of food." for word in foods[0].values]
).to(device)

with torch.no_grad():
    text_features = model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)

for attr in ['gender', 'race']:
    # Projection obtained from the inferred attribute labels.
    projection_train = projection_inferred[attr]
    # Each pass hands just_transform a fresh float64 copy of the embeddings.
    all_features_val_transf = projection_train.just_transform(features.cpu().numpy().astype(np.float64))
    text_features_pca = projection_train.just_transform(text_features.cpu().numpy().astype(np.float64))
    similarity = softmax(100.0 * np.matmul(all_features_val_transf, np.transpose(text_features_pca)), axis=1)
    predictions = np.argmax(similarity, axis=1)
    # accuracy_score expects (y_true, y_pred).
    print(f'{accuracy_score(labels, predictions):.2f}')

********** Fair PCA Inf ***************
0.82
0.81


In [23]:
print("********** MI GT ***************")

# Prompt encoding and the float64 conversions depend neither on the attribute
# nor on num_clip, so they are done once up front.
text_inputs = clip.tokenize(
    [f"a photo of {word.lower()}, a type of food." for word in foods[0].values]
).to(device)

with torch.no_grad():
    text_features = model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)

text_features_np = text_features.cpu().numpy().astype(np.float64)
image_features_np = features.cpu().numpy().astype(np.float64)

# Numbers of embedding dimensions to keep.
num_clip_s = [400, 256]

for attr in ['gender', 'race']:
    # NOTE(review): presumably an ordering of embedding dimensions derived from
    # the ground-truth attribute labels; confirm against ut.calculate_projections_ff.
    mis = MI_GT[attr]
    for num_clip in num_clip_s:
        # Keep only the first num_clip dimensions listed in `mis`.
        text_features_mi = text_features_np[:, mis[:num_clip]]
        image_features_val = image_features_np[:, mis[:num_clip]]
        similarity = softmax(100.0 * np.matmul(image_features_val, np.transpose(text_features_mi)), axis=1)
        predictions = np.argmax(similarity, axis=1)
        # accuracy_score expects (y_true, y_pred).
        print(num_clip, attr, f'{accuracy_score(labels, predictions):.2f}')

********** MI GT ***************
400 gender 0.79
256 gender 0.70
400 race 0.77
256 race 0.66


In [24]:
print("********** MI inf ***************")

# Prompt encoding and the float64 conversions depend neither on the attribute
# nor on num_clip, so they are done once up front.
text_inputs = clip.tokenize(
    [f"a photo of {word.lower()}, a type of food." for word in foods[0].values]
).to(device)

with torch.no_grad():
    text_features = model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)

text_features_np = text_features.cpu().numpy().astype(np.float64)
image_features_np = features.cpu().numpy().astype(np.float64)

# Numbers of embedding dimensions to keep.
num_clip_s = [400, 256]

for attr in ['gender', 'race']:
    # NOTE(review): presumably an ordering of embedding dimensions derived from
    # the inferred attribute labels; confirm against ut.calculate_projections_ff.
    mis = MI_inferred[attr]
    for num_clip in num_clip_s:
        # Keep only the first num_clip dimensions listed in `mis`.
        text_features_mi = text_features_np[:, mis[:num_clip]]
        image_features_val = image_features_np[:, mis[:num_clip]]
        similarity = softmax(100.0 * np.matmul(image_features_val, np.transpose(text_features_mi)), axis=1)
        predictions = np.argmax(similarity, axis=1)
        # accuracy_score expects (y_true, y_pred).
        print(num_clip, attr, f'{accuracy_score(labels, predictions):.2f}')

********** MI inf ***************
400 gender 0.79
256 gender 0.68
400 race 0.78
256 race 0.67


In [16]:
print("********** Prompt ***************")

# Put the sibling directory ../debias-vision-lang (relative to the notebook's
# working directory) on the import path before importing `debias_clip`;
# presumably that is where the package lives. The path must be inserted first.
import sys
sys.path.insert(1, '../debias-vision-lang')
import debias_clip

********** Prompt ***************


In [17]:
print("********** Prompt ***************")
# NOTE: this rebinds the notebook-level `device`. Cells above that use `device`
# together with the ViT-B/32 model must not be re-run after this cell without
# resetting it first.
device = "cuda:1"
deb_clip_model, deb_preprocess = debias_clip.load("ViT-B/16-gender", device=device)
dataset_deb = Food101("../../Food101", split='test', transform=deb_preprocess, download=True)
deb_clip_model.eval()
features_deb, labels_deb = get_features_food(dataset_deb, deb_clip_model)
features_deb /= features_deb.norm(dim=-1, keepdim=True)

# Text is encoded with a second copy of the model loaded on the CPU; per the
# original author's note, moving the GPU model with .to("cpu") did not work.
text_inputs = torch.cat([clip.tokenize(f"a photo of {word.lower()}, a type of food.") for word in foods[0].values]).to("cpu")
deb_clip_model_cpu, deb_preprocess = debias_clip.load("ViT-B/16-gender", device='cpu')
# Put the CPU copy (the one actually used for text encoding) in eval mode;
# previously the GPU model was switched to eval a second time instead.
deb_clip_model_cpu.eval()
with torch.no_grad():
    # Cast to float16 before moving to the GPU for the matmul with the image features.
    text_features_deb = deb_clip_model_cpu.encode_text(text_inputs).to(torch.float16)
    text_features_deb = text_features_deb.to(device)
text_features_deb /= text_features_deb.norm(dim=-1, keepdim=True)
similarity_deb = (100.0 * features_deb @ text_features_deb.T).softmax(dim=-1).cpu().numpy().astype(np.float64)
predictions = np.argmax(similarity_deb, axis=1)
# Score against the labels returned alongside features_deb rather than the
# `labels` left over from an earlier cell; accuracy_score expects (y_true, y_pred).
print(f'{accuracy_score(labels_deb, predictions):.2f}')

********** Prompt ***************
Installing pretrained embedings
 best_ndkl_oai-clip-vit-b-16_neptune_run_OXVLB-317_model_e4_step_5334_embeddings.pt...


100%|█████████████████████████████████████| 4.73k/4.73k [00:00<00:00, 9.47MiB/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 253/253 [03:37<00:00,  1.16it/s]


Installing pretrained embedings
 best_ndkl_oai-clip-vit-b-16_neptune_run_OXVLB-317_model_e4_step_5334_embeddings.pt...


100%|█████████████████████████████████████| 4.73k/4.73k [00:00<00:00, 10.1MiB/s]


0.8734653465346535
