In [2]:
import clip
import torch
import pandas as pd
import numpy as np
from PIL import Image

In [3]:
def filter_df(df, race=None, gender=None):
    """Return the rows of ``df`` that match the requested gender and/or race.

    A filter is applied only when its argument is truthy; with both left at
    ``None`` the input frame is returned unchanged (same object, not a copy).

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a ``'gender'`` column and/or a ``'race'`` column; a column
        is only accessed when its filter is requested.
    race : optional
        Value the ``'race'`` column must equal.
    gender : optional
        Value the ``'gender'`` column must equal.

    Returns
    -------
    pandas.DataFrame
        The filtered rows.
    """
    # Gender is narrowed first, then race, on the already-narrowed frame.
    criteria = (('gender', gender), ('race', race))

    selected = df
    for column, wanted in criteria:
        if not wanted:
            continue
        selected = selected[selected[column] == wanted]
    return selected

In [4]:
%matplotlib inline

# Candidate CLIP ResNet variants and layer names to choose from.
available_models = ['RN50', 'RN101', 'RN50x4', 'RN50x16']
layers = ['layer4', 'layer3', 'layer2', 'layer1']

# Selection is by position: change the index to switch model / layer.
# NOTE(review): `saliency_layer` is not used in the cells visible here;
# presumably it is consumed further down the notebook.
clip_model, saliency_layer = available_models[0], layers[0]

print('\nLoading model...')

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# `clip.load` returns the model together with its image preprocessing callable.
model, preprocess = clip.load(clip_model, device=device, jit=False)
print(f"Done! Model loaded to {device} device")


Loading model...
Done! Model loaded to cuda device


In [5]:
# Root directory of the FairFace dataset.
# NOTE(review): hardcoded absolute path on the author's machine; adjust it (or
# make it configurable) before running elsewhere.
# It must end with "/" because file paths are built elsewhere in this notebook
# by plain string concatenation (`path + file`).
path = "/home/lazye/Documents/ufrgs/mcs/datasets/FairFace/"

# `path` already ends with a separator, so none is added here; this avoids the
# doubled "FairFace//train/..." and matches how `path` is joined elsewhere.
fface_df = pd.read_csv(f"{path}train/fairface_label_train.csv")

In [6]:
# Split the labels into two subsets that share race 'White' and differ only in gender.
# Keyword arguments make the (race, gender) order explicit.
man_df = filter_df(fface_df, race='White', gender='Male')
woman_df = filter_df(fface_df, race='White', gender='Female')

In [7]:
# Sanity check on the filter: count the distinct values left in the 'gender'
# column of the male subset (only 'Male' is expected to remain).
man_df['gender'].value_counts()

gender
Male    8701
Name: count, dtype: int64

In [12]:
def generate_embeddings_dataframe(df):
    """Encode every image listed in ``df`` with CLIP and collect the results.

    Despite its name, ``df`` is iterated directly, so it must be an iterable of
    image file paths (the notebook passes the ``'file'`` column of a label
    frame). Iterating a whole DataFrame would yield its column names instead.

    The function relies on the notebook-level globals ``path``, ``preprocess``,
    ``model`` and ``device``. Each file path is appended to ``path`` by plain
    string concatenation, so ``path`` must end with a path separator.

    Parameters
    ----------
    df : iterable of str
        Image paths relative to ``path``.

    Returns
    -------
    pandas.DataFrame
        One row per image, with columns ``'file'`` (the path as given) and
        ``'embeddings'`` (the NumPy array obtained by encoding a batch
        containing that single image and dividing it by its norm along the
        last dimension).
    """
    files = []
    embs = []

    for file in df:
        # Images are encoded one at a time. The context manager releases the
        # underlying file handle as soon as the image has been preprocessed,
        # instead of leaving that to garbage collection over thousands of files.
        with Image.open(path + file) as img:
            img_input = preprocess(img).unsqueeze(0).to(device)

        # Inference only: keep both the forward pass and the normalisation
        # outside of autograd.
        with torch.no_grad():
            image_features = model.encode_image(img_input)
            image_features /= image_features.norm(dim=-1, keepdim=True)

        files.append(file)
        embs.append(image_features.cpu().numpy())

    return pd.DataFrame(data={'file': files, 'embeddings': embs})

In [14]:
# Embed every image of the female subset. Only the 'file' column is passed,
# because the function iterates over its argument to get the image paths.
woman_embs_df = generate_embeddings_dataframe(woman_df['file'])

In [22]:
# Embed every image of the male subset (same procedure as for the female subset).
# NOTE(review): this variable is spelled `embds`, while the female counterpart
# uses `embs`.
man_embds_df = generate_embeddings_dataframe(man_df['file'])

In [23]:
# Persist the male embeddings so the encoding pass does not have to be repeated.
# NOTE(review): the file is written with `to_pickle`, so despite the `.csv`
# extension it is a pickle and must be read back with `pd.read_pickle`.
man_embds_df.to_pickle('man_embeddings.csv')

In [17]:
# Persist the female embeddings so the encoding pass does not have to be repeated.
# NOTE(review): the file is written with `to_pickle`, so despite the `.csv`
# extension it is a pickle and must be read back with `pd.read_pickle`.
woman_embs_df.to_pickle('woman_embeddings.csv')

In [18]:
# Round-trip check: reload the saved female embeddings from disk.
# NOTE(review): the file is a pickle despite its `.csv` name; only unpickle
# files you produced yourself, since unpickling can execute arbitrary code.
loaded_embds = pd.read_pickle('woman_embeddings.csv')

In [21]:
# Inspect the embedding stored in the first row to confirm that the array
# survived the save/load round trip.
loaded_embds.iloc[0]['embeddings']

array([[-0.003   ,  0.02199 , -0.01199 , ...,  0.004536,  0.0315  ,
        -0.02223 ]], dtype=float16)