# **Task 2 - Custom Loss**
**Vision Transformer (ViT) Baseline**

**Authors:** Iris Vukovic & Olmo Gordon Rodriguez

---

## Downloading the APPA-REAL dataset
In this notebook we use the full dataset (not its reduced version).

In [None]:
# Download the APPA-REAL archive with an IPython shell escape (notebook-only),
# extract it, then delete the archive.
from zipfile import ZipFile

# downloading the data
!wget http://data.chalearnlap.cvc.uab.cat/Colab_MFPDS/2025/appa-real-dataset_v2.zip

# extractall() without a path argument unpacks into the current working directory
with ZipFile('appa-real-dataset_v2.zip','r') as zip_ref:
   zip_ref.extractall()
   print('Data decompressed successfully')

# removing the .zip file after extraction to clean space
!rm appa-real-dataset_v2.zip

## Check PyTorch version
 - This notebook was successfully tested on version 2.5.1+cu124.

In [None]:
import torch

# Show the installed PyTorch build (this notebook was tested with 2.5.1+cu124).
print(str(torch.__version__))

In [None]:
# installing torchinfo to print model summary
# (IPython shell escape: only runs inside a notebook)
!pip install torchinfo

## Import required libraries

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import csv
from PIL import Image
from timm import create_model
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import copy
from tqdm import tqdm

# Defining the dataset class
In this example, the metadata is loaded but not used; future implementations can take advantage of it.
Note that age labels are divided by 100 (assuming 100 is the maximum age in the dataset) so that they are normalized to the range [0, 1]. This lets us use a sigmoid activation in the last layer of the model.

In [None]:
class AgeEstimationDataset(Dataset):
    """Face-image dataset returning ``(image, normalized_age, metadata)``.

    Each CSV row is read as: image id (column 0), age (column 1), then any
    number of metadata columns. Images are loaded from
    ``<image_dir>/<id zero-padded to 6 digits>.jpg`` and converted to RGB.
    Ages are divided by ``age_normalization_factor`` (100) so labels lie in
    [0, 1] for ages up to 100, matching a sigmoid output layer.
    """

    def __init__(self, image_dir, csv_file, transform=None):
        self.image_dir = image_dir
        self.data_info = pd.read_csv(csv_file)
        self.transform = transform
        # Divisor used to normalize age labels (assumes 100 is the max age).
        self.age_normalization_factor = 100

    def __len__(self):
        return len(self.data_info)

    def __normalization_factor__(self):
        """Return the divisor used to normalize ages (multiply to get years back)."""
        return self.age_normalization_factor

    def __getitem__(self, idx):
        image_id = f"{self.data_info.iloc[idx, 0]:06d}.jpg"  # zero-padded file name
        image_path = os.path.join(self.image_dir, image_id)
        # The context manager closes the file handle once the RGB copy exists.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        # Normalize the age label into [0, 1] (assuming 100 is the max age).
        age = float(self.data_info.iloc[idx, 1]) / self.age_normalization_factor
        # All columns after the age are passed through unchanged as metadata.
        metadata = self.data_info.iloc[idx, 2:].tolist()

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(age, dtype=torch.float32), metadata

In [None]:
data_transforms = transforms.Compose(
    [
        transforms.ToTensor(),  # PIL image -> float CHW tensor in [0, 1]
        transforms.Resize((224, 224)),  # input resolution of vit_base_patch16_224
        transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1] on every channel
    ]
)

In [None]:
# Training split: reshuffled every epoch.
dataset_train = AgeEstimationDataset("train_data", "labels_metadata_train.csv", transform=data_transforms)
dataloader_train = DataLoader(dataset_train, batch_size=32, shuffle=True, num_workers=2)
print(f"Total number of train samples: {len(dataloader_train.dataset)}")

# Validation split: evaluation gains nothing from shuffling, and a fixed order
# keeps per-sample outputs reproducible.
dataset_valid = AgeEstimationDataset("valid_data", "labels_metadata_valid.csv", transform=data_transforms)
dataloader_valid = DataLoader(dataset_valid, batch_size=32, shuffle=False, num_workers=2)
print(f"Total number of valid samples: {len(dataloader_valid.dataset)}")

In [None]:
import random

# Function to display image samples
def display_samples(dataset, num_samples=3):
    """Plot ``num_samples`` randomly chosen samples with their age and metadata.

    Tensor images are assumed to be normalized with mean 0.5 / std 0.5 (as done
    by ``data_transforms``). They are mapped back to [0, 1] and clipped before
    display. Ages are converted back to years with the dataset's normalization
    factor.
    """
    _, axes = plt.subplots(1, num_samples, figsize=(11, 5))
    # plt.subplots returns a single Axes (not an array) when num_samples == 1.
    axes = np.atleast_1d(axes)
    factor = dataset.__normalization_factor__()

    for ax in axes:
        random_index = random.randint(0, len(dataset) - 1)
        image, age, metadata = dataset[random_index]

        if isinstance(image, torch.Tensor):
            # Undo Normalize([0.5], [0.5]), clip to the valid range, CHW -> HWC.
            image = (image * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy()
        else:
            image = np.array(image)

        ax.imshow(image)
        ax.axis("off")
        # Metadata may hold non-string values (e.g. NaN when a CSV cell is empty),
        # which str.join rejects, so stringify every entry first.
        metadata_text = ", ".join(str(m) for m in metadata)
        ax.set_title(f"Age: {float(age) * factor:.2f}\n{metadata_text}")
    plt.show()

display_samples(dataset_train, num_samples=3)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Histogram of the raw (un-normalized) ages in the training labels CSV.
df = pd.read_csv("labels_metadata_train.csv")

ax = plt.subplots(figsize=(10, 6))[1]
ax.hist(df['age'], bins=30, color='blue', alpha=0.7)
ax.set_title('Age Distribution')
ax.set_xlabel('Age')
ax.set_ylabel('Frequency')
plt.show()

In [None]:
# Bar chart of the ethnicity labels, keeping only the three annotated categories.
# `df` (from the previous cell) is reassigned to the filtered rows.
ethnicity_valid_values = ['asian', 'caucasian', 'afroamerican']
df = df[df['ethnicity'].isin(ethnicity_valid_values)]
ethnicity_counts = df['ethnicity'].value_counts()

ax = plt.subplots(figsize=(10, 6))[1]
ax.bar(ethnicity_counts.index, ethnicity_counts.values, color='orange', alpha=0.7)
ax.set_title('Ethnicity Distribution')
ax.set_xlabel('Ethnicity')
ax.set_ylabel('Frequency')
ax.tick_params(axis='x', labelrotation=45)
plt.show()

In [None]:
# Bar chart of the gender labels. The CSV is re-read because the previous cell
# reassigned `df` to a filtered frame.
df = pd.read_csv("labels_metadata_train.csv")
gender_valid_values = ['male', 'female']
df = df.loc[df['gender'].isin(gender_valid_values)]
gender_counts = df['gender'].value_counts()

ax = plt.subplots(figsize=(10, 6))[1]
ax.bar(gender_counts.index, gender_counts.values, color='green', alpha=0.7)
ax.set_title('Gender Distribution')
ax.set_xlabel('Gender')
ax.set_ylabel('Frequency')
ax.tick_params(axis='x', labelrotation=45)
plt.show()

In [None]:
# Bar chart of facial-expression labels (stored in the 'emotion' column).
# The CSV is re-read because earlier cells reassigned `df` to filtered rows.
df = pd.read_csv("labels_metadata_train.csv")
expression_valid_values = ['slightlyhappy', 'happy', 'neutral', 'other']
df = df.loc[df['emotion'].isin(expression_valid_values)]
expression_counts = df['emotion'].value_counts()

ax = plt.subplots(figsize=(10, 6))[1]
ax.bar(expression_counts.index, expression_counts.values, color='red', alpha=0.7)
ax.set_title('Expression Distribution')
ax.set_xlabel('Expression')
ax.set_ylabel('Frequency')
ax.tick_params(axis='x', labelrotation=45)
plt.show()

# Augmenting

In [None]:
import pandas as pd
import os
from PIL import Image
from torchvision import transforms
from tqdm import tqdm  # for progress bar (optional)
import torch

class AgeEstimationDataset(Dataset):
    """Age-estimation dataset with optional augmentation for selected samples.

    CSV rows are read as (image id, age, metadata...). Images are loaded from
    ``<image_dir>/<id:06d>.jpg`` as RGB. When ``augmentation_transform`` is
    given, it is applied before ``transform`` to samples whose second metadata
    column (ethnicity) is 'afroamerican' or 'asian', or whose age is above 60.
    Ages are returned divided by 100.
    """

    def __init__(self, image_dir, csv_file, transform=None, augmentation_transform=None):
        self.image_dir = image_dir
        self.data_info = pd.read_csv(csv_file)
        self.transform = transform
        # Applied only to the targeted samples (see __getitem__); may be None.
        self.augmentation_transform = augmentation_transform
        # Ages are divided by this to map them into [0, 1] (assumes max age 100).
        self.age_normalization_factor = 100

    def __len__(self):
        return self.data_info.shape[0]

    def __normalization_factor__(self):
        """Return the divisor applied to the age labels."""
        return self.age_normalization_factor

    def __getitem__(self, idx):
        table = self.data_info
        file_name = "{:06d}.jpg".format(table.iloc[idx, 0])
        image = Image.open(os.path.join(self.image_dir, file_name)).convert("RGB")

        raw_age = table.iloc[idx, 1]
        normalized_age = float(raw_age) / self.age_normalization_factor
        metadata = table.iloc[idx, 2:].tolist()
        ethnicity = metadata[1]
        age_in_years = int(raw_age)

        # Targeted (under-represented) samples: selected ethnicities or age > 60.
        targeted = ethnicity in ['afroamerican', 'asian'] or age_in_years > 60
        if targeted and self.augmentation_transform:
            image = self.augmentation_transform(image)

        # Regular preprocessing (e.g. resizing / normalization).
        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(normalized_age, dtype=torch.float32), metadata


def save_augmented_images_and_update_csv(dataset, augmentation_transforms, output_dir, csv_file, num_augmented_images=3):
    """Write augmented copies of under-represented samples and an extended CSV.

    For every row of ``dataset.data_info`` whose ethnicity (second metadata
    column) is 'afroamerican' or 'asian', or whose age is above 60, the raw RGB
    image is loaded from ``dataset.image_dir``. ``augmentation_transforms`` is
    applied to it ``num_augmented_images`` times. Each result is saved to
    ``output_dir`` as ``augmented_<id>_<i>.jpg``, and one CSV row is recorded
    per result.

    The rows of ``csv_file`` plus the new rows are written to
    ``augmented_<basename(csv_file)>`` in the current working directory. The
    original ``csv_file`` is never modified.

    NOTE(review): new rows reuse the original image id, and the saved file name
    is not stored in the CSV. A dataset that resolves images as
    ``<image_dir>/<id:06d>.jpg`` therefore loads the *original* image for these
    rows (plain oversampling), not the augmented file in ``output_dir``.

    Args:
        dataset: Dataset exposing ``data_info`` (DataFrame) and ``image_dir``.
        augmentation_transforms: Callable applied to a PIL RGB image. It may
            return a PIL image or a CHW float tensor in [0, 1].
        output_dir: Directory for the augmented JPEGs (created if missing).
        csv_file: Labels CSV whose rows are copied into the output CSV.
        num_augmented_images: Augmented copies generated per selected image.
    """
    os.makedirs(output_dir, exist_ok=True)
    df = pd.read_csv(csv_file)
    to_pil = transforms.ToPILImage()

    augmented_rows = []
    for idx in tqdm(range(len(dataset)), desc="Processing images"):
        original_image_id = dataset.data_info.iloc[idx, 0]
        raw_age = dataset.data_info.iloc[idx, 1]
        metadata = dataset.data_info.iloc[idx, 2:].tolist()
        ethnicity = metadata[1]

        # Only under-represented groups are augmented; skip the rest early so
        # their images are never loaded.
        if ethnicity not in ('afroamerican', 'asian') and int(raw_age) <= 60:
            continue

        # Augment the raw image, not dataset[idx]: that output is already
        # normalized to [-1, 1], while colour jitter on float tensors assumes
        # [0, 1] and clamps to it, destroying every pixel below mid-grey.
        image_path = os.path.join(dataset.image_dir, f"{original_image_id:06d}.jpg")
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        for i in range(num_augmented_images):
            augmented = augmentation_transforms(image)
            if isinstance(augmented, torch.Tensor):
                augmented = to_pil(augmented)
            augmented.save(os.path.join(output_dir, f"augmented_{original_image_id}_{i}.jpg"))
            # Store the raw CSV age: rescaling the float32 label (age / 100 * 100)
            # yields values like 60.99999 that break later `int(age) > 60` checks.
            augmented_rows.append([original_image_id, raw_age] + metadata)

    augmented_df = pd.DataFrame(augmented_rows, columns=df.columns.tolist())

    # Never overwrite the original labels file; write the extended table alongside.
    new_csv_file = f"augmented_{os.path.basename(csv_file)}"
    pd.concat([df, augmented_df], ignore_index=True).to_csv(new_csv_file, index=False)
    print(f"Added {len(augmented_rows)} augmented rows. Augmented data saved to new file: {new_csv_file}")


# Random geometric and photometric perturbations used to generate extra
# samples for the under-represented groups.
augmentation_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        # NOTE(review): hue=0.2 permits hue shifts of up to +/-0.2 of the colour
        # wheel, which may visibly alter skin tone. Augmentation targets samples
        # by ethnicity, so confirm this range is intended.
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
    ]
)

# Deterministic preprocessing: resize the PIL image, convert it to a [0, 1]
# tensor, then map it to [-1, 1]. Train and validation use identical but
# separate pipelines.
data_transforms = {
    split: transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    for split in ('train', 'val')
}

# Source dataset for the offline augmentation pass (no on-the-fly augmentation).
dataset_train = AgeEstimationDataset(
    image_dir="train_data",
    csv_file="labels_metadata_train.csv",
    transform=data_transforms['train'],
    augmentation_transform=None,
)

output_dir = "augmented_train_data"  # destination folder for the augmented JPEGs
num_augmented_images = 3  # augmented copies per selected image

save_augmented_images_and_update_csv(
    dataset_train,
    augmentation_transforms,
    output_dir,
    "labels_metadata_train.csv",
    num_augmented_images,
)


In [None]:
# Training data: the original rows plus the extra rows from the augmented CSV.
# NOTE(review): images are resolved from "train_data" by image id, and the extra
# rows reuse the original ids, so they load the original (un-augmented) images.
# The files written to "augmented_train_data" are not read here; confirm this
# is intended.
aug_dataset_train = AgeEstimationDataset("train_data", "augmented_labels_metadata_train.csv", transform=data_transforms['train'])
aug_dataloader_train = DataLoader(aug_dataset_train, batch_size=32, shuffle=True, num_workers=2)
print(f"Total number of train samples: {len(aug_dataloader_train.dataset)}")

# Validation data: order is irrelevant for evaluation, so keep it fixed.
dataset_valid = AgeEstimationDataset("valid_data", "labels_metadata_valid.csv", transform=data_transforms['val'])
dataloader_valid = DataLoader(dataset_valid, batch_size=32, shuffle=False, num_workers=2)
print(f"Total number of valid samples: {len(dataloader_valid.dataset)}")

In [None]:
import random

# Function to denormalize images for visualization
def denormalize(tensor):
    """Invert ``Normalize([0.5], [0.5])`` and clip the result to [0, 1]."""
    mean, std = 0.5, 0.5
    return (tensor * std + mean).clip(0, 1)

# Select a random training image and preview 5 random augmentations of it.
random_idx = random.randint(0, len(dataset_train) - 1)
image, _, _ = dataset_train[random_idx]  # normalized CHW tensor in [-1, 1]

# Undo the normalization *before* augmenting. Colour jitter on float tensors
# assumes values in [0, 1] and clamps to that range, so jittering the
# normalized tensor would clip every pixel darker than mid-grey.
image = denormalize(image)

augmentation_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
])

augmented_images = [augmentation_transforms(image) for _ in range(5)]

# CHW tensors -> HWC arrays for matplotlib (already in [0, 1]).
image = image.numpy().transpose(1, 2, 0)
augmented_images = [img.clip(0, 1).numpy().transpose(1, 2, 0) for img in augmented_images]

# Plot the original image followed by its augmented versions in one row.
plt.figure(figsize=(15, 5))
panels = [('Original Image', image)] + [(f'Augmented {i + 1}', img) for i, img in enumerate(augmented_images)]
for position, (title, picture) in enumerate(panels, start=1):
    plt.subplot(1, 6, position)
    plt.imshow(picture)
    plt.axis('off')
    plt.title(title)

plt.show()


In [None]:
def display_samples(dataset, num_samples=4):
    """Show the first ``num_samples`` items of ``dataset`` with age and metadata.

    Images are expected as CHW tensors normalized to [-1, 1]; they are mapped
    back to [0, 1] for display. Ages are rescaled by the dataset's factor.
    """
    _, axes = plt.subplots(1, num_samples, figsize=(15, 5))

    for idx in range(num_samples):
        image, age, metadata = dataset[idx]
        # CHW tensor -> HWC array, then undo Normalize(mean=0.5, std=0.5).
        pixels = image.numpy().transpose((1, 2, 0)) * np.array([0.5]) + np.array([0.5])

        ax = axes[idx]
        ax.imshow(pixels)
        ax.axis("off")

        caption = ", ".join(map(str, metadata))  # metadata may contain non-strings
        ax.set_title(f"Age: {age * dataset.__normalization_factor__():.2f}\n{caption}")

    plt.show()


In [None]:
# Preview the first four rows of the augmented training set.
display_samples(dataset=aug_dataset_train, num_samples=4)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Histogram of ages in the augmented training CSV (original + added rows).
df = pd.read_csv("augmented_labels_metadata_train.csv")

ax = plt.subplots(figsize=(10, 6))[1]
ax.hist(df['age'], bins=30, color='blue', alpha=0.7)
ax.set_title('Age Distribution after augmenting')
ax.set_xlabel('Age')
ax.set_ylabel('Frequency')
plt.show()

In [None]:
# Ethnicity counts in the augmented CSV, restricted to the three annotated
# categories. `df` is reassigned to the filtered rows.
ethnicity_valid_values = ['asian', 'caucasian', 'afroamerican']
df = df[df['ethnicity'].isin(ethnicity_valid_values)]
ethnicity_counts = df['ethnicity'].value_counts()

ax = plt.subplots(figsize=(10, 6))[1]
ax.bar(ethnicity_counts.index, ethnicity_counts.values, color='orange', alpha=0.7)
ax.set_title('Ethnicity Distribution after augmenting')
ax.set_xlabel('Ethnicity')
ax.set_ylabel('Frequency')
ax.tick_params(axis='x', labelrotation=45)
plt.show()

In [None]:
# Gender counts in the augmented CSV. The file is re-read because the previous
# cell reassigned `df` to ethnicity-filtered rows.
df = pd.read_csv("augmented_labels_metadata_train.csv")
gender_valid_values = ['female', 'male']
df = df.loc[df['gender'].isin(gender_valid_values)]
gender_counts = df['gender'].value_counts()

ax = plt.subplots(figsize=(10, 6))[1]
ax.bar(gender_counts.index, gender_counts.values, color='green', alpha=0.7)
ax.set_title('Gender Distribution after augmenting')
ax.set_xlabel('Gender')
ax.set_ylabel('Frequency')
ax.tick_params(axis='x', labelrotation=45)
plt.show()

In [None]:
# Emotion counts in the augmented CSV, on the rows kept by the previous cell.
emotion_valid_values = ['slightlyhappy', 'happy', 'neutral', 'other']
df = df.loc[df['emotion'].isin(emotion_valid_values)]
emotion_counts = df['emotion'].value_counts()

ax = plt.subplots(figsize=(10, 6))[1]
ax.bar(emotion_counts.index, emotion_counts.values, color='red', alpha=0.7)
ax.set_title('Emotion Distribution after augmenting')
ax.set_xlabel('Emotion')
ax.set_ylabel('Frequency')
ax.tick_params(axis='x', labelrotation=45)
plt.show()

In [None]:
# Vision Transformer Model for Age Prediction (pretrained on ImageNet)
# https://pytorch.org/vision/main/models/vision_transformer.html
# https://huggingface.co/docs/transformers/main/en//model_doc/vit
class AgeEstimationViT(nn.Module):
    """Pretrained timm ViT-B/16 backbone followed by an MLP age regressor.

    The backbone is created with ``num_classes=0``, so it has no classifier head.
    ``forward`` takes the first token of ``forward_features`` and maps it through
    Linear/ReLU/Dropout layers to a single sigmoid output in [0, 1] (a
    normalized age).
    """

    def __init__(self):
        super().__init__()
        # Pretrained "vit_base_patch16_224" with its classifier head removed.
        self.vit = create_model("vit_base_patch16_224", pretrained=True, num_classes=0)
        embed_dim = self.vit.num_features

        # Regression head: embed_dim -> 512 -> 256 -> 1, bounded by a sigmoid.
        self.regressor = nn.Sequential(
            nn.Linear(embed_dim, 512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 1), nn.Sigmoid(),
        )

    def forward(self, x):
        tokens = self.vit.forward_features(x)
        # Token 0 is the class token, which summarizes the whole image.
        return self.regressor(tokens[:, 0])

In [None]:
# Instantiate the model on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = AgeEstimationViT()
model = model.to(device)

In [None]:
# print model summary
# Layer-by-layer summary for one 3x224x224 image (batch size 1). It is left as the
# cell's last expression so the notebook renders the returned value.
from torchinfo import summary
summary(model, input_size=(1, 3, 224, 224))  # Adjust based on your model

In [None]:
# Function to plot training curves
def plot_training_curves(train_losses, val_losses):
    """Draw per-epoch training and validation losses on a single chart."""
    plt.figure(figsize=(8, 5))
    curves = (('Train Loss', train_losses), ('Validation Loss', val_losses))
    for label, values in curves:
        plt.plot(values, label=label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training & Validation Loss')
    plt.legend()
    plt.grid()
    plt.show()

### Defining the training function

In [None]:
# Training function with early stopping and model saving
def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs, patience, model_path):
    """Train ``model``, validating every epoch, with LR scheduling and early stopping.

    Each epoch runs a 'train' pass, then a 'val' pass over ``dataloaders``.
    Batches are (inputs, labels, metadata); metadata is ignored. After each
    validation pass, ``scheduler.step`` is called with the validation loss.
    Whenever the validation loss improves, a deep copy of the weights is kept
    and saved to ``model_path``. After ``patience`` consecutive epochs without
    improvement, training stops early.

    Tensors are moved to the module-level ``device``.

    Args:
        model: Network to train; updated in place.
        dataloaders: Dict with 'train' and 'val' DataLoaders.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Scheduler whose ``step`` takes the validation loss.
        num_epochs: Maximum number of epochs.
        patience: Number of non-improving validation epochs tolerated.
        model_path: File the best ``state_dict`` is saved to.

    Returns:
        ``model`` with the best validation weights loaded. Loss curves are
        plotted before returning.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float("inf")
    early_stopping_counter = 0
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')

        # One full training pass, then one validation pass, per epoch.
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()  # evaluation mode (e.g. dropout disabled)

            running_loss = 0.0

            with tqdm(dataloaders[phase], desc=f"{phase.capitalize()} Epoch {epoch+1}") as t:
                # Batches are (images, labels, metadata); metadata is unused here.
                for inputs, labels, _ in t:
                    inputs, labels = inputs.to(device), labels.to(device)  # module-level `device`

                    # Also runs in the 'val' phase; harmless since no backward() happens there.
                    optimizer.zero_grad()

                    # Autograd is only tracked during training.
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        # Reshape outputs to the labels' shape (e.g. (B, 1) -> (B,)) before the loss.
                        loss = criterion(outputs.view_as(labels), labels)

                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # Accumulate a per-sample sum. This assumes `criterion` returns the batch
                    # mean (nn.MSELoss's default reduction); confirm for other criteria.
                    running_loss += loss.item() * inputs.size(0)
                    t.set_postfix(loss=loss.item())

            # Mean loss over every sample in this phase's dataset.
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            print(f'{phase} Loss: {epoch_loss:.6f}')

            if phase == 'train':
                train_losses.append(epoch_loss)
            else:
                val_losses.append(epoch_loss)
                # Plateau-style schedulers (e.g. ReduceLROnPlateau) take the monitored metric.
                scheduler.step(epoch_loss)

                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    # Deep copy: state_dict() returns references that keep changing as training continues.
                    best_model_wts = copy.deepcopy(model.state_dict())
                    early_stopping_counter = 0
                    # Save best model during training
                    print("saving best model...")
                    torch.save(best_model_wts, model_path)
                else:
                    early_stopping_counter += 1
                    if early_stopping_counter >= patience:
                        print("Early stopping triggered.")
                        model.load_state_dict(best_model_wts)
                        plot_training_curves(train_losses, val_losses)
                        return model

    # All epochs finished: restore the best validation weights before returning.
    model.load_state_dict(best_model_wts)
    plot_training_curves(train_losses, val_losses)
    return model

# Train the model

In [None]:
#==================
MODEL_TRAIN = True
#==================

# model hyperparameters
criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=2)
num_epochs = 10
patience = 10

# data loaders
dataloaders = {"train": aug_dataloader_train, "val": dataloader_valid}  # Assuming split dataset

if (MODEL_TRAIN):
  #if(MOUNT_GOOGLE_DRIVE):
    #best_model = train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs, patience,"/content/best_age_estimation_model.pth")
  best_model = train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs, patience,"best_age_estimation_model.pth")
else:
  # download the pretrained model
  !wget http://data.chalearnlap.cvc.uab.cat/Colab_MFPDS/2025/best_age_estimation_model.zip

  # decompressing the data
  from zipfile import ZipFile

  with ZipFile('best_age_estimation_model.zip','r') as zip:
    zip.extractall()
    print('Model decompressed successfully')

  # removing the .zip file after extraction to clean space
  !rm best_age_estimation_model.zip


# Making predictions on the Validation set


In [None]:
# Function to make predictions on a dataset and compute the MAE
def predict_and_evaluate(model_path, test_dataset, batch_size=32, output_csv="predictions.csv", output_zip="predictionsViTbaseline.zip"):
    """Predict ages for ``test_dataset``, report the MAE and save the predictions.

    Loads an ``AgeEstimationViT`` state dict from ``model_path`` and runs it
    over the dataset in order (no shuffling). Predictions and labels are
    rescaled back to years with the dataset's normalization factor.
    Predictions are written one per line (no header) to ``output_csv``, which
    is then zipped into ``output_zip``.

    Returns:
        (predictions, actual_ages, mae, metadatas): per-sample predicted ages,
        true ages, the mean absolute error, and each sample's metadata list
        (the CSV columns after the age), all in dataset order.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AgeEstimationViT().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.eval()

    factor = test_dataset.__normalization_factor__()
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    predictions = []
    actual_ages = []

    with torch.no_grad():
        for images, labels, _ in tqdm(test_loader, desc="Predicting"):
            # reshape(-1) instead of squeeze(): a final batch of size 1 would
            # squeeze to a 0-d array, which list.extend cannot iterate.
            outputs = model(images.to(device)).reshape(-1).cpu().numpy()
            predictions.extend(outputs * factor)
            actual_ages.extend(labels.cpu().numpy().reshape(-1) * factor)

    # Read metadata straight from the CSV table. Indexing the dataset here
    # would needlessly reload and transform every image.
    metadatas = [test_dataset.data_info.iloc[idx, 2:].tolist() for idx in range(len(test_dataset))]

    mae = np.mean(np.abs(np.array(predictions) - np.array(actual_ages)))
    print(f"\n=======\nMean Absolute Error on Test Set: {mae:.4f}")

    # Save only predictions to CSV without headers
    with open(output_csv, mode='w', newline='') as file:
        csv.writer(file).writerows([pred] for pred in predictions)

    # Zip the CSV file using ZipFile
    with ZipFile(output_zip, 'w') as zipf:
        zipf.write(output_csv, os.path.basename(output_csv))

    print(f"Predictions saved to {output_csv} and compressed as {output_zip}")

    return predictions, actual_ages, mae, metadatas

In [None]:
# Evaluation preprocessing. Resize runs on the PIL image *before* ToTensor so it
# matches data_transforms['train'] / ['val'], which the model in this section
# was trained with. torchvision's PIL and tensor resize backends can produce
# slightly different pixels.
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

In [None]:
# Fetch the bias-metric helpers (IPython shell escape). wget saves
# bias_functions.py into the working directory so it can be imported below.
!wget http://data.chalearnlap.cvc.uab.cat/Colab_MFPDS/bias_functions.py

In [None]:
from bias_functions import age_bias, gender_bias, ethnicity_bias, face_expression_bias

def calculate_bias(model, dataset, predictions, actual_ages, metadata):
    """Run each bias metric from ``bias_functions`` on the given predictions.

    ``model`` and ``dataset`` are accepted for interface compatibility but are
    not used. The helpers' return values are discarded; this returns None.
    """
    age_bias(predictions, actual_ages)
    for metric in (gender_bias, ethnicity_bias, face_expression_bias):
        metric(predictions, actual_ages, metadata)

In [None]:
# Validation split with the evaluation preprocessing (no augmentation transform).
dataset_valid = AgeEstimationDataset(
    image_dir="valid_data",
    csv_file="labels_metadata_valid.csv",
    transform=data_transforms,
)

In [None]:
# Load the checkpoint from the same working-directory-relative path the training
# cell writes to. The previous "/content/..." path only worked on Colab.
predictions_val, actual_ages_val, mae_val, metadata_val = predict_and_evaluate("best_age_estimation_model.pth", dataset_valid)


In [None]:
calculate_bias(
    model="/content/best_age_estimation_model.pth",  # not used by calculate_bias
    dataset=dataset_valid,
    predictions=predictions_val,
    actual_ages=actual_ages_val,
    metadata=metadata_val,
)

# Making predictions on the Test set
The following cells generate predictions on the **Test set** (and evaluate them) so that we can create the submission file for our age estimation challenge. **IMPORTANT:** <font color='red'>**Do not evaluate your model on the Test set when defining your training strategy, model or hyperparameters.**</font> **Use the Validation set for this.**

In [None]:
# Test split (no DataLoader needed here; predict_and_evaluate builds its own).
dataset_test = AgeEstimationDataset(
    image_dir="test_data",
    csv_file="labels_metadata_test.csv",
    transform=data_transforms,
)


In [None]:
# Same working-directory-relative checkpoint path that the training cell writes to.
predictions, actual_ages, mae, metadata = predict_and_evaluate("best_age_estimation_model.pth", dataset_test)


In [None]:
calculate_bias(
    model="/content/best_age_estimation_model.pth",  # not used by calculate_bias
    dataset=dataset_test,
    predictions=predictions,
    actual_ages=actual_ages,
    metadata=metadata,
)

# Custom Loss (non-augmented data)

In [None]:
class AgeEstimationDataset(Dataset):
    """Face-image dataset returning ``(image, normalized_age, metadata)``.

    Each CSV row is read as: image id (column 0), age (column 1), then any
    number of metadata columns. Images are loaded from
    ``<image_dir>/<id zero-padded to 6 digits>.jpg`` and converted to RGB.
    Ages are divided by ``age_normalization_factor`` (100) so labels lie in
    [0, 1] for ages up to 100, matching a sigmoid output layer.
    """

    def __init__(self, image_dir, csv_file, transform=None):
        self.image_dir = image_dir
        self.data_info = pd.read_csv(csv_file)
        self.transform = transform
        # Divisor used to normalize age labels (assumes 100 is the max age).
        self.age_normalization_factor = 100

    def __len__(self):
        return len(self.data_info)

    def __normalization_factor__(self):
        """Return the divisor used to normalize ages (multiply to get years back)."""
        return self.age_normalization_factor

    def __getitem__(self, idx):
        image_id = f"{self.data_info.iloc[idx, 0]:06d}.jpg"  # zero-padded file name
        image_path = os.path.join(self.image_dir, image_id)
        # The context manager closes the file handle once the RGB copy exists.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        # Normalize the age label into [0, 1] (assuming 100 is the max age).
        age = float(self.data_info.iloc[idx, 1]) / self.age_normalization_factor
        # All columns after the age are passed through unchanged as metadata.
        metadata = self.data_info.iloc[idx, 2:].tolist()

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(age, dtype=torch.float32), metadata

In [None]:
data_transforms = transforms.Compose(
    [
        transforms.ToTensor(),  # PIL image -> float CHW tensor in [0, 1]
        transforms.Resize((224, 224)),  # input resolution of vit_base_patch16_224
        transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1] on every channel
    ]
)

In [None]:
# Training split: reshuffled every epoch.
dataset_train = AgeEstimationDataset("train_data", "labels_metadata_train.csv", transform=data_transforms)
dataloader_train = DataLoader(dataset_train, batch_size=32, shuffle=True, num_workers=2)
print(f"Total number of train samples: {len(dataloader_train.dataset)}")

# Validation split: evaluation gains nothing from shuffling, and a fixed order
# keeps per-sample outputs reproducible.
dataset_valid = AgeEstimationDataset("valid_data", "labels_metadata_valid.csv", transform=data_transforms)
dataloader_valid = DataLoader(dataset_valid, batch_size=32, shuffle=False, num_workers=2)
print(f"Total number of valid samples: {len(dataloader_valid.dataset)}")

In [None]:
# Group sizes in the original (non-augmented) training CSV.
df = pd.read_csv("labels_metadata_train.csv")
for column in ('ethnicity', 'gender', 'emotion'):
    print(df[column].value_counts())
over_60_count = int((df['age'] > 60).sum())
print(f"Number of people over 60: {over_60_count}")

In [None]:
# Number of training samples per ethnicity / gender / emotion group, used as the
# denominators of the inverse-frequency loss weights. Derived from the training
# CSV rather than hard-coded, so the numbers cannot drift out of sync with it.
_train_labels = pd.read_csv("labels_metadata_train.csv")
group_counts = {
    group: int(count)
    for column in ('ethnicity', 'gender', 'emotion')
    for group, count in _train_labels[column].value_counts().items()
}

def customized_mse_loss_batch(y_true, y_pred, group_counts, metadata, default_count=None):
    """
    Inverse-frequency weighted MSE loss, weighting each sample by its ethnicity.

    Each squared error is multiplied by ``1 / group_counts[ethnicity]``, so
    samples from rarer ethnicity groups contribute more to the loss.

    Parameters:
    - y_true: Tensor of true values (e.g., normalized ages) of shape (batch_size,)
    - y_pred: Tensor of predicted values with the same shape as y_true
    - group_counts: Dictionary mapping a group label to its number of training
      samples; only ethnicity labels are looked up here
    - metadata: Batch metadata as collated by the DataLoader; metadata[1] holds
      the per-sample ethnicity labels
    - default_count: Count assumed for labels missing from group_counts.
      Defaults to the largest count in group_counts, so an unknown label gets
      the smallest known weight. A fallback of 1 would give it a far larger
      weight than even the rarest known group.

    Returns:
    - Scalar tensor: mean of the weighted squared errors for the batch

    NOTE(review): the weights are raw 1/count values (e.g. 1/3000 for a group of
    3000 samples), so the loss is orders of magnitude smaller than plain MSE.
    This acts like a much lower learning rate. Normalizing the weights (e.g. to
    mean 1) may be worth considering, but it would change training dynamics.
    """
    y_true = y_true.float()
    y_pred = y_pred.float()

    # Per-sample squared error.
    squared_error = (y_pred - y_true) ** 2

    if default_count is None:
        default_count = max(group_counts.values(), default=1)

    # Inverse-frequency weights; max(..., 1) guards against a zero count.
    weights = [1.0 / max(group_counts.get(label, default_count), 1) for label in metadata[1]]
    weight_vec = torch.tensor(weights, dtype=torch.float32, device=y_true.device)

    return (squared_error * weight_vec).mean()

In [None]:
# Train a model with the inverse-frequency (ethnicity) weighted MSE loss.
def train_model_with_custom_loss(model, dataloaders, optimizer, scheduler, num_epochs, patience, model_path):
    """
    Train ``model`` with the group-weighted MSE loss, LR scheduling and early stopping.

    Each epoch runs a 'train' phase followed by a 'val' phase. Both phases use
    ``customized_mse_loss_batch`` with the module-level ``group_counts``, so
    validation loss, LR scheduling, checkpointing and early stopping are all
    driven by the weighted loss.

    Parameters:
    - model: network to train; it is updated in place
    - dataloaders: dict with 'train' and 'val' DataLoaders yielding
      ``(inputs, labels, metadata)`` batches
    - optimizer: optimizer over the model's parameters
    - scheduler: scheduler stepped with the validation loss (e.g. ReduceLROnPlateau)
    - num_epochs: maximum number of epochs
    - patience: consecutive epochs without validation improvement before stopping
    - model_path: file the best ``state_dict`` is written to with ``torch.save``

    Returns:
    - ``model`` with the best (lowest validation loss) weights loaded

    Uses the module-level names ``device``, ``group_counts``,
    ``customized_mse_loss_batch`` and ``plot_training_curves``.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float("inf")
    early_stopping_counter = 0
    train_losses = []
    val_losses = []
    stop_training = False

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0

            with tqdm(dataloaders[phase], desc=f"{phase.capitalize()} Epoch {epoch+1}") as t:
                for inputs, labels, metadata in t:
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()

                    # Gradients are tracked only during the training phase
                    with torch.set_grad_enabled(is_train):
                        outputs = model(inputs)
                        # Weighted MSE; the collated metadata selects each sample's group weight
                        loss = customized_mse_loss_batch(labels, outputs.view_as(labels), group_counts, metadata)

                        if is_train:
                            loss.backward()
                            optimizer.step()

                    # loss is a per-batch mean; weight it by batch size for the epoch average
                    running_loss += loss.item() * inputs.size(0)
                    t.set_postfix(loss=loss.item())

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            print(f'{phase} Loss: {epoch_loss:.6f}')

            if is_train:
                train_losses.append(epoch_loss)
                continue

            val_losses.append(epoch_loss)
            scheduler.step(epoch_loss)

            # Save the best model during training
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                early_stopping_counter = 0
                print("Saving best model...")
                torch.save(best_model_wts, model_path)
            else:
                early_stopping_counter += 1
                if early_stopping_counter >= patience:
                    print("Early stopping triggered.")
                    stop_training = True

        if stop_training:
            break

    # Load best model weights and plot training curves
    model.load_state_dict(best_model_wts)
    plot_training_curves(train_losses, val_losses)
    return model


In [None]:
#==================
MODEL_TRAIN = True
#==================

# model hyperparameters
#criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=2)
num_epochs =10
patience = 10

# data loaders
dataloaders = {"train": dataloader_train, "val": dataloader_valid}  # Assuming split dataset

if (MODEL_TRAIN):
  #if(MOUNT_GOOGLE_DRIVE):
    #best_model = train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs, patience,"/content/best_age_estimation_model.pth")
  best_model = train_model_with_custom_loss(model, dataloaders, optimizer, scheduler, num_epochs, patience,"best_age_estimation_model_customloss.pth")
else:
  # download the pretrained model
  !wget http://data.chalearnlap.cvc.uab.cat/Colab_MFPDS/2025/best_age_estimation_model.zip

  # decompressing the data
  from zipfile import ZipFile

  with ZipFile('best_age_estimation_model_customloss.zip','r') as zip:
    zip.extractall()
    print('Model decompressed successfully')

  # removing the .zip file after extraction to clean space
  !rm best_age_estimation_model.zip


# Making predictions on Validation set (w/ custom loss)

In [None]:
# Validation split, preprocessed with the current data_transforms
dataset_valid = AgeEstimationDataset(
    image_dir="valid_data",
    csv_file="labels_metadata_valid.csv",
    transform=data_transforms,
)


In [None]:
# Load the checkpoint from the same relative path the training cell saves to.
# The previous absolute "/content/..." path only matched when the notebook runs
# from /content (the Colab default).
predictions_val, actual_ages_val, mae_val, metadata_val = predict_and_evaluate(
    "best_age_estimation_model_customloss.pth", dataset_valid
)
print(predictions_val)
print(actual_ages_val)

In [None]:
# Same relative checkpoint path the training cell saves to
calculate_bias("best_age_estimation_model_customloss.pth", dataset_valid, predictions_val, actual_ages_val, metadata_val)

# Making predictions on Test set (custom loss)

In [None]:
# Test split, preprocessed with the current data_transforms
dataset_test = AgeEstimationDataset(
    image_dir="test_data",
    csv_file="labels_metadata_test.csv",
    transform=data_transforms,
)


In [None]:
# Same relative checkpoint path the training cell saves to
predictions, actual_ages, mae, metadata = predict_and_evaluate("best_age_estimation_model_customloss.pth", dataset_test)


In [None]:
# Same relative checkpoint path the training cell saves to
calculate_bias("best_age_estimation_model_customloss.pth", dataset_test, predictions, actual_ages, metadata)

# Custom Loss with Augmented Data

# Augmenting the Training Data

In [None]:
import pandas as pd
import os
from PIL import Image
from torchvision import transforms
from tqdm import tqdm  # for progress bar (optional)
import torch

class AgeEstimationDataset(Dataset):
    """
    Age-estimation dataset that can augment selected samples on the fly.

    Each CSV row holds:
    - column 0: the numeric image id (file ``{id:06d}.jpg`` in ``image_dir``)
    - column 1: the age
    - the remaining columns: the metadata (gender, ethnicity, expression)

    A sample is passed through ``augmentation_transform`` (when one is given)
    if its ethnicity is 'afroamerican' or 'asian' or its age is above 60. After
    that, ``transform`` is applied to every sample.

    Items are ``(image, age / 100 as a float32 tensor, metadata list)``.
    """

    def __init__(self, image_dir, csv_file, transform=None, augmentation_transform=None):
        self.image_dir = image_dir
        self.data_info = pd.read_csv(csv_file)
        self.transform = transform
        # Only used for the samples selected in __getitem__
        self.augmentation_transform = augmentation_transform
        # Ages are divided by this to map them (assumed <= 100) into [0, 1]
        self.age_normalization_factor = 100

    def __len__(self):
        return len(self.data_info)

    def __normalization_factor__(self):
        return self.age_normalization_factor

    def __getitem__(self, idx):
        table = self.data_info

        file_name = f"{table.iloc[idx, 0]:06d}.jpg"
        image = Image.open(os.path.join(self.image_dir, file_name)).convert("RGB")

        raw_age = table.iloc[idx, 1]
        normalized_age = float(raw_age) / self.age_normalization_factor
        metadata = table.iloc[idx, 2:].tolist()
        age_in_years = int(raw_age)

        # Under-represented ethnicities and people over 60 get the extra augmentation
        is_underrepresented = metadata[1] in ['afroamerican', 'asian'] or age_in_years > 60
        if is_underrepresented and self.augmentation_transform:
            image = self.augmentation_transform(image)

        # Deterministic preprocessing (resize / tensor conversion / normalization)
        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(normalized_age, dtype=torch.float32), metadata


def save_augmented_images_and_update_csv(dataset, augmentation_transforms, output_dir, csv_file, num_augmented_images=3, output_csv=None):
    """
    Oversample under-represented rows and save augmented copies of their images.

    A row is selected when its ethnicity is 'afroamerican' or 'asian', or its
    age is above 60. For each selected row:
    - ``num_augmented_images`` augmented versions of the original image are
      written to ``output_dir`` as ``augmented_{image_id}_{i}.jpg``;
    - the row is appended ``num_augmented_images`` times to a copy of ``csv_file``.

    The original CSV is never modified. The combined table is written to
    ``output_csv`` (default: ``augmented_<basename of csv_file>`` in the current
    working directory).

    NOTE: the appended rows keep the ORIGINAL image id. An ``AgeEstimationDataset``
    built on the new CSV therefore reads the original images from its
    ``image_dir``; the files saved in ``output_dir`` are not referenced by the CSV.

    Parameters:
    - dataset: AgeEstimationDataset built from ``csv_file``; only its
      ``image_dir`` and ``data_info`` are used
    - augmentation_transforms: callable applied to a PIL RGB image; it may return
      a PIL image or a CHW float tensor in [0, 1]
    - output_dir: directory for the augmented image files (created if missing)
    - csv_file: original labels/metadata CSV (columns: id, age, gender,
      ethnicity, expression)
    - num_augmented_images: augmented copies per selected row
    - output_csv: path of the combined CSV to write (default described above)

    Returns:
    - Path of the CSV file that was written
    """
    os.makedirs(output_dir, exist_ok=True)

    # Original table; the augmented rows are appended under its column names
    df = pd.read_csv(csv_file)
    info = dataset.data_info
    to_pil = transforms.ToPILImage()

    # Row positions in dataset.data_info, one entry per augmented copy
    selected_rows = []
    for idx in tqdm(range(len(dataset)), desc="Processing images"):
        original_image_id = info.iloc[idx, 0]
        ethnicity = info.iloc[idx, 3]  # metadata field 1
        age_value = int(info.iloc[idx, 1])

        # Only afroamerican/asian samples and people over 60 are augmented
        if ethnicity not in ('afroamerican', 'asian') and age_value <= 60:
            continue

        # Augment the raw RGB image from disk (same file naming as the dataset).
        # dataset[idx] is already normalized to [-1, 1] by the train transforms,
        # whereas torchvision expects float tensor images in [0, 1]; augmenting
        # and saving it produced corrupted images.
        image_path = os.path.join(dataset.image_dir, f"{original_image_id:06d}.jpg")
        with Image.open(image_path) as raw_image:
            source_image = raw_image.convert("RGB")

        for i in range(num_augmented_images):
            augmented_image = augmentation_transforms(source_image)
            if isinstance(augmented_image, torch.Tensor):
                augmented_image = to_pil(augmented_image)
            augmented_image.save(os.path.join(output_dir, f"augmented_{original_image_id}_{i}.jpg"))
            selected_rows.append(idx)

    # Copy the selected rows verbatim, so ages are written exactly as in the
    # source CSV (no age / 100 * 100 float32 round-trip).
    augmented_df = info.iloc[selected_rows].set_axis(df.columns, axis=1)
    combined = pd.concat([df, augmented_df], ignore_index=True)

    # csv_file was read above, so it always exists: never overwrite it
    if output_csv is None:
        output_csv = f"augmented_{os.path.basename(csv_file)}"
    combined.to_csv(output_csv, index=False)

    print(f"Original CSV left unchanged: {csv_file}")
    print(f"Augmented data ({len(augmented_df)} new rows) saved to: {output_csv}")
    return output_csv


# Random augmentations for the under-represented samples.
# NOTE(review): hue=0.2 noticeably shifts skin colour; for an ethnicity-focused
# study a smaller hue jitter may be preferable -- consider before re-running.
augmentation_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
    ]
)

# Both splits share the same deterministic preprocessing: resize to 224x224,
# convert to a tensor, then normalize with mean 0.5 / std 0.5.
data_transforms = {
    split: transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    for split in ('train', 'val')
}

# Original training split. With augmentation_transform=None, __getitem__
# applies no augmentation, only the train preprocessing.
dataset_train = AgeEstimationDataset(
    image_dir="train_data",
    csv_file="labels_metadata_train.csv",
    transform=data_transforms['train'],
    augmentation_transform=None,
)

output_dir = "augmented_train_data"  # destination of the augmented image files
num_augmented_images = 3             # augmented copies per selected sample

# Write the augmented images and the extended labels CSV
save_augmented_images_and_update_csv(
    dataset_train,
    augmentation_transforms,
    output_dir,
    "labels_metadata_train.csv",
    num_augmented_images,
)


In [None]:
# Create dataset and dataloader (train set).
# The augmented CSV repeats the rows of under-represented samples, and those
# rows still point at the original images in train_data. Without an
# augmentation transform every copy would be identical, so augmentation is
# applied on the fly. The dataset only augments afroamerican/asian or over-60
# samples, so each repeated copy gets a different random augmentation.
aug_dataset_train = AgeEstimationDataset(
    "train_data",
    "augmented_labels_metadata_train.csv",
    transform=data_transforms['train'],
    augmentation_transform=augmentation_transforms,
)
aug_dataloader_train = DataLoader(aug_dataset_train, batch_size=32, shuffle=True, num_workers=2)
print(f"Total number of train samples: {len(aug_dataloader_train.dataset)}")

# Create dataset and dataloader (validation set). No shuffling: the validation
# loss is averaged over the whole set, so batch order does not matter.
dataset_valid = AgeEstimationDataset("valid_data", "labels_metadata_valid.csv", transform=data_transforms['val'])
dataloader_valid = DataLoader(dataset_valid, batch_size=32, shuffle=False, num_workers=2)
print(f"Total number of valid samples: {len(dataloader_valid.dataset)}")

In [None]:
#==================
MODEL_TRAIN = True
#==================

# model hyperparameters
#criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=2)
num_epochs =10
patience = 10

# data loaders
dataloaders = {"train": aug_dataloader_train, "val": dataloader_valid}  # Assuming split dataset

if (MODEL_TRAIN):
  #if(MOUNT_GOOGLE_DRIVE):
    #best_model = train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs, patience,"/content/best_age_estimation_model.pth")
  best_model = train_model_with_custom_loss(model, dataloaders, optimizer, scheduler, num_epochs, patience,"best_age_estimation_model_customloss_augmentations.pth")
else:
  # download the pretrained model
  !wget http://data.chalearnlap.cvc.uab.cat/Colab_MFPDS/2025/best_age_estimation_model.zip

  # decompressing the data
  from zipfile import ZipFile

  with ZipFile('best_age_estimation_model_customloss_augmenations.zip','r') as zip:
    zip.extractall()
    print('Model decompressed successfully')

  # removing the .zip file after extraction to clean space
  !rm best_age_estimation_model.zip


# Making predictions on Validation set (w/ custom loss and augmentations)

In [None]:
# Evaluation preprocessing. Same steps and order as the training transforms:
# resize the PIL image first, then convert to a tensor and normalize.
# Resizing a tensor instead of a PIL image can give slightly different pixels
# (antialiasing depends on the torchvision version), which would make the
# evaluation inputs differ from the training inputs.
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

In [None]:
# Validation split for evaluating the custom-loss + augmentation model
dataset_valid = AgeEstimationDataset(
    image_dir="valid_data",
    csv_file="labels_metadata_valid.csv",
    transform=data_transforms,
)


In [None]:
# Load the checkpoint from the same relative path the training cell saves to.
# The previous absolute "/content/..." path only matched when the notebook runs
# from /content (the Colab default).
predictions_val, actual_ages_val, mae_val, metadata_val = predict_and_evaluate(
    "best_age_estimation_model_customloss_augmentations.pth", dataset_valid
)
print(predictions_val)
print(actual_ages_val)

In [None]:
# Same relative checkpoint path the training cell saves to
calculate_bias("best_age_estimation_model_customloss_augmentations.pth", dataset_valid, predictions_val, actual_ages_val, metadata_val)

# Making predictions on Test set (custom loss and augmentations)

In [None]:
# Test split for evaluating the custom-loss + augmentation model
dataset_test = AgeEstimationDataset(
    image_dir="test_data",
    csv_file="labels_metadata_test.csv",
    transform=data_transforms,
)


In [None]:
# Same relative checkpoint path the training cell saves to
predictions, actual_ages, mae, metadata = predict_and_evaluate("best_age_estimation_model_customloss_augmentations.pth", dataset_test)


In [None]:
# Same relative checkpoint path the training cell saves to
calculate_bias("best_age_estimation_model_customloss_augmentations.pth", dataset_test, predictions, actual_ages, metadata)