In [1]:
import clip
import torch
import pandas as pd
import numpy as np
from PIL import Image
import json

In [2]:
# Ask the CLIP package which model names it knows about. The list is kept in a
# variable because the model-loading cell selects an entry from it by index.
available_models = clip.available_models()
print(available_models)

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']


In [3]:
print('\nLoading model...')

# Use the final entry of the list returned by clip.available_models().
clip_model = available_models[-1]

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load returns the network together with its image preprocessing callable.
model, preprocess = clip.load(clip_model, device=device, jit=False)
print(f"Done! Model {clip_model} loaded to {device} device")


Loading model...
Done! Model ViT-L/14@336px loaded to cuda device


In [4]:
def generate_embeddings_dataframe(df):
    """Encode every listed image with the CLIP image encoder.

    Relies on notebook-level globals: ``path`` (string prefix concatenated in
    front of each file name), ``preprocess``, ``model`` and ``device``.

    Parameters
    ----------
    df : pandas.DataFrame or iterable of str
        Image file names relative to ``path``. If a DataFrame with a
        ``'file'`` column is given, that column is used (iterating a DataFrame
        directly yields its column labels, not its rows). Any other iterable,
        e.g. a Series or list of names, is iterated as-is.

    Returns
    -------
    pandas.DataFrame
        Two columns: ``'file'`` (each name exactly as received) and
        ``'embeddings'`` (the output of ``model.encode_image`` for that single
        image, scaled to unit norm along the last axis, as a NumPy array).
    """
    if isinstance(df, pd.DataFrame) and 'file' in df.columns:
        file_names = df['file']
    else:
        file_names = df

    files = list()
    embs = list()

    for file in file_names:
        img_path = path + file
        # The context manager releases the file handle even if preprocessing
        # raises part-way through a long run.
        with Image.open(img_path) as img:
            # unsqueeze(0) adds the batch axis expected by the encoder.
            img_input = preprocess(img).unsqueeze(0).to(device)

        # Both the forward pass and the in-place normalisation are
        # inference-only, so keep them together under no_grad.
        with torch.no_grad():
            image_features = model.encode_image(img_input)
            image_features /= image_features.norm(dim=-1, keepdim=True)

        files.append(file)
        embs.append(image_features.cpu().numpy())

    return pd.DataFrame(data={'file': files, 'embeddings': embs})

In [5]:
def generate_text_embeddings(txts, model, device):
    """Encode class names into unit-norm CLIP text features.

    Every element ``c`` of ``txts`` is placed into the template
    ``"a photo of a {c}"`` before tokenization, so callers should pass bare
    class names rather than complete prompts.

    Args:
        txts: Iterable of class-name strings.
        model: Object exposing ``encode_text`` (the loaded CLIP model).
        device: Device the token tensor is moved to before encoding.

    Returns:
        Tensor returned by ``model.encode_text``, divided in place by its norm
        along the last dimension.
    """
    token_batches = []
    for class_name in txts:
        token_batches.append(clip.tokenize(f"a photo of a {class_name}"))
    text_inputs = torch.cat(token_batches).to(device)

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        encoded = model.encode_text(text_inputs)

    row_norms = encoded.norm(dim=-1, keepdim=True)
    encoded /= row_norms
    return encoded

In [6]:
def get_similarities(self, img_embs, txt_embs):
    """Grab similarity between classes and image embeddings.

    Args:
        self: Object whose ``device`` attribute names the device to compute
            on. This function is defined at notebook top level rather than on
            a class, so ``None`` (or any object without a ``device``
            attribute) is also accepted; the device of ``txt_embs`` is used
            in that case.
        img_embs: NumPy array of image embeddings, one embedding per row.
        txt_embs: Torch tensor of text embeddings, one embedding per row, with
            the same embedding width as ``img_embs``.

    Returns:
        Torch tensor holding ``100.0 * img_embs @ txt_embs.T``: one row per
        image embedding, one column per text embedding.
    """
    target_device = getattr(self, "device", None)
    if target_device is None:
        # No usable owner object: keep both operands on the same device.
        target_device = txt_embs.device

    image_features = torch.from_numpy(img_embs).to(target_device)
    # `*` and `@` share precedence and associate left to right, so the scale
    # is applied to the image features before the matrix product.
    similarity = 100.0 * image_features @ txt_embs.T

    return similarity

In [7]:
# Prefix that generate_embeddings_dataframe concatenates in front of each image
# file name, so the trailing slash is required.
# NOTE(review): machine-specific absolute path; change it before running on
# another machine.
path = "/home/lazye/Documents/ufrgs/mcs/datasets/FairFace/"
# Plain string literal: the previous f-string had no placeholders.
fface_df = pd.read_csv("../data/fface_val.csv")

In [8]:
# Maps "<race>_<gender>" class keys to the text prompt used for that class.
# Built from (race key, article + race phrase) pairs crossed with
# (gender key, noun) pairs; insertion order is Male then Female per race.
labels = {
    f"{race_key}_{gender_key}": f"a photo of {race_phrase} {gender_noun}"
    for race_key, race_phrase in (
        ('White', 'a white'),
        ('Black', 'a black'),
        ('Latino_Hispanic', 'a latino'),
        ('East Asian', 'an east asian'),
        ('Southeast Asian', 'a southeast asian'),
        ('Indian', 'an indian'),
        ('Middle Eastern', 'a middle eastern'),
    )
    for gender_key, gender_noun in (('Male', 'man'), ('Female', 'woman'))
}

In [9]:
# Split the mapping into two parallel lists. Dicts preserve insertion order,
# so class_labels[i] is the key whose prompt is prompts[i].
class_labels = [class_key for class_key in labels]
prompts = [prompt_text for prompt_text in labels.values()]

In [10]:
# clip.tokenize accepts a list of strings and returns one row of tokens per
# prompt, so the whole batch is tokenized in a single call.
tokenized_prompts = clip.tokenize(prompts).to(device)

# Inference only: encode the prompts and scale each row to unit norm.
with torch.no_grad():
    prompt_features = model.encode_text(tokenized_prompts)
    prompt_norms = prompt_features.norm(dim=-1, keepdim=True)
    prompt_features = prompt_features / prompt_norms

In [11]:
# Persist the class-key -> prompt mapping as indented UTF-8 JSON
# (ensure_ascii=False writes non-ASCII characters unescaped).
with open('../data/labels.json', 'w', encoding='utf-8') as f:
    serialized_labels = json.dumps(labels, ensure_ascii=False, indent=4)
    f.write(serialized_labels)