# Question 2

In [None]:
import itertools
import os
import random
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from IPython.display import display
from sklearn.model_selection import train_test_split
from torchsummary import summary
from torchvision import transforms

from models.multi_task_auto_encoder import MultiTaskAutoEncoder
from training_testing.training_mtautoencoder import train_mtautoencoder
from utils.data_loader import CustomMTImageDataset
from utils.transformations_v2 import augment_dataset_with_replacement, resize_dataset

In [None]:
# Seed every RNG this notebook relies on (Python, NumPy, PyTorch), in that order.
random_seed = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(random_seed)

## Data Preprocessing

In [None]:
# Download (or reuse the cached copy of) the emoji dataset under ../data.
dataset = load_dataset(path="valhalla/emoji-dataset", cache_dir="../data")

### Creating a Subset

In [None]:
# Keywords selecting the expression-related subset; matched as substrings of the
# emoji's text description (so e.g. "hero" also matches "superhero").
expression_categories = [
    "face",
    "vampire",
    "elf",
    "mage",
    "hero",
    "villain",
    "evil monkey",
    "zombie",
    "haircut",
    "juggling",
]


def _mentions_expression_category(example):
    """Return True when the example's description contains any category keyword."""
    description = example["text"]
    return any(category in description for category in expression_categories)


data_subset = dataset["train"].filter(_mentions_expression_category)

In [None]:
# Report the subset size and preview one example (text + image).
print("Number of images related to expression categories:", len(data_subset))
preview_example = data_subset[25]
print("Subset example:", preview_example["text"])
sample_image = preview_example["image"]
display(sample_image)

In [None]:
def categorize_emoji(example):
    """Assign a coarse class label to an emoji example from its text description.

    Classes:
        0 -- Human and Human-like Characters
        1 -- Animals and Mythical Creatures
        2 -- Miscellaneous

    Animal terms are checked *before* human-like terms. "face" is a human-like
    term, so checking humans first labelled every animal face ("cat face",
    "dog face", "dragon face", ...) as class 0.

    Animal terms are matched as whole words, so "cowboy hat face" does not match
    "cow". Human-like terms keep the original substring matching. That means, for
    example, "woman" also matches "man" and "facepalming" matches "face".

    Args:
        example: a dataset row; only its "text" field (a string) is read.

    Returns:
        dict: ``{"class": int}``, suitable for ``Dataset.map``.
    """
    human_like = [
        "face",
        "superhero",
        "supervillain",
        "mage",
        "vampire",
        "elf",
        "zombie",
        "man",
        "woman",
    ]
    animals_mythical = [
        "cat",
        "dog",
        "monkey",
        "fox",
        "lion",
        "tiger",
        "horse",
        "unicorn",
        "cow",
        "pig",
        "mouse",
        "rabbit",
        "bear",
        "frog",
        "dragon",
    ]

    description = example["text"]
    # Lower-case alphabetic tokens, e.g. "see-no-evil monkey" -> {see, no, evil, monkey}.
    words = set(re.findall(r"[a-z]+", description.lower()))
    if any(animal in words for animal in animals_mythical):
        return {"class": 1}  # Animals and Mythical Creatures
    elif any(word in description for word in human_like):
        return {"class": 0}  # Human and Human-like Characters
    else:
        return {"class": 2}  # Miscellaneous

In [None]:
# Attach a "class" label to every example in the subset.
data_subset = data_subset.map(function=categorize_emoji)

In [None]:
# Peek at the first labelled example.
first_example = data_subset[0]
first_example

### Splitting the data

Divide this subset into training, validation, and test sets using a stratified 60/20/20 ratio.

In [None]:
# 60/20/20 split sizes; the test split absorbs any rounding remainder.
total_size = len(data_subset)
train_size, val_size = (int(fraction * total_size) for fraction in (0.6, 0.2))
test_size = total_size - (train_size + val_size)

In [None]:
# Read each column once. Iterating over full rows formats every column, so the
# original three row loops decoded every image three times just to read
# "text" and "class". list() also normalises newer `datasets` versions, where
# column access returns a lazy Column object.
images = list(data_subset["image"])
text = list(data_subset["text"])  # NOTE(review): not used by any later cell shown here
labels = list(data_subset["class"])

In [None]:
# Stratified split: carve off the test set first, then split the remaining pool
# into train/validation. Both test sizes are absolute counts, giving 60/20/20.
pool_images, test_images, pool_labels, test_labels = train_test_split(
    images,
    labels,
    test_size=test_size,
    stratify=labels,
    random_state=random_seed,
)
train_images, val_images, train_labels, val_labels = train_test_split(
    pool_images,
    pool_labels,
    test_size=val_size,
    stratify=pool_labels,
    random_state=random_seed,
)

In [None]:
def _as_records(split_images, split_labels):
    """Pair each image with its label as an {"image": ..., "class": ...} dict."""
    return [
        {"image": image, "class": label}
        for image, label in zip(split_images, split_labels)
    ]


train_dataset = _as_records(train_images, train_labels)
val_dataset = _as_records(val_images, val_labels)
test_dataset = _as_records(test_images, test_labels)

In [None]:
print("Split Size:\n----------")
for size_label, split_records in (
    ("Train dataset size:", train_dataset),
    ("Validation dataset size:", val_dataset),
    ("Test dataset size:", test_dataset),
):
    print(size_label, len(split_records))

### Augmenting to 600/200/200

In [None]:
# Random augmentations applied when up-sampling each split in the next cell.
augmentation_transforms = transforms.Compose(
    [
        transforms.RandomRotation(degrees=15),  # Random rotation up to 15 degrees
        transforms.RandomHorizontalFlip(
            p=0.5
        ),  # Random horizontal flip with a probability of 0.5
        transforms.RandomVerticalFlip(
            p=0.5
        ),  # Random vertical flip with a probability of 0.5
        transforms.ColorJitter(
            brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
        ),  # Randomly adjust brightness, contrast, saturation, and hue
        transforms.RandomAffine(
            degrees=5, translate=(0.1, 0.1)
        ),  # Random affine transformation (extra rotation up to 5 degrees + shift up to 10%)
        transforms.RandomApply(
            [transforms.GaussianBlur(kernel_size=3)], p=0.1
        ),  # Random Gaussian blur with probability 0.1
    ]
)

In [None]:
# Up-sample every split to a fixed size with the project helper (per its name,
# sampling with replacement plus augmentation). Order is train -> val -> test.
split_targets = [(train_dataset, 600), (val_dataset, 200), (test_dataset, 200)]
train_aug, val_aug, test_aug = [
    augment_dataset_with_replacement(split_records, target_size, augmentation_transforms)
    for split_records, target_size in split_targets
]

In [None]:
for size_label, augmented_split in (
    ("Augmented train dataset size:", train_aug),
    ("Augmented validation dataset size:", val_aug),
    ("Augmented test dataset size:", test_aug),
):
    print(size_label, len(augmented_split))

### Resizing to 64x64

In [None]:
# Every image is brought to a fixed 64x64 resolution.
resize_transform = transforms.Resize(size=(64, 64))

In [None]:
# Resize each augmented split (train, then val, then test).
train_aug_resized, val_aug_resized, test_aug_resized = (
    resize_dataset(augmented_split, resize_transform)
    for augmented_split in (train_aug, val_aug, test_aug)
)

### Sample data

In [None]:
# Inspect one augmented, resized training example.
resized_example = train_aug_resized[500]
print("Subset class:", resized_example["class"])
sample_image = resized_example["image"]
display(sample_image)

### Tensor Dataset

In [None]:
# Wrap each resized split in the project's Dataset class for use with DataLoader.
train_dataset, val_dataset, test_dataset = (
    CustomMTImageDataset(resized_split)
    for resized_split in (train_aug_resized, val_aug_resized, test_aug_resized)
)

## Hyperparameters

Do not run the cells below unless you want to perform an extensive grid search: 96 hyperparameter combinations, each trained for 350 epochs.

In [None]:
# Grid-search space: 4 * 2 * 2 * 2 * 1 * 1 * 1 * 3 = 96 configurations.
# Each latent entry is (latent channels, flattened latent size).
latent_sizes = [(32, 2048), (64, 4096), (128, 8192), (256, 16384)]
encoder_channel_options = [(16, 32), (32, 64)]
kernel_sizes = [3]
strides = [2]
paddings = [1]

learning_rates = [0.001, 0.0001]
weight_decays = [1e-5, 1e-4]
lambda_classifications = [0.1, 0.5, 1.0]

batch_size = 16
num_epochs = 350

# One dict of final metrics per configuration is appended by the search loop.
results = []

In [None]:
# Shuffle training batches each epoch; keep validation order fixed.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
learning_curves_dir = "learning_curves/question_2"
# Create the output directory before the search starts; otherwise savefig raises
# FileNotFoundError only after the first (expensive) model has finished training.
os.makedirs(learning_curves_dir, exist_ok=True)

# Train one model per hyperparameter combination, save its learning curves, and
# record its final-epoch metrics in `results`.
for (
    (latent_channels, flattened_size),
    lr,
    weight_decay,
    encoder_channels,
    kernel_size,
    stride,
    padding,
    lambda_classification,
) in itertools.product(
    latent_sizes,
    learning_rates,
    weight_decays,
    encoder_channel_options,
    kernel_sizes,
    strides,
    paddings,
    lambda_classifications,
):
    model = MultiTaskAutoEncoder(
        latent_size=latent_channels,
        flattened_latent_size=flattened_size,
        encoder_channels=encoder_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        num_classes=3,
    )

    # Per-epoch metric histories; the first returned value is not used here.
    (
        _,
        train_mse_losses,
        train_classification_losses,
        train_classification_accuracies,
        val_mse_losses,
        val_classification_losses,
        val_classification_accuracies,
    ) = train_mtautoencoder(
        model,
        train_loader,
        val_loader,
        num_epochs=num_epochs,
        lr=lr,
        weight_decay=weight_decay,
        lambda_classification=lambda_classification,
    )

    fig, (ax_mse, ax_cls, ax_acc) = plt.subplots(1, 3, figsize=(15, 5))

    # MSE (reconstruction) loss
    ax_mse.plot(train_mse_losses, label="Train MSE Loss")
    ax_mse.plot(val_mse_losses, label="Validation MSE Loss")
    ax_mse.set(title="MSE Loss", xlabel="Epochs", ylabel="Loss")
    ax_mse.legend()

    # Classification loss
    ax_cls.plot(train_classification_losses, label="Train Classification Loss")
    ax_cls.plot(val_classification_losses, label="Validation Classification Loss")
    ax_cls.set(title="Classification Loss", xlabel="Epochs", ylabel="Loss")
    ax_cls.legend()

    # Classification accuracy
    ax_acc.plot(train_classification_accuracies, label="Train Classification Accuracy")
    ax_acc.plot(val_classification_accuracies, label="Validation Classification Accuracy")
    ax_acc.set(title="Classification Accuracy", xlabel="Epochs", ylabel="Accuracy (%)")
    ax_acc.legend()

    fig.suptitle(
        f"LS: {latent_channels}, LR: {lr}, WD: {weight_decay}, EC: {encoder_channels}, KS: {kernel_size}, S: {stride}, P: {padding}, LC: {lambda_classification}"
    )
    fig.savefig(
        os.path.join(
            learning_curves_dir,
            f"combined_metrics_ls{latent_channels}_lr{lr}_wd{weight_decay}_ec{encoder_channels}_ks{kernel_size}_s{stride}_p{padding}_lc{lambda_classification}.png",
        )
    )
    # Close this figure explicitly so 96 runs do not accumulate open figures.
    plt.close(fig)

    # Store the final-epoch metrics for this configuration.
    results.append(
        {
            "latent_size": latent_channels,
            "learning_rate": lr,
            "weight_decay": weight_decay,
            "encoder_channels": encoder_channels,
            "kernel_size": kernel_size,
            "stride": stride,
            "padding": padding,
            "lambda_classification": lambda_classification,
            "final_train_mse_loss": train_mse_losses[-1],
            "final_val_mse_loss": val_mse_losses[-1],
            "final_train_classification_loss": train_classification_losses[-1],
            "final_val_classification_loss": val_classification_losses[-1],
            "final_train_classification_accuracy": train_classification_accuracies[-1],
            "final_val_classification_accuracy": val_classification_accuracies[-1],
        }
    )

In [None]:
# One row per grid-search configuration.
results_df = pd.DataFrame(data=results)

In [None]:
# Persist the grid-search table. Create the folder first: to_csv does not create
# missing parent directories.
hyperparam_results_path = "results/question_2/hyperparam_results.csv"
os.makedirs(os.path.dirname(hyperparam_results_path), exist_ok=True)
results_df.to_csv(hyperparam_results_path, index=False)

In [None]:
# The ten configurations with the lowest final validation reconstruction (MSE) loss.
results_df.sort_values("final_val_mse_loss").head(n=10)

## Testing Best Model

In [None]:
# Settings for retraining the selected configuration.
num_epochs = 150
batch_size = 16

In [None]:
# Build all three loaders here. train_loader / val_loader were otherwise only
# created in the optional grid-search section, so the training cell below raised
# NameError on a fresh kernel when that section was skipped.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False
)
# Unshuffled, so batches follow test_dataset order (relied on when saving latents).
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False
)

In [None]:
# Hyperparameters of the configuration selected for final training.
# Architecture
latent_size = 256
flattened_latent_size = 16384
encoder_channels = (32, 64)
kernel_size = 3
stride = 2
padding = 1
# Optimisation / loss weighting
learning_rate = 0.001
weight_decay = 0.00001
lambda_classification = 0.1

In [None]:
# Build the selected architecture. num_classes=3 matches the three labels
# (0, 1, 2) produced by categorize_emoji and the models built in the grid
# search, rather than relying on the constructor's default.
model = MultiTaskAutoEncoder(
    latent_size=latent_size,
    encoder_channels=encoder_channels,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
    flattened_latent_size=flattened_latent_size,
    num_classes=3,
)

In [None]:
# Retrain the selected configuration. Use `learning_rate` from the cell above:
# the previous `lr=lr` silently picked up the last value of the grid-search loop
# variable (0.0001), or raised NameError if the grid search was skipped.
(
    _,
    train_mse_losses,
    train_classification_losses,
    train_classification_accuracies,
    val_mse_losses,
    val_classification_losses,
    val_classification_accuracies,
) = train_mtautoencoder(
    model,
    train_loader,
    val_loader,
    num_epochs=num_epochs,
    lr=learning_rate,
    weight_decay=weight_decay,
    lambda_classification=lambda_classification,
)

In [None]:
# Learning curves for the retrained model. The title and filename use this
# section's `latent_size` / `learning_rate` rather than the grid-search loop
# leftovers `latent_channels` / `lr`, which are stale or undefined here.
fig, (ax_mse, ax_cls, ax_acc) = plt.subplots(1, 3, figsize=(15, 5))

# MSE (reconstruction) loss
ax_mse.plot(train_mse_losses, label="Train MSE Loss")
ax_mse.plot(val_mse_losses, label="Validation MSE Loss")
ax_mse.set(title="MSE Loss", xlabel="Epochs", ylabel="Loss")
ax_mse.legend()

# Classification loss
ax_cls.plot(train_classification_losses, label="Train Classification Loss")
ax_cls.plot(val_classification_losses, label="Validation Classification Loss")
ax_cls.set(title="Classification Loss", xlabel="Epochs", ylabel="Loss")
ax_cls.legend()

# Classification accuracy
ax_acc.plot(train_classification_accuracies, label="Train Classification Accuracy")
ax_acc.plot(val_classification_accuracies, label="Validation Classification Accuracy")
ax_acc.set(title="Classification Accuracy", xlabel="Epochs", ylabel="Accuracy (%)")
ax_acc.legend()

fig.suptitle(
    f"LS: {latent_size}, LR: {learning_rate}, WD: {weight_decay}, EC: {encoder_channels}, KS: {kernel_size}, S: {stride}, P: {padding}, LC: {lambda_classification}"
)
final_results_dir = "results/question_2"
os.makedirs(final_results_dir, exist_ok=True)  # savefig does not create folders
fig.savefig(
    os.path.join(
        final_results_dir,
        f"combined_metrics_ls{latent_size}_lr{learning_rate}_wd{weight_decay}_ec{encoder_channels}_ks{kernel_size}_s{stride}_p{padding}_lc{lambda_classification}.png",
    )
)
plt.close(fig)

In [None]:
# Evaluate the retrained model on the test split: reconstruction MSE,
# classification cross-entropy, and classification accuracy.
model.eval()
# NOTE(review): train_mtautoencoder may move the model to a GPU; send each batch
# to wherever the model's parameters live so this works on CPU or GPU.
device = next(model.parameters()).device
criterion_mse = torch.nn.MSELoss()
criterion_classification = torch.nn.CrossEntropyLoss()

total_test_mse_loss = 0.0
total_test_classification_loss = 0.0
correct_test = 0
total_test = 0

with torch.no_grad():
    for batch in test_loader:
        # Batch-local names, so the notebook-level `images` / `labels` lists
        # built during preprocessing are not overwritten.
        batch_images = batch["image"].to(device)
        batch_labels = batch["class"].to(device)
        reconstructed, classification_logits = model(batch_images)

        loss_mse = criterion_mse(reconstructed, batch_images)
        loss_classification = criterion_classification(
            classification_logits, batch_labels
        )

        # Both criteria use the default reduction="mean"; weight by batch size so
        # the final figures are averages over test samples.
        batch_count = batch_labels.size(0)
        total_test_mse_loss += loss_mse.item() * batch_count
        total_test_classification_loss += loss_classification.item() * batch_count

        _, predicted = torch.max(classification_logits, 1)
        correct_test += (predicted == batch_labels).sum().item()
        total_test += batch_count

test_mse_loss = total_test_mse_loss / len(test_loader.dataset)
test_classification_loss = total_test_classification_loss / len(test_loader.dataset)
test_accuracy = 100 * correct_test / total_test

print(f"Test MSE Loss: {test_mse_loss}")
print(f"Test Classification Loss: {test_classification_loss}")
print(f"Test Accuracy: {test_accuracy}%")

## Saving Best Model

In [None]:
# Architecture: layer-by-layer summary for a single 3x64x64 input.
# torchsummary's `summary` prints the table itself and returns None, so wrapping
# it in print() only appended a stray "None". Its `device` argument defaults to
# "cuda", so pass the model's actual device type to build the dummy input there.
summary(
    model,
    (3, 64, 64),
    device="cuda" if next(model.parameters()).is_cuda else "cpu",
)

In [None]:
# Save the learned parameters only (state_dict). Reloading requires rebuilding
# MultiTaskAutoEncoder with the same hyperparameters as above.
weights_path = "results/question_2/q2_model_weights.pth"
os.makedirs(os.path.dirname(weights_path), exist_ok=True)  # torch.save won't create it
torch.save(model.state_dict(), weights_path)

In [None]:
# latent representation
model.eval()
latent_representations = []

with torch.no_grad():
    for batch in test_loader:
        images = batch["image"]
        latent = model.encoder(images)
        latent_representations.append(latent.cpu().numpy())

latent_representations = np.concatenate(latent_representations, axis=0)

latent_representations_path = "results/question_2/latent_representations.npy"
np.save(latent_representations_path, latent_representations)

## Sample Image

In [None]:
# Run the model on a single test image (the first image of the first batch).
model.eval()
# NOTE(review): follow the model's device in case training moved it to a GPU.
device = next(model.parameters()).device

# Index the batch directly instead of unpacking into `images` / `labels`, which
# would overwrite the notebook-level lists built during preprocessing.
sample_batch = next(iter(test_loader))
sample_image = sample_batch["image"][0]
true_label = sample_batch["class"][0]

with torch.no_grad():
    reconstructed, classification_logits = model(sample_image.unsqueeze(0).to(device))
    predicted_label = torch.argmax(classification_logits, dim=1)

In [None]:
# convert image
sample_image_np = sample_image.numpy().transpose(1, 2, 0)
reconstructed_np = reconstructed.squeeze(0).numpy().transpose(1, 2, 0)

In [None]:
# Side-by-side view of the original test image and its reconstruction.
fig, (ax_original, ax_reconstructed) = plt.subplots(1, 2, figsize=(12, 6))

ax_original.imshow(sample_image_np)
# int() turns the 0-d label tensor into a plain number; formatting the tensor
# directly renders as "tensor(0)" in the title.
ax_original.set_title(f"Original Image (True Class: {int(true_label)})")

ax_reconstructed.imshow(reconstructed_np)
ax_reconstructed.set_title(
    f"Reconstructed Image (Predicted Class: {predicted_label.item()})"
)

plt.show()