In [2]:
from io import BytesIO
from PIL import Image
import datasets
import joblib
from transformers import CLIPProcessor, CLIPModel, CLIPVisionModel
from sklearn.linear_model import LogisticRegression
import torch
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import precision_recall_fscore_support
import pandas as pd
from sklearn.metrics import classification_report

In [3]:
# Seed torch's global RNG so any stochastic torch op is repeatable across runs.
torch.manual_seed(42)

# Run on the GPU when one is visible to torch; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Load the pretrained CLIP weights and the matching pre-processor.
# All three objects come from the same checkpoint, so name it once.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

# Full text+image model, moved to the selected device.
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT).to(device)

# Vision-only tower, also moved to the selected device.
# NOTE(review): not referenced by any other cell visible in this notebook.
vision_model = CLIPVisionModel.from_pretrained(CLIP_CHECKPOINT).to(device)

# Tokenizer + image pre-processing bundled together.
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

In [5]:
def get_embedding_and_zs(sample):
    """Attach zero-shot CLIP age-group predictions to a batch of samples.

    Written to be passed to ``Dataset.map(..., batched=True)``: each age bucket
    is turned into a text prompt, CLIP scores every image against every prompt,
    and the index of the best-scoring prompt is stored as the prediction.
    Despite the name, no embedding is stored; only the prediction column is added.

    Args:
        sample: Batch mapping whose ``"image"`` entry holds the images to classify.

    Returns:
        The same mapping, with an added ``"zs_age_clip"`` list holding one integer
        prompt index per image.
    """
    # One prompt per age bucket; a prompt's list position is the predicted class id.
    # NOTE(review): assumes this order matches the dataset's integer "age" labels - confirm.
    age_groups = ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "more than 70"]
    age_texts = [f"A person in the {c} age group" for c in age_groups]
    inputs = processor(text=age_texts, images=sample["image"], return_tensors="pt", padding=True).to(device)

    # Pure inference: without no_grad every batch would build an autograd graph,
    # costing memory and time for gradients that are never used.
    with torch.no_grad():
        outputs = model(**inputs)

    logits_per_image = outputs.logits_per_image  # image-text similarity scores
    age_pred = logits_per_image.argmax(dim=1)  # best-matching prompt per image

    # tolist() brings all indices to the host at once as plain Python ints,
    # instead of converting the tensor one element at a time.
    sample["zs_age_clip"] = age_pred.tolist()

    return sample

In [6]:
# Training split of FairFace (config '1.25'): load it, shuffle with a fixed seed,
# then add the zero-shot prediction column in batches of 16.
# For a quick smoke test, add a .select(range(N)) step right after the shuffle.
train_ds = (
    datasets.load_dataset('HuggingFaceM4/FairFace', '1.25', split='train', verification_mode="no_checks")
    .shuffle(seed=42)
    .map(get_embedding_and_zs, batched=True, batch_size=16)
)

In [7]:
# Validation split of FairFace (config '1.25'), used here as the evaluation set:
# load it, shuffle with a fixed seed, then add the zero-shot prediction column
# in batches of 16.
valid_ds = (
    datasets.load_dataset('HuggingFaceM4/FairFace', '1.25', split="validation", verification_mode="no_checks")
    .shuffle(seed=42)
    .map(get_embedding_and_zs, batched=True, batch_size=16)
)

Map:   0%|          | 0/10954 [00:00<?, ? examples/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


In [8]:
# Ground-truth "age" labels of both splits, materialised as NumPy arrays.
y_train_age, y_val_age = (np.array(split["age"]) for split in (train_ds, valid_ds))

In [9]:
# Number of training samples per integer age label (class balance check).
np.bincount(y_train_age)

array([ 1792, 10408,  9103, 25598, 19250, 10744,  6228,  2779,   842],
      dtype=int64)

In [10]:
# Number of validation samples per integer age label (class balance check).
np.bincount(y_val_age)

array([ 199, 1356, 1181, 3300, 2330, 1353,  796,  321,  118], dtype=int64)

In [11]:
# Per-class precision / recall / F1 of the zero-shot predictions on the training split.
train_zs_report = classification_report(
    y_train_age,
    np.array(train_ds["zs_age_clip"]),
)
print(train_zs_report)

              precision    recall  f1-score   support

           0       0.41      0.77      0.54      1792
           1       0.80      0.57      0.66     10408
           2       0.39      0.55      0.46      9103
           3       0.57      0.29      0.39     25598
           4       0.41      0.51      0.45     19250
           5       0.30      0.23      0.26     10744
           6       0.30      0.24      0.27      6228
           7       0.26      0.15      0.19      2779
           8       0.05      0.75      0.10       842

    accuracy                           0.40     86744
   macro avg       0.39      0.45      0.37     86744
weighted avg       0.47      0.40      0.41     86744



In [12]:
# Aggregate zero-shot age metrics on the training split.
y_preds = np.array(train_ds["zs_age_clip"])

# Weighted averaging supplies the headline precision / recall / F-score.
precision, recall, f_score_weighted, _ = precision_recall_fscore_support(
    y_train_age, y_preds, average='weighted'
)
# From the macro and micro averages only the F-score (position 2) is reported.
f_score_macro = precision_recall_fscore_support(y_train_age, y_preds, average='macro')[2]
f_score_micro = precision_recall_fscore_support(y_train_age, y_preds, average='micro')[2]

# The dict form of the report is used for its overall accuracy entry.
class_rep = classification_report(y_train_age, y_preds, output_dict=True)

print("Training set metrics - Age (CLIP ZS) \n" + "=" * 40)
print(
    f"Accuracy: {class_rep['accuracy']:.4f} "
    f"Precision: {precision:.4f}, "
    f"Recall: {recall:.4f}, "
    f"F-Score(Weighted): {f_score_weighted:.4f}, "
    f"F-Score(Micro): {f_score_micro:.4f}, "
    f"F-Score(Macro): {f_score_macro:.4f}"
)

Training set metrics - Age (CLIP ZS) 
Accuracy: 0.3980 Precision: 0.4741, Recall: 0.3980, F-Score(Weighted): 0.4121, F-Score(Micro): 0.3980, F-Score(Macro): 0.3688


In [15]:
# Sanity check: number of zero-shot predictions produced for the training split.
np.array(train_ds["zs_age_clip"]).shape[0]

86744

In [None]:
# Persist the training-split zero-shot predictions as a .npy file in the working directory.
np.save(
    file="clip_zs_age_preds_train_42.npy",
    arr=np.array(train_ds["zs_age_clip"]),
)

: 

In [13]:
# Per-class precision / recall / F1 of the zero-shot predictions on the validation split.
val_zs_report = classification_report(
    y_val_age,
    np.array(valid_ds["zs_age_clip"]),
)
print(val_zs_report)

              precision    recall  f1-score   support

           0       0.38      0.76      0.50       199
           1       0.83      0.57      0.68      1356
           2       0.40      0.58      0.48      1181
           3       0.57      0.28      0.37      3300
           4       0.40      0.51      0.44      2330
           5       0.31      0.24      0.27      1353
           6       0.31      0.23      0.27       796
           7       0.28      0.17      0.22       321
           8       0.06      0.81      0.12       118

    accuracy                           0.40     10954
   macro avg       0.39      0.46      0.37     10954
weighted avg       0.48      0.40      0.41     10954



In [14]:
# Aggregate zero-shot age metrics on the validation split.
y_preds = np.array(valid_ds["zs_age_clip"])

# Weighted averaging supplies the headline precision / recall / F-score.
precision, recall, f_score_weighted, _ = precision_recall_fscore_support(
    y_val_age, y_preds, average='weighted'
)
# From the macro and micro averages only the F-score (position 2) is reported.
f_score_macro = precision_recall_fscore_support(y_val_age, y_preds, average='macro')[2]
f_score_micro = precision_recall_fscore_support(y_val_age, y_preds, average='micro')[2]

# The dict form of the report is used for its overall accuracy entry.
class_rep = classification_report(y_val_age, y_preds, output_dict=True)

print("Validation set metrics - Age (CLIP ZS) \n" + "=" * 40)
print(
    f"Accuracy: {class_rep['accuracy']:.4f} "
    f"Precision: {precision:.4f}, "
    f"Recall: {recall:.4f}, "
    f"F-Score(Weighted): {f_score_weighted:.4f}, "
    f"F-Score(Micro): {f_score_micro:.4f}, "
    f"F-Score(Macro): {f_score_macro:.4f}"
)

Validation set metrics - Age (CLIP ZS) 
Accuracy: 0.3984 Precision: 0.4782, Recall: 0.3984, F-Score(Weighted): 0.4110, F-Score(Micro): 0.3984, F-Score(Macro): 0.3714


In [16]:
# Sanity check: number of zero-shot predictions produced for the validation split.
np.array(valid_ds["zs_age_clip"]).shape[0]

10954

In [17]:
# Persist the validation-split zero-shot predictions as a .npy file in the working directory.
np.save(
    file="clip_zs_age_preds_val_42.npy",
    arr=np.array(valid_ds["zs_age_clip"]),
)