# Use Case 8: Gender Classification in Face Images - Embedding Extraction

In [14]:
from io import BytesIO
from PIL import Image

import datasets

import os

import torch
import numpy as np
import h5py
from tqdm import tqdm
from transformers import ViTImageProcessor, ViTForImageClassification

## Load the Saved Dataset

In [2]:
AGE_CLASSES = ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "more than 70"]
GENDER_CLASSES = ["Male", "Female"]
RACE_CLASSES = ["East Asian", "Indian", "Black", "White", "Middle Eastern", "Latino_Hispanic", "Southeast Asian"]

# Index -> class name: a label id is simply the class's position in its list.
id2age = dict(enumerate(AGE_CLASSES))
id2gender = dict(enumerate(GENDER_CLASSES))
id2race = dict(enumerate(RACE_CLASSES))

# Class name -> index, obtained by inverting the tables above.
age2id = {name: idx for idx, name in id2age.items()}
gender2id = {name: idx for idx, name in id2gender.items()}
race2id = {name: idx for idx, name in id2race.items()}

In [11]:
# Load the DatasetDict stored on disk (written earlier with `save_to_disk`); its splits are displayed next.
ds = datasets.load_from_disk("fairface/data")

In [12]:
# Show the available splits, their columns and row counts.
ds

DatasetDict({
    train: Dataset({
        features: ['img_bytes', 'age', 'gender', 'race', 'id'],
        num_rows: 44425
    })
    test: Dataset({
        features: ['img_bytes', 'age', 'gender', 'race', 'id'],
        num_rows: 9438
    })
    drifted: Dataset({
        features: ['img_bytes', 'age', 'gender', 'race', 'id'],
        num_rows: 13835
    })
    new_unseen: Dataset({
        features: ['img_bytes', 'age', 'gender', 'race', 'id'],
        num_rows: 30000
    })
})

In [13]:
def bytes_to_pil(example_batch):
    """
    On-the-fly dataset transform: decode raw image bytes into RGB PIL images.

    Removes the ``img_bytes`` column from the batch and adds an ``img`` column
    holding the decoded images, in the same order.

    Images are converted to RGB. Grayscale, RGBA or palette files would
    otherwise reach the 3-channel ViT image processor with the wrong channel
    count. For files that are already RGB this is a no-op.
    """
    example_batch['img'] = [
        Image.open(BytesIO(raw)).convert("RGB")
        for raw in example_batch.pop('img_bytes')
    ]
    return example_batch

# The transform runs on the fly whenever rows are accessed; the stored data keeps `img_bytes`.
ds = ds.with_transform(bytes_to_pil)

## Load the Fine-Tuned Model

In [15]:
# Prefer the GPU when one is visible to torch; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [16]:
# Directory from which both the image processor and the fine-tuned model are loaded.
_checkpoint_parts = ("fairface", "saved_model", "best_model")
model_path = os.path.join(*_checkpoint_parts)

In [17]:
# Load the fine-tuned ViT model and its matching image processor from `model_path`.

processor = ViTImageProcessor.from_pretrained(model_path)

model = ViTForImageClassification.from_pretrained(
    model_path,
    num_labels=len(GENDER_CLASSES),  # Gender classification (Male/Female)
    id2label=id2gender,               # reuse the mappings defined above instead of duplicating literals
    label2id=gender2id,
    # `ignore_mismatched_sizes` is deliberately NOT set. With it, a checkpoint/head shape
    # mismatch silently re-initialises the classifier with random weights.
    # For inference on a fine-tuned checkpoint, such a mismatch must raise instead.
    output_hidden_states=True  # Ensure hidden states are returned
).to(device)
model.eval()  # inference only: make dropout deterministic


In [21]:
def collate_fn(batch):
    """
    Assemble a list of dataset items into a single model-ready batch.

    Each item must provide ``"img"``, ``"id"`` and ``"gender"``.
    The images are preprocessed together by the module-level ``processor`` into one
    ``pixel_values`` tensor. The ids are passed through unchanged as a list. The
    gender labels become a ``torch.long`` tensor.
    """
    images, ids, genders = [], [], []
    for item in batch:
        images.append(item["img"])
        ids.append(item["id"])
        genders.append(item["gender"])

    encoded = processor(images=images, return_tensors="pt")
    return {
        "pixel_values": encoded["pixel_values"],
        "id": ids,
        "gender": torch.tensor(genders, dtype=torch.long),
    }


In [22]:
def extract_embedding_and_predict(model, processor, dataset, layer_id):
    """
    Run the classifier over ``dataset``, collecting CLS embeddings and predictions.

    Items are batched with the module-level ``collate_fn``, which also performs the
    image preprocessing. Each batch runs through ``model`` on the model's own device
    under ``torch.no_grad()``. For each image, the token-0 (CLS) vector of
    ``outputs["hidden_states"][layer_id]`` is kept as its embedding.

    Args:
        model: classification model whose forward accepts ``output_hidden_states=True``
            and returns ``logits`` and ``hidden_states``. Its ``config.id2label`` is used
            to name both true and predicted labels. It is switched to eval mode.
        processor: NOTE(review): not used in this function. Preprocessing happens in
            the global ``collate_fn``. Kept for interface compatibility.
        dataset: items exposing ``"img"``, ``"id"`` and ``"gender"``. Gender values are
            assumed to use the same integer ids as ``model.config.id2label``.
        layer_id: index into ``outputs["hidden_states"]`` (``-1`` selects the last one).

    Returns:
        ``(X, E, Y_original, Y_original_names, Y_predicted, Y_predicted_names)``, all in
        dataset order:
        - ids
        - a float64 embedding matrix with one row per image
        - true label ids and names
        - predicted label ids and names
    """
    BATCH_SIZE = 32  # Adjust as needed
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    model.eval()
    # Follow the model wherever it lives rather than relying on a global `device`.
    model_device = next(model.parameters()).device
    id2label = model.config.id2label
    # The original hard-coded 768 (ViT-base); read it from the config when available.
    hidden_size = getattr(model.config, "hidden_size", 768)

    X = []
    embedding_chunks = []  # one (batch, hidden) array per batch; concatenated once at the end
    Y_original = []
    Y_predicted = []

    for batch in tqdm(dataloader):
        pixel_values = batch["pixel_values"].to(model_device)

        with torch.no_grad():
            outputs = model(pixel_values, output_hidden_states=True)

        # argmax over logits equals argmax over softmax(logits); the softmax is unnecessary.
        batch_labels = outputs["logits"].argmax(dim=-1).tolist()

        selected_layer_states = outputs["hidden_states"][layer_id]
        embedding_chunks.append(selected_layer_states[:, 0, :].cpu().numpy())

        X.extend(batch["id"])
        Y_original.extend(batch["gender"].tolist())
        Y_predicted.extend(batch_labels)

    # A single concatenation avoids the quadratic copying of np.vstack inside the loop.
    # Keep float64, which is what the original accumulation produced.
    if embedding_chunks:
        E = np.concatenate(embedding_chunks, axis=0).astype(np.float64)
    else:
        E = np.empty((0, hidden_size))

    Y_original_names = [id2label[label] for label in Y_original]
    Y_predicted_names = [id2label[label] for label in Y_predicted]

    return X, E, Y_original, Y_original_names, Y_predicted, Y_predicted_names


In [23]:
def save_embedding(output_path, X, E, Y_original, Y_original_names, Y_predicted, Y_predicted_names):
    """
    Write ids, embeddings and labels to a gzip-compressed HDF5 file.

    The parent directory of ``output_path`` is created if missing; without it,
    ``h5py.File(..., "w")`` fails. An existing file at ``output_path`` is overwritten.
    All inputs are converted before the file is opened, so a conversion error
    (e.g. a non-ASCII string for the "S" dtype) leaves no half-written file behind.

    HDF5 datasets written:
        X                  sample ids as fixed-length byte strings
        E                  embedding matrix, stored as given
        Y_original         true label ids (int)
        Y_original_names   true label names as byte strings
        Y_predicted        predicted label ids (int)
        Y_predicted_names  predicted label names as byte strings
    """
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    arrays = {
        "X": np.array(X, dtype="S"),  # Convert to bytes for HDF5
        "E": E,
        "Y_original": np.array(Y_original, dtype=int),
        "Y_original_names": np.array(Y_original_names, dtype="S"),
        "Y_predicted": np.array(Y_predicted, dtype=int),
        "Y_predicted_names": np.array(Y_predicted_names, dtype="S"),
    }

    with h5py.File(output_path, "w") as fp:
        for name, data in arrays.items():
            fp.create_dataset(name, data=data, compression="gzip")
    print(f"Saved embeddings to {output_path}")


In [24]:
layer_id = -1  # Last layer

# These must match the DatasetDict keys shown earlier. The previous names
# ("train_embedding", "drift_embedding", ...) do not exist, so `ds[split]`
# raised KeyError on a fresh kernel.
SPLITS = ["train", "test", "new_unseen", "drifted"]
missing_splits = [name for name in SPLITS if name not in ds]
assert not missing_splits, f"Splits not found in dataset: {missing_splits}"

# h5py will not create intermediate directories itself.
os.makedirs("fairface/saved_embedding", exist_ok=True)

for split in SPLITS:
    print(f"Processing {split} split...")
    X, E, Y_original, Y_original_names, Y_predicted, Y_predicted_names = extract_embedding_and_predict(
        model, processor, ds[split], layer_id
    )

    output_path = f"fairface/saved_embedding/{split}.h5"
    save_embedding(output_path, X, E, Y_original, Y_original_names, Y_predicted, Y_predicted_names)


Processing train split...


100%|██████████| 1389/1389 [05:31<00:00,  4.19it/s]


Saved embeddings to fairface/saved_embedding/train.h5
Processing test split...


100%|██████████| 295/295 [00:53<00:00,  5.47it/s]


Saved embeddings to fairface/saved_embedding/test.h5
Processing new_unseen split...


100%|██████████| 938/938 [03:18<00:00,  4.72it/s]


Saved embeddings to fairface/saved_embedding/new_unseen.h5
Processing drifted split...


100%|██████████| 433/433 [01:22<00:00,  5.25it/s]


Saved embeddings to fairface/saved_embedding/drifted.h5
