In [2]:
import torch
from tqdm import tqdm
import pandas as pd
import open_clip
from bias_explorer.utils import dataloader

In [3]:
# Model architecture and pretrained-weights tag handed to open_clip.
model_name, data_source = "ViT-B-16", "openai"

# Both output CSV names and the embedding directory are keyed by this pair.
run_tag = f"{model_name}_{data_source}"
preds_path = f"clip_aggregation_preds_{run_tag}.csv"
report_path = f"clip_aggregation_report_{run_tag}.csv"

# Precomputed image embeddings (pickle) and the validation label CSV.
img_path = (
    "./data/fairface/embeddings/openCLIP/"
    f"{model_name}/{data_source}/generated_img_embs.pkl"
)
val_path = "./data/fairface/fface_val.csv"

# Default to CPU and switch to the GPU only when CUDA is actually available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [3]:
# Instantiate the CLIP model on the selected device. The second element of the
# returned tuple is not needed here; the third is kept as `preprocessing`.
model, _, preprocessing = open_clip.create_model_and_transforms(
    model_name, pretrained=data_source, device=device)
# open_clip documents that models are returned in train mode; switch to eval so
# any dropout / stochastic-depth / batch-norm layers behave deterministically
# during inference.
model.eval()

# NOTE(review): unpickling can execute arbitrary code -- only load embedding
# files produced by a trusted pipeline.
img_embs = pd.read_pickle(img_path)
# Validation labels; joined to the predictions on the 'file' column later on.
fface_df = pd.read_csv(val_path)

In [4]:
# (article, group) pairs used to generate the prompt ensemble. The article only
# applies to the plain "<article> <group>" prompt; the "young"/"old" variants
# always use "a young" / "an old".
# NOTE(review): 'east asian' is paired with "a" (grammatically "an"); it is
# kept as-is because changing a prompt changes the resulting text embeddings.
group_articles = [
    ('a', 'white'),
    ('a', 'black'),
    ('a', 'latino'),
    ('a', 'east asian'),
    ('a', 'southeast asian'),
    ('an', 'indian'),
    ('a', 'middle eastern'),
]

# One generic prompt followed by three prompts (plain, young, old) per group.
# Each template keeps a single positional `{}` for the class name.
clip_templates = ['a photo of a {}'] + [
    prompt
    for article, group in group_articles
    for prompt in (
        f'a photo of {article} {group} {{}}',
        f'a photo of a young {group} {{}}',
        f'a photo of an old {group} {{}}',
    )
]

# Class names substituted into the templates; their order fixes the column
# order of the zero-shot weight matrix.
gender_classes = ['man', 'woman']

# Output labels, index-aligned with `gender_classes` (0 -> 'Male', 1 -> 'Female').
preds_classes = ['Male', 'Female']

In [5]:
def zeroshot_classifier(classnames, templates):
    """Build a zero-shot classification head from prompt-ensembled text embeddings.

    For each class name every template is formatted, tokenized and encoded with
    the text encoder of the module-level ``model``. The per-prompt embeddings
    are L2-normalised, averaged over the templates, and the mean is normalised
    again to give one vector per class.

    Tensors are placed on the module-level ``device`` (instead of a hard-coded
    ``.cuda()``), so the function also works on CPU-only machines.

    Args:
        classnames: iterable of class-name strings inserted into each template.
        templates: iterable of format strings with one positional ``{}``.

    Returns:
        A tuple ``(zeroshot_weights, prompts)`` where ``zeroshot_weights`` is a
        tensor with one column per class (class vectors stacked along dim 1)
        and ``prompts`` is a list holding, per class, the list of formatted
        prompt strings in template order.
    """
    with torch.no_grad():
        zeroshot_weights = []
        prompts = []
        for classname in tqdm(classnames):
            # Fill the class name into every template.
            texts = [template.format(classname) for template in templates]
            prompts.append(texts)
            tokens = open_clip.tokenize(texts).to(device)
            class_embeddings = model.encode_text(tokens)
            # Normalise each prompt embedding before averaging ...
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            # ... and re-normalise the ensemble mean.
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights, prompts


# Build the classification head for the gender classes from the prompt
# ensemble; `prompts` keeps the formatted strings per class.
zeroshot_weights, prompts = zeroshot_classifier(
    classnames=gender_classes,
    templates=clip_templates,
)

100%|██████████| 2/2 [00:00<00:00,  5.52it/s]


In [6]:
# Flatten the per-class prompt lists into a single list, keeping the original
# order (all prompts of the first class, then the next class, ...).
plist = [text for class_prompts in prompts for text in class_prompts]

In [8]:
# Classify every precomputed image embedding against the zero-shot head and
# collect one predicted label per file.
fnames = []
preds = []
with torch.no_grad():
    # Iterate the two needed columns directly instead of building a Series per
    # row with iterrows().
    for name, img_features in zip(img_embs['file'], img_embs['embeddings']):
        image_features = torch.from_numpy(img_features).to(device)
        # 100. is the fixed scale applied to the similarities before softmax
        # (presumably CLIP's usual logit scale -- confirm).
        logits = 100. * image_features @ zeroshot_weights
        text_probs = logits.softmax(dim=-1)
        # Index of the most probable class; .item() moves the single value to
        # the host, so no explicit .cpu() round-trips are needed.
        pindex = text_probs.argmax(dim=-1).item()
        fnames.append(name)
        preds.append(preds_classes[pindex])

# Column-oriented dict consumed by pd.DataFrame in the next cell.
preds_dict = {'file': fnames, 'gender_preds': preds}

In [9]:
# Turn the collected predictions into a frame and attach them to the label
# table by file name (left join on the 'file' index), then discard the
# 'service_test' column.
preds_df = pd.DataFrame(data=preds_dict)
labels_by_file = fface_df.set_index('file')
preds_by_file = preds_df.set_index('file')
new_df = labels_by_file.join(preds_by_file).drop(columns=['service_test'])

In [25]:
# Prompt skeleton: each label below carries its own leading space, and an empty
# label drops that attribute from the prompt altogether.
# NOTE(review): the article is fixed to "a", so labels starting with a vowel
# yield prompts such as "a photo of a old ..." / "a photo of a indian ...".
template = "a photo of a{age}{race}{gender}"

age_labels = [
    " young",
    " middle-aged",
    " old",
    "",  # no age attribute
]

race_labels = [
    " black",
    " indian",
    " latino hispanic",
    " middle eastern",
    " southeast asian",
    " east asian",
    " white",
    "",  # no race attribute
]

gender_labels = [
    " woman",
    " man",
]

In [26]:
# Stage 1 of the prompt expansion: fill in only the age slot. The race and
# gender slots are re-emitted as literal "{race}" / "{gender}" placeholders so
# later stages can format them.
passthrough = {"race": "{race}", "gender": "{gender}"}
age_prompts = [template.format(age=age, **passthrough) for age in age_labels]

In [27]:
# Stage 2 of the prompt expansion: combine every age-filled prompt with every
# race label (age varies slowest). The gender slot is passed through as a
# literal "{gender}" placeholder for the final stage.
race_prompts = [
    age_prompt.format(race=race, gender="{gender}")
    for age_prompt in age_prompts
    for race in race_labels
]

In [28]:
# Final stage of the prompt expansion: fill the remaining gender slot, giving
# one complete prompt per (age, race, gender) combination with gender varying
# fastest.
final_prompts = [
    race_prompt.format(gender=gender)
    for race_prompt in race_prompts
    for gender in gender_labels
]

In [30]:
import json

In [31]:
# Persist the generated prompts as a JSON array in the working directory.
serialized_prompts = json.dumps(final_prompts)
with open("testlabels.json", "w", encoding="utf-8") as labels_file:
    labels_file.write(serialized_prompts)