In [131]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as f
torch.set_float32_matmul_precision('medium')

In [132]:
class SharpenTransform:
    def __init__(self, sharpness_factor=2.0):
        self.sharpness_factor = sharpness_factor

    def __call__(self, img):
        return f.adjust_sharpness(img, self.sharpness_factor)

In [133]:
class FERPlusDataset(Dataset):
    def __init__(self, img_folder, csv_file, transform=None, skip_nf=True):
        """
        Args:
            img_folder (str): Directory containing the image files.
            csv_file (str): Path to CSV annotation file containing image names, labels, and emotion distributions.
            transform (callable, optional): Optional torchvision transforms to apply on image.
            skip_nf (bool): Whether to exclude "NF" (Not Face) images from dataset.
            
        Notes:
            - The CSV must include columns for image name, Usage split (Training, PublicTest, PrivateTest),
              emotion count columns matching 'neutral', 'happiness', 'surprise', 'sadness', 'anger', 
              'disgust', 'fear', 'contempt', 'unknown', and the 'NF' label.
            - Images are loaded as grayscale, resized to 224x224, and sharpened with factor 5 by default.
            - Label distribution is computed by normalizing the raw emotion counts per sample.
        """
        self.img_folder = img_folder
        self.transform = transform
        
        # Load CSV with label distributions and metadata
        self.data = pd.read_csv(csv_file)
        
        # Map folder name to 'Usage' column value in CSV to filter dataset split accordingly
        usage_map = {
            "FER2013Train": "Training",
            "FER2013Valid": "PublicTest",
            "FER2013Test": "PrivateTest"
        }
        # Identify which subset (split) to load based on folder name
        usage_required = usage_map.get(os.path.basename(img_folder), None)
        if usage_required is None:
            raise ValueError(f"Unknown dataset folder: {img_folder}")
        
        # Filter the dataset to only include images from the selected split
        self.data = self.data[self.data['Usage'] == usage_required].reset_index(drop=True)
        
        # Optionally exclude images marked as Not Face (NF = 1)
        if skip_nf:
            self.data = self.data[self.data['NF'] == 0].reset_index(drop=True)
        
        # Define emotion categories (order must match columns in CSV exactly)
        self.emotions = ['neutral','happiness','surprise','sadness','anger','disgust','fear','contempt','unknown']
        
        # If no transform provided, define default preprocessing:
        # Convert to 224x224 grayscale, apply sharpening with factor 5
        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((112, 112), interpolation=transforms.InterpolationMode.LANCZOS),
                SharpenTransform(sharpness_factor=2.0),
                transforms.ToTensor(),   # convert PIL Image to tensor after preprocessing
                transforms.Normalize(mean=[0.5076], std=[0.2119])
            ])


    def __len__(self):
        # Return total number of valid samples
        return len(self.data)

    def __getitem__(self, idx):
        # Get image filename at index
        img_name = self.data.loc[idx, 'Image name']

        # Construct full image path
        img_path = os.path.join(self.img_folder, img_name)
        
        # Load image as grayscale with PIL
        image = Image.open(img_path).convert('L')
        
        # Apply preprocessing transforms (resize, sharpen, to tensor)
        image = self.transform(image)

        # Extract raw emotion counts for label distribution
        counts = self.data.loc[idx, self.emotions].values.astype(float)
        
        # Normalize counts to create label distribution (probability vector)
        label_distribution = counts / counts.sum()
        
        # Convert label distribution to PyTorch tensor of floats
        label_distribution = torch.tensor(label_distribution, dtype=torch.float32)
        
        # Return preprocessed image tensor and LDL label distribution
        return image, label_distribution

In [134]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Paths remain the same
train_imgs_path = "../FERPlusData/FER2013Train"
val_imgs_path = "../FERPlusData/FER2013Valid"
test_imgs_path = "../FERPlusData/FER2013Test"
data_csv_path = "../fer2013new.csv"

# Define the sharpening transform once to reuse
sharpen = SharpenTransform(sharpness_factor=2.0)

# Training dataset and loader with full augmentations
train_dataset = FERPlusDataset(
    img_folder=train_imgs_path,
    csv_file=data_csv_path,
    skip_nf=True,
    transform=transforms.Compose([
        transforms.Resize((112, 112), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.RandomAffine(degrees=5, translate=(0.05, 0.05), scale=(0.95, 1.05), shear=5),
        sharpen,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5076], std=[0.2119])
    ])
)

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
)

# Validation dataset and loader with only deterministic preprocessing
val_dataset = FERPlusDataset(
    img_folder=val_imgs_path,
    csv_file=data_csv_path,
    skip_nf=True,
    transform=transforms.Compose([
        transforms.Resize((112, 112), interpolation=transforms.InterpolationMode.LANCZOS),
        sharpen,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5076], std=[0.2119])
    ])
)

val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
)

In [135]:
# Attention
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1=nn.Conv2d(in_planes, in_planes//ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return self.sigmoid(avg_out + max_out)
    
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv1(x_cat))

class CBAM(nn.Module):
    def __init__(self, in_planes, ratio=8, spatial_kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(spatial_kernel_size)

    def forward(self, x):
        out = x * self.ca(x)
        out = out * self.sa(out)
        return out

In [136]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=False, p=0.2):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout2d(p=p) if dropout else nn.Identity()

        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False) if in_channels != out_channels else nn.Identity()

    def forward(self, x):
        identity = self.skip(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out

In [137]:
class ResidualCBAMClassifier(nn.Module):
    def __init__(self, img_size=112, in_channels=1, num_classes=9):
        super().__init__()
        self.blocks = nn.Sequential(
            ResidualBlock(in_channels, 16),
            ResidualBlock(16, 32),
            ResidualBlock(32, 64),
            ResidualBlock(64, 128),
            ResidualBlock(128, 128, dropout=True, p=0.1)
        )
        self.cbam = CBAM(128, ratio=8)

        # Get flattened feature dim dynamically
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, img_size, img_size)
            feat_dim = torch.flatten(self.cbam(self.blocks(dummy)), 1).shape[1]

        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        x = self.blocks(x)
        x = self.cbam(x)
        x = torch.flatten(x, 1)
        logits = self.fc(x)
        return F.softmax(logits, dim=1)

In [138]:
# 1. Device
device = torch.device(
    "cuda" if torch.cuda.is_available() 
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")

# 2. Max epochs
max_epochs = 25

# 3. Model & Optimizer
model = ResidualCBAMClassifier().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# 4. LR Scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, 
    T_max=max_epochs,
    eta_min=1e-6
)

Using device: cuda


In [139]:
import numpy as np
class EarlyStopper:
    """
    Early stops training if validation loss doesn't improve after a given patience.
    """
    def __init__(self, patience=7, min_delta=0.0, verbose=False, save_path=None):
        """
        Args:
            patience (int): How long to wait after last improvement before stopping.
            min_delta (float): Minimum change to qualify as improvement.
            verbose (bool): Print messages when validation improves.
            save_path (str or None): Path to save the best model checkpoint.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = np.Inf
        self.early_stop = False
        self.save_path = save_path

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            if self.save_path is not None:
                torch.save(model.state_dict(), self.save_path)
                if self.verbose:
                    print(f"Validation loss decreased. Saving model to {self.save_path}")
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True


In [140]:
def loss_fn(logits, label_distributions):
    log_probs = F.log_softmax(logits, dim=1)
    loss = F.kl_div(log_probs, label_distributions, reduction='batchmean')
    return loss

In [141]:
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
from torchmetrics import MeanSquaredError
from tqdm import tqdm

def train_val_loop(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    loss_fn,
    early_stopper,
    num_epochs,
    device
):
    # Initialize torchmetrics for accuracy, f1 and mse on the device
    train_acc = MulticlassAccuracy(num_classes=9, average='micro').to(device)
    val_acc = MulticlassAccuracy(num_classes=9, average='micro').to(device)

    train_f1 = MulticlassF1Score(num_classes=9, average='weighted').to(device)
    val_f1 = MulticlassF1Score(num_classes=9, average='weighted').to(device)

    train_mse = MeanSquaredError().to(device)
    val_mse = MeanSquaredError().to(device)

    scaler = torch.amp.GradScaler()
    use_amp = (device.type == 'cuda')

    for epoch in range(1, num_epochs + 1):
        # --- Training Phase ---
        model.train()
        total_train_loss = 0
        train_acc.reset()
        train_f1.reset()
        train_mse.reset()
        train_samples = 0

        train_bar = tqdm(train_loader, desc=f"Epoch {epoch} Training", leave=False)
        for images, labels in train_bar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            with torch.amp.autocast(enabled=use_amp, device_type=device.type):
                logits = model(images)
                loss = loss_fn(logits, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_train_loss += loss.item() * images.size(0)
            train_samples += images.size(0)

            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)
            targets = torch.argmax(labels, dim=1)

            train_acc.update(preds, targets)
            train_f1.update(preds, targets)
            train_mse.update(probs, labels)

            train_bar.set_postfix(loss=loss.item())

        avg_train_loss = total_train_loss / train_samples
        avg_train_acc = train_acc.compute().item()
        avg_train_f1 = train_f1.compute().item()
        avg_train_mse = train_mse.compute().item()

        # --- Validation Phase ---
        model.eval()
        total_val_loss = 0
        val_acc.reset()
        val_f1.reset()
        val_mse.reset()
        val_samples = 0

        with torch.no_grad():
            val_bar = tqdm(val_loader, desc=f"Epoch {epoch} Validation", leave=False)
            for images, labels in val_bar:
                images, labels = images.to(device), labels.to(device)
                with torch.amp.autocast(enabled=use_amp, device_type=device.type):
                    logits = model(images)
                    loss = loss_fn(logits, labels)
                total_val_loss += loss.item() * images.size(0)
                val_samples += images.size(0)

                probs = F.softmax(logits, dim=1)
                preds = torch.argmax(probs, dim=1)
                targets = torch.argmax(labels, dim=1)

                val_acc.update(preds, targets)
                val_f1.update(preds, targets)
                val_mse.update(probs, labels)

                val_bar.set_postfix(loss=loss.item())

        avg_val_loss = total_val_loss / val_samples
        avg_val_acc = val_acc.compute().item()
        avg_val_f1 = val_f1.compute().item()
        avg_val_mse = val_mse.compute().item()

        # Step learning rate scheduler once per epoch after validation
        scheduler.step()

        # Print training and validation results for the epoch
        print(f"[Epoch {epoch}/{num_epochs}] "
              f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f} | "
              f"Train Acc: {avg_train_acc:.4f}, Val Acc: {avg_val_acc:.4f} | "
              f"Train F1: {avg_train_f1:.4f}, Val F1: {avg_val_f1:.4f} | "
              f"Train MSE: {avg_train_mse:.4f}, Val MSE: {avg_val_mse:.4f}")

        # Early stopping check and save best model
        early_stopper(avg_val_loss, model)
        if early_stopper.early_stop:
            print("Early stopping triggered. Stopping training.")
            break

    # Load the best weights saved by early stopper (if any)
    if early_stopper.save_path is not None:
        print(f"Loading best model from {early_stopper.save_path}")
        model.load_state_dict(torch.load(early_stopper.save_path, map_location=device))

In [None]:
early_stopper = EarlyStopper(
    patience=10,
    min_delta=1e-4,
    verbose=True,
    save_path="../models/early_best_models/new_arch_best_model.pth"
)

train_val_loop(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    early_stopper=early_stopper,
    num_epochs=max_epochs,
    device=device
)
# We will need to change the model or aq


Epoch 1 Training:   0%|          | 0/890 [00:00<?, ?it/s]

                                                                                

[Epoch 1/25] Train Loss: 1.4384, Val Loss: 1.4919 | Train Acc: 0.3603, Val Acc: 0.3727 | Train F1: 0.1914, Val F1: 0.2023 | Train MSE: 0.0550, Val MSE: 0.0587
Validation loss decreased. Saving model to ../models/early_best_models/new_arch_best_model.pth


                                                                                

[Epoch 2/25] Train Loss: 1.4383, Val Loss: 1.4919 | Train Acc: 0.3607, Val Acc: 0.3727 | Train F1: 0.1912, Val F1: 0.2023 | Train MSE: 0.0550, Val MSE: 0.0587
EarlyStopping counter: 1 out of 10


                                                                                

[Epoch 3/25] Train Loss: 1.4383, Val Loss: 1.4919 | Train Acc: 0.3607, Val Acc: 0.3727 | Train F1: 0.1912, Val F1: 0.2023 | Train MSE: 0.0550, Val MSE: 0.0587
EarlyStopping counter: 2 out of 10


                                                                              

KeyboardInterrupt: 