In [3]:
import torch
import clip
import pandas as pd
import numpy as np
import json

In [4]:
available_models = clip.available_models()
print(available_models)

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']


In [5]:
print('\nLoading model...')

clip_model = 'ViT-B/16'

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load(clip_model, device=device, jit=False)
print(f"Done! Model {clip_model} loaded to {device} device")


Loading model...


100%|███████████████████████████████████████| 335M/335M [00:09<00:00, 38.1MiB/s]


Done! Model ViT-B/16 loaded to cuda device


In [6]:
with open('../data/synms_gender_labels.json', encoding='utf-8') as json_data:
    data = json.load(json_data)
    fface_classes = list(data.keys())
    fface_prompts = list(data.values())

In [7]:
man_prompts = fface_prompts[0]
woman_prompts = fface_prompts[1]

In [9]:
def get_similarities(img_embs, classes):
    image_features = torch.from_numpy(img_embs).to(device)

    text_inputs = torch.cat(
        [clip.tokenize(f"a photo of a {c}") for c in classes]).to(device)

    with torch.no_grad():
        text_features = model.encode_text(text_inputs)

    text_features /= text_features.norm(dim=-1, keepdim=True)
    
    # is this the cosine distance approximation?
    similarity = (100.0 * image_features @ text_features.T)
    return similarity

In [None]:
def get_synms_winner(sims):
    np_sims = sims.cpu().numpy()
    np_loc = np.where(np_sims[0] == np_sims.max())
    return np_loc[0][0]

In [None]:
def run_clip_classifier(img_emb, classes):
    """Run classes by CLIP to choose the closest one"""
    sims = get_similarities(img_emb, classes)
    sims_max = sims.softmax(dim=-1)
    values, indices = sims_max[0].topk(len(sims_max[0]))
    scores = []
    for value, index in zip(values, indices):
        scores.append(
            (classes[index], round(100 * value.item(), 2)))
    return scores

In [None]:
print(run_clip_classifier(woman_img_emb, final_classes))

[('woman', 95.75), ('boy', 4.27)]
