In [19]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
from timm import create_model
from focal_loss import FocalLoss
from lookahead_pytorch import Lookahead
from face_recognition import face_locations
import numpy as np
from collections import defaultdict
import torch.nn.functional as F
from sklearn.utils.class_weight import compute_class_weight

In [20]:
# Paths
data_dir = "../DATA_PREPARE_ATT_02/AffectNet"
model_save_path = "efficientnet_b2_emotion_model.pth"
checkpoint_path = "adaptive_training_checkpoint_efficientnet_b2.pth"

In [21]:
# Configuration
min_batch_size, max_batch_size = 16, 64
initial_batch_size = 16
initial_lr = 1e-3
num_classes = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Emotion categories
emotion_classes = ["Anger", "Contempt", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]


Using device: cuda


In [22]:
# Region-of-Interest Detection
def detect_roi(image_tensor):
    """Focuses on facial regions like eyes and mouth."""
    images_np = (image_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    faces = face_locations(images_np)
    if faces:
        top, right, bottom, left = faces[0]  # Assume first detected face
        cropped = images_np[top:bottom, left:right]
        return torch.tensor(cropped).permute(2, 0, 1).float() / 255.0
    return image_tensor  # If no face found, return original image

In [23]:
# Data Preparation
def prepare_data_loaders(data_dir, batch_size, augmentations):
    train_transforms = transforms.Compose(augmentations + [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    test_transforms = transforms.Compose([
        transforms.Resize((260, 260)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    print("Loading datasets...")
    train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=train_transforms)
    val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=test_transforms)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    return train_loader, val_loader, train_dataset

In [24]:
# Load Model with CBAM
def load_model(num_classes):
    print("Loading EfficientNet_b2 model with CBAM...")
    model = create_model('efficientnet_b2', pretrained=True, num_classes=num_classes, in_chans=3)
    for param in model.parameters():
        param.requires_grad = False
    for param in model.get_classifier().parameters():
        param.requires_grad = True
    return model.to(device)

# Adjust Trainable Layers Dynamically
def adjust_trainable_layers(model, num_layers_to_unfreeze):
    total_layers = len(list(model.parameters()))
    layers_to_unfreeze = int(total_layers * (num_layers_to_unfreeze / 100))
    for i, param in enumerate(model.parameters()):
        param.requires_grad = i >= total_layers - layers_to_unfreeze

In [25]:
# Dynamically Update Augmentations
def update_augmentations(augmentations, epoch_performance):
    """
    Dynamically adjusts data augmentations based on epoch performance.

    Args:
        augmentations (list): Current list of augmentations.
        epoch_performance (dict): Dictionary containing performance metrics (val_accuracy, train_loss, etc.).

    Returns:
        list: Updated list of augmentations.
    """
    # If validation accuracy is low, add stronger augmentations to increase robustness
    if epoch_performance['val_accuracy'] < 0.7:
        augmentations.append(transforms.RandomApply([
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.2),
        ], p=0.5))
        augmentations.append(transforms.RandomErasing(p=0.2))

    # If the training loss is very low (indicating potential overfitting), reduce augmentation intensity
    elif epoch_performance['train_loss'] < 0.3:
        if len(augmentations) > 0:
            augmentations.pop()

    # Add specific augmentations for balancing performance
    if epoch_performance['val_accuracy'] >= 0.7 and epoch_performance['val_accuracy'] < 0.8:
        if not any(isinstance(aug, transforms.RandomRotation) for aug in augmentations):
            augmentations.insert(0, transforms.RandomRotation(degrees=15))
        if not any(isinstance(aug, transforms.RandomHorizontalFlip) for aug in augmentations):
            augmentations.insert(0, transforms.RandomHorizontalFlip(p=0.5))
        if not any(isinstance(aug, transforms.RandomResizedCrop) for aug in augmentations):
            augmentations.insert(0, transforms.RandomResizedCrop(size=(260, 260), scale=(0.8, 1.0)))

    # If validation accuracy is plateauing, introduce aggressive augmentations to break stagnation
    if epoch_performance.get('plateau', False):
        augmentations.extend([
            transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10),
            transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2.0))
        ])

    return augmentations

In [26]:
# Adjust Batch Size Dynamically
def adjust_batch_size(current_batch_size, epoch_performance):
    if epoch_performance['val_accuracy'] < 0.5:
        return min(current_batch_size + 8, max_batch_size)
    elif epoch_performance['val_accuracy'] > 0.75:
        return max(current_batch_size - 8, min_batch_size)
    return current_batch_size

In [27]:
# Emotion-specific Loss Weighting
def calculate_emotion_weights(train_dataset, num_classes):
    labels = [sample[1] for sample in train_dataset.samples]
    class_weights = compute_class_weight('balanced', classes=np.arange(num_classes), y=labels)
    return torch.tensor(class_weights, dtype=torch.float).to(device)

In [28]:
# Training Loop
def train_model(model, train_loader, val_loader, optimizer, scheduler, augmentations, num_epochs, max_epochs):
    best_accuracy = 0
    epoch_performance = defaultdict(float)
    num_layers_to_unfreeze = 10

    for epoch in range(num_epochs, max_epochs):
        print(f"\nEpoch {epoch+1}/{max_epochs}")
        train_loss, val_loss, correct, total = 0.0, 0.0, 0, 0
        model.train()

        # Training
        for inputs, labels in tqdm(train_loader, desc="Training"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            outputs_probabilities = F.softmax(outputs, dim=1)  # Convert logits to probabilities
            loss = criterion(outputs_probabilities, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation
        model.eval()
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validating"):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                outputs_probabilities = F.softmax(outputs, dim=1)  # Convert logits to probabilities
                loss = criterion(outputs_probabilities, labels)  # Use probabilities with FocalLoss
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)  # Use logits for prediction
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_accuracy = correct / total
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        # Log Performance
        epoch_performance['train_loss'] = train_loss
        epoch_performance['val_loss'] = val_loss
        epoch_performance['val_accuracy'] = val_accuracy

        print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}")

        # Save Best Model
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), model_save_path)
            print(f"New best model saved with accuracy: {val_accuracy:.2f}")

        # Update Learning Rate
        scheduler.step(val_loss)

        # Adjust Training Parameters
        adjust_trainable_layers(model, num_layers_to_unfreeze)
        num_layers_to_unfreeze += 10
        augmentations = update_augmentations(augmentations, epoch_performance)

        # Save Checkpoint
        checkpoint = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        torch.save(checkpoint, checkpoint_path)

        # Dynamic Epochs
        if epoch > 10 and val_accuracy > 0.8:
            print("Stopping early as accuracy is satisfactory.")
            break

In [31]:
# Main Script
augmentations = [transforms.Resize((260, 260))]
train_loader, val_loader, train_dataset = prepare_data_loaders(data_dir, initial_batch_size, augmentations)
class_weights = calculate_emotion_weights(train_dataset, num_classes)

criterion = FocalLoss(gamma=2.0)  # Pass alpha if class weighting is required
model = load_model(num_classes)

base_optimizer = optim.AdamW(model.parameters(), lr=initial_lr)
optimizer = Lookahead(base_optimizer)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)

train_model(model, train_loader, val_loader, optimizer, scheduler, augmentations, num_epochs=0, max_epochs=50)

Loading datasets...
Loading EfficientNet_b2 model with CBAM...

Epoch 1/50


Training: 100%|██████████| 2500/2500 [16:58<00:00,  2.45it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  5.64it/s]


Train Loss: 1.8734, Val Loss: 1.8405, Val Accuracy: 0.23
New best model saved with accuracy: 0.23

Epoch 2/50


Training: 100%|██████████| 2500/2500 [03:15<00:00, 12.80it/s]
Validating: 100%|██████████| 50/50 [00:07<00:00,  6.29it/s]


Train Loss: 1.0362, Val Loss: 1.1101, Val Accuracy: 0.41
New best model saved with accuracy: 0.41

Epoch 3/50


Training: 100%|██████████| 2500/2500 [03:44<00:00, 11.13it/s]
Validating: 100%|██████████| 50/50 [00:07<00:00,  6.36it/s]


Train Loss: 0.8282, Val Loss: 1.0138, Val Accuracy: 0.45
New best model saved with accuracy: 0.45

Epoch 4/50


Training: 100%|██████████| 2500/2500 [04:12<00:00,  9.92it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.07it/s]


Train Loss: 0.7344, Val Loss: 0.9149, Val Accuracy: 0.49
New best model saved with accuracy: 0.49

Epoch 5/50


Training: 100%|██████████| 2500/2500 [04:47<00:00,  8.70it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.06it/s]


Train Loss: 0.6603, Val Loss: 0.8339, Val Accuracy: 0.51
New best model saved with accuracy: 0.51

Epoch 6/50


Training: 100%|██████████| 2500/2500 [05:20<00:00,  7.81it/s]
Validating: 100%|██████████| 50/50 [00:07<00:00,  6.35it/s]


Train Loss: 0.6014, Val Loss: 0.8190, Val Accuracy: 0.52
New best model saved with accuracy: 0.52

Epoch 7/50


Training: 100%|██████████| 2500/2500 [05:33<00:00,  7.50it/s]
Validating: 100%|██████████| 50/50 [00:07<00:00,  6.33it/s]


Train Loss: 0.5287, Val Loss: 0.8936, Val Accuracy: 0.52

Epoch 8/50


Training: 100%|██████████| 2500/2500 [06:01<00:00,  6.92it/s]
Validating: 100%|██████████| 50/50 [00:07<00:00,  6.31it/s]


Train Loss: 0.4611, Val Loss: 0.8482, Val Accuracy: 0.53
New best model saved with accuracy: 0.53

Epoch 9/50


Training: 100%|██████████| 2500/2500 [06:49<00:00,  6.10it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  5.98it/s]


Train Loss: 0.4041, Val Loss: 0.8619, Val Accuracy: 0.54
New best model saved with accuracy: 0.54

Epoch 10/50


Training: 100%|██████████| 2500/2500 [07:49<00:00,  5.32it/s]
Validating: 100%|██████████| 50/50 [00:07<00:00,  6.26it/s]


Train Loss: 0.1890, Val Loss: 1.1025, Val Accuracy: 0.50

Epoch 11/50


Training: 100%|██████████| 2500/2500 [09:08<00:00,  4.56it/s]
Validating: 100%|██████████| 50/50 [00:07<00:00,  6.35it/s]


Train Loss: 0.1089, Val Loss: 1.1737, Val Accuracy: 0.52

Epoch 12/50


Training: 100%|██████████| 2500/2500 [09:07<00:00,  4.57it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.07it/s]


Train Loss: 0.0912, Val Loss: 1.2420, Val Accuracy: 0.53

Epoch 13/50


Training: 100%|██████████| 2500/2500 [09:11<00:00,  4.53it/s]
Validating: 100%|██████████| 50/50 [00:07<00:00,  6.29it/s]


Train Loss: 0.0409, Val Loss: 1.3832, Val Accuracy: 0.54

Epoch 14/50


Training: 100%|██████████| 2500/2500 [09:12<00:00,  4.53it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.12it/s]


Train Loss: 0.0263, Val Loss: 1.4299, Val Accuracy: 0.54

Epoch 15/50


Training: 100%|██████████| 2500/2500 [09:13<00:00,  4.52it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.18it/s]


Train Loss: 0.0261, Val Loss: 1.3713, Val Accuracy: 0.54
New best model saved with accuracy: 0.54

Epoch 16/50


Training: 100%|██████████| 2500/2500 [09:08<00:00,  4.56it/s]
Validating: 100%|██████████| 50/50 [00:07<00:00,  6.34it/s]


Train Loss: 0.0132, Val Loss: 1.4884, Val Accuracy: 0.55
New best model saved with accuracy: 0.55

Epoch 17/50


Training: 100%|██████████| 2500/2500 [09:11<00:00,  4.53it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.11it/s]


Train Loss: 0.0087, Val Loss: 1.4892, Val Accuracy: 0.54

Epoch 18/50


Training: 100%|██████████| 2500/2500 [09:15<00:00,  4.50it/s]
Validating: 100%|██████████| 50/50 [00:07<00:00,  6.28it/s]


Train Loss: 0.0086, Val Loss: 1.5495, Val Accuracy: 0.54

Epoch 19/50


Training: 100%|██████████| 2500/2500 [09:15<00:00,  4.50it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.23it/s]


Train Loss: 0.0055, Val Loss: 1.6024, Val Accuracy: 0.55

Epoch 20/50


Training: 100%|██████████| 2500/2500 [09:11<00:00,  4.53it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.00it/s]


Train Loss: 0.0039, Val Loss: 1.6779, Val Accuracy: 0.54

Epoch 21/50


Training: 100%|██████████| 2500/2500 [09:11<00:00,  4.53it/s]
Validating: 100%|██████████| 50/50 [00:07<00:00,  6.30it/s]


Train Loss: 0.0044, Val Loss: 1.6805, Val Accuracy: 0.55
New best model saved with accuracy: 0.55

Epoch 22/50


Training: 100%|██████████| 2500/2500 [09:07<00:00,  4.57it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.21it/s]


Train Loss: 0.0030, Val Loss: 1.7140, Val Accuracy: 0.56
New best model saved with accuracy: 0.56

Epoch 23/50


Training: 100%|██████████| 2500/2500 [09:11<00:00,  4.53it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  5.99it/s]


Train Loss: 0.0027, Val Loss: 1.7347, Val Accuracy: 0.56

Epoch 24/50


Training: 100%|██████████| 2500/2500 [09:09<00:00,  4.55it/s]
Validating: 100%|██████████| 50/50 [00:07<00:00,  6.31it/s]


Train Loss: 0.0026, Val Loss: 1.7087, Val Accuracy: 0.55

Epoch 25/50


Training: 100%|██████████| 2500/2500 [09:09<00:00,  4.55it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.08it/s]


Train Loss: 0.0022, Val Loss: 1.7520, Val Accuracy: 0.55

Epoch 26/50


Training:   5%|▍         | 120/2500 [00:39<12:59,  3.05it/s] 


KeyboardInterrupt: 