In [5]:
import clip
import torch
import pandas as pd
import numpy as np
from PIL import Image
import json

In [6]:
def filter_df(df, race=None, gender=None):
    """Return the rows of ``df`` that match the requested race and/or gender.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to filter. The ``'gender'`` / ``'race'`` column is only read
        when the corresponding filter is supplied.
    race : optional
        Value the ``'race'`` column must equal. ``None`` disables the filter.
    gender : optional
        Value the ``'gender'`` column must equal. ``None`` disables the filter.

    Returns
    -------
    pandas.DataFrame
        ``df`` itself when no filter is active, otherwise the selection of
        rows satisfying every active filter.
    """
    selected = df
    # Gender is applied first, then race, mirroring the original order.
    for column, wanted in (('gender', gender), ('race', race)):
        # Compare against None instead of relying on truthiness so that falsy
        # labels (e.g. an integer-encoded class 0) still filter the frame.
        if wanted is not None:
            selected = selected[selected[column] == wanted]
    return selected

In [7]:
# CLIP backbone names that can be handed to clip.load, and candidate layer
# names. `layers` / `saliency_layer` are not used in the visible cells
# (presumably consumed by a later saliency step — confirm).
available_models = ['RN50', 'RN101', 'RN50x4', 'RN50x16']
layers = ['layer4', 'layer3', 'layer2', 'layer1']

# The first entry of each list is the active choice.
clip_model, saliency_layer = available_models[0], layers[0]

print('\nLoading model...')

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# jit=False asks clip.load for the non-JIT model.
model, preprocess = clip.load(clip_model, device=device, jit=False)
print(f"Done! Model loaded to {device} device")


Loading model...


Done! Model loaded to cuda device


In [8]:
def generate_embeddings_dataframe(df):
    """Encode every listed image with the CLIP image encoder.

    Relies on notebook-level globals: ``path`` (string prefix prepended to
    each file name), ``preprocess``, ``model`` and ``device``.

    Parameters
    ----------
    df : iterable of str or pandas.DataFrame
        File names relative to ``path``. When a DataFrame with a ``'file'``
        column is passed, that column is used; iterating a DataFrame directly
        would yield its column labels rather than its rows.

    Returns
    -------
    pandas.DataFrame
        One row per image with columns ``'file'`` and ``'embeddings'``. Each
        embedding is the L2-normalised image feature as a NumPy array that
        keeps the size-1 batch axis added before encoding.
    """
    if isinstance(df, pd.DataFrame) and 'file' in df.columns:
        file_names = df['file']
    else:
        file_names = df

    files = []
    embs = []

    for file in file_names:
        img_path = path + file
        # Use a context manager so the underlying file handle is released
        # after each image instead of staying open for the whole loop.
        with Image.open(img_path) as img:
            img_input = preprocess(img).unsqueeze(0).to(device)

        with torch.no_grad():
            image_features = model.encode_image(img_input)
            # Scale to unit length along the feature axis.
            image_features /= image_features.norm(dim=-1, keepdim=True)

        files.append(file)
        embs.append(image_features.cpu().numpy())

    return pd.DataFrame(data={'file': files, 'embeddings': embs})

In [9]:
def generate_text_embeddings(txts, model, device):
    """Encode ``"a photo of a <term>"`` prompts into unit-length text features.

    Parameters
    ----------
    txts : iterable
        Terms interpolated into the prompt template, one prompt per term.
    model
        Object exposing ``encode_text`` (the CLIP model in this notebook).
    device
        Device the tokenised prompts are moved to before encoding.

    Returns
    -------
    torch.Tensor
        Output of ``model.encode_text`` divided in place by its norm along
        the last axis.
    """
    prompts = [f"a photo of a {term}" for term in txts]
    token_rows = [clip.tokenize(prompt) for prompt in prompts]
    tokens = torch.cat(token_rows).to(device)

    # Inference only: no autograd graph is needed for the encoder pass.
    with torch.no_grad():
        encoded = model.encode_text(tokens)

    encoded /= encoded.norm(dim=-1, keepdim=True)
    return encoded

In [41]:
# Location of the JSON file holding the gender synonyms.
gender_json = '../data/gender-synms.json'

# Read the file as UTF-8 text and parse it into Python objects.
with open(gender_json, encoding='utf-8') as json_data:
    gender_synms = json.loads(json_data.read())

In [42]:
# Pull the entries stored under the 'synms' key of the loaded JSON; they are
# passed to generate_text_embeddings in a later cell (presumably a list of
# gender-related terms — confirm against ../data/gender-synms.json).
gender_synms_list = gender_synms['synms']

In [43]:
# Embed every synonym with the loaded CLIP model, then persist the resulting
# tensor to disk.
gender_feats = generate_text_embeddings(
    txts=gender_synms_list,
    model=model,
    device=device,
)
torch.save(obj=gender_feats, f='../data/gender-synms-embs.pt')

In [31]:
def get_similarities(self, img_embs, txt_embs):
    """Score image embeddings against text embeddings.

    Parameters
    ----------
    self
        Object providing a ``device`` attribute; the image embeddings are
        moved to that device.
    img_embs : numpy.ndarray
        Image embeddings (converted with ``torch.from_numpy``).
    txt_embs : torch.Tensor
        Text embeddings; its transpose is the right-hand matrix operand.

    Returns
    -------
    torch.Tensor
        Image embeddings scaled by 100 and matrix-multiplied with
        ``txt_embs.T``.
    """
    img_tensor = torch.from_numpy(img_embs).to(self.device)
    # Scale first, then multiply, matching the left-to-right evaluation of
    # ``100.0 * a @ b``.
    scaled = 100.0 * img_tensor
    return torch.matmul(scaled, txt_embs.T)