<a href="https://colab.research.google.com/github/dimna21/ML_Assignment4/blob/main/FER2013.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install wandb



In [2]:
import wandb
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdimna21[0m ([33mdimna21-free-university-of-tbilisi-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Load the CSV
csv_path = "/content/drive/MyDrive/FER_data/fer2013/fer2013.csv"
df = pd.read_csv(csv_path)

# Split by Usage
df_train = df[df['Usage']=="Training"].copy()
df_val   = df[df['Usage']=="PublicTest"].copy()
df_test  = df[df['Usage']=="PrivateTest"].copy()

In [5]:
# Balance function: upsampling & random ±10 intensity shifts
def balance_dataset(df, target_count, img_shape=(48,48)):
    def augment(pix_str):
        arr = np.fromstring(pix_str, sep=' ', dtype=int).reshape(img_shape)
        shift = np.random.randint(-10, 11)
        arr = np.clip(arr + shift, 0, 255).astype(int)
        return ' '.join(map(str, arr.ravel()))
    parts = [df]
    for emo, grp in df.groupby('emotion'):
        n = len(grp)
        if n < target_count:
            extra = grp.sample(n=target_count-n, replace=True).copy()
            extra['pixels'] = extra['pixels'].map(augment)
            parts.append(extra)
    return pd.concat(parts, ignore_index=True)


In [6]:
# Dataset class
class FERDataset(Dataset):
    def __init__(self, dataframe):
        self.pixels = dataframe['pixels'].values
        self.labels = dataframe['emotion'].values.astype(int)
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        arr = np.fromstring(self.pixels[idx], sep=' ', dtype=np.uint8).reshape(48,48)
        arr = arr.astype(np.float32) / 255.0
        tensor = torch.from_numpy(arr).unsqueeze(0)  # shape [1,48,48]
        return tensor, self.labels[idx]


In [7]:
# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [8]:
# DataLoaders
max_count = df_train['emotion'].value_counts().max()
balanced_train = balance_dataset(df_train, target_count=max_count)

#train_ds = FERDataset(df_train)
train_ds = FERDataset(balanced_train)
val_ds = FERDataset(df_val)
test_ds = FERDataset(df_test)

batch_size = 128
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=2)
val_dl = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=2)
test_dl = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=2)

# Class names
class_names = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]

In [9]:
import torch
from tqdm import tqdm
import wandb
from sklearn.metrics import confusion_matrix, f1_score
import matplotlib.pyplot as plt
import numpy as np

def train_model(model, train_loader, val_loader,
                criterion, optimizer, device,
                epochs=5, class_names=None):

    wandb.init(
        project="ML_Assignment4",
        config={
            "epochs": epochs,
            "batch_size": train_loader.batch_size,
            "optimizer": optimizer.__class__.__name__,
            "lr": optimizer.param_groups[0]["lr"],
            "criterion": criterion.__class__.__name__,
        },
    )
    cfg = wandb.config
    wandb.watch(model, log="all", log_freq=100)
    model.to(device)

    train_loss_plot, val_loss_plot = [], []
    train_acc_plot,  val_acc_plot  = [], []

    for epoch in range(1, cfg.epochs + 1):
        # — TRAIN —
        model.train()
        running_loss = running_correct = running_total = 0
        for X, y in tqdm(train_loader, desc=f"Epoch {epoch} [Train]"):
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(X)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()

            running_loss    += loss.item() * X.size(0)
            preds            = out.argmax(dim=1)
            running_correct += preds.eq(y).sum().item()
            running_total   += y.size(0)

        train_loss = running_loss / running_total
        train_acc  = running_correct / running_total

        # — VALIDATE —
        model.eval()
        val_running_loss = val_running_correct = val_running_total = 0
        all_preds, all_targets = [], []
        with torch.no_grad():
            for X, y in tqdm(val_loader, desc=f"Epoch {epoch} [Val]"):
                X, y = X.to(device), y.to(device)
                out = model(X)
                loss = criterion(out, y)

                val_running_loss    += loss.item() * X.size(0)
                preds                = out.argmax(dim=1)
                val_running_correct += preds.eq(y).sum().item()
                val_running_total   += y.size(0)

                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(y.cpu().numpy())

        val_loss = val_running_loss / val_running_total
        val_acc  = val_running_correct / val_running_total

        # — SCALARS & F1 —
        cm = confusion_matrix(all_targets, all_preds)
        f1_per_class = f1_score(all_targets, all_preds, average=None)

        log_data = {
            "train_loss": train_loss,
            "train_acc":  train_acc,
            "val_loss":   val_loss,
            "val_acc":    val_acc,
        }
        for i, name in enumerate(class_names):
            log_data[f"f1_{name}"] = f1_per_class[i]

        wandb.log(log_data, step=epoch)

        # — 7×7 Confusion Matrix Plot & Log —
        fig_cm, ax = plt.subplots(figsize=(6,6))
        im = ax.imshow(cm, interpolation='nearest', cmap='Blues')
        fig_cm.colorbar(im, ax=ax)

        ax.set_xticks(np.arange(len(class_names)))
        ax.set_yticks(np.arange(len(class_names)))
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.set_yticklabels(class_names)

        for i in range(len(class_names)):
            for j in range(len(class_names)):
                ax.text(j, i, cm[i, j],
                        ha='center', va='center')

        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        ax.set_title(f'Epoch {epoch} Confusion Matrix')

        wandb.log({"confusion_matrix": wandb.Image(fig_cm)}, step=epoch)
        plt.close(fig_cm)

        # — PRINT & STORE FOR CURVES —
        print(
            f"Epoch {epoch}/{cfg.epochs} — "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}  |  "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
        )
        print("-" * 60)

        train_loss_plot.append(train_loss)
        val_loss_plot.append(val_loss)
        train_acc_plot.append(train_acc)
        val_acc_plot.append(val_acc)

    # — PLOT & LOG LOSS/ACC CURVES TO W&B —
    epochs_range = list(range(1, cfg.epochs + 1))

    # Loss curve
    fig1, ax1 = plt.subplots()
    ax1.plot(epochs_range, train_loss_plot, label="Train Loss")
    ax1.plot(epochs_range, val_loss_plot,   label="Val Loss")
    ax1.set(title="Loss vs Epoch", xlabel="Epoch", ylabel="Loss")
    ax1.legend()
    wandb.log({"loss_curve": wandb.Image(fig1)})
    plt.close(fig1)

    # Accuracy curve
    fig2, ax2 = plt.subplots()
    ax2.plot(epochs_range, train_acc_plot, label="Train Acc")
    ax2.plot(epochs_range, val_acc_plot,   label="Val Acc")
    ax2.set(title="Accuracy vs Epoch", xlabel="Epoch", ylabel="Accuracy")
    ax2.legend()
    wandb.log({"acc_curve": wandb.Image(fig2)})
    plt.close(fig2)

    return model


In [None]:
class BaselineModel(nn.Module):
    def __init__(self):
      super().__init__()

      self.conv1 = nn.Conv2d(1, 32, kernel_size = 3, padding = 1)
      self.conv2 = nn.Conv2d(32, 64, kernel_size = 3, padding = 1)
      self.conv3 = nn.Conv2d(64, 128, kernel_size = 3, padding = 1)
      self.pooling = nn.MaxPool2d(2,2)
      self.relu = nn.ReLU()

      self.flatten = nn.Flatten()
      self.linear = nn.Linear((128 * 6 * 6), 128)
      self.output = nn.Linear(128, 7)


    def forward(self, x):
      x = self.conv1(x)
      x = self.pooling(x)
      x = self.relu(x)
      x = self.conv2(x)
      x = self.pooling(x)
      x = self.relu(x)
      x = self.conv3(x)
      x = self.pooling(x)
      x = self.relu(x)
      x = self.flatten(x)
      x = self.linear(x)
      x = self.output(x)

      return x

#Run with dataset balancing:
#https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/7amufnzs?nw=nwuserdimna21

#Run without dataset balancing:
#https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/dnyp6qje
#20 epochs instead of 10:
#https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/0axoxyj9?nw=nwuserdimna21

In [None]:
baseline_model = BaselineModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(baseline_model.parameters(), lr=1e-3)
trained = train_model(
    baseline_model,
    train_dl,
    val_dl,
    criterion,
    optimizer,
    device,
    epochs=20,
    class_names=class_names
)

0,1
f1_Angry,▁▂▅▆▆▆▇▇█▇
f1_Disgust,▁▁▂▄▅▇▅▇██
f1_Fear,▁▃▄▅▅▆▇█▆█
f1_Happy,▁▃▄▇▆▇█▇██
f1_Neutral,▁▂▄▆▇██▇▇█
f1_Sad,▁▅▃▇█▅█▆▇▆
f1_Surprise,▁▃▅▇▇▇▇█▇█
train_acc,▁▃▄▅▆▆▇▇██
train_loss,█▆▅▄▃▃▂▂▁▁
val_acc,▁▄▄▇▇▇████

0,1
f1_Angry,0.4639
f1_Disgust,0.50633
f1_Fear,0.37651
f1_Happy,0.75697
f1_Neutral,0.49241
f1_Sad,0.37192
f1_Surprise,0.71262
train_acc,0.67275
train_loss,0.89268
val_acc,0.54416


Epoch 1 [Train]: 100%|██████████| 225/225 [00:06<00:00, 34.76it/s]
Epoch 1 [Val]: 100%|██████████| 29/29 [00:01<00:00, 25.32it/s]


Epoch 1/20 — Train Loss: 1.6812, Train Acc: 0.3272  |  Val Loss: 1.5147, Val Acc: 0.4235
------------------------------------------------------------


Epoch 2 [Train]: 100%|██████████| 225/225 [00:05<00:00, 42.10it/s]
Epoch 2 [Val]: 100%|██████████| 29/29 [00:00<00:00, 46.90it/s]


Epoch 2/20 — Train Loss: 1.4493, Train Acc: 0.4449  |  Val Loss: 1.3864, Val Acc: 0.4762
------------------------------------------------------------


Epoch 3 [Train]: 100%|██████████| 225/225 [00:07<00:00, 28.47it/s]
Epoch 3 [Val]: 100%|██████████| 29/29 [00:00<00:00, 33.85it/s]


Epoch 3/20 — Train Loss: 1.3132, Train Acc: 0.5038  |  Val Loss: 1.3009, Val Acc: 0.5032
------------------------------------------------------------


Epoch 4 [Train]: 100%|██████████| 225/225 [00:04<00:00, 54.99it/s]
Epoch 4 [Val]: 100%|██████████| 29/29 [00:00<00:00, 50.87it/s]


Epoch 4/20 — Train Loss: 1.2133, Train Acc: 0.5430  |  Val Loss: 1.2520, Val Acc: 0.5319
------------------------------------------------------------


Epoch 5 [Train]: 100%|██████████| 225/225 [00:04<00:00, 54.86it/s]
Epoch 5 [Val]: 100%|██████████| 29/29 [00:00<00:00, 48.91it/s]


Epoch 5/20 — Train Loss: 1.1404, Train Acc: 0.5729  |  Val Loss: 1.2164, Val Acc: 0.5467
------------------------------------------------------------


Epoch 6 [Train]: 100%|██████████| 225/225 [00:06<00:00, 36.63it/s]
Epoch 6 [Val]: 100%|██████████| 29/29 [00:00<00:00, 50.79it/s]


Epoch 6/20 — Train Loss: 1.0759, Train Acc: 0.5992  |  Val Loss: 1.2318, Val Acc: 0.5389
------------------------------------------------------------


Epoch 7 [Train]: 100%|██████████| 225/225 [00:04<00:00, 55.49it/s]
Epoch 7 [Val]: 100%|██████████| 29/29 [00:00<00:00, 52.67it/s]


Epoch 7/20 — Train Loss: 1.0205, Train Acc: 0.6215  |  Val Loss: 1.2220, Val Acc: 0.5500
------------------------------------------------------------


Epoch 8 [Train]: 100%|██████████| 225/225 [00:04<00:00, 48.70it/s]
Epoch 8 [Val]: 100%|██████████| 29/29 [00:01<00:00, 27.93it/s]


Epoch 8/20 — Train Loss: 0.9569, Train Acc: 0.6484  |  Val Loss: 1.2517, Val Acc: 0.5464
------------------------------------------------------------


Epoch 9 [Train]: 100%|██████████| 225/225 [00:05<00:00, 43.54it/s]
Epoch 9 [Val]: 100%|██████████| 29/29 [00:00<00:00, 48.15it/s]


Epoch 9/20 — Train Loss: 0.9012, Train Acc: 0.6701  |  Val Loss: 1.2355, Val Acc: 0.5469
------------------------------------------------------------


Epoch 10 [Train]: 100%|██████████| 225/225 [00:04<00:00, 55.54it/s]
Epoch 10 [Val]: 100%|██████████| 29/29 [00:00<00:00, 53.94it/s]


Epoch 10/20 — Train Loss: 0.8462, Train Acc: 0.6903  |  Val Loss: 1.2926, Val Acc: 0.5517
------------------------------------------------------------


Epoch 11 [Train]: 100%|██████████| 225/225 [00:06<00:00, 33.39it/s]
Epoch 11 [Val]: 100%|██████████| 29/29 [00:00<00:00, 38.83it/s]


Epoch 11/20 — Train Loss: 0.7881, Train Acc: 0.7095  |  Val Loss: 1.3603, Val Acc: 0.5414
------------------------------------------------------------


Epoch 12 [Train]: 100%|██████████| 225/225 [00:04<00:00, 54.46it/s]
Epoch 12 [Val]: 100%|██████████| 29/29 [00:00<00:00, 53.20it/s]


Epoch 12/20 — Train Loss: 0.7348, Train Acc: 0.7298  |  Val Loss: 1.3331, Val Acc: 0.5578
------------------------------------------------------------


Epoch 13 [Train]: 100%|██████████| 225/225 [00:04<00:00, 55.74it/s]
Epoch 13 [Val]: 100%|██████████| 29/29 [00:00<00:00, 37.73it/s]


Epoch 13/20 — Train Loss: 0.6840, Train Acc: 0.7514  |  Val Loss: 1.4282, Val Acc: 0.5553
------------------------------------------------------------


Epoch 14 [Train]: 100%|██████████| 225/225 [00:05<00:00, 38.07it/s]
Epoch 14 [Val]: 100%|██████████| 29/29 [00:00<00:00, 49.62it/s]


Epoch 14/20 — Train Loss: 0.6396, Train Acc: 0.7645  |  Val Loss: 1.5129, Val Acc: 0.5472
------------------------------------------------------------


Epoch 15 [Train]: 100%|██████████| 225/225 [00:04<00:00, 55.25it/s]
Epoch 15 [Val]: 100%|██████████| 29/29 [00:00<00:00, 50.40it/s]


Epoch 15/20 — Train Loss: 0.5856, Train Acc: 0.7889  |  Val Loss: 1.5878, Val Acc: 0.5344
------------------------------------------------------------


Epoch 16 [Train]: 100%|██████████| 225/225 [00:05<00:00, 43.86it/s]
Epoch 16 [Val]: 100%|██████████| 29/29 [00:01<00:00, 28.67it/s]


Epoch 16/20 — Train Loss: 0.5397, Train Acc: 0.8020  |  Val Loss: 1.6771, Val Acc: 0.5419
------------------------------------------------------------


Epoch 17 [Train]: 100%|██████████| 225/225 [00:04<00:00, 47.47it/s]
Epoch 17 [Val]: 100%|██████████| 29/29 [00:00<00:00, 47.57it/s]


Epoch 17/20 — Train Loss: 0.4986, Train Acc: 0.8182  |  Val Loss: 1.7001, Val Acc: 0.5472
------------------------------------------------------------


Epoch 18 [Train]: 100%|██████████| 225/225 [00:04<00:00, 47.82it/s]
Epoch 18 [Val]: 100%|██████████| 29/29 [00:00<00:00, 47.82it/s]


Epoch 18/20 — Train Loss: 0.4636, Train Acc: 0.8318  |  Val Loss: 1.8411, Val Acc: 0.5414
------------------------------------------------------------


Epoch 19 [Train]: 100%|██████████| 225/225 [00:06<00:00, 35.93it/s]
Epoch 19 [Val]: 100%|██████████| 29/29 [00:00<00:00, 50.49it/s]


Epoch 19/20 — Train Loss: 0.4270, Train Acc: 0.8452  |  Val Loss: 1.8972, Val Acc: 0.5511
------------------------------------------------------------


Epoch 20 [Train]: 100%|██████████| 225/225 [00:04<00:00, 54.63it/s]
Epoch 20 [Val]: 100%|██████████| 29/29 [00:00<00:00, 48.18it/s]


Epoch 20/20 — Train Loss: 0.3815, Train Acc: 0.8583  |  Val Loss: 2.1096, Val Acc: 0.5411
------------------------------------------------------------


In [12]:
class ImprovedModel(nn.Module):
    def __init__(self, dropout_p=0.5):
        super().__init__()
        # — conv backbone with 4 pooling stages —
        self.conv1    = nn.Conv2d(1,   32, 3, padding=1, bias=False)
        self.bn1      = nn.BatchNorm2d(32)
        self.conv2    = nn.Conv2d(32,  64, 3, padding=1, bias=False)
        self.bn2      = nn.BatchNorm2d(64)
        self.conv3    = nn.Conv2d(64, 128, 3, padding=1, bias=False)
        self.bn3      = nn.BatchNorm2d(128)
        self.conv4    = nn.Conv2d(128,256, 3, padding=1, bias=False)
        self.bn4      = nn.BatchNorm2d(256)

        self.pool     = nn.MaxPool2d(2,2)
        self.relu     = nn.ReLU(inplace=True)
        self.drop_conv= nn.Dropout2d(dropout_p)

        # — two extra conv layers (no further pooling) —
        self.conv5    = nn.Conv2d(256, 512, 3, padding=1, bias=False)
        self.bn5      = nn.BatchNorm2d(512)
        self.conv6    = nn.Conv2d(512, 512, 3, padding=1, bias=False)
        self.bn6      = nn.BatchNorm2d(512)

        # — richer head with 4 FCs —
        self.flatten  = nn.Flatten()
        self.fc1      = nn.Linear(512 * 3 * 3, 512, bias=False)
        self.bn_fc1   = nn.BatchNorm1d(512)
        self.drop1    = nn.Dropout(dropout_p)
        self.fc2      = nn.Linear(512, 256, bias=False)
        self.bn_fc2   = nn.BatchNorm1d(256)
        self.drop2    = nn.Dropout(dropout_p)
        self.fc3      = nn.Linear(256, 128, bias=False)
        self.bn_fc3   = nn.BatchNorm1d(128)
        self.drop3    = nn.Dropout(dropout_p)
        self.fc4      = nn.Linear(128, 64, bias=False)
        self.bn_fc4   = nn.BatchNorm1d(64)
        self.drop4    = nn.Dropout(dropout_p)
        self.output   = nn.Linear(64, 7)

    def forward(self, x):
        # conv blocks 1–4 with pooling
        for conv, bn in [(self.conv1,self.bn1),
                         (self.conv2,self.bn2),
                         (self.conv3,self.bn3),
                         (self.conv4,self.bn4)]:
            x = conv(x)
            x = bn(x)
            x = self.relu(x)
            x = self.pool(x)
            x = self.drop_conv(x)

        # extra conv blocks 5–6 (no more pooling)
        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu(x)
        x = self.drop_conv(x)

        x = self.conv6(x)
        x = self.bn6(x)
        x = self.relu(x)
        x = self.drop_conv(x)

        # FC head
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.bn_fc1(x)
        x = self.relu(x)
        x = self.drop1(x)

        x = self.fc2(x)
        x = self.bn_fc2(x)
        x = self.relu(x)
        x = self.drop2(x)

        x = self.fc3(x)
        x = self.bn_fc3(x)
        x = self.relu(x)
        x = self.drop3(x)

        x = self.fc4(x)
        x = self.bn_fc4(x)
        x = self.relu(x)
        x = self.drop4(x)

        x = self.output(x)
        return x


#Run with added batchnorms
#https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/cuxqkkdm?nw=nwuserdimna21
#topped out at 70%/58%

#Run with added batchnorms and dropouts
#10 epochs:  https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/iactnmo6?nw=nwuserdimna21
#20 epochs:  https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/ydlxqyzi?nw=nwuserdimna21
#started overfitting after reaching 62%/60% accuracies on train/validation

#Run with added 4th convolutional block:
#https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/lqfkix8t
#no improvements, got 62%/60%

#Run with a more complex head of 512->256->128
#https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/65qgarhf
#slight improvement, got 64%/61%

#Run with added residual blocks
#20 epochs: https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/op1fbc6g
#40 epochs: https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/cyn9apy4
#no improvement

#Run without residual blocks, with augmented data and increased complexity: 2 extra conv layers and 4 FCs in head
#40 epochs, 0.3 dropout: https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/ofsajdfh?nw=nwuserdimna21
#40 epochs, 0.5 dropout: https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/itxksjhl

In [None]:
import torch.nn as nn

class ResidualBlock(nn.Module):
    """
    A basic residual block with two 3x3 conv layers and an optional downsample.
    """
    expansion = 1
    def __init__(self, in_channels, out_channels, stride=1, dropout_p=0.0):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout2d(dropout_p) if dropout_p > 0 else nn.Identity()

        # downsample path for matching dimensions
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        out = self.dropout(out)
        return out

class GigaModel(nn.Module):
    def __init__(self, dropout_p=0.3):
        super().__init__()
        # define ReLU for use in head
        self.relu = nn.ReLU(inplace=True)
        # initial conv to expand grayscale to 32 channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        # residual layers: doubling channels and downsampling spatially at each stage
        self.layer1 = ResidualBlock(32,  32, stride=1, dropout_p=dropout_p)
        self.layer2 = ResidualBlock(32,  64, stride=2, dropout_p=dropout_p)
        self.layer3 = ResidualBlock(64, 128, stride=2, dropout_p=dropout_p)
        self.layer4 = ResidualBlock(128,256, stride=2, dropout_p=dropout_p)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))

        self.flatten = nn.Flatten()
        self.fc1     = nn.Linear(256 * ResidualBlock.expansion, 512, bias=False)
        self.bn_fc1  = nn.BatchNorm1d(512)
        self.drop1   = nn.Dropout(dropout_p)
        self.fc2     = nn.Linear(512, 256, bias=False)
        self.bn_fc2  = nn.BatchNorm1d(256)
        self.drop2   = nn.Dropout(dropout_p)
        self.fc3     = nn.Linear(256, 128, bias=False)
        self.bn_fc3  = nn.BatchNorm1d(128)
        self.drop3   = nn.Dropout(dropout_p)
        self.output  = nn.Linear(128, 7)

    def forward(self, x):
        # initial conv
        x = self.conv1(x)
        # residual conv backbone with downsamples
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # head
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.bn_fc1(x)
        x = self.relu(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.bn_fc2(x)
        x = self.relu(x)
        x = self.drop2(x)
        x = self.fc3(x)
        x = self.bn_fc3(x)
        x = self.relu(x)
        x = self.drop3(x)
        x = self.output(x)
        return x

#training pt1: https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/8a4wuiau?nw=nwuserdimna21
#training pt2: https://wandb.ai/dimna21-free-university-of-tbilisi-/ML_Assignment4/runs/127aogq7

In [13]:
improved_model = ImprovedModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(improved_model.parameters(), lr=1e-3)
trained = train_model(
    improved_model,
    train_dl,
    val_dl,
    criterion,
    optimizer,
    device,
    epochs=40,
    class_names=class_names
)

0,1
f1_Angry,▁▂▄▅▅▆▆▇▇▇▆▇▇▇▇▇▇█████▇▇█▇███▇█▇████████
f1_Disgust,▁▃▃▄▄▅▅▅▇▇▇▇▇▇▆▇▇▇▇▇▇█▇▇▇▇█▇▇▇▇▇▇▇█▇▇▇▇▇
f1_Fear,▁▂▃▂▃▃▃▅▄▅▆▄▆▆▇▆▇▇▇▇▇▇▇█▇█▇█████████████
f1_Happy,▁▆▇▇▇███████████████████████████████████
f1_Neutral,▁▆▇▇▇▇▇▇█▇█▇████████████████████████████
f1_Sad,▃▂▂▃▁▂▄▄▄▅▅▅▅▅▅▅▆▅▇▆▇▆▆▇▇▆▇▆▇▇▇▇█▆▇▇▇█▇█
f1_Surprise,▁▃▅▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███▇▇▇█████████████
train_acc,▁▂▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████
train_loss,█▇▆▅▅▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
val_acc,▁▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇██████████████████████

0,1
f1_Angry,0.51735
f1_Disgust,0.56489
f1_Fear,0.41069
f1_Happy,0.79188
f1_Neutral,0.5423
f1_Sad,0.49021
f1_Surprise,0.77396
train_acc,0.82422
train_loss,0.51021
val_acc,0.60156


Epoch 1 [Train]: 100%|██████████| 395/395 [00:11<00:00, 34.29it/s]
Epoch 1 [Val]: 100%|██████████| 29/29 [00:00<00:00, 37.57it/s]


Epoch 1/40 — Train Loss: 1.9626, Train Acc: 0.1572  |  Val Loss: 1.9187, Val Acc: 0.1956
------------------------------------------------------------


Epoch 2 [Train]: 100%|██████████| 395/395 [00:13<00:00, 29.15it/s]
Epoch 2 [Val]: 100%|██████████| 29/29 [00:01<00:00, 28.85it/s]


Epoch 2/40 — Train Loss: 1.9336, Train Acc: 0.1770  |  Val Loss: 1.9011, Val Acc: 0.2251
------------------------------------------------------------


Epoch 3 [Train]: 100%|██████████| 395/395 [00:14<00:00, 27.71it/s]
Epoch 3 [Val]: 100%|██████████| 29/29 [00:00<00:00, 50.30it/s]


Epoch 3/40 — Train Loss: 1.9281, Train Acc: 0.1852  |  Val Loss: 1.9011, Val Acc: 0.2090
------------------------------------------------------------


Epoch 4 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.44it/s]
Epoch 4 [Val]: 100%|██████████| 29/29 [00:00<00:00, 51.71it/s]


Epoch 4/40 — Train Loss: 1.9233, Train Acc: 0.1870  |  Val Loss: 1.9031, Val Acc: 0.1672
------------------------------------------------------------


Epoch 5 [Train]: 100%|██████████| 395/395 [00:12<00:00, 31.26it/s]
Epoch 5 [Val]: 100%|██████████| 29/29 [00:01<00:00, 20.69it/s]


Epoch 5/40 — Train Loss: 1.9177, Train Acc: 0.1941  |  Val Loss: 1.8829, Val Acc: 0.2555
------------------------------------------------------------


Epoch 6 [Train]: 100%|██████████| 395/395 [00:10<00:00, 38.31it/s]
Epoch 6 [Val]: 100%|██████████| 29/29 [00:00<00:00, 34.20it/s]


Epoch 6/40 — Train Loss: 1.8734, Train Acc: 0.2154  |  Val Loss: 1.8087, Val Acc: 0.1254
------------------------------------------------------------


Epoch 7 [Train]: 100%|██████████| 395/395 [00:11<00:00, 34.32it/s]
Epoch 7 [Val]: 100%|██████████| 29/29 [00:00<00:00, 51.45it/s]


Epoch 7/40 — Train Loss: 1.7830, Train Acc: 0.2721  |  Val Loss: 1.7457, Val Acc: 0.2388
------------------------------------------------------------


Epoch 8 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.43it/s]
Epoch 8 [Val]: 100%|██████████| 29/29 [00:00<00:00, 52.90it/s]


Epoch 8/40 — Train Loss: 1.7037, Train Acc: 0.3184  |  Val Loss: 1.7152, Val Acc: 0.2875
------------------------------------------------------------


Epoch 9 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.20it/s]
Epoch 9 [Val]: 100%|██████████| 29/29 [00:00<00:00, 38.86it/s]


Epoch 9/40 — Train Loss: 1.6460, Train Acc: 0.3525  |  Val Loss: 1.6577, Val Acc: 0.3555
------------------------------------------------------------


Epoch 10 [Train]: 100%|██████████| 395/395 [00:12<00:00, 32.10it/s]
Epoch 10 [Val]: 100%|██████████| 29/29 [00:00<00:00, 51.28it/s]


Epoch 10/40 — Train Loss: 1.5904, Train Acc: 0.3823  |  Val Loss: 1.5283, Val Acc: 0.4154
------------------------------------------------------------


Epoch 11 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.66it/s]
Epoch 11 [Val]: 100%|██████████| 29/29 [00:00<00:00, 50.67it/s]


Epoch 11/40 — Train Loss: 1.5324, Train Acc: 0.4092  |  Val Loss: 1.4888, Val Acc: 0.4288
------------------------------------------------------------


Epoch 12 [Train]: 100%|██████████| 395/395 [00:10<00:00, 36.95it/s]
Epoch 12 [Val]: 100%|██████████| 29/29 [00:00<00:00, 30.60it/s]


Epoch 12/40 — Train Loss: 1.4810, Train Acc: 0.4279  |  Val Loss: 1.4728, Val Acc: 0.4478
------------------------------------------------------------


Epoch 13 [Train]: 100%|██████████| 395/395 [00:10<00:00, 36.32it/s]
Epoch 13 [Val]: 100%|██████████| 29/29 [00:00<00:00, 41.24it/s]


Epoch 13/40 — Train Loss: 1.4390, Train Acc: 0.4436  |  Val Loss: 1.4145, Val Acc: 0.4556
------------------------------------------------------------


Epoch 14 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.62it/s]
Epoch 14 [Val]: 100%|██████████| 29/29 [00:00<00:00, 53.87it/s]


Epoch 14/40 — Train Loss: 1.4043, Train Acc: 0.4600  |  Val Loss: 1.4139, Val Acc: 0.4639
------------------------------------------------------------


Epoch 15 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.31it/s]
Epoch 15 [Val]: 100%|██████████| 29/29 [00:00<00:00, 54.39it/s]


Epoch 15/40 — Train Loss: 1.3762, Train Acc: 0.4665  |  Val Loss: 1.3976, Val Acc: 0.4689
------------------------------------------------------------


Epoch 16 [Train]: 100%|██████████| 395/395 [00:12<00:00, 32.03it/s]
Epoch 16 [Val]: 100%|██████████| 29/29 [00:01<00:00, 21.16it/s]


Epoch 16/40 — Train Loss: 1.3470, Train Acc: 0.4805  |  Val Loss: 1.3598, Val Acc: 0.4684
------------------------------------------------------------


Epoch 17 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.03it/s]
Epoch 17 [Val]: 100%|██████████| 29/29 [00:00<00:00, 41.03it/s]


Epoch 17/40 — Train Loss: 1.3273, Train Acc: 0.4861  |  Val Loss: 1.3508, Val Acc: 0.4804
------------------------------------------------------------


Epoch 18 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.56it/s]
Epoch 18 [Val]: 100%|██████████| 29/29 [00:00<00:00, 52.44it/s]


Epoch 18/40 — Train Loss: 1.3067, Train Acc: 0.4961  |  Val Loss: 1.3328, Val Acc: 0.4826
------------------------------------------------------------


Epoch 19 [Train]: 100%|██████████| 395/395 [00:11<00:00, 34.50it/s]
Epoch 19 [Val]: 100%|██████████| 29/29 [00:00<00:00, 31.03it/s]


Epoch 19/40 — Train Loss: 1.2950, Train Acc: 0.4969  |  Val Loss: 1.3287, Val Acc: 0.4923
------------------------------------------------------------


Epoch 20 [Train]: 100%|██████████| 395/395 [00:10<00:00, 37.35it/s]
Epoch 20 [Val]: 100%|██████████| 29/29 [00:00<00:00, 51.52it/s]


Epoch 20/40 — Train Loss: 1.2769, Train Acc: 0.5064  |  Val Loss: 1.3314, Val Acc: 0.4765
------------------------------------------------------------


Epoch 21 [Train]: 100%|██████████| 395/395 [00:11<00:00, 34.16it/s]
Epoch 21 [Val]: 100%|██████████| 29/29 [00:00<00:00, 41.72it/s]


Epoch 21/40 — Train Loss: 1.2614, Train Acc: 0.5124  |  Val Loss: 1.3024, Val Acc: 0.4926
------------------------------------------------------------


Epoch 22 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.73it/s]
Epoch 22 [Val]: 100%|██████████| 29/29 [00:00<00:00, 40.14it/s]


Epoch 22/40 — Train Loss: 1.2530, Train Acc: 0.5116  |  Val Loss: 1.2937, Val Acc: 0.4996
------------------------------------------------------------


Epoch 23 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.50it/s]
Epoch 23 [Val]: 100%|██████████| 29/29 [00:00<00:00, 51.71it/s]


Epoch 23/40 — Train Loss: 1.2396, Train Acc: 0.5192  |  Val Loss: 1.2946, Val Acc: 0.4935
------------------------------------------------------------


Epoch 24 [Train]: 100%|██████████| 395/395 [00:14<00:00, 27.89it/s]
Epoch 24 [Val]: 100%|██████████| 29/29 [00:00<00:00, 40.60it/s]


Epoch 24/40 — Train Loss: 1.2392, Train Acc: 0.5186  |  Val Loss: 1.2938, Val Acc: 0.4993
------------------------------------------------------------


Epoch 25 [Train]: 100%|██████████| 395/395 [00:11<00:00, 35.48it/s]
Epoch 25 [Val]: 100%|██████████| 29/29 [00:01<00:00, 24.56it/s]


Epoch 25/40 — Train Loss: 1.2206, Train Acc: 0.5272  |  Val Loss: 1.2907, Val Acc: 0.4976
------------------------------------------------------------


Epoch 26 [Train]: 100%|██████████| 395/395 [00:10<00:00, 38.51it/s]
Epoch 26 [Val]: 100%|██████████| 29/29 [00:01<00:00, 28.26it/s]


Epoch 26/40 — Train Loss: 1.2117, Train Acc: 0.5290  |  Val Loss: 1.2746, Val Acc: 0.5088
------------------------------------------------------------


Epoch 27 [Train]: 100%|██████████| 395/395 [00:12<00:00, 32.66it/s]
Epoch 27 [Val]: 100%|██████████| 29/29 [00:00<00:00, 53.65it/s]


Epoch 27/40 — Train Loss: 1.2082, Train Acc: 0.5295  |  Val Loss: 1.2668, Val Acc: 0.5052
------------------------------------------------------------


Epoch 28 [Train]: 100%|██████████| 395/395 [00:12<00:00, 31.47it/s]
Epoch 28 [Val]: 100%|██████████| 29/29 [00:00<00:00, 51.77it/s]


Epoch 28/40 — Train Loss: 1.1938, Train Acc: 0.5371  |  Val Loss: 1.2757, Val Acc: 0.5104
------------------------------------------------------------


Epoch 29 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.21it/s]
Epoch 29 [Val]: 100%|██████████| 29/29 [00:00<00:00, 51.52it/s]


Epoch 29/40 — Train Loss: 1.1888, Train Acc: 0.5391  |  Val Loss: 1.2657, Val Acc: 0.5160
------------------------------------------------------------


Epoch 30 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.42it/s]
Epoch 30 [Val]: 100%|██████████| 29/29 [00:00<00:00, 40.96it/s]


Epoch 30/40 — Train Loss: 1.1820, Train Acc: 0.5399  |  Val Loss: 1.2652, Val Acc: 0.5096
------------------------------------------------------------


Epoch 31 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.17it/s]
Epoch 31 [Val]: 100%|██████████| 29/29 [00:00<00:00, 52.07it/s]


Epoch 31/40 — Train Loss: 1.1714, Train Acc: 0.5446  |  Val Loss: 1.2629, Val Acc: 0.5171
------------------------------------------------------------


Epoch 32 [Train]: 100%|██████████| 395/395 [00:10<00:00, 36.18it/s]
Epoch 32 [Val]: 100%|██████████| 29/29 [00:00<00:00, 29.86it/s]


Epoch 32/40 — Train Loss: 1.1691, Train Acc: 0.5456  |  Val Loss: 1.2500, Val Acc: 0.5188
------------------------------------------------------------


Epoch 33 [Train]: 100%|██████████| 395/395 [00:10<00:00, 36.86it/s]
Epoch 33 [Val]: 100%|██████████| 29/29 [00:00<00:00, 53.74it/s]


Epoch 33/40 — Train Loss: 1.1711, Train Acc: 0.5428  |  Val Loss: 1.2620, Val Acc: 0.5116
------------------------------------------------------------


Epoch 34 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.13it/s]
Epoch 34 [Val]: 100%|██████████| 29/29 [00:00<00:00, 39.99it/s]


Epoch 34/40 — Train Loss: 1.1587, Train Acc: 0.5490  |  Val Loss: 1.2587, Val Acc: 0.5163
------------------------------------------------------------


Epoch 35 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.27it/s]
Epoch 35 [Val]: 100%|██████████| 29/29 [00:00<00:00, 51.20it/s]


Epoch 35/40 — Train Loss: 1.1471, Train Acc: 0.5535  |  Val Loss: 1.2463, Val Acc: 0.5180
------------------------------------------------------------


Epoch 36 [Train]: 100%|██████████| 395/395 [00:11<00:00, 32.93it/s]
Epoch 36 [Val]: 100%|██████████| 29/29 [00:00<00:00, 51.72it/s]


Epoch 36/40 — Train Loss: 1.1472, Train Acc: 0.5539  |  Val Loss: 1.2495, Val Acc: 0.5177
------------------------------------------------------------


Epoch 37 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.08it/s]
Epoch 37 [Val]: 100%|██████████| 29/29 [00:00<00:00, 52.34it/s]


Epoch 37/40 — Train Loss: 1.1445, Train Acc: 0.5560  |  Val Loss: 1.2519, Val Acc: 0.5213
------------------------------------------------------------


Epoch 38 [Train]: 100%|██████████| 395/395 [00:11<00:00, 33.50it/s]
Epoch 38 [Val]: 100%|██████████| 29/29 [00:00<00:00, 39.96it/s]


Epoch 38/40 — Train Loss: 1.1388, Train Acc: 0.5580  |  Val Loss: 1.2461, Val Acc: 0.5247
------------------------------------------------------------


Epoch 39 [Train]: 100%|██████████| 395/395 [00:11<00:00, 35.74it/s]
Epoch 39 [Val]: 100%|██████████| 29/29 [00:00<00:00, 29.68it/s]


Epoch 39/40 — Train Loss: 1.1293, Train Acc: 0.5600  |  Val Loss: 1.2440, Val Acc: 0.5283
------------------------------------------------------------


Epoch 40 [Train]: 100%|██████████| 395/395 [00:10<00:00, 36.80it/s]
Epoch 40 [Val]: 100%|██████████| 29/29 [00:00<00:00, 48.62it/s]


Epoch 40/40 — Train Loss: 1.1246, Train Acc: 0.5632  |  Val Loss: 1.2518, Val Acc: 0.5255
------------------------------------------------------------


In [None]:
torch.save(improved_model.state_dict(), "checkpoint40.pth")

In [None]:
improved_model = GigaModel().to(device)
improved_model.load_state_dict(torch.load("checkpoint40.pth"))


<All keys matched successfully>

In [None]:
trained = train_model(
    improved_model, train_dl, val_dl,
    criterion, optimizer, device,
    epochs=40,
    class_names=class_names
)

0,1
f1_Angry,▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▅▂▃▃▃▆▄▃▆▅▆▆▅▆▆▆▇▇▇█▇████
f1_Disgust,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1_Fear,▁▂▂▂▂▂█▄▃█▂▃▆▃▄▄▂▅▄▅▅▅▃▄▄▃▃▅▄▆▆▆▆▆▆▆▇▆▆▆
f1_Happy,▂▂▂▂▃▃▁▄▄▅▆▆▇▆▇▇▆▆▇▇▇▆▇▇▇█▇▇▇▇██████████
f1_Neutral,▁▁▁▂▃▃▃▅▅▅▆▃▆▆▇▇▃▆▇▇▆▆▅▇▇▇▆▇▇▇██████████
f1_Sad,▂▃▁▂▃▂▄▃▅▅▆▇▇▆▆▆▇▆▇▅▇▇▇▇▇▇▇███▇▇████████
f1_Surprise,▃▁▁▄▆▇▄▇▅▇▇▇▇█▇█▇███▇█▇█████▇███████████
train_acc,▁▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇███████
train_loss,█▇▇▇▇▆▆▆▅▅▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
val_acc,▂▂▂▂▃▃▁▄▄▅▅▅▆▆▆▆▅▆▆▆▇▇▆▇▇▇▇▇▇███████████

0,1
f1_Angry,0.3198
f1_Disgust,0.0
f1_Fear,0.11076
f1_Happy,0.77668
f1_Neutral,0.50311
f1_Sad,0.47066
f1_Surprise,0.69486
train_acc,0.52283
train_loss,1.22974
val_acc,0.5294


Epoch 1 [Train]: 100%|██████████| 225/225 [01:26<00:00,  2.59it/s]
Epoch 1 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.39it/s]


Epoch 1/40 — Train Loss: 1.2203, Train Acc: 0.5256  |  Val Loss: 1.2049, Val Acc: 0.5313
------------------------------------------------------------


Epoch 2 [Train]: 100%|██████████| 225/225 [01:28<00:00,  2.55it/s]
Epoch 2 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.87it/s]


Epoch 2/40 — Train Loss: 1.2183, Train Acc: 0.5259  |  Val Loss: 1.2044, Val Acc: 0.5297
------------------------------------------------------------


Epoch 3 [Train]: 100%|██████████| 225/225 [01:29<00:00,  2.51it/s]
Epoch 3 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.44it/s]


Epoch 3/40 — Train Loss: 1.2160, Train Acc: 0.5263  |  Val Loss: 1.2035, Val Acc: 0.5300
------------------------------------------------------------


Epoch 4 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.50it/s]
Epoch 4 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.79it/s]


Epoch 4/40 — Train Loss: 1.2168, Train Acc: 0.5275  |  Val Loss: 1.2046, Val Acc: 0.5308
------------------------------------------------------------


Epoch 5 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.49it/s]
Epoch 5 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.68it/s]


Epoch 5/40 — Train Loss: 1.2122, Train Acc: 0.5264  |  Val Loss: 1.2030, Val Acc: 0.5333
------------------------------------------------------------


Epoch 6 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.49it/s]
Epoch 6 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.84it/s]


Epoch 6/40 — Train Loss: 1.2172, Train Acc: 0.5233  |  Val Loss: 1.2049, Val Acc: 0.5316
------------------------------------------------------------


Epoch 7 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.50it/s]
Epoch 7 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.73it/s]


Epoch 7/40 — Train Loss: 1.2129, Train Acc: 0.5299  |  Val Loss: 1.2034, Val Acc: 0.5302
------------------------------------------------------------


Epoch 8 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.50it/s]
Epoch 8 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.95it/s]


Epoch 8/40 — Train Loss: 1.2109, Train Acc: 0.5294  |  Val Loss: 1.2037, Val Acc: 0.5283
------------------------------------------------------------


Epoch 9 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.50it/s]
Epoch 9 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.69it/s]


Epoch 9/40 — Train Loss: 1.2134, Train Acc: 0.5294  |  Val Loss: 1.2068, Val Acc: 0.5272
------------------------------------------------------------


Epoch 10 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.49it/s]
Epoch 10 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.66it/s]


Epoch 10/40 — Train Loss: 1.2156, Train Acc: 0.5277  |  Val Loss: 1.2055, Val Acc: 0.5325
------------------------------------------------------------


Epoch 11 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.49it/s]
Epoch 11 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.67it/s]


Epoch 11/40 — Train Loss: 1.2132, Train Acc: 0.5276  |  Val Loss: 1.2051, Val Acc: 0.5305
------------------------------------------------------------


Epoch 12 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.49it/s]
Epoch 12 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.92it/s]


Epoch 12/40 — Train Loss: 1.2083, Train Acc: 0.5318  |  Val Loss: 1.2028, Val Acc: 0.5322
------------------------------------------------------------


Epoch 13 [Train]: 100%|██████████| 225/225 [01:30<00:00,  2.49it/s]
Epoch 13 [Val]: 100%|██████████| 29/29 [00:03<00:00,  7.66it/s]


Epoch 13/40 — Train Loss: 1.2159, Train Acc: 0.5283  |  Val Loss: 1.2077, Val Acc: 0.5277
------------------------------------------------------------


Epoch 14 [Train]:  77%|███████▋  | 174/225 [01:10<00:20,  2.47it/s]


KeyboardInterrupt: 