In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
from tqdm import tqdm
# from sklearn.metrics import confusion_matrix
from torch.optim.lr_scheduler import CyclicLR

In [None]:
# Paths
data_dir = "../DATA_PREPARE_ATT_04/Grayscale_Face_images"
model_save_path = "efficientnet_b2_emotion_model.pth"
checkpoint_path = "training_checkpoint.pth"

In [None]:
# Configuration
batch_size = 32 
accumulation_steps = 4  
num_epochs = 20
learning_rate = 1e-4
num_classes = 8  # Number of emotion categories
patience = 5  # Early stopping patience
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Emotion categories
emotion_classes = ["Anger", "Contempt", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]


Using device: cuda


In [5]:
# Early Stopping Class
class EarlyStopping:
    def __init__(self, patience=3, delta=0.01):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

In [None]:
# Data Augmentation
def prepare_data_loaders(data_dir, batch_size):
    transform_train = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
        transforms.RandomAffine(degrees=0, shear=10, scale=(0.8, 1.2)),
        transforms.ToTensor(),  
        transforms.Normalize(mean=[0.485], std=[0.229]), 
        transforms.RandomErasing(p=0.5, scale=(0.02, 0.2)) 
    ])
    transform_val_test = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),  
        transforms.Resize((224, 224)),
        transforms.ToTensor(),  
        transforms.Normalize(mean=[0.485], std=[0.229]) 
    ])

    print("Loading datasets...")
    train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=transform_train)
    val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=transform_val_test)
    test_dataset = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=transform_val_test)
    print("Datasets loaded successfully.")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    return train_loader, val_loader, test_loader

In [None]:
# Model Definition
model = models.efficientnet_b2(pretrained=True)
model.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)  
model.classifier = nn.Sequential(
    nn.Dropout(p=0.6), 
    nn.Linear(model.classifier[1].in_features, num_classes),
)
model = model.to(device)



In [9]:
# Focal Loss
class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        ce_loss = nn.CrossEntropyLoss()(inputs, targets)
        pt = torch.exp(-ce_loss)
        return ((1 - pt) ** self.gamma * ce_loss).mean()

In [None]:
# Loaders and Criterion
train_loader, val_loader, test_loader = prepare_data_loaders(data_dir, batch_size)
criterion = FocalLoss(gamma=2) 
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-3)
scheduler = CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-3, step_size_up=2000, mode='triangular')

Loading datasets...
Datasets loaded successfully.


In [11]:
# Training Function
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, model_save_path, checkpoint_path):
    best_val_loss = float("inf")
    early_stopping = EarlyStopping(patience=patience)

    start_epoch = 0
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        print(f"Resuming training from epoch {start_epoch}.")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        train_loss = 0.0
        print(f"\nEpoch {epoch + 1}/{num_epochs}: Training...")
        optimizer.zero_grad()
        for i, (inputs, labels) in enumerate(tqdm(train_loader, desc="Training Batches", leave=False)):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels) / accumulation_steps
            loss.backward()
            train_loss += loss.item()
            if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
        train_loss /= len(train_loader)

        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        print(f"Epoch {epoch + 1}/{num_epochs}: Validating...")
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validation Batches", leave=False):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        val_loss /= len(val_loader)
        val_accuracy = correct / total

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy * 100:.2f}%")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), model_save_path)
            print(f"Model saved at epoch {epoch + 1} with Val Accuracy: {val_accuracy * 100:.2f}%")

        if early_stopping(val_loss):
            print("Early stopping triggered.")
            break


train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, model_save_path, checkpoint_path)


Epoch 1/20: Training...


                                                                       

Epoch 1/20: Validating...


                                                                     

Epoch 1/20, Train Loss: 0.3482, Val Loss: 0.8604, Val Accuracy: 46.68%
Model saved at epoch 1 with Val Accuracy: 46.68%

Epoch 2/20: Training...


                                                                     

Epoch 2/20: Validating...


                                                                     

Epoch 2/20, Train Loss: 0.2469, Val Loss: 0.6523, Val Accuracy: 55.83%
Model saved at epoch 2 with Val Accuracy: 55.83%

Epoch 3/20: Training...


                                                                       

Epoch 3/20: Validating...


                                                                     

Epoch 3/20, Train Loss: 0.2195, Val Loss: 0.5925, Val Accuracy: 57.84%
Model saved at epoch 3 with Val Accuracy: 57.84%

Epoch 4/20: Training...


                                                                       

Epoch 4/20: Validating...


                                                                     

Epoch 4/20, Train Loss: 0.1937, Val Loss: 0.4881, Val Accuracy: 62.11%
Model saved at epoch 4 with Val Accuracy: 62.11%

Epoch 5/20: Training...


                                                                       

Epoch 5/20: Validating...


                                                                     

Epoch 5/20, Train Loss: 0.1693, Val Loss: 0.4203, Val Accuracy: 64.43%
Model saved at epoch 5 with Val Accuracy: 64.43%

Epoch 6/20: Training...


                                                                       

Epoch 6/20: Validating...


                                                                     

Epoch 6/20, Train Loss: 0.1502, Val Loss: 0.4117, Val Accuracy: 65.94%
Model saved at epoch 6 with Val Accuracy: 65.94%

Epoch 7/20: Training...


                                                                       

Epoch 7/20: Validating...


                                                                     

Epoch 7/20, Train Loss: 0.1509, Val Loss: 0.4373, Val Accuracy: 64.49%

Epoch 8/20: Training...


                                                                       

Epoch 8/20: Validating...


                                                                     

Epoch 8/20, Train Loss: 0.1651, Val Loss: 0.4804, Val Accuracy: 62.25%

Epoch 9/20: Training...


                                                                       

Epoch 9/20: Validating...


                                                                     

Epoch 9/20, Train Loss: 0.1720, Val Loss: 0.4847, Val Accuracy: 62.82%

Epoch 10/20: Training...


                                                                       

Epoch 10/20: Validating...


                                                                     

Epoch 10/20, Train Loss: 0.1554, Val Loss: 0.4007, Val Accuracy: 65.74%
Model saved at epoch 10 with Val Accuracy: 65.74%

Epoch 11/20: Training...


                                                                       

Epoch 11/20: Validating...


                                                                     

Epoch 11/20, Train Loss: 0.1367, Val Loss: 0.3688, Val Accuracy: 67.81%
Model saved at epoch 11 with Val Accuracy: 67.81%

Epoch 12/20: Training...


                                                                       

Epoch 12/20: Validating...


                                                                     

Epoch 12/20, Train Loss: 0.1243, Val Loss: 0.3619, Val Accuracy: 67.60%
Model saved at epoch 12 with Val Accuracy: 67.60%

Epoch 13/20: Training...


                                                                       

Epoch 13/20: Validating...


                                                                     

Epoch 13/20, Train Loss: 0.1321, Val Loss: 0.3970, Val Accuracy: 66.81%

Epoch 14/20: Training...


                                                                       

Epoch 14/20: Validating...


                                                                     

Epoch 14/20, Train Loss: 0.1455, Val Loss: 0.4410, Val Accuracy: 65.07%

Epoch 15/20: Training...


                                                                       

Epoch 15/20: Validating...


                                                                     

Epoch 15/20, Train Loss: 0.1495, Val Loss: 0.3954, Val Accuracy: 65.78%

Epoch 16/20: Training...


                                                                       

Epoch 16/20: Validating...


                                                                     

Epoch 16/20, Train Loss: 0.1332, Val Loss: 0.3619, Val Accuracy: 68.21%

Epoch 17/20: Training...


                                                                       

Epoch 17/20: Validating...


                                                                     

Epoch 17/20, Train Loss: 0.1166, Val Loss: 0.3359, Val Accuracy: 69.09%
Model saved at epoch 17 with Val Accuracy: 69.09%

Epoch 18/20: Training...


                                                                       

Epoch 18/20: Validating...


                                                                     

Epoch 18/20, Train Loss: 0.1115, Val Loss: 0.3505, Val Accuracy: 68.46%

Epoch 19/20: Training...


                                                                       

Epoch 19/20: Validating...


                                                                     

Epoch 19/20, Train Loss: 0.1206, Val Loss: 0.3900, Val Accuracy: 67.41%

Epoch 20/20: Training...


                                                                       

Epoch 20/20: Validating...


                                                                     

Epoch 20/20, Train Loss: 0.1351, Val Loss: 0.4151, Val Accuracy: 65.94%




In [12]:
# Testing Function
def test_model(model, test_loader, emotion_classes, model_save_path):
    model.load_state_dict(torch.load(model_save_path))
    model.eval()
    correct = [0] * len(emotion_classes)
    total = [0] * len(emotion_classes)
    print("Testing the model...")
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Testing Progress"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            for i in range(len(labels)):
                label = labels[i].item()
                total[label] += 1
                correct[label] += (predicted[i] == label).item()

    for i, emotion in enumerate(emotion_classes):
        accuracy = 100 * correct[i] / total[i] if total[i] > 0 else 0
        print(f"{emotion}: {accuracy:.2f}%")
    overall_accuracy = sum(correct) / sum(total)
    print(f"Overall Test Accuracy: {overall_accuracy * 100:.2f}%")

test_model(model, test_loader, emotion_classes, model_save_path)

  model.load_state_dict(torch.load(model_save_path))


Testing the model...


Testing Progress: 100%|██████████| 604/604 [02:32<00:00,  3.95it/s]

Anger: 54.04%
Contempt: 81.70%
Disgust: 68.39%
Fear: 51.29%
Happy: 88.52%
Neutral: 83.17%
Sad: 59.17%
Surprise: 68.39%
Overall Test Accuracy: 69.83%



