In [2]:
import json
import os

import clip
import numpy as np
import pandas as pd
import torch
from PIL import Image

In [3]:
# Ask the installed `clip` package which pretrained model names it knows;
# one of these names is what the model-loading cell passes to clip.load.
available_models = clip.available_models()
print(available_models)

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']


In [4]:
# Announce the load up front so the cell gives feedback before clip.load runs.
print('\nLoading model...')

# Run on the GPU whenever PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Name of the pretrained CLIP model to load (see clip.available_models()).
clip_model = 'ViT-B/16'

# clip.load returns the model together with its matching image preprocessing
# callable; both are used as notebook-level globals by the cells below.
# NOTE(review): jit=False presumably selects the plain nn.Module rather than
# a TorchScript build -- confirm against the CLIP documentation.
model, preprocess = clip.load(clip_model, device=device, jit=False)
print(f"Done! Model {clip_model} loaded to {device} device")


Loading model...
Done! Model ViT-B/16 loaded to cuda device


In [5]:
def generate_embeddings_dataframe(df, base_path=None):
    """Encode every image listed in ``df`` with the loaded CLIP model.

    Relies on the notebook-level ``preprocess``, ``model`` and ``device``
    objects, so the model-loading cell must have been run first.

    Args:
        df: Iterable of image file names, consumed with a plain ``for`` loop.
            NOTE(review): iterating a ``pd.DataFrame`` directly yields its
            column labels, so callers presumably pass a Series or list of
            file names -- confirm at the call site.
        base_path: Prefix prepended to each file name by plain string
            concatenation (so it must end with a path separator). When
            omitted, the notebook-level ``path`` variable is used, looked
            up at call time.

    Returns:
        pd.DataFrame with a ``file`` column (the names exactly as given) and
        an ``embeddings`` column holding one NumPy array per image: the
        output of ``model.encode_image`` for a single-image batch, divided
        by its norm along the last dimension.
    """
    if base_path is None:
        # Fall back to the global so existing one-argument calls keep working.
        base_path = path

    files = list()
    embs = list()

    for file in df:
        # The context manager releases the file handle deterministically,
        # even if preprocessing raises part-way through a long run.
        with Image.open(base_path + file) as img:
            # unsqueeze(0) adds the leading batch dimension for a single image.
            img_input = preprocess(img).unsqueeze(0).to(device)

        with torch.no_grad():
            image_features = model.encode_image(img_input)
            # Scale to unit length along the feature (last) dimension.
            image_features /= image_features.norm(dim=-1, keepdim=True)

        files.append(file)
        embs.append(image_features.cpu().numpy())

    return pd.DataFrame(data={'file': files, 'embeddings': embs})

In [6]:
def generate_text_embeddings(txts, model, device, template="a photo of a {}"):
    """Generate normalised text embeddings with a CLIP model.

    Each entry of ``txts`` is inserted into ``template``, tokenized with
    ``clip.tokenize`` and encoded by ``model.encode_text``.

    Args:
        txts: Iterable of class names / phrases to embed.
        model: Object exposing ``encode_text`` (the model returned by
            ``clip.load`` in this notebook).
        device: Device the token tensor is moved to before encoding.
        template: ``str.format`` template with a single ``{}`` placeholder
            that receives each entry. The default reproduces the original
            hard-coded prompt; pass ``"{}"`` to embed the entries verbatim.

    Returns:
        The tensor produced by ``model.encode_text`` with one row per entry
        of ``txts``, divided by its norm along the last dimension.
    """
    sentences = [template.format(c) for c in txts]

    # Tokenize one sentence at a time and concatenate the rows into a batch.
    text_inputs = torch.cat(
        [clip.tokenize(sentence) for sentence in sentences]).to(device)

    with torch.no_grad():
        text_features = model.encode_text(text_inputs)
        # Scale every row to unit length along the feature (last) dimension.
        text_features /= text_features.norm(dim=-1, keepdim=True)

    return text_features

In [7]:
# Root directory of the FairFace images. generate_embeddings_dataframe builds
# each image location as `path + file`, so the value must end with a path
# separator; os.path.join(..., "") guarantees that without altering a value
# that already ends with one.
# Set the FAIRFACE_DIR environment variable to point at a different copy of
# the dataset without editing the notebook; the default is the original
# hard-coded location.
path = os.path.join(
    os.environ.get(
        "FAIRFACE_DIR",
        "/home/lazye/Documents/ufrgs/mcs/datasets/FairFace/",
    ),
    "",
)

# Table read from the project's data folder (presumably the FairFace
# validation split, judging by the file name -- confirm).
fface_df = pd.read_csv("../data/fface_val.csv")

In [8]:
# Load the label definitions from JSON.
# NOTE(review): the following cells call .keys()/.values() on the result and
# concatenate its values, so the file presumably maps each class label to a
# list of prompt strings -- confirm against the JSON file.
with open('../data/synms_gender_labels.json', encoding='utf-8') as json_data:
    labels = json.load(json_data)

In [9]:
# Split the mapping into two parallel lists that share the mapping's
# iteration order: the keys, and the value stored under each key.
class_labels = [*labels]
prompts = [*labels.values()]

In [10]:
# Join the prompts of the first two classes into one sequence.
# NOTE(review): only prompts[0] and prompts[1] are used, so any further
# classes in the JSON would be silently ignored -- confirm the file has
# exactly two entries.
unified_prompts = prompts[0] + prompts[1]

In [11]:
# Encode every prompt verbatim (no template is applied in this cell).
with torch.no_grad():
    # Tokenize one prompt at a time, concatenate the rows into a single
    # batch and move it to the device used for encoding.
    tokenized_prompts = torch.cat(
        [clip.tokenize(prompt) for prompt in unified_prompts]
    ).to(device)

    prompt_features = model.encode_text(tokenized_prompts)

    # Scale each row to unit length along the feature (last) dimension.
    prompt_features = prompt_features / prompt_features.norm(dim=-1, keepdim=True)

In [12]:
# Persist the normalised prompt embeddings for reuse outside this notebook.
# NOTE(review): the tensor is saved as-is, i.e. still on `device`; loading it
# on a machine without that device presumably needs
# torch.load(..., map_location=...) -- confirm with downstream consumers.
torch.save(prompt_features, "../data/synms-gender-labels.pt")