In [1]:
import torch
import clip
import pandas as pd
import numpy as np
import json

In [2]:
available_models = clip.available_models()
print(available_models)

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']


In [3]:
print('\nLoading model...')

clip_model = 'ViT-B/16'

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load(clip_model, device=device, jit=False)
print(f"Done! Model {clip_model} loaded to {device} device")


Loading model...
Done! Model ViT-B/16 loaded to cuda device


In [4]:
img_embs = pd.read_pickle('../data/fface_val_img_embs.pkl')

txt_embs = torch.load('../data/synms-gender-labels.pt')

with open('../data/synms_gender_labels.json', encoding='utf-8') as json_data:
    data = json.load(json_data)
    fface_classes = list(data.keys())
    prompts = list(data.values())
    fface_prompts = prompts[0] + prompts[1]

fface_df = pd.read_csv('../data/fface_val.csv')

In [5]:
def get_similarities(img, txts):
    """Grab similarity between text and image embeddings."""
    image_features = torch.from_numpy(img).to('cuda')
    similarity = 100.0 * image_features @ txts.T

    return similarity

In [6]:
fface_df.drop(columns=['service_test'], inplace=True)
fface_df.head()

Unnamed: 0,file,age,gender,race
0,val/1.jpg,3-9,Male,East Asian
1,val/2.jpg,50-59,Female,East Asian
2,val/3.jpg,30-39,Male,White
3,val/4.jpg,20-29,Female,Latino_Hispanic
4,val/5.jpg,20-29,Male,Southeast Asian


In [7]:
single_img_emb = img_embs.iloc[0]['embeddings']
single_img_sims = get_similarities(single_img_emb, txt_embs)

In [8]:
single_img_sims

tensor([[24.6875, 18.7656, 21.6719, 22.2500, 24.4375, 25.9688, 23.2969, 22.4688,
         22.8906, 22.3125, 19.5938, 16.0938, 17.7656, 20.6406, 19.7344, 22.2656,
         21.2344, 21.4531, 21.6719, 21.3594]], device='cuda:0',
       dtype=torch.float16)

In [9]:
sims_dict = {}
final_dict = {}
for label, score in zip(fface_prompts, single_img_sims[0]):
    sims_dict[label] = score.cpu().numpy().item()
final_dict[img_embs.iloc[0]['file']] = sims_dict

In [10]:
final_dict

{'val/1.jpg': {'young man': 24.6875,
  'adult male': 18.765625,
  'male': 21.671875,
  'man': 22.25,
  'guy': 24.4375,
  'boy': 25.96875,
  'middle-aged man': 23.296875,
  'old man': 22.46875,
  'grandfather': 22.890625,
  'grandpa': 22.3125,
  'young woman': 19.59375,
  'adult female': 16.09375,
  'female': 17.765625,
  'woman': 20.640625,
  'lady': 19.734375,
  'girl': 22.265625,
  'madam': 21.234375,
  'old woman': 21.453125,
  'grandmother': 21.671875,
  'grandma': 21.359375}}

In [11]:
final_dict = {}
for idx, emb in img_embs.iterrows():
    name = emb['file']
    img_features = emb['embeddings']
    img_sims = get_similarities(img_features, txt_embs)
    sims_dict = {}
    for label, score in zip(fface_prompts, single_img_sims[0]):
        sims_dict[label] = score.cpu().numpy().item()
    final_dict[name] = sims_dict

In [17]:
sims_df = pd.DataFrame(data=final_dict)

In [18]:
sims_df.head()

Unnamed: 0,val/1.jpg,val/2.jpg,val/3.jpg,val/4.jpg,val/5.jpg,val/6.jpg,val/7.jpg,val/8.jpg,val/9.jpg,val/10.jpg,...,val/10945.jpg,val/10946.jpg,val/10947.jpg,val/10948.jpg,val/10949.jpg,val/10950.jpg,val/10951.jpg,val/10952.jpg,val/10953.jpg,val/10954.jpg
young man,24.6875,24.6875,24.6875,24.6875,24.6875,24.6875,24.6875,24.6875,24.6875,24.6875,...,24.6875,24.6875,24.6875,24.6875,24.6875,24.6875,24.6875,24.6875,24.6875,24.6875
adult male,18.765625,18.765625,18.765625,18.765625,18.765625,18.765625,18.765625,18.765625,18.765625,18.765625,...,18.765625,18.765625,18.765625,18.765625,18.765625,18.765625,18.765625,18.765625,18.765625,18.765625
male,21.671875,21.671875,21.671875,21.671875,21.671875,21.671875,21.671875,21.671875,21.671875,21.671875,...,21.671875,21.671875,21.671875,21.671875,21.671875,21.671875,21.671875,21.671875,21.671875,21.671875
man,22.25,22.25,22.25,22.25,22.25,22.25,22.25,22.25,22.25,22.25,...,22.25,22.25,22.25,22.25,22.25,22.25,22.25,22.25,22.25,22.25
guy,24.4375,24.4375,24.4375,24.4375,24.4375,24.4375,24.4375,24.4375,24.4375,24.4375,...,24.4375,24.4375,24.4375,24.4375,24.4375,24.4375,24.4375,24.4375,24.4375,24.4375
