<a href="https://colab.research.google.com/github/dimna21/ML_Assignment4/blob/main/FER2013.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install wandb



In [2]:
import wandb
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdimna21[0m ([33mdimna21-free-university-of-tbilisi-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Load the CSV
csv_path = "/content/drive/MyDrive/FER_data/fer2013/fer2013.csv"
df = pd.read_csv(csv_path)

# Split by Usage
df_train = df[df['Usage']=="Training"].copy()
df_val   = df[df['Usage']=="PublicTest"].copy()
df_test  = df[df['Usage']=="PrivateTest"].copy()

In [5]:
# Balance function: upsampling & random ±10 intensity shifts
def balance_dataset(df, target_count, img_shape=(48,48)):
    def augment(pix_str):
        arr = np.fromstring(pix_str, sep=' ', dtype=int).reshape(img_shape)
        shift = np.random.randint(-10, 11)
        arr = np.clip(arr + shift, 0, 255).astype(int)
        return ' '.join(map(str, arr.ravel()))
    parts = [df]
    for emo, grp in df.groupby('emotion'):
        n = len(grp)
        if n < target_count:
            extra = grp.sample(n=target_count-n, replace=True).copy()
            extra['pixels'] = extra['pixels'].map(augment)
            parts.append(extra)
    return pd.concat(parts, ignore_index=True)


In [6]:
# Dataset class
class FERDataset(Dataset):
    def __init__(self, dataframe):
        self.pixels = dataframe['pixels'].values
        self.labels = dataframe['emotion'].values.astype(int)
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        arr = np.fromstring(self.pixels[idx], sep=' ', dtype=np.uint8).reshape(48,48)
        arr = arr.astype(np.float32) / 255.0
        tensor = torch.from_numpy(arr).unsqueeze(0)  # shape [1,48,48]
        return tensor, self.labels[idx]


In [7]:
# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [9]:
# DataLoaders
max_count = df_train['emotion'].value_counts().max()
balanced_train = balance_dataset(df_train, target_count=max_count)

train_ds = FERDataset(balanced_train)
val_ds = FERDataset(df_val)
test_ds = FERDataset(df_test)

batch_size = 128
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=2)
val_dl = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=2)
test_dl = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=2)

# Class names
class_names = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]

In [19]:
import torch
from tqdm import tqdm
import wandb
from sklearn.metrics import confusion_matrix, f1_score
import matplotlib.pyplot as plt
import numpy as np

def train_model(model, train_loader, val_loader,
                criterion, optimizer, device,
                epochs=5, class_names=None):

    wandb.init(
        project="ML_Assignment4",
        config={
            "epochs": epochs,
            "batch_size": train_loader.batch_size,
            "optimizer": optimizer.__class__.__name__,
            "lr": optimizer.param_groups[0]["lr"],
            "criterion": criterion.__class__.__name__,
        },
    )
    cfg = wandb.config
    wandb.watch(model, log="all", log_freq=100)
    model.to(device)

    train_loss_plot, val_loss_plot = [], []
    train_acc_plot,  val_acc_plot  = [], []

    for epoch in range(1, cfg.epochs + 1):
        # — TRAIN —
        model.train()
        running_loss = running_correct = running_total = 0
        for X, y in tqdm(train_loader, desc=f"Epoch {epoch} [Train]"):
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(X)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()

            running_loss    += loss.item() * X.size(0)
            preds            = out.argmax(dim=1)
            running_correct += preds.eq(y).sum().item()
            running_total   += y.size(0)

        train_loss = running_loss / running_total
        train_acc  = running_correct / running_total

        # — VALIDATE —
        model.eval()
        val_running_loss = val_running_correct = val_running_total = 0
        all_preds, all_targets = [], []
        with torch.no_grad():
            for X, y in tqdm(val_loader, desc=f"Epoch {epoch} [Val]"):
                X, y = X.to(device), y.to(device)
                out = model(X)
                loss = criterion(out, y)

                val_running_loss    += loss.item() * X.size(0)
                preds                = out.argmax(dim=1)
                val_running_correct += preds.eq(y).sum().item()
                val_running_total   += y.size(0)

                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(y.cpu().numpy())

        val_loss = val_running_loss / val_running_total
        val_acc  = val_running_correct / val_running_total

        # — SCALARS & F1 —
        cm = confusion_matrix(all_targets, all_preds)
        f1_per_class = f1_score(all_targets, all_preds, average=None)

        log_data = {
            "train_loss": train_loss,
            "train_acc":  train_acc,
            "val_loss":   val_loss,
            "val_acc":    val_acc,
        }
        for i, name in enumerate(class_names):
            log_data[f"f1_{name}"] = f1_per_class[i]

        wandb.log(log_data, step=epoch)

        # — 7×7 Confusion Matrix Plot & Log —
        fig_cm, ax = plt.subplots(figsize=(6,6))
        im = ax.imshow(cm, interpolation='nearest', cmap='Blues')
        fig_cm.colorbar(im, ax=ax)

        ax.set_xticks(np.arange(len(class_names)))
        ax.set_yticks(np.arange(len(class_names)))
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.set_yticklabels(class_names)

        for i in range(len(class_names)):
            for j in range(len(class_names)):
                ax.text(j, i, cm[i, j],
                        ha='center', va='center')

        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        ax.set_title(f'Epoch {epoch} Confusion Matrix')

        wandb.log({"confusion_matrix": wandb.Image(fig_cm)}, step=epoch)
        plt.close(fig_cm)

        # — PRINT & STORE FOR CURVES —
        print(
            f"Epoch {epoch}/{cfg.epochs} — "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}  |  "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
        )
        print("-" * 60)

        train_loss_plot.append(train_loss)
        val_loss_plot.append(val_loss)
        train_acc_plot.append(train_acc)
        val_acc_plot.append(val_acc)

    # — PLOT & LOG LOSS/ACC CURVES TO W&B —
    epochs_range = list(range(1, cfg.epochs + 1))

    # Loss curve
    fig1, ax1 = plt.subplots()
    ax1.plot(epochs_range, train_loss_plot, label="Train Loss")
    ax1.plot(epochs_range, val_loss_plot,   label="Val Loss")
    ax1.set(title="Loss vs Epoch", xlabel="Epoch", ylabel="Loss")
    ax1.legend()
    wandb.log({"loss_curve": wandb.Image(fig1)})
    plt.close(fig1)

    # Accuracy curve
    fig2, ax2 = plt.subplots()
    ax2.plot(epochs_range, train_acc_plot, label="Train Acc")
    ax2.plot(epochs_range, val_acc_plot,   label="Val Acc")
    ax2.set(title="Accuracy vs Epoch", xlabel="Epoch", ylabel="Accuracy")
    ax2.legend()
    wandb.log({"acc_curve": wandb.Image(fig2)})
    plt.close(fig2)

    return model


In [20]:
class BaselineModel(nn.Module):
    def __init__(self):
      super().__init__()

      self.conv1 = nn.Conv2d(1, 32, kernel_size = 3, padding = 1)
      self.conv2 = nn.Conv2d(32, 64, kernel_size = 3, padding = 1)
      self.conv3 = nn.Conv2d(64, 128, kernel_size = 3, padding = 1)
      self.pooling = nn.MaxPool2d(2,2)
      self.relu = nn.ReLU()

      self.flatten = nn.Flatten()
      self.linear = nn.Linear((128 * 6 * 6), 128)
      self.output = nn.Linear(128, 7)


    def forward(self, x):
      x = self.conv1(x)
      x = self.pooling(x)
      x = self.relu(x)
      x = self.conv2(x)
      x = self.pooling(x)
      x = self.relu(x)
      x = self.conv3(x)
      x = self.pooling(x)
      x = self.relu(x)
      x = self.flatten(x)
      x = self.linear(x)
      x = self.output(x)

      return x

In [21]:
baseline_model = BaselineModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(baseline_model.parameters(), lr=1e-3)
trained = train_model(
    baseline_model,
    train_dl,
    val_dl,
    criterion,
    optimizer,
    device,
    epochs=10,
    class_names=class_names
)

0,1
f1_Angry,▁
f1_Disgust,▁
f1_Fear,▁
f1_Happy,▁
f1_Neutral,▁
f1_Sad,▁
f1_Surprise,▁
train_acc,▁
train_loss,▁
val_acc,▁

0,1
f1_Angry,0.29811
f1_Disgust,0.18845
f1_Fear,0.26126
f1_Happy,0.60232
f1_Neutral,0.39846
f1_Sad,0.15109
f1_Surprise,0.57737
train_acc,0.32347
train_loss,1.71568
val_acc,0.39621


Epoch 1 [Train]: 100%|██████████| 395/395 [00:06<00:00, 62.24it/s]
Epoch 1 [Val]: 100%|██████████| 29/29 [00:00<00:00, 61.22it/s]


Epoch 1/10 — Train Loss: 1.6891, Train Acc: 0.3348  |  Val Loss: 1.5523, Val Acc: 0.4129
------------------------------------------------------------


Epoch 2 [Train]: 100%|██████████| 395/395 [00:07<00:00, 49.69it/s]
Epoch 2 [Val]: 100%|██████████| 29/29 [00:00<00:00, 60.54it/s]


Epoch 2/10 — Train Loss: 1.3016, Train Acc: 0.5122  |  Val Loss: 1.4285, Val Acc: 0.4675
------------------------------------------------------------


Epoch 3 [Train]: 100%|██████████| 395/395 [00:06<00:00, 63.77it/s]
Epoch 3 [Val]: 100%|██████████| 29/29 [00:00<00:00, 63.93it/s]


Epoch 3/10 — Train Loss: 1.0911, Train Acc: 0.5885  |  Val Loss: 1.3528, Val Acc: 0.5029
------------------------------------------------------------


Epoch 4 [Train]: 100%|██████████| 395/395 [00:07<00:00, 49.66it/s]
Epoch 4 [Val]: 100%|██████████| 29/29 [00:00<00:00, 61.07it/s]


Epoch 4/10 — Train Loss: 0.9745, Train Acc: 0.6346  |  Val Loss: 1.4600, Val Acc: 0.4723
------------------------------------------------------------


Epoch 5 [Train]: 100%|██████████| 395/395 [00:06<00:00, 64.40it/s]
Epoch 5 [Val]: 100%|██████████| 29/29 [00:00<00:00, 57.07it/s]


Epoch 5/10 — Train Loss: 0.8938, Train Acc: 0.6679  |  Val Loss: 1.4338, Val Acc: 0.5049
------------------------------------------------------------


Epoch 6 [Train]: 100%|██████████| 395/395 [00:08<00:00, 49.32it/s]
Epoch 6 [Val]: 100%|██████████| 29/29 [00:00<00:00, 62.79it/s]


Epoch 6/10 — Train Loss: 0.8252, Train Acc: 0.6977  |  Val Loss: 1.4726, Val Acc: 0.5065
------------------------------------------------------------


Epoch 7 [Train]: 100%|██████████| 395/395 [00:06<00:00, 63.28it/s]
Epoch 7 [Val]: 100%|██████████| 29/29 [00:00<00:00, 59.52it/s]


Epoch 7/10 — Train Loss: 0.7707, Train Acc: 0.7183  |  Val Loss: 1.5149, Val Acc: 0.5054
------------------------------------------------------------


Epoch 8 [Train]: 100%|██████████| 395/395 [00:07<00:00, 49.55it/s]
Epoch 8 [Val]: 100%|██████████| 29/29 [00:00<00:00, 59.65it/s]


Epoch 8/10 — Train Loss: 0.7238, Train Acc: 0.7376  |  Val Loss: 1.5778, Val Acc: 0.5146
------------------------------------------------------------


Epoch 9 [Train]: 100%|██████████| 395/395 [00:06<00:00, 64.15it/s]
Epoch 9 [Val]: 100%|██████████| 29/29 [00:00<00:00, 33.93it/s]


Epoch 9/10 — Train Loss: 0.6771, Train Acc: 0.7573  |  Val Loss: 1.6033, Val Acc: 0.5258
------------------------------------------------------------


Epoch 10 [Train]: 100%|██████████| 395/395 [00:07<00:00, 53.14it/s]
Epoch 10 [Val]: 100%|██████████| 29/29 [00:00<00:00, 60.88it/s]


Epoch 10/10 — Train Loss: 0.6394, Train Acc: 0.7711  |  Val Loss: 1.6563, Val Acc: 0.5135
------------------------------------------------------------
