<a href="https://colab.research.google.com/github/dimna21/ML_Assignment4/blob/main/FER2013.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install wandb



In [None]:
import wandb
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdimna21[0m ([33mdimna21-free-university-of-tbilisi-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Load the CSV
csv_path = "/content/drive/MyDrive/FER_data/fer2013/fer2013.csv"
df = pd.read_csv(csv_path)

# Split by Usage
df_train = df[df['Usage']=="Training"].copy()
df_val   = df[df['Usage']=="PublicTest"].copy()
df_test  = df[df['Usage']=="PrivateTest"].copy()

In [None]:
# Balance function: upsampling & random ±10 intensity shifts
def balance_dataset(df, target_count, img_shape=(48,48)):
    def augment(pix_str):
        arr = np.fromstring(pix_str, sep=' ', dtype=int).reshape(img_shape)
        shift = np.random.randint(-10, 11)
        arr = np.clip(arr + shift, 0, 255).astype(int)
        return ' '.join(map(str, arr.ravel()))
    parts = [df]
    for emo, grp in df.groupby('emotion'):
        n = len(grp)
        if n < target_count:
            extra = grp.sample(n=target_count-n, replace=True).copy()
            extra['pixels'] = extra['pixels'].map(augment)
            parts.append(extra)
    return pd.concat(parts, ignore_index=True)


In [None]:
# Dataset class
class FERDataset(Dataset):
    def __init__(self, dataframe):
        self.pixels = dataframe['pixels'].values
        self.labels = dataframe['emotion'].values.astype(int)
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        arr = np.fromstring(self.pixels[idx], sep=' ', dtype=np.uint8).reshape(48,48)
        arr = arr.astype(np.float32) / 255.0
        tensor = torch.from_numpy(arr).unsqueeze(0)  # shape [1,48,48]
        return tensor, self.labels[idx]


In [None]:
# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [None]:
# DataLoaders
max_count = df_train['emotion'].value_counts().max()
balanced_train = balance_dataset(df_train, target_count=max_count)

#train_ds = FERDataset(df_train)
train_ds = FERDataset(balanced_train)
val_ds = FERDataset(df_val)
test_ds = FERDataset(df_test)

batch_size = 256
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=2)
val_dl = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=2)
test_dl = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=2)

# Class names
class_names = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]

In [None]:
balanced_train['emotion'].value_counts()

Unnamed: 0_level_0,count
emotion,Unnamed: 1_level_1
0,7215
2,7215
4,7215
6,7215
3,7215
5,7215
1,7215


In [None]:
import torch
from tqdm import tqdm
import wandb
from sklearn.metrics import confusion_matrix, f1_score
import matplotlib.pyplot as plt
import numpy as np

def train_model(model, train_loader, val_loader,
                criterion, optimizer, device,
                epochs=5, class_names=None):

    wandb.init(
        project="ML_Assignment4",
        config={
            "epochs": epochs,
            "batch_size": train_loader.batch_size,
            "optimizer": optimizer.__class__.__name__,
            "lr": optimizer.param_groups[0]["lr"],
            "criterion": criterion.__class__.__name__,
        },
    )
    cfg = wandb.config
    wandb.watch(model, log="all", log_freq=100)
    model.to(device)

    train_loss_plot, val_loss_plot = [], []
    train_acc_plot,  val_acc_plot  = [], []

    for epoch in range(1, cfg.epochs + 1):
        # — TRAIN —
        model.train()
        running_loss = running_correct = running_total = 0
        for X, y in tqdm(train_loader, desc=f"Epoch {epoch} [Train]"):
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(X)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()

            running_loss    += loss.item() * X.size(0)
            preds            = out.argmax(dim=1)
            running_correct += preds.eq(y).sum().item()
            running_total   += y.size(0)

        train_loss = running_loss / running_total
        train_acc  = running_correct / running_total

        # — VALIDATE —
        model.eval()
        val_running_loss = val_running_correct = val_running_total = 0
        all_preds, all_targets = [], []
        with torch.no_grad():
            for X, y in tqdm(val_loader, desc=f"Epoch {epoch} [Val]"):
                X, y = X.to(device), y.to(device)
                out = model(X)
                loss = criterion(out, y)

                val_running_loss    += loss.item() * X.size(0)
                preds                = out.argmax(dim=1)
                val_running_correct += preds.eq(y).sum().item()
                val_running_total   += y.size(0)

                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(y.cpu().numpy())

        val_loss = val_running_loss / val_running_total
        val_acc  = val_running_correct / val_running_total

        # — SCALARS & F1 —
        cm = confusion_matrix(all_targets, all_preds)
        f1_per_class = f1_score(all_targets, all_preds, average=None)

        log_data = {
            "train_loss": train_loss,
            "train_acc":  train_acc,
            "val_loss":   val_loss,
            "val_acc":    val_acc,
        }
        for i, name in enumerate(class_names):
            log_data[f"f1_{name}"] = f1_per_class[i]

        wandb.log(log_data, step=epoch)

        # — 7×7 Confusion Matrix Plot & Log —
        fig_cm, ax = plt.subplots(figsize=(6,6))
        im = ax.imshow(cm, interpolation='nearest', cmap='Blues')
        fig_cm.colorbar(im, ax=ax)

        ax.set_xticks(np.arange(len(class_names)))
        ax.set_yticks(np.arange(len(class_names)))
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.set_yticklabels(class_names)

        for i in range(len(class_names)):
            for j in range(len(class_names)):
                ax.text(j, i, cm[i, j],
                        ha='center', va='center')

        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        ax.set_title(f'Epoch {epoch} Confusion Matrix')

        wandb.log({"confusion_matrix": wandb.Image(fig_cm)}, step=epoch)
        plt.close(fig_cm)

        # — PRINT & STORE FOR CURVES —
        print(
            f"Epoch {epoch}/{cfg.epochs} — "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}  |  "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
        )
        print("-" * 60)

        train_loss_plot.append(train_loss)
        val_loss_plot.append(val_loss)
        train_acc_plot.append(train_acc)
        val_acc_plot.append(val_acc)

    # — PLOT & LOG LOSS/ACC CURVES TO W&B —
    epochs_range = list(range(1, cfg.epochs + 1))

    # Loss curve
    fig1, ax1 = plt.subplots()
    ax1.plot(epochs_range, train_loss_plot, label="Train Loss")
    ax1.plot(epochs_range, val_loss_plot,   label="Val Loss")
    ax1.set(title="Loss vs Epoch", xlabel="Epoch", ylabel="Loss")
    ax1.legend()
    wandb.log({"loss_curve": wandb.Image(fig1)})
    plt.close(fig1)

    # Accuracy curve
    fig2, ax2 = plt.subplots()
    ax2.plot(epochs_range, train_acc_plot, label="Train Acc")
    ax2.plot(epochs_range, val_acc_plot,   label="Val Acc")
    ax2.set(title="Accuracy vs Epoch", xlabel="Epoch", ylabel="Accuracy")
    ax2.legend()
    wandb.log({"acc_curve": wandb.Image(fig2)})
    plt.close(fig2)

    return model


In [None]:
class BaselineModel(nn.Module):
    def __init__(self):
      super().__init__()

      self.conv1 = nn.Conv2d(1, 32, kernel_size = 3, padding = 1)
      self.conv2 = nn.Conv2d(32, 64, kernel_size = 3, padding = 1)
      self.conv3 = nn.Conv2d(64, 128, kernel_size = 3, padding = 1)
      self.pooling = nn.MaxPool2d(2,2)
      self.relu = nn.ReLU()

      self.flatten = nn.Flatten()
      self.linear = nn.Linear((128 * 6 * 6), 128)
      self.output = nn.Linear(128, 7)


    def forward(self, x):
      x = self.conv1(x)
      x = self.pooling(x)
      x = self.relu(x)
      x = self.conv2(x)
      x = self.pooling(x)
      x = self.relu(x)
      x = self.conv3(x)
      x = self.pooling(x)
      x = self.relu(x)
      x = self.flatten(x)
      x = self.linear(x)
      x = self.output(x)

      return x

#Run with dataset balancing:
#https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/7amufnzs?nw=nwuserdimna21

#Run without dataset balancing:
#https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/dnyp6qje
#20 epochs instead of 10:
#https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/0axoxyj9?nw=nwuserdimna21

In [None]:
baseline_model = BaselineModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(baseline_model.parameters(), lr=1e-3)
trained = train_model(
    baseline_model,
    train_dl,
    val_dl,
    criterion,
    optimizer,
    device,
    epochs=20,
    class_names=class_names
)

0,1
f1_Angry,▁▂▅▆▆▆▇▇█▇
f1_Disgust,▁▁▂▄▅▇▅▇██
f1_Fear,▁▃▄▅▅▆▇█▆█
f1_Happy,▁▃▄▇▆▇█▇██
f1_Neutral,▁▂▄▆▇██▇▇█
f1_Sad,▁▅▃▇█▅█▆▇▆
f1_Surprise,▁▃▅▇▇▇▇█▇█
train_acc,▁▃▄▅▆▆▇▇██
train_loss,█▆▅▄▃▃▂▂▁▁
val_acc,▁▄▄▇▇▇████

0,1
f1_Angry,0.4639
f1_Disgust,0.50633
f1_Fear,0.37651
f1_Happy,0.75697
f1_Neutral,0.49241
f1_Sad,0.37192
f1_Surprise,0.71262
train_acc,0.67275
train_loss,0.89268
val_acc,0.54416


Epoch 1 [Train]: 100%|██████████| 225/225 [00:06<00:00, 34.76it/s]
Epoch 1 [Val]: 100%|██████████| 29/29 [00:01<00:00, 25.32it/s]


Epoch 1/20 — Train Loss: 1.6812, Train Acc: 0.3272  |  Val Loss: 1.5147, Val Acc: 0.4235
------------------------------------------------------------


Epoch 2 [Train]: 100%|██████████| 225/225 [00:05<00:00, 42.10it/s]
Epoch 2 [Val]: 100%|██████████| 29/29 [00:00<00:00, 46.90it/s]


Epoch 2/20 — Train Loss: 1.4493, Train Acc: 0.4449  |  Val Loss: 1.3864, Val Acc: 0.4762
------------------------------------------------------------


Epoch 3 [Train]: 100%|██████████| 225/225 [00:07<00:00, 28.47it/s]
Epoch 3 [Val]: 100%|██████████| 29/29 [00:00<00:00, 33.85it/s]


Epoch 3/20 — Train Loss: 1.3132, Train Acc: 0.5038  |  Val Loss: 1.3009, Val Acc: 0.5032
------------------------------------------------------------


Epoch 4 [Train]: 100%|██████████| 225/225 [00:04<00:00, 54.99it/s]
Epoch 4 [Val]: 100%|██████████| 29/29 [00:00<00:00, 50.87it/s]


Epoch 4/20 — Train Loss: 1.2133, Train Acc: 0.5430  |  Val Loss: 1.2520, Val Acc: 0.5319
------------------------------------------------------------


Epoch 5 [Train]: 100%|██████████| 225/225 [00:04<00:00, 54.86it/s]
Epoch 5 [Val]: 100%|██████████| 29/29 [00:00<00:00, 48.91it/s]


Epoch 5/20 — Train Loss: 1.1404, Train Acc: 0.5729  |  Val Loss: 1.2164, Val Acc: 0.5467
------------------------------------------------------------


Epoch 6 [Train]: 100%|██████████| 225/225 [00:06<00:00, 36.63it/s]
Epoch 6 [Val]: 100%|██████████| 29/29 [00:00<00:00, 50.79it/s]


Epoch 6/20 — Train Loss: 1.0759, Train Acc: 0.5992  |  Val Loss: 1.2318, Val Acc: 0.5389
------------------------------------------------------------


Epoch 7 [Train]: 100%|██████████| 225/225 [00:04<00:00, 55.49it/s]
Epoch 7 [Val]: 100%|██████████| 29/29 [00:00<00:00, 52.67it/s]


Epoch 7/20 — Train Loss: 1.0205, Train Acc: 0.6215  |  Val Loss: 1.2220, Val Acc: 0.5500
------------------------------------------------------------


Epoch 8 [Train]: 100%|██████████| 225/225 [00:04<00:00, 48.70it/s]
Epoch 8 [Val]: 100%|██████████| 29/29 [00:01<00:00, 27.93it/s]


Epoch 8/20 — Train Loss: 0.9569, Train Acc: 0.6484  |  Val Loss: 1.2517, Val Acc: 0.5464
------------------------------------------------------------


Epoch 9 [Train]: 100%|██████████| 225/225 [00:05<00:00, 43.54it/s]
Epoch 9 [Val]: 100%|██████████| 29/29 [00:00<00:00, 48.15it/s]


Epoch 9/20 — Train Loss: 0.9012, Train Acc: 0.6701  |  Val Loss: 1.2355, Val Acc: 0.5469
------------------------------------------------------------


Epoch 10 [Train]: 100%|██████████| 225/225 [00:04<00:00, 55.54it/s]
Epoch 10 [Val]: 100%|██████████| 29/29 [00:00<00:00, 53.94it/s]


Epoch 10/20 — Train Loss: 0.8462, Train Acc: 0.6903  |  Val Loss: 1.2926, Val Acc: 0.5517
------------------------------------------------------------


Epoch 11 [Train]: 100%|██████████| 225/225 [00:06<00:00, 33.39it/s]
Epoch 11 [Val]: 100%|██████████| 29/29 [00:00<00:00, 38.83it/s]


Epoch 11/20 — Train Loss: 0.7881, Train Acc: 0.7095  |  Val Loss: 1.3603, Val Acc: 0.5414
------------------------------------------------------------


Epoch 12 [Train]: 100%|██████████| 225/225 [00:04<00:00, 54.46it/s]
Epoch 12 [Val]: 100%|██████████| 29/29 [00:00<00:00, 53.20it/s]


Epoch 12/20 — Train Loss: 0.7348, Train Acc: 0.7298  |  Val Loss: 1.3331, Val Acc: 0.5578
------------------------------------------------------------


Epoch 13 [Train]: 100%|██████████| 225/225 [00:04<00:00, 55.74it/s]
Epoch 13 [Val]: 100%|██████████| 29/29 [00:00<00:00, 37.73it/s]


Epoch 13/20 — Train Loss: 0.6840, Train Acc: 0.7514  |  Val Loss: 1.4282, Val Acc: 0.5553
------------------------------------------------------------


Epoch 14 [Train]: 100%|██████████| 225/225 [00:05<00:00, 38.07it/s]
Epoch 14 [Val]: 100%|██████████| 29/29 [00:00<00:00, 49.62it/s]


Epoch 14/20 — Train Loss: 0.6396, Train Acc: 0.7645  |  Val Loss: 1.5129, Val Acc: 0.5472
------------------------------------------------------------


Epoch 15 [Train]: 100%|██████████| 225/225 [00:04<00:00, 55.25it/s]
Epoch 15 [Val]: 100%|██████████| 29/29 [00:00<00:00, 50.40it/s]


Epoch 15/20 — Train Loss: 0.5856, Train Acc: 0.7889  |  Val Loss: 1.5878, Val Acc: 0.5344
------------------------------------------------------------


Epoch 16 [Train]: 100%|██████████| 225/225 [00:05<00:00, 43.86it/s]
Epoch 16 [Val]: 100%|██████████| 29/29 [00:01<00:00, 28.67it/s]


Epoch 16/20 — Train Loss: 0.5397, Train Acc: 0.8020  |  Val Loss: 1.6771, Val Acc: 0.5419
------------------------------------------------------------


Epoch 17 [Train]: 100%|██████████| 225/225 [00:04<00:00, 47.47it/s]
Epoch 17 [Val]: 100%|██████████| 29/29 [00:00<00:00, 47.57it/s]


Epoch 17/20 — Train Loss: 0.4986, Train Acc: 0.8182  |  Val Loss: 1.7001, Val Acc: 0.5472
------------------------------------------------------------


Epoch 18 [Train]: 100%|██████████| 225/225 [00:04<00:00, 47.82it/s]
Epoch 18 [Val]: 100%|██████████| 29/29 [00:00<00:00, 47.82it/s]


Epoch 18/20 — Train Loss: 0.4636, Train Acc: 0.8318  |  Val Loss: 1.8411, Val Acc: 0.5414
------------------------------------------------------------


Epoch 19 [Train]: 100%|██████████| 225/225 [00:06<00:00, 35.93it/s]
Epoch 19 [Val]: 100%|██████████| 29/29 [00:00<00:00, 50.49it/s]


Epoch 19/20 — Train Loss: 0.4270, Train Acc: 0.8452  |  Val Loss: 1.8972, Val Acc: 0.5511
------------------------------------------------------------


Epoch 20 [Train]: 100%|██████████| 225/225 [00:04<00:00, 54.63it/s]
Epoch 20 [Val]: 100%|██████████| 29/29 [00:00<00:00, 48.18it/s]


Epoch 20/20 — Train Loss: 0.3815, Train Acc: 0.8583  |  Val Loss: 2.1096, Val Acc: 0.5411
------------------------------------------------------------


In [None]:
class ImprovedModel(nn.Module):
    def __init__(self, dropout_p=0.5):
        super().__init__()
        # — conv backbone with 4 pooling stages —
        self.conv1    = nn.Conv2d(1,   32, 3, padding=1, bias=False)
        self.bn1      = nn.BatchNorm2d(32)
        self.conv2    = nn.Conv2d(32,  64, 3, padding=1, bias=False)
        self.bn2      = nn.BatchNorm2d(64)
        self.conv3    = nn.Conv2d(64, 128, 3, padding=1, bias=False)
        self.bn3      = nn.BatchNorm2d(128)
        self.conv4    = nn.Conv2d(128,256, 3, padding=1, bias=False)
        self.bn4      = nn.BatchNorm2d(256)

        self.pool     = nn.MaxPool2d(2,2)
        self.relu     = nn.ReLU(inplace=True)
        self.drop_conv= nn.Dropout2d(dropout_p)

        # — two extra conv layers (no further pooling) —
        self.conv5    = nn.Conv2d(256, 512, 3, padding=1, bias=False)
        self.bn5      = nn.BatchNorm2d(512)
        self.conv6    = nn.Conv2d(512, 512, 3, padding=1, bias=False)
        self.bn6      = nn.BatchNorm2d(512)

        # — richer head with 4 FCs —
        self.flatten  = nn.Flatten()
        self.fc1      = nn.Linear(512 * 3 * 3, 512, bias=False)
        self.bn_fc1   = nn.BatchNorm1d(512)
        self.drop1    = nn.Dropout(dropout_p)
        self.fc2      = nn.Linear(512, 256, bias=False)
        self.bn_fc2   = nn.BatchNorm1d(256)
        self.drop2    = nn.Dropout(dropout_p)
        self.fc3      = nn.Linear(256, 128, bias=False)
        self.bn_fc3   = nn.BatchNorm1d(128)
        self.drop3    = nn.Dropout(dropout_p)
        self.fc4      = nn.Linear(128, 64, bias=False)
        self.bn_fc4   = nn.BatchNorm1d(64)
        self.drop4    = nn.Dropout(dropout_p)
        self.output   = nn.Linear(64, 7)

    def forward(self, x):
        # conv blocks 1–4 with pooling
        for conv, bn in [(self.conv1,self.bn1),
                         (self.conv2,self.bn2),
                         (self.conv3,self.bn3),
                         (self.conv4,self.bn4)]:
            x = conv(x)
            x = bn(x)
            x = self.relu(x)
            x = self.pool(x)
            x = self.drop_conv(x)

        # extra conv blocks 5–6 (no more pooling)
        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu(x)
        x = self.drop_conv(x)

        x = self.conv6(x)
        x = self.bn6(x)
        x = self.relu(x)
        x = self.drop_conv(x)

        # FC head
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.bn_fc1(x)
        x = self.relu(x)
        x = self.drop1(x)

        x = self.fc2(x)
        x = self.bn_fc2(x)
        x = self.relu(x)
        x = self.drop2(x)

        x = self.fc3(x)
        x = self.bn_fc3(x)
        x = self.relu(x)
        x = self.drop3(x)

        x = self.fc4(x)
        x = self.bn_fc4(x)
        x = self.relu(x)
        x = self.drop4(x)

        x = self.output(x)
        return x


#Run with added batchnorms
#https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/cuxqkkdm?nw=nwuserdimna21
#topped out at 70%/58%

#Run with added batchnorms and dropouts
#10 epochs:  https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/iactnmo6?nw=nwuserdimna21
#20 epochs:  https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/ydlxqyzi?nw=nwuserdimna21
#started overfitting after reaching 62%/60% accuracies on train/validation

#Run with added 4th convolutional block:
#https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/lqfkix8t
#no improvements, got 62%/60%

#Run with a more complex head of 512->256->128
#https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/65qgarhf
#slight improvement, got 64%/61%

#Run with added residual blocks
#20 epochs: https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/op1fbc6g
#40 epochs: https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/cyn9apy4
#no improvement

#Run without residual blocks, with augmented data and increased complexity: 2 extra conv layers and 4 FCs in head
#40 epochs, 0.3 dropout: https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/ofsajdfh?nw=nwuserdimna21
#40 epochs, 0.5 dropout: https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/itxksjhl
#40 epochs, 0.5 dropout, double batch size (256): https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/x225ygh0

In [None]:
import torch.nn as nn

class ResidualBlock(nn.Module):
    """
    A basic residual block with two 3x3 conv layers and an optional downsample.
    """
    expansion = 1
    def __init__(self, in_channels, out_channels, stride=1, dropout_p=0.0):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout2d(dropout_p) if dropout_p > 0 else nn.Identity()

        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        out = self.dropout(out)
        return out

class GigaModel(nn.Module):
    def __init__(self, dropout_p=0.3):
        super().__init__()
        # define ReLU for use in head
        self.relu = nn.ReLU(inplace=True)
        # initial conv to expand grayscale to 32 channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        # residual layers: doubling channels and downsampling spatially at each stage
        self.layer1 = ResidualBlock(32,  32, stride=1, dropout_p=dropout_p)
        self.layer2 = ResidualBlock(32,  64, stride=2, dropout_p=dropout_p)
        self.layer3 = ResidualBlock(64, 128, stride=2, dropout_p=dropout_p)
        self.layer4 = ResidualBlock(128,256, stride=2, dropout_p=dropout_p)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))

        self.flatten = nn.Flatten()
        self.fc1     = nn.Linear(256 * ResidualBlock.expansion, 512, bias=False)
        self.bn_fc1  = nn.BatchNorm1d(512)
        self.drop1   = nn.Dropout(dropout_p)
        self.fc2     = nn.Linear(512, 256, bias=False)
        self.bn_fc2  = nn.BatchNorm1d(256)
        self.drop2   = nn.Dropout(dropout_p)
        self.fc3     = nn.Linear(256, 128, bias=False)
        self.bn_fc3  = nn.BatchNorm1d(128)
        self.drop3   = nn.Dropout(dropout_p)
        self.output  = nn.Linear(128, 7)

    def forward(self, x):
        # initial conv
        x = self.conv1(x)
        # residual conv backbone with downsamples
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # head
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.bn_fc1(x)
        x = self.relu(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.bn_fc2(x)
        x = self.relu(x)
        x = self.drop2(x)
        x = self.fc3(x)
        x = self.bn_fc3(x)
        x = self.relu(x)
        x = self.drop3(x)
        x = self.output(x)
        return x

#training pt1: https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/8a4wuiau?nw=nwuserdimna21
#training pt2: https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/127aogq7

In [None]:
improved_model = ImprovedModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(improved_model.parameters(), lr=1e-3)
trained = train_model(
    improved_model,
    train_dl,
    val_dl,
    criterion,
    optimizer,
    device,
    epochs=40,
    class_names=class_names
)

Epoch 1 [Train]: 100%|██████████| 198/198 [00:10<00:00, 18.09it/s]
Epoch 1 [Val]: 100%|██████████| 15/15 [00:01<00:00, 14.57it/s]


Epoch 1/40 — Train Loss: 1.9672, Train Acc: 0.1561  |  Val Loss: 1.9162, Val Acc: 0.1914
------------------------------------------------------------


Epoch 2 [Train]: 100%|██████████| 198/198 [00:08<00:00, 22.85it/s]
Epoch 2 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.18it/s]


Epoch 2/40 — Train Loss: 1.9370, Train Acc: 0.1686  |  Val Loss: 1.9108, Val Acc: 0.1962
------------------------------------------------------------


Epoch 3 [Train]: 100%|██████████| 198/198 [00:11<00:00, 16.61it/s]
Epoch 3 [Val]: 100%|██████████| 15/15 [00:00<00:00, 18.22it/s]


Epoch 3/40 — Train Loss: 1.9295, Train Acc: 0.1796  |  Val Loss: 1.8984, Val Acc: 0.2126
------------------------------------------------------------


Epoch 4 [Train]: 100%|██████████| 198/198 [00:12<00:00, 16.11it/s]
Epoch 4 [Val]: 100%|██████████| 15/15 [00:00<00:00, 18.72it/s]


Epoch 4/40 — Train Loss: 1.9232, Train Acc: 0.1833  |  Val Loss: 1.8858, Val Acc: 0.2187
------------------------------------------------------------


Epoch 5 [Train]: 100%|██████████| 198/198 [00:11<00:00, 17.87it/s]
Epoch 5 [Val]: 100%|██████████| 15/15 [00:01<00:00, 13.64it/s]


Epoch 5/40 — Train Loss: 1.8966, Train Acc: 0.2054  |  Val Loss: 1.8180, Val Acc: 0.2391
------------------------------------------------------------


Epoch 6 [Train]: 100%|██████████| 198/198 [00:09<00:00, 20.40it/s]
Epoch 6 [Val]: 100%|██████████| 15/15 [00:00<00:00, 30.68it/s]


Epoch 6/40 — Train Loss: 1.8230, Train Acc: 0.2338  |  Val Loss: 1.7634, Val Acc: 0.3324
------------------------------------------------------------


Epoch 7 [Train]: 100%|██████████| 198/198 [00:08<00:00, 22.87it/s]
Epoch 7 [Val]: 100%|██████████| 15/15 [00:00<00:00, 31.66it/s]


Epoch 7/40 — Train Loss: 1.7721, Train Acc: 0.2648  |  Val Loss: 1.7897, Val Acc: 0.2059
------------------------------------------------------------


Epoch 8 [Train]: 100%|██████████| 198/198 [00:09<00:00, 20.11it/s]
Epoch 8 [Val]: 100%|██████████| 15/15 [00:00<00:00, 23.64it/s]


Epoch 8/40 — Train Loss: 1.7208, Train Acc: 0.3044  |  Val Loss: 1.7883, Val Acc: 0.2235
------------------------------------------------------------


Epoch 9 [Train]: 100%|██████████| 198/198 [00:10<00:00, 18.95it/s]
Epoch 9 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.65it/s]


Epoch 9/40 — Train Loss: 1.6670, Train Acc: 0.3409  |  Val Loss: 1.6652, Val Acc: 0.3463
------------------------------------------------------------


Epoch 10 [Train]: 100%|██████████| 198/198 [00:10<00:00, 18.77it/s]
Epoch 10 [Val]: 100%|██████████| 15/15 [00:00<00:00, 18.71it/s]


Epoch 10/40 — Train Loss: 1.6050, Train Acc: 0.3752  |  Val Loss: 1.5551, Val Acc: 0.3834
------------------------------------------------------------


Epoch 11 [Train]: 100%|██████████| 198/198 [00:09<00:00, 19.84it/s]
Epoch 11 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.74it/s]


Epoch 11/40 — Train Loss: 1.5495, Train Acc: 0.4017  |  Val Loss: 1.5007, Val Acc: 0.4246
------------------------------------------------------------


Epoch 12 [Train]: 100%|██████████| 198/198 [00:09<00:00, 20.16it/s]
Epoch 12 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.66it/s]


Epoch 12/40 — Train Loss: 1.5015, Train Acc: 0.4203  |  Val Loss: 1.5068, Val Acc: 0.4310
------------------------------------------------------------


Epoch 13 [Train]: 100%|██████████| 198/198 [00:09<00:00, 19.99it/s]
Epoch 13 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.06it/s]


Epoch 13/40 — Train Loss: 1.4538, Train Acc: 0.4362  |  Val Loss: 1.4427, Val Acc: 0.4575
------------------------------------------------------------


Epoch 14 [Train]: 100%|██████████| 198/198 [00:09<00:00, 21.17it/s]
Epoch 14 [Val]: 100%|██████████| 15/15 [00:01<00:00, 11.99it/s]


Epoch 14/40 — Train Loss: 1.4186, Train Acc: 0.4519  |  Val Loss: 1.4275, Val Acc: 0.4625
------------------------------------------------------------


Epoch 15 [Train]: 100%|██████████| 198/198 [00:09<00:00, 21.75it/s]
Epoch 15 [Val]: 100%|██████████| 15/15 [00:00<00:00, 30.76it/s]


Epoch 15/40 — Train Loss: 1.3927, Train Acc: 0.4627  |  Val Loss: 1.4248, Val Acc: 0.4636
------------------------------------------------------------


Epoch 16 [Train]: 100%|██████████| 198/198 [00:09<00:00, 20.07it/s]
Epoch 16 [Val]: 100%|██████████| 15/15 [00:00<00:00, 24.51it/s]


Epoch 16/40 — Train Loss: 1.3641, Train Acc: 0.4727  |  Val Loss: 1.3744, Val Acc: 0.4703
------------------------------------------------------------


Epoch 17 [Train]: 100%|██████████| 198/198 [00:11<00:00, 17.71it/s]
Epoch 17 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.83it/s]


Epoch 17/40 — Train Loss: 1.3348, Train Acc: 0.4862  |  Val Loss: 1.3638, Val Acc: 0.4745
------------------------------------------------------------


Epoch 18 [Train]: 100%|██████████| 198/198 [00:08<00:00, 22.59it/s]
Epoch 18 [Val]: 100%|██████████| 15/15 [00:00<00:00, 18.91it/s]


Epoch 18/40 — Train Loss: 1.3119, Train Acc: 0.4916  |  Val Loss: 1.3919, Val Acc: 0.4642
------------------------------------------------------------


Epoch 19 [Train]: 100%|██████████| 198/198 [00:09<00:00, 20.59it/s]
Epoch 19 [Val]: 100%|██████████| 15/15 [00:00<00:00, 28.50it/s]


Epoch 19/40 — Train Loss: 1.3016, Train Acc: 0.4941  |  Val Loss: 1.3459, Val Acc: 0.4859
------------------------------------------------------------


Epoch 20 [Train]: 100%|██████████| 198/198 [00:09<00:00, 20.26it/s]
Epoch 20 [Val]: 100%|██████████| 15/15 [00:00<00:00, 30.39it/s]


Epoch 20/40 — Train Loss: 1.2860, Train Acc: 0.5027  |  Val Loss: 1.3306, Val Acc: 0.4887
------------------------------------------------------------


Epoch 21 [Train]: 100%|██████████| 198/198 [00:09<00:00, 21.65it/s]
Epoch 21 [Val]: 100%|██████████| 15/15 [00:00<00:00, 19.54it/s]


Epoch 21/40 — Train Loss: 1.2675, Train Acc: 0.5101  |  Val Loss: 1.3328, Val Acc: 0.4943
------------------------------------------------------------


Epoch 22 [Train]: 100%|██████████| 198/198 [00:09<00:00, 21.54it/s]
Epoch 22 [Val]: 100%|██████████| 15/15 [00:00<00:00, 30.46it/s]


Epoch 22/40 — Train Loss: 1.2512, Train Acc: 0.5149  |  Val Loss: 1.3347, Val Acc: 0.4965
------------------------------------------------------------


Epoch 23 [Train]: 100%|██████████| 198/198 [00:10<00:00, 19.08it/s]
Epoch 23 [Val]: 100%|██████████| 15/15 [00:00<00:00, 30.78it/s]


Epoch 23/40 — Train Loss: 1.2447, Train Acc: 0.5164  |  Val Loss: 1.3223, Val Acc: 0.4999
------------------------------------------------------------


Epoch 24 [Train]: 100%|██████████| 198/198 [00:10<00:00, 18.75it/s]
Epoch 24 [Val]: 100%|██████████| 15/15 [00:01<00:00, 10.96it/s]


Epoch 24/40 — Train Loss: 1.2333, Train Acc: 0.5200  |  Val Loss: 1.3193, Val Acc: 0.4940
------------------------------------------------------------


Epoch 25 [Train]: 100%|██████████| 198/198 [00:09<00:00, 21.85it/s]
Epoch 25 [Val]: 100%|██████████| 15/15 [00:00<00:00, 18.30it/s]


Epoch 25/40 — Train Loss: 1.2197, Train Acc: 0.5263  |  Val Loss: 1.3039, Val Acc: 0.5038
------------------------------------------------------------


Epoch 26 [Train]: 100%|██████████| 198/198 [00:09<00:00, 21.25it/s]
Epoch 26 [Val]: 100%|██████████| 15/15 [00:00<00:00, 30.03it/s]


Epoch 26/40 — Train Loss: 1.2126, Train Acc: 0.5292  |  Val Loss: 1.2893, Val Acc: 0.5065
------------------------------------------------------------


Epoch 27 [Train]: 100%|██████████| 198/198 [00:09<00:00, 19.97it/s]
Epoch 27 [Val]: 100%|██████████| 15/15 [00:00<00:00, 30.72it/s]


Epoch 27/40 — Train Loss: 1.2051, Train Acc: 0.5308  |  Val Loss: 1.3056, Val Acc: 0.5046
------------------------------------------------------------


Epoch 28 [Train]: 100%|██████████| 198/198 [00:09<00:00, 20.86it/s]
Epoch 28 [Val]: 100%|██████████| 15/15 [00:00<00:00, 18.84it/s]


Epoch 28/40 — Train Loss: 1.1968, Train Acc: 0.5374  |  Val Loss: 1.3058, Val Acc: 0.5116
------------------------------------------------------------


Epoch 29 [Train]: 100%|██████████| 198/198 [00:08<00:00, 22.49it/s]
Epoch 29 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.65it/s]


Epoch 29/40 — Train Loss: 1.1876, Train Acc: 0.5398  |  Val Loss: 1.3043, Val Acc: 0.5082
------------------------------------------------------------


Epoch 30 [Train]: 100%|██████████| 198/198 [00:09<00:00, 19.84it/s]
Epoch 30 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.57it/s]


Epoch 30/40 — Train Loss: 1.1813, Train Acc: 0.5427  |  Val Loss: 1.2970, Val Acc: 0.5099
------------------------------------------------------------


Epoch 31 [Train]: 100%|██████████| 198/198 [00:09<00:00, 20.11it/s]
Epoch 31 [Val]: 100%|██████████| 15/15 [00:00<00:00, 23.46it/s]


Epoch 31/40 — Train Loss: 1.1747, Train Acc: 0.5438  |  Val Loss: 1.3142, Val Acc: 0.5132
------------------------------------------------------------


Epoch 32 [Train]: 100%|██████████| 198/198 [00:08<00:00, 22.13it/s]
Epoch 32 [Val]: 100%|██████████| 15/15 [00:00<00:00, 18.25it/s]


Epoch 32/40 — Train Loss: 1.1683, Train Acc: 0.5482  |  Val Loss: 1.2650, Val Acc: 0.5199
------------------------------------------------------------


Epoch 33 [Train]: 100%|██████████| 198/198 [00:10<00:00, 18.45it/s]
Epoch 33 [Val]: 100%|██████████| 15/15 [00:00<00:00, 18.91it/s]


Epoch 33/40 — Train Loss: 1.1661, Train Acc: 0.5476  |  Val Loss: 1.2707, Val Acc: 0.5166
------------------------------------------------------------


Epoch 34 [Train]: 100%|██████████| 198/198 [00:10<00:00, 19.15it/s]
Epoch 34 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.86it/s]


Epoch 34/40 — Train Loss: 1.1595, Train Acc: 0.5511  |  Val Loss: 1.2825, Val Acc: 0.5146
------------------------------------------------------------


Epoch 35 [Train]: 100%|██████████| 198/198 [00:09<00:00, 19.92it/s]
Epoch 35 [Val]: 100%|██████████| 15/15 [00:00<00:00, 30.18it/s]


Epoch 35/40 — Train Loss: 1.1508, Train Acc: 0.5537  |  Val Loss: 1.2661, Val Acc: 0.5191
------------------------------------------------------------


Epoch 36 [Train]: 100%|██████████| 198/198 [00:08<00:00, 22.42it/s]
Epoch 36 [Val]: 100%|██████████| 15/15 [00:00<00:00, 25.93it/s]


Epoch 36/40 — Train Loss: 1.1439, Train Acc: 0.5572  |  Val Loss: 1.2769, Val Acc: 0.5163
------------------------------------------------------------


Epoch 37 [Train]: 100%|██████████| 198/198 [00:09<00:00, 20.19it/s]
Epoch 37 [Val]: 100%|██████████| 15/15 [00:00<00:00, 30.60it/s]


Epoch 37/40 — Train Loss: 1.1429, Train Acc: 0.5562  |  Val Loss: 1.2572, Val Acc: 0.5185
------------------------------------------------------------


Epoch 38 [Train]: 100%|██████████| 198/198 [00:09<00:00, 20.20it/s]
Epoch 38 [Val]: 100%|██████████| 15/15 [00:00<00:00, 30.62it/s]


Epoch 38/40 — Train Loss: 1.1365, Train Acc: 0.5599  |  Val Loss: 1.2715, Val Acc: 0.5235
------------------------------------------------------------


Epoch 39 [Train]: 100%|██████████| 198/198 [00:09<00:00, 21.35it/s]
Epoch 39 [Val]: 100%|██████████| 15/15 [00:01<00:00, 14.57it/s]


Epoch 39/40 — Train Loss: 1.1286, Train Acc: 0.5650  |  Val Loss: 1.2506, Val Acc: 0.5249
------------------------------------------------------------


Epoch 40 [Train]: 100%|██████████| 198/198 [00:09<00:00, 21.63it/s]
Epoch 40 [Val]: 100%|██████████| 15/15 [00:00<00:00, 28.93it/s]


Epoch 40/40 — Train Loss: 1.1290, Train Acc: 0.5610  |  Val Loss: 1.2449, Val Acc: 0.5286
------------------------------------------------------------


In [None]:
torch.save(improved_model.state_dict(), "checkpoint40.pth")

In [None]:
improved_model = GigaModel().to(device)
improved_model.load_state_dict(torch.load("checkpoint40.pth"))


<All keys matched successfully>

In [None]:
trained = train_model(
    improved_model, train_dl, val_dl,
    criterion, optimizer, device,
    epochs=40,
    class_names=class_names
)

0,1
f1_Angry,▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▅▂▃▃▃▆▄▃▆▅▆▆▅▆▆▆▇▇▇█▇████
f1_Disgust,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1_Fear,▁▂▂▂▂▂█▄▃█▂▃▆▃▄▄▂▅▄▅▅▅▃▄▄▃▃▅▄▆▆▆▆▆▆▆▇▆▆▆
f1_Happy,▂▂▂▂▃▃▁▄▄▅▆▆▇▆▇▇▆▆▇▇▇▆▇▇▇█▇▇▇▇██████████
f1_Neutral,▁▁▁▂▃▃▃▅▅▅▆▃▆▆▇▇▃▆▇▇▆▆▅▇▇▇▆▇▇▇██████████
f1_Sad,▂▃▁▂▃▂▄▃▅▅▆▇▇▆▆▆▇▆▇▅▇▇▇▇▇▇▇███▇▇████████
f1_Surprise,▃▁▁▄▆▇▄▇▅▇▇▇▇█▇█▇███▇█▇█████▇███████████
train_acc,▁▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇███████
train_loss,█▇▇▇▇▆▆▆▅▅▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
val_acc,▂▂▂▂▃▃▁▄▄▅▅▅▆▆▆▆▅▆▆▆▇▇▆▇▇▇▇▇▇███████████

0,1
f1_Angry,0.3198
f1_Disgust,0.0
f1_Fear,0.11076
f1_Happy,0.77668
f1_Neutral,0.50311
f1_Sad,0.47066
f1_Surprise,0.69486
train_acc,0.52283
train_loss,1.22974
val_acc,0.5294


Epoch 1 [Train]: 100%|██████████| 225/225 [01:26<00:00,  2.59it/s]
Epoch 1 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.39it/s]


Epoch 1/40 — Train Loss: 1.2203, Train Acc: 0.5256  |  Val Loss: 1.2049, Val Acc: 0.5313
------------------------------------------------------------


Epoch 2 [Train]: 100%|██████████| 225/225 [01:28<00:00,  2.55it/s]
Epoch 2 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.87it/s]


Epoch 2/40 — Train Loss: 1.2183, Train Acc: 0.5259  |  Val Loss: 1.2044, Val Acc: 0.5297
------------------------------------------------------------


Epoch 3 [Train]: 100%|██████████| 225/225 [01:29<00:00,  2.51it/s]
Epoch 3 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.44it/s]


Epoch 3/40 — Train Loss: 1.2160, Train Acc: 0.5263  |  Val Loss: 1.2035, Val Acc: 0.5300
------------------------------------------------------------


Epoch 4 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.50it/s]
Epoch 4 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.79it/s]


Epoch 4/40 — Train Loss: 1.2168, Train Acc: 0.5275  |  Val Loss: 1.2046, Val Acc: 0.5308
------------------------------------------------------------


Epoch 5 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.49it/s]
Epoch 5 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.68it/s]


Epoch 5/40 — Train Loss: 1.2122, Train Acc: 0.5264  |  Val Loss: 1.2030, Val Acc: 0.5333
------------------------------------------------------------


Epoch 6 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.49it/s]
Epoch 6 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.84it/s]


Epoch 6/40 — Train Loss: 1.2172, Train Acc: 0.5233  |  Val Loss: 1.2049, Val Acc: 0.5316
------------------------------------------------------------


Epoch 7 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.50it/s]
Epoch 7 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.73it/s]


Epoch 7/40 — Train Loss: 1.2129, Train Acc: 0.5299  |  Val Loss: 1.2034, Val Acc: 0.5302
------------------------------------------------------------


Epoch 8 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.50it/s]
Epoch 8 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.95it/s]


Epoch 8/40 — Train Loss: 1.2109, Train Acc: 0.5294  |  Val Loss: 1.2037, Val Acc: 0.5283
------------------------------------------------------------


Epoch 9 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.50it/s]
Epoch 9 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.69it/s]


Epoch 9/40 — Train Loss: 1.2134, Train Acc: 0.5294  |  Val Loss: 1.2068, Val Acc: 0.5272
------------------------------------------------------------


Epoch 10 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.49it/s]
Epoch 10 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.66it/s]


Epoch 10/40 — Train Loss: 1.2156, Train Acc: 0.5277  |  Val Loss: 1.2055, Val Acc: 0.5325
------------------------------------------------------------


Epoch 11 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.49it/s]
Epoch 11 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.67it/s]


Epoch 11/40 — Train Loss: 1.2132, Train Acc: 0.5276  |  Val Loss: 1.2051, Val Acc: 0.5305
------------------------------------------------------------


Epoch 12 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.49it/s]
Epoch 12 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.92it/s]


Epoch 12/40 — Train Loss: 1.2083, Train Acc: 0.5318  |  Val Loss: 1.2028, Val Acc: 0.5322
------------------------------------------------------------


Epoch 13 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.49it/s]
Epoch 13 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.66it/s]


Epoch 13/40 — Train Loss: 1.2159, Train Acc: 0.5283  |  Val Loss: 1.2077, Val Acc: 0.5277
------------------------------------------------------------


Epoch 14 [Train]:  77%|███████▋  | 174/225 [01:10<00:20,  2.47it/s]


KeyboardInterrupt: 

#Targeted oversampling of underperforming classes

In [None]:
import numpy as np
import pandas as pd

def targeted_oversampling(df, target_count, img_shape=(48,48)):
    """
    Upsample each emotion class:
      – if emotion in [0,1,2,6], new_target = original_count * 2
      – otherwise           , new_target = target_count
    Applies a random ±10 intensity shift to each augmented image.
    """
    def augment(pix_str):
        arr   = np.fromstring(pix_str, sep=' ', dtype=np.uint8).reshape(img_shape)
        shift = np.random.randint(-10, 11)
        arr   = np.clip(arr.astype(int) + shift, 0, 255).astype(np.uint8)
        return ' '.join(map(str, arr.ravel()))

    # define which emotion labels to double
    double_classes = {0, 1, 2, 6}

    parts = [df]
    for emo, grp in df.groupby('emotion'):
        n_orig = len(grp)
        # decide how many we want in the end
        if emo in double_classes:
            new_target = n_orig * 2
        else:
            new_target = target_count

        # only augment if we need more
        if new_target > n_orig:
            n_extra = new_target - n_orig
            extra  = grp.sample(n_extra, replace=True).copy()
            extra['pixels'] = extra['pixels'].map(augment)
            parts.append(extra)

    return pd.concat(parts, ignore_index=True)


In [None]:
# DataLoaders
max_count = df_train['emotion'].value_counts().max()
balanced_train = targeted_oversampling(df_train, target_count=max_count)

#train_ds = FERDataset(df_train)
train_ds = FERDataset(balanced_train)
val_ds = FERDataset(df_val)
test_ds = FERDataset(df_test)

batch_size = 256
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=2)
val_dl = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=2)
test_dl = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=2)

# Class names
class_names = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]

In [None]:
balanced_train['emotion'].value_counts()

Unnamed: 0_level_0,count
emotion,Unnamed: 1_level_1
6,9930
2,8194
0,7990
4,7215
3,7215
5,7215
1,872


In [None]:
#improved_model = ImprovedModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(improved_model.parameters(), lr=1e-3)
trained = train_model(
    improved_model,
    train_dl,
    val_dl,
    criterion,
    optimizer,
    device,
    epochs=40,
    class_names=class_names
)

#Targeted oversampling run1: https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/fufcp077?nw=nwuserdimna21
#run2 (continued): https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/gltl4ney
# Tops out at 61%/58%

0,1
f1_Angry,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▅▅▆▆▆▇▇▇▇▇▇▇▇▇▇█▇▇████
f1_Disgust,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1_Fear,▁▁▃▁▁▁▄▅▃▅▄▅▅▆▆▇█▇▆▅▅▄▄▄▄▄▅▅▆▅▅▆▆▆▆▆▇▆█▆
f1_Happy,▁▁▁▁▁▁▁▆▆▇▇▇▇▇▇█████████████████████████
f1_Neutral,▁▁▁▁▁▂▂▁▃▃▄▄▄▄▄▅▅▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██▇██
f1_Sad,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▂▆▃▅▆▆▅▅▆▆▆▇▆▆▆▆▇▇▇███
f1_Surprise,▂▃▁▄▂▆▆▇▇▇██████████████████████████████
train_acc,▁▁▂▂▂▂▃▃▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████
train_loss,█▇▇▇▇▇▆▆▅▅▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
val_acc,▁▁▁▁▁▂▂▄▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇██▇█████████

0,1
f1_Angry,0.45845
f1_Disgust,0.0
f1_Fear,0.15123
f1_Happy,0.79413
f1_Neutral,0.53029
f1_Sad,0.30561
f1_Surprise,0.73812
train_acc,0.52308
train_loss,1.24769
val_acc,0.53998


Epoch 1 [Train]: 100%|██████████| 190/190 [00:09<00:00, 19.43it/s]
Epoch 1 [Val]: 100%|██████████| 15/15 [00:00<00:00, 18.58it/s]


Epoch 1/40 — Train Loss: 1.2425, Train Acc: 0.5240  |  Val Loss: 1.2003, Val Acc: 0.5366
------------------------------------------------------------


Epoch 2 [Train]: 100%|██████████| 190/190 [00:09<00:00, 20.37it/s]
Epoch 2 [Val]: 100%|██████████| 15/15 [00:01<00:00, 12.74it/s]


Epoch 2/40 — Train Loss: 1.2296, Train Acc: 0.5290  |  Val Loss: 1.2111, Val Acc: 0.5380
------------------------------------------------------------


Epoch 3 [Train]: 100%|██████████| 190/190 [00:10<00:00, 18.43it/s]
Epoch 3 [Val]: 100%|██████████| 15/15 [00:00<00:00, 18.01it/s]


Epoch 3/40 — Train Loss: 1.2275, Train Acc: 0.5336  |  Val Loss: 1.1952, Val Acc: 0.5394
------------------------------------------------------------


Epoch 4 [Train]: 100%|██████████| 190/190 [00:09<00:00, 19.95it/s]
Epoch 4 [Val]: 100%|██████████| 15/15 [00:00<00:00, 28.49it/s]


Epoch 4/40 — Train Loss: 1.2203, Train Acc: 0.5343  |  Val Loss: 1.1931, Val Acc: 0.5425
------------------------------------------------------------


Epoch 5 [Train]: 100%|██████████| 190/190 [00:09<00:00, 19.13it/s]
Epoch 5 [Val]: 100%|██████████| 15/15 [00:00<00:00, 28.49it/s]


Epoch 5/40 — Train Loss: 1.2163, Train Acc: 0.5375  |  Val Loss: 1.1953, Val Acc: 0.5394
------------------------------------------------------------


Epoch 6 [Train]: 100%|██████████| 190/190 [00:09<00:00, 19.94it/s]
Epoch 6 [Val]: 100%|██████████| 15/15 [00:00<00:00, 18.43it/s]


Epoch 6/40 — Train Loss: 1.2136, Train Acc: 0.5387  |  Val Loss: 1.1956, Val Acc: 0.5447
------------------------------------------------------------


Epoch 7 [Train]: 100%|██████████| 190/190 [00:10<00:00, 18.41it/s]
Epoch 7 [Val]: 100%|██████████| 15/15 [00:00<00:00, 15.20it/s]


Epoch 7/40 — Train Loss: 1.2025, Train Acc: 0.5435  |  Val Loss: 1.1990, Val Acc: 0.5511
------------------------------------------------------------


Epoch 8 [Train]: 100%|██████████| 190/190 [00:10<00:00, 18.27it/s]
Epoch 8 [Val]: 100%|██████████| 15/15 [00:00<00:00, 28.29it/s]


Epoch 8/40 — Train Loss: 1.2006, Train Acc: 0.5457  |  Val Loss: 1.2068, Val Acc: 0.5508
------------------------------------------------------------


Epoch 9 [Train]: 100%|██████████| 190/190 [00:10<00:00, 17.84it/s]
Epoch 9 [Val]: 100%|██████████| 15/15 [00:00<00:00, 28.38it/s]


Epoch 9/40 — Train Loss: 1.1962, Train Acc: 0.5486  |  Val Loss: 1.1813, Val Acc: 0.5528
------------------------------------------------------------


Epoch 10 [Train]: 100%|██████████| 190/190 [00:09<00:00, 20.10it/s]
Epoch 10 [Val]: 100%|██████████| 15/15 [00:00<00:00, 17.68it/s]


Epoch 10/40 — Train Loss: 1.1975, Train Acc: 0.5518  |  Val Loss: 1.1922, Val Acc: 0.5573
------------------------------------------------------------


Epoch 11 [Train]: 100%|██████████| 190/190 [00:08<00:00, 21.16it/s]
Epoch 11 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.16it/s]


Epoch 11/40 — Train Loss: 1.1917, Train Acc: 0.5517  |  Val Loss: 1.2061, Val Acc: 0.5534
------------------------------------------------------------


Epoch 12 [Train]: 100%|██████████| 190/190 [00:10<00:00, 18.92it/s]
Epoch 12 [Val]: 100%|██████████| 15/15 [00:00<00:00, 30.02it/s]


Epoch 12/40 — Train Loss: 1.1851, Train Acc: 0.5544  |  Val Loss: 1.1744, Val Acc: 0.5534
------------------------------------------------------------


Epoch 13 [Train]: 100%|██████████| 190/190 [00:10<00:00, 18.98it/s]
Epoch 13 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.32it/s]


Epoch 13/40 — Train Loss: 1.1797, Train Acc: 0.5572  |  Val Loss: 1.1763, Val Acc: 0.5584
------------------------------------------------------------


Epoch 14 [Train]: 100%|██████████| 190/190 [00:08<00:00, 21.44it/s]
Epoch 14 [Val]: 100%|██████████| 15/15 [00:00<00:00, 17.69it/s]


Epoch 14/40 — Train Loss: 1.1748, Train Acc: 0.5597  |  Val Loss: 1.1706, Val Acc: 0.5642
------------------------------------------------------------


Epoch 15 [Train]: 100%|██████████| 190/190 [00:09<00:00, 19.40it/s]
Epoch 15 [Val]: 100%|██████████| 15/15 [00:00<00:00, 27.66it/s]


Epoch 15/40 — Train Loss: 1.1693, Train Acc: 0.5633  |  Val Loss: 1.1762, Val Acc: 0.5634
------------------------------------------------------------


Epoch 16 [Train]: 100%|██████████| 190/190 [00:10<00:00, 19.00it/s]
Epoch 16 [Val]: 100%|██████████| 15/15 [00:00<00:00, 28.34it/s]


Epoch 16/40 — Train Loss: 1.1667, Train Acc: 0.5649  |  Val Loss: 1.1682, Val Acc: 0.5581
------------------------------------------------------------


Epoch 17 [Train]: 100%|██████████| 190/190 [00:09<00:00, 19.63it/s]
Epoch 17 [Val]: 100%|██████████| 15/15 [00:00<00:00, 18.53it/s]


Epoch 17/40 — Train Loss: 1.1665, Train Acc: 0.5663  |  Val Loss: 1.1721, Val Acc: 0.5617
------------------------------------------------------------


Epoch 18 [Train]: 100%|██████████| 190/190 [00:08<00:00, 21.21it/s]
Epoch 18 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.79it/s]


Epoch 18/40 — Train Loss: 1.1568, Train Acc: 0.5686  |  Val Loss: 1.1858, Val Acc: 0.5634
------------------------------------------------------------


Epoch 19 [Train]: 100%|██████████| 190/190 [00:10<00:00, 18.52it/s]
Epoch 19 [Val]: 100%|██████████| 15/15 [00:00<00:00, 28.49it/s]


Epoch 19/40 — Train Loss: 1.1597, Train Acc: 0.5684  |  Val Loss: 1.1622, Val Acc: 0.5648
------------------------------------------------------------


Epoch 20 [Train]: 100%|██████████| 190/190 [00:10<00:00, 18.97it/s]
Epoch 20 [Val]: 100%|██████████| 15/15 [00:00<00:00, 18.19it/s]


Epoch 20/40 — Train Loss: 1.1546, Train Acc: 0.5703  |  Val Loss: 1.1743, Val Acc: 0.5595
------------------------------------------------------------


Epoch 21 [Train]: 100%|██████████| 190/190 [00:08<00:00, 21.55it/s]
Epoch 21 [Val]: 100%|██████████| 15/15 [00:01<00:00, 10.76it/s]


Epoch 21/40 — Train Loss: 1.1478, Train Acc: 0.5735  |  Val Loss: 1.1625, Val Acc: 0.5662
------------------------------------------------------------


Epoch 22 [Train]: 100%|██████████| 190/190 [00:09<00:00, 19.89it/s]
Epoch 22 [Val]: 100%|██████████| 15/15 [00:00<00:00, 19.22it/s]


Epoch 22/40 — Train Loss: 1.1403, Train Acc: 0.5747  |  Val Loss: 1.1577, Val Acc: 0.5712
------------------------------------------------------------


Epoch 23 [Train]: 100%|██████████| 190/190 [00:10<00:00, 18.99it/s]
Epoch 23 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.29it/s]


Epoch 23/40 — Train Loss: 1.1359, Train Acc: 0.5764  |  Val Loss: 1.1747, Val Acc: 0.5673
------------------------------------------------------------


Epoch 24 [Train]: 100%|██████████| 190/190 [00:09<00:00, 19.04it/s]
Epoch 24 [Val]: 100%|██████████| 15/15 [00:00<00:00, 30.28it/s]


Epoch 24/40 — Train Loss: 1.1320, Train Acc: 0.5804  |  Val Loss: 1.1526, Val Acc: 0.5715
------------------------------------------------------------


Epoch 25 [Train]: 100%|██████████| 190/190 [00:08<00:00, 21.44it/s]
Epoch 25 [Val]: 100%|██████████| 15/15 [00:00<00:00, 20.94it/s]


Epoch 25/40 — Train Loss: 1.1340, Train Acc: 0.5792  |  Val Loss: 1.1628, Val Acc: 0.5648
------------------------------------------------------------


Epoch 26 [Train]: 100%|██████████| 190/190 [00:09<00:00, 19.45it/s]
Epoch 26 [Val]: 100%|██████████| 15/15 [00:00<00:00, 28.37it/s]


Epoch 26/40 — Train Loss: 1.1229, Train Acc: 0.5845  |  Val Loss: 1.1614, Val Acc: 0.5656
------------------------------------------------------------


Epoch 27 [Train]: 100%|██████████| 190/190 [00:09<00:00, 19.10it/s]
Epoch 27 [Val]: 100%|██████████| 15/15 [00:00<00:00, 28.77it/s]


Epoch 27/40 — Train Loss: 1.1266, Train Acc: 0.5823  |  Val Loss: 1.1625, Val Acc: 0.5628
------------------------------------------------------------


Epoch 28 [Train]: 100%|██████████| 190/190 [00:10<00:00, 17.83it/s]
Epoch 28 [Val]: 100%|██████████| 15/15 [00:00<00:00, 18.12it/s]


Epoch 28/40 — Train Loss: 1.1132, Train Acc: 0.5895  |  Val Loss: 1.1752, Val Acc: 0.5673
------------------------------------------------------------


Epoch 29 [Train]: 100%|██████████| 190/190 [00:11<00:00, 17.12it/s]
Epoch 29 [Val]: 100%|██████████| 15/15 [00:02<00:00,  6.95it/s]


Epoch 29/40 — Train Loss: 1.1139, Train Acc: 0.5857  |  Val Loss: 1.1662, Val Acc: 0.5717
------------------------------------------------------------


Epoch 30 [Train]: 100%|██████████| 190/190 [00:08<00:00, 21.38it/s]
Epoch 30 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.79it/s]


Epoch 30/40 — Train Loss: 1.1078, Train Acc: 0.5922  |  Val Loss: 1.1616, Val Acc: 0.5667
------------------------------------------------------------


Epoch 31 [Train]: 100%|██████████| 190/190 [00:09<00:00, 19.18it/s]
Epoch 31 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.87it/s]


Epoch 31/40 — Train Loss: 1.1052, Train Acc: 0.5946  |  Val Loss: 1.1503, Val Acc: 0.5756
------------------------------------------------------------


Epoch 32 [Train]: 100%|██████████| 190/190 [00:10<00:00, 18.95it/s]
Epoch 32 [Val]: 100%|██████████| 15/15 [00:00<00:00, 27.40it/s]


Epoch 32/40 — Train Loss: 1.1011, Train Acc: 0.5949  |  Val Loss: 1.1609, Val Acc: 0.5751
------------------------------------------------------------


Epoch 33 [Train]: 100%|██████████| 190/190 [00:08<00:00, 21.34it/s]
Epoch 33 [Val]: 100%|██████████| 15/15 [00:00<00:00, 17.23it/s]


Epoch 33/40 — Train Loss: 1.1011, Train Acc: 0.5969  |  Val Loss: 1.1679, Val Acc: 0.5754
------------------------------------------------------------


Epoch 34 [Train]: 100%|██████████| 190/190 [00:09<00:00, 19.67it/s]
Epoch 34 [Val]: 100%|██████████| 15/15 [00:00<00:00, 29.16it/s]


Epoch 34/40 — Train Loss: 1.0856, Train Acc: 0.6025  |  Val Loss: 1.1641, Val Acc: 0.5787
------------------------------------------------------------


Epoch 35 [Train]: 100%|██████████| 190/190 [00:10<00:00, 18.91it/s]
Epoch 35 [Val]: 100%|██████████| 15/15 [00:00<00:00, 28.89it/s]


Epoch 35/40 — Train Loss: 1.0914, Train Acc: 0.5995  |  Val Loss: 1.1571, Val Acc: 0.5782
------------------------------------------------------------


Epoch 36 [Train]: 100%|██████████| 190/190 [00:09<00:00, 19.60it/s]
Epoch 36 [Val]: 100%|██████████| 15/15 [00:00<00:00, 19.67it/s]


Epoch 36/40 — Train Loss: 1.0850, Train Acc: 0.6011  |  Val Loss: 1.1559, Val Acc: 0.5765
------------------------------------------------------------


Epoch 37 [Train]: 100%|██████████| 190/190 [00:08<00:00, 21.27it/s]
Epoch 37 [Val]: 100%|██████████| 15/15 [00:00<00:00, 28.01it/s]


Epoch 37/40 — Train Loss: 1.0830, Train Acc: 0.6037  |  Val Loss: 1.1743, Val Acc: 0.5731
------------------------------------------------------------


Epoch 38 [Train]: 100%|██████████| 190/190 [00:10<00:00, 18.94it/s]
Epoch 38 [Val]: 100%|██████████| 15/15 [00:00<00:00, 27.78it/s]


Epoch 38/40 — Train Loss: 1.0810, Train Acc: 0.6052  |  Val Loss: 1.1568, Val Acc: 0.5754
------------------------------------------------------------


Epoch 39 [Train]: 100%|██████████| 190/190 [00:09<00:00, 19.00it/s]
Epoch 39 [Val]: 100%|██████████| 15/15 [00:00<00:00, 28.82it/s]


Epoch 39/40 — Train Loss: 1.0805, Train Acc: 0.6047  |  Val Loss: 1.1654, Val Acc: 0.5770
------------------------------------------------------------


Epoch 40 [Train]: 100%|██████████| 190/190 [00:08<00:00, 21.37it/s]
Epoch 40 [Val]: 100%|██████████| 15/15 [00:01<00:00, 10.90it/s]


Epoch 40/40 — Train Loss: 1.0709, Train Acc: 0.6073  |  Val Loss: 1.1593, Val Acc: 0.5745
------------------------------------------------------------


In [None]:
import torch
from sklearn.metrics import classification_report, confusion_matrix

# switch to eval mode
improved_model.eval()

all_preds, all_labels = [], []

with torch.no_grad():
    for X, y in test_dl:
        X, y = X.to(device), y.to(device)
        out = improved_model(X)
        preds = out.argmax(dim=1)
        all_preds .extend(preds.cpu().numpy())
        all_labels.extend(y.cpu().numpy())

# 1) Full report
print("Test Classification Report")
print(classification_report(all_labels, all_preds, target_names=class_names))

# 2) Confusion matrix
cm = confusion_matrix(all_labels, all_preds)
print("Test Confusion Matrix:\n", cm)

# 3) (Optional) save raw preds
import pandas as pd
test_df = df_test.copy().reset_index(drop=True)
test_df['predicted_emotion'] = [class_names[p] for p in all_preds]
test_df.to_csv("fer2013_test_predictions.csv", index=False)
print("Saved predictions to fer2013_test_predictions.csv")


Test Classification Report
              precision    recall  f1-score   support

       Angry       0.48      0.55      0.51       491
     Disgust       0.00      0.00      0.00        55
        Fear       0.47      0.28      0.35       528
       Happy       0.82      0.83      0.82       879
         Sad       0.47      0.35      0.40       594
    Surprise       0.68      0.84      0.75       416
     Neutral       0.50      0.70      0.58       626

    accuracy                           0.60      3589
   macro avg       0.49      0.51      0.49      3589
weighted avg       0.58      0.60      0.58      3589

Test Confusion Matrix:
 [[270   0  31  27  51  18  94]
 [ 42   0   9   2   0   1   1]
 [ 88   0 146  24  86  86  98]
 [ 19   0  14 730  34  30  52]
 [ 81   0  69  46 205  14 179]
 [  9   0  23  16   3 351  14]
 [ 50   0  16  50  57  15 438]]


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saved predictions to fer2013_test_predictions.csv
