In [17]:
from torchvision.datasets import CelebA
import numpy as np
import os
import torch
import clip
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import copy
import pandas as pd 
import matplotlib.pyplot as plt
from IPython.display import Image, display
%matplotlib inline
import sys
sys.path.insert(0, '../')
import utils as ut
import importlib

In [18]:
# Reload the local `utils` module (imported from '../' above) so edits to
# ../utils.py take effect without restarting the kernel.
importlib.reload(ut)

<module 'utils' from '/mnt/efs/fairclip/FinalCode/CelebA/../utils.py'>

In [19]:
# CLIP ViT-B/16 is loaded once and shared by every cell below.
# GPU 3 was used for the original runs; fall back to the default CUDA device
# (or CPU) when that index does not exist instead of failing inside clip.load.
GPU_INDEX = 3
if torch.cuda.device_count() > GPU_INDEX:
    device = f"cuda:{GPU_INDEX}"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)
model.eval();  # trailing ';' suppresses the very long module repr

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [20]:
# CelebA's annotation files are whitespace-delimited (like the partition file
# below). After skipping the first line, the header row has one field fewer than
# the data rows, so pandas uses the filename column as the index. The default
# comma separator would instead collapse each row into a single string column.
attr_file = pd.read_csv("../../celebA/celeba/list_attr_celeba.txt", sep=r"\s+", skiprows=1)

In [21]:
# Split assignment per image, indexed by the image filename.
# `sep=r"\s+"` is the documented equivalent of the deprecated `delim_whitespace=True`.
splits = pd.read_csv("../../celebA/celeba/list_eval_partition.txt", sep=r"\s+", header=None, index_col=0)
# First filename with its 4-character extension stripped.
splits.index.values[0][:-4]

'000001'

In [22]:
def get_features_CelebA(dataset, batch_size=100, clip_model=None, target_device=None):
    """Encode every image of `dataset` with CLIP and collect its labels.

    Args:
        dataset: map-style dataset yielding ``(image_tensor, label_tensor)`` pairs,
            e.g. a torchvision CelebA built with ``transform=preprocess``.
        batch_size: number of images encoded per forward pass.
        clip_model: model exposing ``encode_image``; defaults to the
            notebook-level ``model``.
        target_device: device images are moved to before encoding; defaults to
            the notebook-level ``device``.

    Returns:
        (features, labels): ``features`` is the concatenated output of
        ``encode_image`` (not moved to CPU); ``labels`` is a NumPy array of the
        concatenated label tensors.

    Raises:
        ValueError: if `dataset` yields no batches.
    """
    # Fall back to the notebook globals so existing calls keep working.
    clip_model = model if clip_model is None else clip_model
    target_device = device if target_device is None else target_device

    all_features = []
    all_labels = []
    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=batch_size)):
            all_features.append(clip_model.encode_image(images.to(target_device)))
            all_labels.append(labels)

    if not all_features:
        raise ValueError("get_features_CelebA: dataset is empty")
    return torch.cat(all_features), torch.cat(all_labels).cpu().numpy()

In [23]:
# Build the test split once up front so the download/integrity check runs before
# feature extraction. This global object is not used again in this section;
# get_CelebA builds its own dataset.
CelebA_ds = CelebA(root="../../celebA", split="test", download=True, transform=None)

Files already downloaded and verified


In [24]:
def get_CelebA(split='test'):
    """Load a CelebA split and return normalized CLIP image features plus labels.

    Args:
        split: one of 'train', 'test', 'val' or 'all'; 'val' is translated to
            torchvision's 'valid' split name.

    Returns:
        dict with keys
            'features'              -- L2-normalized CLIP image embeddings,
            'labels'                -- {attribute name: per-image label column},
            'int_to_attr'           -- {sensitive attr: {code: group name}},
            'attr_to_int'           -- {sensitive attr: {group name: code}},
            'nr_groups_to_consider' -- {sensitive attr: number of groups},
            'group_sizes'           -- {sensitive attr: {group name: count}}.
    """
    assert split in ["train", "test", "val", "all"], "split must be either 'train', 'test', 'val' or 'all' for CelebA"
    torchvision_split = 'valid' if split == 'val' else split

    dataset = CelebA("../../celebA", split=torchvision_split, transform=preprocess, download=True)
    features, raw_labels = get_features_CelebA(dataset)
    # L2-normalize in place so dot products with normalized text embeddings are cosines.
    features /= features.norm(dim=-1, keepdim=True)

    # Column of the CelebA attribute matrix backing each label. Insertion order
    # matters: later cells derive `desired_labels` from these keys and pair them
    # positionally with the text queries.
    attribute_columns = {
        'black_hair': 8,
        'blond_hair': 9,
        'brown_hair': 11,
        'glasses': 15,
        'gender': 20,  # 1 = male, per the codec below
        'smiling': 31,
        'wavy_hair': 33,
        'earrings': 34,
        'hat': 35,
        'necktie': 38,
    }
    labels = {name: raw_labels[:, column] for name, column in attribute_columns.items()}

    attr_to_int = {'gender': {'female': 0, 'male': 1}}
    int_to_attr = {'gender': {0: 'female', 1: 'male'}}
    # Number of images in each sensitive group within this split.
    group_sizes = {
        attr: {group: np.sum(labels[attr] == code) for group, code in codes.items()}
        for attr, codes in attr_to_int.items()
    }

    return {
        'features': features,
        'labels': labels,
        'int_to_attr': int_to_attr,
        'attr_to_int': attr_to_int,
        'nr_groups_to_consider': {'gender': 2},
        'group_sizes': group_sizes,
    }


In [25]:
# CLIP features and selected attribute labels for the CelebA test split.
data = get_CelebA(split="test")

Files already downloaded and verified


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:01<00:00,  3.27it/s]


In [26]:
# Per-label prevalence and gender breakdown. Columns printed:
# name, positive rate, positive count, female share, male share (of positives).
# Labels are assumed binary 0/1 (consistent with the means printed below).
gender = data['labels']['gender']
for name, values in data['labels'].items():
    positives = values == 1
    n_pos = int(np.sum(positives))
    n_female = int(np.sum(positives & (gender == 0)))
    n_male = int(np.sum(positives & (gender == 1)))
    # Guard labels with no positive examples instead of dividing 0 by 0.
    female_share = n_female / n_pos if n_pos else float('nan')
    male_share = n_male / n_pos if n_pos else float('nan')
    print(f"{name}, {np.mean(values):.2f} {n_pos} {female_share:.2f} {male_share:.2f}")

black_hair, 0.27 5422 0.53 0.47
blond_hair, 0.13 2660 0.93 0.07
brown_hair, 0.18 3587 0.74 0.26
glasses, 0.06 1289 0.25 0.75
gender, 0.39 7715 0.00 1.00
smiling, 0.50 9987 0.69 0.31
wavy_hair, 0.36 7267 0.84 0.16
earrings, 0.21 4125 0.96 0.04
hat, 0.04 839 0.33 0.67
necktie, 0.07 1399 0.01 0.99


In [27]:
def run_relevance_celeba(similarities, all_labels, desired_cat, fname,
                         query_names=None, topks=(20, 50, 100),
                         out_dir="../results_csv"):
    """Compute precision@k of text-to-image retrieval for each query.

    Args:
        similarities: array-like with one row per query; row i holds the score
            of query i against every image (higher = more similar).
        all_labels: mapping label name -> per-image label array (1 = relevant).
        desired_cat: label name that counts as relevant for each query, aligned
            positionally with the queries.
        fname: prefix of the CSV written to ``{out_dir}/{fname}_relevance.csv``.
        query_names: query strings, one per row of `similarities`. Defaults to
            the notebook-level ``queries`` list for backward compatibility.
        topks: positive cut-offs k at which precision is reported.
        out_dir: directory for the CSV; created if it does not exist.

    Returns:
        DataFrame with a 'Query' column and one 'precision_top_{k}' column per k,
        rounded to 2 decimals. It is also printed and saved as CSV.

    Raises:
        ValueError: if queries, similarity rows and categories are not aligned,
            or a cut-off is not positive.
    """
    if query_names is None:
        query_names = queries  # notebook-level global (original behaviour)
    if not (len(query_names) == len(similarities) == len(desired_cat)):
        raise ValueError(
            f"misaligned inputs: {len(query_names)} queries, "
            f"{len(similarities)} similarity rows, {len(desired_cat)} categories"
        )
    if any(k <= 0 for k in topks):
        # order[-0:] would select every image, so k must be positive.
        raise ValueError(f"topks must be positive, got {tuple(topks)}")

    columns = ['Query'] + [f'precision_top_{k}' for k in topks]
    rows = []
    # Iterate positionally (not via a dict keyed by query text) so duplicate
    # query strings cannot collapse and misalign with `desired_cat`.
    for query, scores, category in zip(query_names, similarities, desired_cat):
        relevant = np.asarray(all_labels[category])
        order = np.asarray(scores).argsort()  # ascending: best matches at the end
        row = [query]
        for k in topks:
            top_idx = order[-k:][::-1]
            # Fraction of the k retrieved images that carry the desired label.
            row.append(np.sum(relevant[top_idx]) / k)
        rows.append(row)

    # Build the frame once (DataFrame.append was removed in pandas 2.0).
    results = pd.DataFrame(rows, columns=columns).round(2)
    os.makedirs(out_dir, exist_ok=True)
    results.to_csv(os.path.join(out_dir, f"{fname}_relevance.csv"))
    print(results)
    return results

In [28]:
# Retrieval targets: every label column except the sensitive attribute,
# in the key order of data['labels'].
desired_labels = [*data['labels']]
desired_labels.remove('gender')
print(desired_labels)


['black_hair', 'blond_hair', 'brown_hair', 'glasses', 'smiling', 'wavy_hair', 'earrings', 'hat', 'necktie']


In [29]:
# Sanity check: number of positive images for the first retrieval target.
print(f"{desired_labels[0]} {np.sum(data['labels'][desired_labels[0]])}")


black_hair 5422


In [30]:
# Text query for each label. `queries` is derived from `desired_labels`, so
# query i is always scored against label i in run_relevance_celeba, which pairs
# them positionally. Labels that are currently disabled keep their query here.
LABEL_TO_QUERY = {
    'black_hair': "a person with black hair",
    'blond_hair': "a person with blond hair",
    'brown_hair': "a person with brown hair",
    'glasses': "a person with glasses",
    'smiling': "a person who is smiling",
    'straight_hair': "a person with straight hair",
    'wavy_hair': "a person with wavy hair",
    'earrings': "a person wearing earrings",
    'hat': "a person wearing a hat",
    'necktie': "a person wearing a tie",
    'necklace': "a person wearing a necklace",
    'arched_eyebrow': "a person with arched eyebrows",
    'baggy_eyes': "a person with bags under the eyes",
    'big_lips': "a person with big lips",
    'big_nose': "a person with a big nose",
    'bushy_eyebrows': "a person with bushy eyebrows",
    'chubby': "a chubby person",
    'double_chin': "a person with a double chin",
    'high_Cheekbones': "a person with high cheekbones",
    'oval_face': "a person with an oval face",
    'pointy_nose': "a person with a pointy nose",
}
queries = [LABEL_TO_QUERY[label] for label in desired_labels]

text_tokens = clip.tokenize(["a photo of " + desc for desc in queries]).to(device)
with torch.no_grad():
    text_features = model.encode_text(text_tokens)
# L2-normalize to match the normalized image features from get_CelebA.
text_features /= text_features.norm(dim=-1, keepdim=True)

# similarity[i, j]: scaled cosine similarity of query i to image j.
similarity = (100.0 * data['features'] @ text_features.T).cpu().numpy().astype(np.float64).T

In [31]:
# Baseline: retrieval precision with the unmodified CLIP embeddings.
run_relevance_celeba(similarity, data['labels'], desired_labels, fname="orig_celeba")

                       Query  precision_top_20  precision_top_50  \
0   a person with black hair              0.95              0.88   
1   a person with blond hair              0.60              0.62   
2   a person with brown hair              0.50              0.54   
3      a person with glasses              1.00              1.00   
4    a person who is smiling              1.00              1.00   
5    a person with wavy hair              0.90              0.90   
6  a person wearing earrings              0.65              0.70   
7     a person wearing a hat              0.95              0.98   
8     a person wearing a tie              0.60              0.66   

   precision_top_100  
0               0.84  
1               0.70  
2               0.57  
3               0.98  
4               1.00  
5               0.90  
6               0.72  
7               0.97  
8               0.56  


Unnamed: 0,Query,precision_top_20,precision_top_50,precision_top_100
0,a person with black hair,0.95,0.88,0.84
1,a person with blond hair,0.6,0.62,0.7
2,a person with brown hair,0.5,0.54,0.57
3,a person with glasses,1.0,1.0,0.98
4,a person who is smiling,1.0,1.0,1.0
5,a person with wavy hair,0.9,0.9,0.9
6,a person wearing earrings,0.65,0.7,0.72
7,a person wearing a hat,0.95,0.98,0.97
8,a person wearing a tie,0.6,0.66,0.56


In [None]:
# Fit the debiasing transforms with the project helper in ../utils.py (its log
# output shows it loads COCO-style annotations).
# NOTE(review): judging by the names and how later cells use them,
# projection_GT / projection_inferred are projection objects exposing
# `just_transform` (fit with ground-truth vs. inferred sensitive labels), and
# MI_GT / MI_inferred are sequences of embedding-column indices, presumably
# ranked by mutual information -- confirm in utils.py.
projection_GT,projection_inferred, MI_GT, MI_inferred = ut.calculate_projections_coco(model, preprocess, device)

loading annotations into memory...
Done (t=10.91s)
creating index...
index created!
loading annotations into memory...
Done (t=0.67s)
creating index...
index created!


 34%|██████████████████████████████████████████████████████████▊                                                                                                                 | 283/828 [07:55<15:04,  1.66s/it]

In [None]:
print("======== Running Fair pca G.T on the model ============== ")

# Retrieval after mapping both image and text embeddings through the projection
# fitted with ground-truth sensitive labels.
text_inputs = clip.tokenize([f"a photo of {word}" for word in queries]).to(device)
with torch.no_grad():
    text_features = model.encode_text(text_inputs)
text_features /= text_features.norm(dim=-1, keepdim=True)

projection_train = projection_GT
all_features_val_transf = projection_train.just_transform(data['features'].cpu().numpy().astype(np.float64))
text_features_pca = projection_train.just_transform(text_features.cpu().numpy().astype(np.float64))
# NOTE(review): projected vectors are not re-normalized, so this ranks by dot
# product rather than cosine similarity -- confirm this is intended.
similarity = (100.0 * all_features_val_transf @ text_features_pca.T).T
retrieval_fpca_gt = run_relevance_celeba(similarity, data['labels'], desired_labels, "fpca_gt_celeba")

In [None]:
print("======== Running Fair pca INF on the model ============== ")

# Retrieval after mapping both image and text embeddings through the projection
# fitted with inferred sensitive labels.
text_inputs = clip.tokenize([f"a photo of {word}" for word in queries]).to(device)
with torch.no_grad():
    text_features = model.encode_text(text_inputs)
text_features /= text_features.norm(dim=-1, keepdim=True)

projection_train = projection_inferred
all_features_val_transf = projection_train.just_transform(data['features'].cpu().numpy().astype(np.float64))
text_features_pca = projection_train.just_transform(text_features.cpu().numpy().astype(np.float64))
# NOTE(review): projected vectors are not re-normalized, so this ranks by dot
# product rather than cosine similarity -- confirm this is intended.
similarity = (100.0 * all_features_val_transf @ text_features_pca.T).T
retrieval_fpca_inf = run_relevance_celeba(similarity, data['labels'], desired_labels, "fpca_inf_celeba")

In [None]:
# Retrieval using only the first `num_clip` embedding columns listed in MI_GT
# (presumably selected with ground-truth sensitive labels).
text_inputs = clip.tokenize([f"a photo of {word}" for word in queries]).to(device)
with torch.no_grad():
    text_features = model.encode_text(text_inputs)
text_features /= text_features.norm(dim=-1, keepdim=True)
text_features = text_features.cpu().numpy().astype(np.float64)
# Convert the image features once rather than on every loop iteration.
image_features_all = data['features'].cpu().numpy().astype(np.float64)

num_clip_s = [400, 256]
mis = MI_GT
for num_clip in num_clip_s:
    print(f"..... {num_clip}.........")
    kept_columns = mis[:num_clip]
    text_features_mi = text_features[:, kept_columns]
    image_features_val = image_features_all[:, kept_columns]
    # Sub-vectors are not re-normalized: scores are dot products on kept columns.
    similarity = (100.0 * image_features_val @ text_features_mi.T).T
    run_relevance_celeba(similarity, data['labels'], desired_labels, f"MI_gt{num_clip}_celeba")
             

In [None]:
# Retrieval using only the first `num_clip` embedding columns listed in
# MI_inferred (presumably selected with inferred sensitive labels).
text_inputs = clip.tokenize([f"a photo of {word}" for word in queries]).to(device)
with torch.no_grad():
    text_features = model.encode_text(text_inputs)
text_features /= text_features.norm(dim=-1, keepdim=True)
text_features = text_features.cpu().numpy().astype(np.float64)
# Convert the image features once rather than on every loop iteration.
image_features_all = data['features'].cpu().numpy().astype(np.float64)

num_clip_s = [400, 256]
mis = MI_inferred
for num_clip in num_clip_s:
    print(f"..... {num_clip}.........")
    kept_columns = mis[:num_clip]
    text_features_mi = text_features[:, kept_columns]
    image_features_val = image_features_all[:, kept_columns]
    # Sub-vectors are not re-normalized: scores are dot products on kept columns.
    similarity = (100.0 * image_features_val @ text_features_mi.T).T
    run_relevance_celeba(similarity, data['labels'], desired_labels, f"MI_inf{num_clip}_celeba")