In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from torchvision.models.efficientnet import EfficientNet_B0_Weights

In [2]:
# Paths
data_dir = "../FBMM/Unsplitted_Ready_Sets/set_01_class_balanced_augs_applied_splitted"
train_dir = os.path.join(data_dir, "train")
val_dir = os.path.join(data_dir, "val")
test_dir = os.path.join(data_dir, "test")
model_save_path = "./models/optimized_efficientnet_b0_emotion_model.pth"

# Configuration
batch_size = 8
num_epochs = 50
initial_lr = 1e-4 
weight_decay = 1e-4  
num_classes = 7
img_height, img_width = 224, 224
seed = 42
accumulation_steps = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Emotion categories
emotion_classes = ["Anger", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]


In [3]:
# ✅ Use EfficientNet-B0 Weights
weights = EfficientNet_B0_Weights.IMAGENET1K_V1

In [4]:
# Data Augmentation & Normalization
transform = transforms.Compose([
    transforms.Resize((img_height, img_width)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ✅ Correct Normalization
])

In [5]:
# Load Datasets
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
val_dataset = datasets.ImageFolder(root=val_dir, transform=transform)
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)

# Compute Class Weights
def compute_class_weights(dataset, num_classes):
    labels = np.array([label for _, label in dataset.samples])
    class_counts = np.bincount(labels, minlength=num_classes)
    class_weights = 1.0 / (class_counts + 1e-6)
    class_weights /= class_weights.sum()
    return torch.tensor(class_weights, dtype=torch.float32).to(device)

class_weights = compute_class_weights(train_dataset, num_classes)

# Data Loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)


In [6]:
# ✅ Load Pretrained EfficientNet-B0 & Fine-Tune
def load_model(num_classes):
    print("Loading and configuring the EfficientNet-B0 model...")

    model = models.efficientnet_b0(weights=weights)  # ✅ Use EfficientNet-B0
    model = model.to(memory_format=torch.channels_last)

    # Freeze all layers initially
    for param in model.parameters():
        param.requires_grad = False

    # Unlock last 20% of convolutional layers + classifier head
    total_layers = len(list(model.features.children()))
    fine_tune_layers = int(total_layers * 0.2)

    for layer in list(model.features.children())[-fine_tune_layers:]:
        for param in layer.parameters():
            param.requires_grad = True

    # Modify classifier head
    model.classifier = nn.Sequential(
        nn.Dropout(0.6),
        nn.Linear(model.classifier[1].in_features, num_classes)
    )

    # Ensure classifier is trainable
    for param in model.classifier.parameters():
        param.requires_grad = True

    return model.to(device)

# Load model
model = load_model(num_classes)

Loading and configuring the EfficientNet-B0 model...


In [7]:
# ✅ Fix: Use Label Smoothing to Prevent NaNs
criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

# ✅ Fix: Reduce Learning Rate & Weight Decay
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=initial_lr, weight_decay=weight_decay)

# ✅ Learning Rate Warm-Up & Scheduling
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=initial_lr, epochs=num_epochs, steps_per_epoch=len(train_loader), pct_start=0.1)

# ✅ Enable Mixed Precision Training
scaler = torch.amp.GradScaler(device="cuda")

In [8]:
# Training Loop with Full Logging (Loss, Accuracy, Best Model)
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, patience=7):
    best_val_loss = np.inf
    epochs_no_improve = 0

    for epoch in range(1, num_epochs + 1):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for i, (images, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, labels) / accumulation_steps

            scaler.scale(loss).backward()

            # ✅ Gradient Clipping to Prevent NaNs
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            if (i + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        train_loss = running_loss / total
        train_acc = correct / total

        # ✅ Validation Step
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs, 1)
                val_correct += (predicted == labels).sum().item()
                val_total += labels.size(0)

        val_loss /= val_total
        val_acc = val_correct / val_total

        # ✅ Print Full Training Details
        print(f"\nEpoch {epoch}:")
        print(f"Train Loss: {train_loss:.4f} | Train Accuracy: {train_acc:.4f}")
        print(f"Validation Loss: {val_loss:.4f} | Validation Accuracy: {val_acc:.4f}")

        # ✅ Save Best Model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), model_save_path)
            print("✅ Model Saved! (Best Validation Loss Improved)")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("⏳ Early Stopping Triggered!")
                break

# Train the model
train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs)

# Load best model for testing
model.load_state_dict(torch.load(model_save_path))

# Evaluate model
def evaluate_model(model, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    print(f"\n🎯 Final Test Accuracy: {accuracy:.4f}")

evaluate_model(model, test_loader)

  with torch.cuda.amp.autocast():
Epoch 1/50: 100%|██████████| 15502/15502 [1:04:06<00:00,  4.03it/s]



Epoch 1:
Train Loss: 0.2464 | Train Accuracy: 0.1513
Validation Loss: 1.9453 | Validation Accuracy: 0.1617
✅ Model Saved! (Best Validation Loss Improved)


Epoch 2/50: 100%|██████████| 15502/15502 [58:39<00:00,  4.40it/s]  



Epoch 2:
Train Loss: 0.2449 | Train Accuracy: 0.1624
Validation Loss: 1.9348 | Validation Accuracy: 0.1823
✅ Model Saved! (Best Validation Loss Improved)


Epoch 3/50: 100%|██████████| 15502/15502 [1:03:35<00:00,  4.06it/s]



Epoch 3:
Train Loss: 0.2438 | Train Accuracy: 0.1695
Validation Loss: 1.9254 | Validation Accuracy: 0.1990
✅ Model Saved! (Best Validation Loss Improved)


Epoch 4/50: 100%|██████████| 15502/15502 [1:00:57<00:00,  4.24it/s]



Epoch 4:
Train Loss: 0.2429 | Train Accuracy: 0.1780
Validation Loss: 1.9191 | Validation Accuracy: 0.2078
✅ Model Saved! (Best Validation Loss Improved)


Epoch 5/50: 100%|██████████| 15502/15502 [1:03:15<00:00,  4.08it/s]



Epoch 5:
Train Loss: 0.2418 | Train Accuracy: 0.1876
Validation Loss: 1.9076 | Validation Accuracy: 0.2263
✅ Model Saved! (Best Validation Loss Improved)


Epoch 6/50: 100%|██████████| 15502/15502 [1:58:01<00:00,  2.19it/s]     



Epoch 6:
Train Loss: 0.2406 | Train Accuracy: 0.1975
Validation Loss: 1.8954 | Validation Accuracy: 0.2475
✅ Model Saved! (Best Validation Loss Improved)


Epoch 7/50: 100%|██████████| 15502/15502 [1:03:48<00:00,  4.05it/s]



Epoch 7:
Train Loss: 0.2390 | Train Accuracy: 0.2120
Validation Loss: 1.8822 | Validation Accuracy: 0.2664
✅ Model Saved! (Best Validation Loss Improved)


Epoch 8/50: 100%|██████████| 15502/15502 [1:02:26<00:00,  4.14it/s]



Epoch 8:
Train Loss: 0.2372 | Train Accuracy: 0.2260
Validation Loss: 1.8676 | Validation Accuracy: 0.2768
✅ Model Saved! (Best Validation Loss Improved)


Epoch 9/50: 100%|██████████| 15502/15502 [1:09:52<00:00,  3.70it/s]  



Epoch 9:
Train Loss: 0.2355 | Train Accuracy: 0.2389
Validation Loss: 1.8523 | Validation Accuracy: 0.2997
✅ Model Saved! (Best Validation Loss Improved)


Epoch 10/50: 100%|██████████| 15502/15502 [1:06:00<00:00,  3.91it/s]



Epoch 10:
Train Loss: 0.2338 | Train Accuracy: 0.2496
Validation Loss: 1.8352 | Validation Accuracy: 0.3186
✅ Model Saved! (Best Validation Loss Improved)


Epoch 11/50: 100%|██████████| 15502/15502 [1:11:48<00:00,  3.60it/s]  



Epoch 11:
Train Loss: 0.2318 | Train Accuracy: 0.2648
Validation Loss: 1.8181 | Validation Accuracy: 0.3298
✅ Model Saved! (Best Validation Loss Improved)


Epoch 12/50: 100%|██████████| 15502/15502 [1:50:42<00:00,  2.33it/s]   



Epoch 12:
Train Loss: 0.2301 | Train Accuracy: 0.2767
Validation Loss: 1.8008 | Validation Accuracy: 0.3382
✅ Model Saved! (Best Validation Loss Improved)


Epoch 13/50: 100%|██████████| 15502/15502 [1:13:32<00:00,  3.51it/s]  



Epoch 13:
Train Loss: 0.2280 | Train Accuracy: 0.2887
Validation Loss: 1.7870 | Validation Accuracy: 0.3528
✅ Model Saved! (Best Validation Loss Improved)


Epoch 14/50:  11%|█         | 1733/15502 [08:19<1:06:08,  3.47it/s]


KeyboardInterrupt: 