In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, WeightedRandomSampler
import numpy as np
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau
import optuna
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Paths
data_dir = "../FBMM/Unsplitted_Ready_Sets/set_01_class_balanced_augs_applied_splitted"
train_dir = os.path.join(data_dir, "train")
val_dir = os.path.join(data_dir, "val")
test_dir = os.path.join(data_dir, "test")
pretrained_model_path = "./efficientnet_b2_emotion_model.pth"
model_save_path = "./models/fine_tuned_efficientnet_b2.pth"

In [3]:
# Configuration
batch_size = 16
num_epochs = 50
initial_lr = 1e-4
weight_decay = 1e-4
num_classes = 7
img_height, img_width = 260, 260
seed = 42
accumulation_steps = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Emotion categories
emotion_classes = ["Anger", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]

In [4]:
# Data Augmentation for RGB Dataset
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(260, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

transform_val_test = transforms.Compose([
    transforms.Resize((260, 260)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [5]:
# Load Datasets
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform_train)
val_dataset = datasets.ImageFolder(root=val_dir, transform=transform_val_test)
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform_val_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [6]:
# Load Pretrained EfficientNet-B2 Model
model = models.efficientnet_b2(weights=None)  # Start with an uninitialized model

In [7]:
# Adjust input channels if necessary (Grayscale → RGB)
if model.features[0][0].in_channels == 1:
    model.features[0][0] = nn.Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

# Modify the classifier for 7 classes
model.classifier[1] = nn.Sequential(
    nn.Linear(model.classifier[1].in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, num_classes)
)

In [8]:
# Load previous model weights (handling key mismatches)
pretrained_dict = torch.load(pretrained_model_path, map_location=device)
model_dict = model.state_dict()

  pretrained_dict = torch.load(pretrained_model_path, map_location=device)


In [9]:
# Remove keys that don't match
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict, strict=False)

<All keys matched successfully>

In [10]:
# Move to device
model = model.to(device)

# Optimizer & Scheduler
optimizer = optim.Adam(model.parameters(), lr=initial_lr, weight_decay=weight_decay)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

In [11]:
# Loss Function
criterion = nn.CrossEntropyLoss()

In [12]:
# Training Function
def train_model():
    model.train()
    for epoch in range(num_epochs):
        total_loss, correct, total = 0, 0, 0
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)

            # Forward
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accuracy
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            total_loss += loss.item()

        # Validation Step
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs, 1)
                val_correct += (predicted == labels).sum().item()
                val_total += labels.size(0)
                val_loss += loss.item()

        train_acc = correct / total
        val_acc = val_correct / val_total
        print(f"Epoch {epoch+1}/{num_epochs} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

        # Adjust learning rate
        scheduler.step(val_loss)

        # Save best model
        torch.save(model.state_dict(), model_save_path)

In [13]:
# Run Training
train_model()

  1%|          | 68/7751 [00:56<1:46:34,  1.20it/s]


KeyboardInterrupt: 

In [None]:
# Testing Function
def test_model():
    model.eval()
    y_true, y_pred = [], []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)

            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    # Compute metrics
    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, target_names=emotion_classes))

    # Confusion Matrix
    cm = confusion_matrix(y_true, y_pred)
    df_cm = pd.DataFrame(cm, index=emotion_classes, columns=emotion_classes)
    plt.figure(figsize=(8, 6))
    sns.heatmap(df_cm, annot=True, fmt="d", cmap="Blues")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title("Confusion Matrix")
    plt.show()

# Run Testing
test_model()