In [26]:
import json
import os

import datasets
import joblib
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import classification_report
from tqdm.auto import tqdm
from transformers import CLIPProcessor, CLIPModel, CLIPVisionModel

import utils

In [2]:
# Load the FairFace validation split (configuration "1.25") with verification checks disabled.
validation_ds = datasets.load_dataset(
    "HuggingFaceM4/FairFace",
    "1.25",
    verification_mode="no_checks",
    split="validation",
)

In [3]:
# Accumulator for every statistic reported by this notebook (stat name -> value).
reported_stats = dict()

In [4]:
# Fix the torch RNG seed, then pick the GPU when one is available and the CPU otherwise.
torch.manual_seed(42)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
# Load the CLIP model, the stand-alone CLIP vision model and the matching pre-processor
# from the same checkpoint; both models are moved to the selected device.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [32]:
def get_embedding_and_zs(sample):
    """Attach CLIP image embeddings to a batch of examples.

    Written for ``Dataset.map(..., batched=True)``: ``sample`` is a
    column-oriented batch (column name -> list of values).

    Args:
        sample: Batch mapping that contains at least an ``"image"`` column.

    Returns:
        The same batch with an added ``"proj_embeddings"`` column holding the
        ``image_embeds`` output of the CLIP forward pass, one row per image.
    """
    images = sample["image"]

    # The processor is called with text as well, so supply one placeholder
    # prompt per image. The batch is a mapping of columns, so len(sample)
    # would count columns rather than examples.
    placeholder_text = ["text"] * len(images)
    inputs = processor(text=placeholder_text, images=images, return_tensors="pt", padding=True).to(device)

    # Inference only: no autograd graph is needed for the stored embeddings.
    with torch.no_grad():
        outputs = model(**inputs)

    # NOTE(review): `image_embeds` from the full forward pass is believed to be
    # L2-normalised, unlike `model.get_image_features` -- confirm the downstream
    # classifiers were trained on the same variant.
    # Store embeddings - dim 512
    sample["proj_embeddings"] = outputs.image_embeds.detach().cpu().numpy()
    return sample

In [33]:
# Compute the CLIP embeddings for the whole validation split, 32 examples per batch.
validation_ds = validation_ds.map(
    get_embedding_and_zs,
    batch_size=32,
    batched=True,
)

Map:   0%|          | 0/10954 [00:00<?, ? examples/s]

In [34]:
def adjust_ages(sample):
    """Shift every age class id in a batch down by two.

    Classes 0 and 1 have been deleted, so the remaining ids are moved down
    accordingly. The batch is modified in place and returned.
    """
    shifted_ages = []
    for age_class in sample["age"]:
        shifted_ages.append(age_class - 2)
    sample["age"] = shifted_ages
    return sample

In [35]:
# Ground-truth gender distribution of the validation split, as percentages.
# Label encoding: 0 - Male, 1 - Female
gender_labels = np.array(validation_ds["gender"])
n_gender_labels = len(gender_labels)
for stat_name, gender_id in (("perc_fem_val", 1), ("perc_mal_val", 0)):
    n_matching = np.sum(gender_labels == gender_id)
    reported_stats[stat_name] = round(n_matching / n_gender_labels * 100, ndigits=2)

In [37]:
# Get gender predictions from the stored embeddings.
# NOTE(review): the recorded run of this cell ended with `KeyError: 10`; the cause
# is not visible from this notebook -- check the `utils` helpers and the model files they load.
gender_model, _, _ = utils._load_lr_classifiers()
val_proj_embeddings = np.array(validation_ds["proj_embeddings"])
gender_preds = utils._predict_gender(images=val_proj_embeddings, gender_model=gender_model)

KeyError: 10

In [None]:
# Predicted gender = index of the highest score in each row of gender_preds;
# report the share of each predicted class as a percentage.
gender_labels_pred = np.argmax(gender_preds, axis=1)
n_gender_preds = len(gender_labels_pred)
for stat_name, gender_id in (("perc_fem_val_pred", 1), ("perc_mal_val_pred", 0)):
    n_matching = np.sum(gender_labels_pred == gender_id)
    reported_stats[stat_name] = round(n_matching / n_gender_preds * 100, ndigits=2)

In [12]:
# Display the statistics collected so far.
reported_stats

{'perc_fem_val': 47.12,
 'perc_mal_val': 52.88,
 'perc_fem_val_pred': 50.14,
 'perc_mal_val_pred': 49.86}

In [13]:
# Gender accuracy: fraction of validation samples whose predicted label equals the true label.
n_gender_correct = np.sum(gender_labels_pred == gender_labels)
round(n_gender_correct / len(gender_labels), 2)

0.96

In [14]:
# Filter and adjust age data
# age_validation_ds = validation_ds.filter(lambda sample: sample["age"] not in {0, 1}).map(adjust_ages, batched = True, batch_size=32) # Filter out the first two classes

In [15]:
# Ground-truth age split of the validation set, as percentages.
# Class ids 0,1,2,3,4,5 - Up to 50 & 6,7,8 - Over 50
age_labels = np.array(validation_ds["age"])
n_age_labels = len(age_labels)
n_up_to_50 = np.sum(age_labels <= 5)
n_over_50 = np.sum(age_labels >= 6)
reported_stats["perc_ut50_val"] = round(n_up_to_50 / n_age_labels * 100, ndigits=2)
reported_stats["perc_o50_val"] = round(n_over_50 / n_age_labels * 100, ndigits=2)

In [29]:
# Get age predictions from the stored embeddings using the scaler/model pair from `utils`.
age_scaler, age_model = utils._load_age_model()
val_proj_embeddings = np.array(validation_ds["proj_embeddings"])
age_preds = utils._predict_age(
    age_scaler=age_scaler,
    age_model=age_model,
    images=val_proj_embeddings,
)

In [31]:
# Per-class evaluation of the age classifier on the validation embeddings.
# `classification_report` is imported in the notebook's import cell.
# NOTE: joblib.load unpickles these files, which can execute arbitrary code --
# only load model artifacts that come from a trusted source.
age_scaler = joblib.load("models/projected_scaler.joblib")
age_clf = joblib.load("models/lr_clf_proj_age.joblib")

# Scale the embeddings with the loaded scaler before predicting.
X_val_scaled = age_scaler.transform(np.array(validation_ds["proj_embeddings"]))
y_val_preds = age_clf.predict(X_val_scaled)
print(classification_report(age_labels, y_val_preds))

              precision    recall  f1-score   support

           0       0.02      0.99      0.05       199
           1       0.47      0.02      0.03      1356
           2       0.39      0.14      0.20      1181
           3       0.72      0.04      0.08      3300
           4       0.58      0.04      0.08      2330
           5       0.44      0.20      0.28      1353
           6       0.41      0.02      0.04       796
           7       0.15      0.16      0.15       321
           8       0.14      0.81      0.24       118

    accuracy                           0.10     10954
   macro avg       0.37      0.27      0.13     10954
weighted avg       0.53      0.10      0.11     10954



In [25]:
# How often each age class is the top prediction: (class ids, counts).
age_top_classes = np.argmax(age_preds, axis=1)
np.unique(age_top_classes, return_counts=True)

(array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=int64),
 array([8424,   51,  408,  201,  177,  623,   39,  352,  679], dtype=int64))

In [None]:
# Age accuracy over all classes: fraction of samples whose top-scoring class equals the label.
# age_preds is indexed as (sample, class) in the other cells of this notebook, so the
# class axis is 1; axis=2 does not exist on a 2-D array and raises an AxisError.
round(np.sum(np.argmax(age_preds, axis=1) == age_labels) / len(age_labels), 2)

0.1

In [None]:
# Collapse the per-class age scores into a binary label per sample:
# sum the scores of the last three classes ("Over 50") and of all remaining
# classes ("Up to 50"); the label is 1 when "Over 50" has the larger sum, else 0.
age_labels_pred = np.array([
    int(sum(pred[:-3]) < sum(pred[-3:]))
    for pred in tqdm(age_preds)
])

In [128]:
# Share of each predicted binary age label, as percentages.
# age_labels_pred holds 0 for "Up to 50" and 1 for "Over 50".
n_age_preds = len(age_labels_pred)
for stat_name, age_group_id in (("perc_ut50_val_pred", 0), ("perc_o50_val_pred", 1)):
    n_matching = np.sum(age_labels_pred == age_group_id)
    reported_stats[stat_name] = round(n_matching / n_age_preds * 100, ndigits=2)

In [129]:
# Display the statistics collected so far.
reported_stats

{'perc_fem_val': 47.12,
 'perc_mal_val': 52.88,
 'perc_fem_val_pred': 50.14,
 'perc_mal_val_pred': 49.86,
 'perc_ut50_val': 88.73,
 'perc_o50_val': 11.27,
 'perc_ut50_val_pred': 90.22,
 'perc_o50_val_pred': 9.78,
 'avg_conf_gender_blackpanther': 0.9994,
 'avg_conf_gender_mamamia': 0.9995,
 'avg_conf_gender_marigold': 0.9993}

In [None]:
# Average top gender score for the Black Panther results; skipped when the CSV is absent.
if os.path.exists("black_panther_demography.csv"):
    df = pd.read_csv("black_panther_demography.csv")
    # Each `gender` cell is a dict-like string: swap single for double quotes so it
    # parses as JSON, then keep the largest value (presumably the winning confidence).
    top_gender_conf = df.gender.apply(
        lambda gender: max(json.loads(gender.replace("\'", "\"")).values())
    )
    reported_stats["avg_conf_gender_blackpanther"] = round(top_gender_conf.mean(), ndigits=4)
    # Age counterpart, currently disabled:
    # reported_stats["avg_conf_age_blackpanther"] = round(df.age.apply(lambda age: max(json.loads(age.replace("\'", "\"")).values())).mean(), ndigits=4)

In [None]:
# Average top gender score for the Mamma Mia results; skipped when the CSV is absent.
if os.path.exists("mama_mia_demography.csv"):
    df = pd.read_csv("mama_mia_demography.csv")
    # Each `gender` cell is a dict-like string: swap single for double quotes so it
    # parses as JSON, then keep the largest value (presumably the winning confidence).
    top_gender_conf = df.gender.apply(
        lambda gender: max(json.loads(gender.replace("\'", "\"")).values())
    )
    reported_stats["avg_conf_gender_mamamia"] = round(top_gender_conf.mean(), ndigits=4)
    # Age counterpart, currently disabled:
    # reported_stats["avg_conf_age_mamamia"] = round(df.age.apply(lambda age: max(json.loads(age.replace("\'", "\"")).values())).mean(), ndigits=4)

In [None]:
# Average top gender score for the Marigold results; skipped when the CSV is absent.
if os.path.exists("marigold_demography.csv"):
    df = pd.read_csv("marigold_demography.csv")
    # Each `gender` cell is a dict-like string: swap single for double quotes so it
    # parses as JSON, then keep the largest value (presumably the winning confidence).
    top_gender_conf = df.gender.apply(
        lambda gender: max(json.loads(gender.replace("\'", "\"")).values())
    )
    reported_stats["avg_conf_gender_marigold"] = round(top_gender_conf.mean(), ndigits=4)
    # Age counterpart, currently disabled:
    # reported_stats["avg_conf_age_marigold"] = round(df.age.apply(lambda age: max(json.loads(age.replace("\'", "\"")).values())).mean(), ndigits=4)

In [106]:
# Display the full set of collected statistics.
reported_stats

{'perc_fem_val': 47.12,
 'perc_mal_val': 52.88,
 'perc_fem_val_pred': 50.14,
 'perc_mal_val_pred': 49.86,
 'perc_ut50_val': 88.73,
 'perc_o50_val': 11.27,
 'perc_ut50_val_pred': 90.23,
 'perc_o50_val_pred': 9.77,
 'avg_conf_gender_blackpanther': 0.9994,
 'avg_conf_gender_mamamia': 0.9995,
 'avg_conf_gender_marigold': 0.9993}

In [110]:
# Render the collected statistics as a one-row DataFrame (one column per statistic).
pd.DataFrame([reported_stats])

Unnamed: 0,perc_fem_val,perc_mal_val,perc_fem_val_pred,perc_mal_val_pred,perc_ut50_val,perc_o50_val,perc_ut50_val_pred,perc_o50_val_pred,avg_conf_gender_blackpanther,avg_conf_gender_mamamia,avg_conf_gender_marigold
0,47.12,52.88,50.14,49.86,88.73,11.27,90.23,9.77,0.9994,0.9995,0.9993
