In [29]:
import torch
from tqdm import tqdm
import pandas as pd
import open_clip
from bias_explorer.operations.report import get_empty_report_dict, gen_dict_report
from bias_explorer.utils import dataloader

In [30]:
# --- Run configuration -------------------------------------------------
# CLIP architecture and pretrained-weights tag passed to open_clip.
model_name = "ViT-B-16"
data_source = "openai"

# Input files.
img_path = f"./data/fairface/embeddings/openCLIP/{model_name}/{data_source}/generated_img_embs.pkl"
val_path = "./data/fairface/fface_val.csv"

# Output files (named after the model / weights combination).
preds_path = f"clip_aggregation_preds_{model_name}_{data_source}.csv"
report_path = f"clip_aggregation_report_{model_name}_{data_source}.csv"

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [31]:
# Precomputed image embeddings (the prediction loop reads the 'file' and
# 'embeddings' columns of this frame).
# NOTE(review): unpickling can execute arbitrary code -- only load embedding
# files produced by a trusted pipeline.
img_embs = pd.read_pickle(img_path)

# Validation annotations; joined with the predictions on the 'file' column.
fface_df = pd.read_csv(val_path)

model, _, preprocessing = open_clip.create_model_and_transforms(
    model_name, pretrained=data_source, device=device)
# open_clip hands the model back in training mode; switch to inference mode
# so layers such as dropout / batch-norm / stochastic depth (where present)
# behave deterministically. The trailing ';' suppresses the module repr.
model.eval();

In [32]:
# Prompt templates: "{}" is replaced by each entry of `gender_classes`.
# Layout: one generic prompt, then a (plain, young, old) triple per ethnicity.
# NOTE(review): "a east asian" is ungrammatical ("an east asian"); it is kept
# verbatim because changing a prompt changes the text embeddings and therefore
# the reported results -- fix deliberately and re-run if desired.
clip_templates = [
    # generic
    "a photo of a {}",
    # white
    "a photo of a white {}",
    "a photo of a young white {}",
    "a photo of an old white {}",
    # black
    "a photo of a black {}",
    "a photo of a young black {}",
    "a photo of an old black {}",
    # latino
    "a photo of a latino {}",
    "a photo of a young latino {}",
    "a photo of an old latino {}",
    # east asian
    "a photo of a east asian {}",
    "a photo of a young east asian {}",
    "a photo of an old east asian {}",
    # southeast asian
    "a photo of a southeast asian {}",
    "a photo of a young southeast asian {}",
    "a photo of an old southeast asian {}",
    # indian
    "a photo of an indian {}",
    "a photo of a young indian {}",
    "a photo of an old indian {}",
    # middle eastern
    "a photo of a middle eastern {}",
    "a photo of a young middle eastern {}",
    "a photo of an old middle eastern {}",
]

# Words substituted into the templates, in classifier-column order.
gender_classes = ["man", "woman"]

# Label written to the predictions for each classifier column
# (index-aligned with `gender_classes`).
preds_classes = ["Male", "Female"]

In [33]:
def zeroshot_classifier(classnames, templates):
    """Build zero-shot classifier weights from prompt-ensembled text embeddings.

    For each class name every template is formatted with that name, the
    prompts are tokenized with ``open_clip.tokenize`` and encoded with the
    notebook-level ``model``. The per-prompt embeddings are L2-normalized,
    averaged, and the mean is normalized again to give one vector per class.

    Tensors are placed on the notebook-level ``device`` rather than on a
    hard-coded CUDA device, so the function also works on CPU-only machines.

    Args:
        classnames: Iterable of strings substituted into the templates.
        templates: Iterable of format strings with a single ``{}`` placeholder.

    Returns:
        A tuple ``(zeroshot_weights, prompts)`` where ``zeroshot_weights`` is a
        tensor with one class embedding per column (classes stacked along
        ``dim=1``) and ``prompts`` is a list holding, per class, the list of
        formatted prompt strings.
    """
    with torch.no_grad():
        zeroshot_weights = []
        prompts = []
        for classname in tqdm(classnames):
            # Fill every template with the current class name.
            texts = [template.format(classname) for template in templates]
            prompts.append(texts)
            # Tokens must live on the same device as the model.
            tokens = open_clip.tokenize(texts).to(device)
            class_embeddings = model.encode_text(tokens)
            # Normalize each prompt embedding before averaging ...
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            # ... and re-normalize the mean so every class vector has unit norm.
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights, prompts


# Text-side classifier for the two gender prompts, plus the prompts used.
zeroshot_weights, prompts = zeroshot_classifier(
    classnames=gender_classes,
    templates=clip_templates,
)

100%|██████████| 2/2 [00:00<00:00, 32.78it/s]


In [34]:
# Flatten the per-class prompt lists into one list, keeping class order.
plist = [text for class_prompts in prompts for text in class_prompts]

In [35]:
# Display the flattened prompts to sanity-check the generated text.
plist

['a photo of a man',
 'a photo of a white man',
 'a photo of a young white man',
 'a photo of an old white man',
 'a photo of a black man',
 'a photo of a young black man',
 'a photo of an old black man',
 'a photo of a latino man',
 'a photo of a young latino man',
 'a photo of an old latino man',
 'a photo of a east asian man',
 'a photo of a young east asian man',
 'a photo of an old east asian man',
 'a photo of a southeast asian man',
 'a photo of a young southeast asian man',
 'a photo of an old southeast asian man',
 'a photo of an indian man',
 'a photo of a young indian man',
 'a photo of an old indian man',
 'a photo of a middle eastern man',
 'a photo of a young middle eastern man',
 'a photo of an old middle eastern man',
 'a photo of a woman',
 'a photo of a white woman',
 'a photo of a young white woman',
 'a photo of an old white woman',
 'a photo of a black woman',
 'a photo of a young black woman',
 'a photo of an old black woman',
 'a photo of a latino woman',
 'a pho

In [36]:
# Classify every precomputed image embedding against the zero-shot weights
# and collect one predicted label per file.
fnames = []
preds = []
with torch.no_grad():
    # Iterate the two needed columns directly instead of building a Series
    # for every row with iterrows().
    for name, img_features in zip(img_embs['file'], img_embs['embeddings']):
        image_features = torch.from_numpy(img_features).to(device)
        # Scaled similarity between the image and each class embedding.
        logits = 100. * image_features @ zeroshot_weights
        text_probs = logits.softmax(dim=-1)
        # Index of the most probable class; .item() moves the scalar to host,
        # so no explicit .cpu() round-trips are needed.
        pindex = text_probs.argmax(dim=-1).item()
        fnames.append(name)
        preds.append(preds_classes[pindex])

preds_dict = {'file': fnames, 'gender_preds': preds}

In [37]:
# Attach the predictions to the validation annotations, matching on 'file',
# then discard the 'service_test' column.
preds_df = pd.DataFrame(data=preds_dict)
new_df = (
    fface_df
    .set_index('file')
    .join(preds_df.set_index('file'))
    .drop(columns=['service_test'])
)

In [38]:
# Persist annotations + predictions to `preds_path`; the frame is indexed by
# 'file', so that index is written out along with the columns.
new_df.to_csv(preds_path)

In [39]:
# Reload the just-written predictions through the project's loader; note that
# this rebinds `new_df` to whatever `load_df` returns.
# NOTE(review): presumably load_df applies the parsing the report helpers
# expect -- confirm against bias_explorer.utils.dataloader.
new_df = dataloader.load_df(preds_path)

In [40]:
# Create the empty report structure for the "accuracy" metric.
# NOTE(review): the layout of the returned dict is defined by
# get_empty_report_dict and is not visible here -- confirm before relying on it.
rep_dict = get_empty_report_dict(new_df, "accuracy")

In [41]:
# Fill the report from the predictions; the result replaces `rep_dict`.
# NOTE(review): "OpenAI" is presumably the label under which this model's
# results are recorded -- verify against gen_dict_report.
rep_dict = gen_dict_report(new_df, "OpenAI", "accuracy", rep_dict)

In [42]:
# Tabulate the report dict and write it to `report_path` as CSV.
rep_df = pd.DataFrame(rep_dict)
rep_df.to_csv(report_path)