In [71]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
from timm import create_model
from focal_loss import FocalLoss
from sklearn.utils.class_weight import compute_class_weight
from lookahead_pytorch import Lookahead
from face_recognition import face_locations
import numpy as np
from collections import defaultdict
import torch.nn.functional as F

In [72]:
# Paths
# Dataset root; the data loaders read the "train" and "val" subfolders beneath it.
data_dir = "../DATA_PREPARE_ATT_02/AffectNet"
# Where train_model writes the weights of the best-scoring model so far.
model_save_path = "efficientnet_b2_emotion_model.pth"
# Where train_model writes the per-epoch checkpoint (model state, optimizer state, epoch index).
checkpoint_path = "adaptive_training_checkpoint_efficientnet_b2.pth"

In [73]:
# Configuration
# Batch size handed to both the training and the validation DataLoader.
batch_size = 16
# Lower/upper bounds that adjust_batch_size clamps its result to.
min_batch_size, max_batch_size = 16, 64
# Starting learning rate for the AdamW base optimizer.
initial_lr = 1e-3
# Number of output classes requested from the model's classifier head.
num_classes = 8
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Emotion categories
# NOTE(review): not referenced by any other visible cell; presumably mirrors the dataset's
# class-folder names — confirm the order matches the label indices before using it for decoding.
emotion_classes = ["Anger", "Contempt", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]

Using device: cuda


In [74]:
# Region-of-Interest Detection
def detect_roi(image_tensor, detector=None):
    """Crop an image tensor to the bounding box of the first detected face.

    Args:
        image_tensor: Float tensor laid out as (channels, height, width) with
            values expected in [0, 1]. Values outside that range are clamped
            before the 8-bit conversion so they cannot wrap around.
        detector: Optional callable taking an (H, W, C) uint8 array and
            returning a sequence of ``(top, right, bottom, left)`` boxes.
            Defaults to ``face_locations``.

    Returns:
        A float tensor (channels, crop_height, crop_width) scaled to [0, 1] on
        the same device as ``image_tensor`` when a face is found. The crop has
        been quantised to 8 bits and its spatial size depends on the detected
        box. When no face (or only a degenerate box) is found, the original
        ``image_tensor`` is returned unchanged.
    """
    if detector is None:
        detector = face_locations

    # detach()/cpu() make the NumPy conversion work for CUDA tensors and for
    # tensors that are part of an autograd graph.
    hwc = image_tensor.detach().cpu().clamp(0.0, 1.0).permute(1, 2, 0)
    # The permuted view is not C-contiguous; some detectors reject such arrays,
    # so hand over a contiguous uint8 image.
    images_np = np.ascontiguousarray((hwc.numpy() * 255).astype(np.uint8))

    faces = detector(images_np)
    if faces:
        top, right, bottom, left = faces[0]  # Only the first detected face is used
        cropped = images_np[top:bottom, left:right]
        if cropped.size:  # Ignore degenerate boxes that would yield an empty crop
            roi = torch.from_numpy(np.ascontiguousarray(cropped)).permute(2, 0, 1).float() / 255.0
            return roi.to(image_tensor.device)
    return image_tensor  # If no usable face found, return original image

In [75]:
# Data Preparation
def prepare_data_loaders(data_dir, batch_size, augmentations):
    """Create the training and validation DataLoaders for an ImageFolder dataset.

    Args:
        data_dir: Directory containing "train" and "val" subfolders.
        batch_size: Batch size used by both loaders.
        augmentations: List of transforms applied to training images before
            tensor conversion and normalisation. The list is concatenated into
            a new list here, so mutating ``augmentations`` afterwards does not
            change the pipeline built by this call.

    Returns:
        Tuple ``(train_loader, val_loader, train_dataset)``.
    """
    def _tensor_and_normalize():
        # Shared tail of both pipelines; built fresh for each pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]

    train_transforms = transforms.Compose(augmentations + _tensor_and_normalize())
    # Validation images get a fixed resize and no augmentation.
    test_transforms = transforms.Compose([transforms.Resize((260, 260))] + _tensor_and_normalize())

    print("Loading datasets...")
    train_root = os.path.join(data_dir, "train")
    val_root = os.path.join(data_dir, "val")
    train_dataset = datasets.ImageFolder(train_root, transform=train_transforms)
    val_dataset = datasets.ImageFolder(val_root, transform=test_transforms)

    # Options common to both loaders; only the training loader shuffles.
    shared_options = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_options)

    return train_loader, val_loader, train_dataset

In [76]:
# Load Model (pretrained EfficientNet-B2 from timm; note that no CBAM module is actually added below)
def load_model(num_classes):
    """Create a pretrained EfficientNet-B2 whose classifier head alone is trainable.

    Args:
        num_classes: Number of output classes for the classifier head.

    Returns:
        The model moved to the module-level ``device``, with every parameter
        frozen except those of the module returned by ``get_classifier()``.

    Note:
        The log message mentions CBAM, but this function only calls
        ``create_model``; it does not attach any extra module itself.
    """
    print("Loading EfficientNet_b2 model with CBAM...")
    network = create_model('efficientnet_b2', pretrained=True, num_classes=num_classes, in_chans=3)
    # Freeze the whole network first, then re-enable gradients for the head.
    network.requires_grad_(False)
    network.get_classifier().requires_grad_(True)
    return network.to(device)

In [77]:
# Adjust Trainable Layers Dynamically
def adjust_trainable_layers(model, num_layers_to_unfreeze):
    """Make only the trailing share of the model's parameter tensors trainable.

    Args:
        model: Module whose parameters are switched between frozen/trainable.
        num_layers_to_unfreeze: Despite the name, this is used as a percentage
            of the parameter tensors (in ``model.parameters()`` order), counted
            from the end, that should require gradients. Values of 100 or more
            unfreeze everything; 0 freezes everything.

    Every parameter is reassigned: tensors before the cutoff are frozen even
    if they were trainable before the call.
    """
    tensors = list(model.parameters())
    count = len(tensors)
    # First index (possibly fractional) that belongs to the trainable tail.
    cutoff = count - count * (num_layers_to_unfreeze / 100)
    for position, tensor in enumerate(tensors):
        tensor.requires_grad_(position >= cutoff)

In [78]:
# Dynamically Update Augmentations
def update_augmentations(augmentations, epoch_performance):
    """Adapt the augmentation list to the latest epoch metrics.

    Args:
        augmentations: List of transforms; it is modified in place.
        epoch_performance: Mapping providing 'val_accuracy' and 'train_loss'.

    Returns:
        The same ``augmentations`` list object that was passed in.

    Behaviour:
        * val_accuracy < 0.7: ensure a ColorJitter(brightness=0.2,
          contrast=0.2) is present. It is added at most once, so repeated
          low-accuracy epochs no longer pile up duplicate jitter transforms.
        * otherwise, train_loss < 0.3: remove the trailing transform, but only
          if it is a ColorJitter. Base transforms supplied by the caller (such
          as a Resize) are never popped.
        * otherwise: leave the list untouched.
    """
    if epoch_performance['val_accuracy'] < 0.7:
        already_jittered = any(isinstance(aug, transforms.ColorJitter) for aug in augmentations)
        if not already_jittered:
            augmentations.append(transforms.ColorJitter(brightness=0.2, contrast=0.2))
    elif epoch_performance['train_loss'] < 0.3:
        # Short-circuit on an empty list; only retire the jitter, never a base transform.
        if augmentations and isinstance(augmentations[-1], transforms.ColorJitter):
            augmentations.pop()
    return augmentations

In [79]:
# Adjust Batch Size Dynamically
def adjust_batch_size(current_batch_size, epoch_performance):
    """Suggest the next batch size from the latest validation accuracy.

    Args:
        current_batch_size: Batch size used in the epoch just finished.
        epoch_performance: Mapping providing 'val_accuracy'.

    Returns:
        ``current_batch_size + 8`` capped at the module-level
        ``max_batch_size`` when accuracy is below 0.5;
        ``current_batch_size - 8`` floored at the module-level
        ``min_batch_size`` when accuracy is above 0.75;
        otherwise ``current_batch_size`` unchanged.
    """
    accuracy = epoch_performance['val_accuracy']

    if accuracy < 0.5:
        enlarged = current_batch_size + 8
        return min(enlarged, max_batch_size)

    if accuracy > 0.75:
        reduced = current_batch_size - 8
        return max(reduced, min_batch_size)

    return current_batch_size

In [82]:
# Training Loop
def train_model(model, train_loader, val_loader, optimizer, scheduler, augmentations, num_epochs, max_epochs):
    """Run the train/validate loop with progressive unfreezing and checkpointing.

    Args:
        model: Network to train; expected to already live on ``device``.
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        val_loader: Loader yielding ``(inputs, labels)`` validation batches.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Scheduler stepped once per epoch with the mean validation loss.
        augmentations: Augmentation list passed to ``update_augmentations``
            after every epoch.
        num_epochs: Index of the first epoch to run (0 starts from scratch).
        max_epochs: Exclusive upper bound on the epoch index.

    Raises:
        ValueError: If at least one epoch would run and either loader has no
            batches. Previously this surfaced only as a ZeroDivisionError
            after a complete training pass.

    Reads the module-level ``criterion``, ``device``, ``model_save_path`` and
    ``checkpoint_path``. Writes the best weights to ``model_save_path`` and a
    checkpoint to ``checkpoint_path`` every epoch. Stops early once the epoch
    index exceeds 10 and validation accuracy exceeds 0.8.
    """
    # Fail fast: the per-epoch averages below divide by the loader lengths and
    # by the number of validation samples.
    if num_epochs < max_epochs:
        if len(train_loader) == 0:
            raise ValueError("train_loader is empty; nothing to train on.")
        if len(val_loader) == 0:
            raise ValueError("val_loader is empty; cannot compute validation metrics.")

    best_accuracy = 0
    epoch_performance = defaultdict(float)
    # Percentage of parameter tensors (from the end) made trainable after each epoch.
    num_layers_to_unfreeze = 10

    for epoch in range(num_epochs, max_epochs):
        print(f"\nEpoch {epoch+1}/{max_epochs}")
        train_loss, val_loss, correct, total = 0.0, 0.0, 0, 0
        model.train()

        # Training
        for inputs, labels in tqdm(train_loader, desc="Training"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            # NOTE(review): the loss is fed probabilities, not logits — confirm
            # the FocalLoss implementation in use expects softmax output.
            outputs_probabilities = F.softmax(outputs, dim=1)  # Convert logits to probabilities
            loss = criterion(outputs_probabilities, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation
        model.eval()
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validating"):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                outputs_probabilities = F.softmax(outputs, dim=1)  # Convert logits to probabilities
                loss = criterion(outputs_probabilities, labels)  # Use probabilities with FocalLoss
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)  # Use logits for prediction
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Accuracy is per sample; the losses are averaged per batch.
        val_accuracy = correct / total
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        # Log Performance
        epoch_performance['train_loss'] = train_loss
        epoch_performance['val_loss'] = val_loss
        epoch_performance['val_accuracy'] = val_accuracy

        print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}")

        # Save Best Model
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), model_save_path)
            print(f"New best model saved with accuracy: {val_accuracy:.2f}")

        # Update Learning Rate
        scheduler.step(val_loss)

        # Adjust Training Parameters
        # Unfreeze a further 10 percentage points of parameter tensors for the next epoch.
        adjust_trainable_layers(model, num_layers_to_unfreeze)
        num_layers_to_unfreeze += 10
        # NOTE(review): the returned list is only rebound to the local name here;
        # nothing in this function pushes it into train_loader, so the change
        # does not reach the data pipeline from this code path.
        augmentations = update_augmentations(augmentations, epoch_performance)

        # Save Checkpoint
        # Scheduler state and best_accuracy are not part of the checkpoint.
        checkpoint = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        torch.save(checkpoint, checkpoint_path)

        # Dynamic Epochs
        # epoch is zero-based, so this can trigger from the 12th epoch onwards.
        if epoch > 10 and val_accuracy > 0.8:
            print("Stopping early as accuracy is satisfactory.")
            break

In [83]:
# Main Script
# Base training-time transforms; prepare_data_loaders appends ToTensor and Normalize after them.
augmentations = [transforms.Resize((260, 260))]
train_loader, val_loader, train_dataset = prepare_data_loaders(data_dir, batch_size, augmentations)

# train_model reads `criterion` as a module-level global, so it must be defined before training starts.
criterion = FocalLoss(gamma=2.0)  # Pass alpha if class weighting is required
model = load_model(num_classes)

# The optimizer receives every parameter, including the ones load_model froze, so tensors
# unfrozen later by adjust_trainable_layers are already registered with it.
base_optimizer = optim.AdamW(model.parameters(), lr=initial_lr)
optimizer = Lookahead(base_optimizer)
# mode='min' with factor=0.5: halve the learning rate when the value passed to scheduler.step
# (the mean validation loss in train_model) stops decreasing, with patience=2.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)

# num_epochs is the index of the first epoch to run; max_epochs is the exclusive upper bound.
train_model(model, train_loader, val_loader, optimizer, scheduler, augmentations, num_epochs=0, max_epochs=50)

Loading datasets...
Loading EfficientNet_b2 model with CBAM...

Epoch 1/50


Training: 100%|██████████| 2500/2500 [02:45<00:00, 15.07it/s]
Validating: 100%|██████████| 50/50 [00:09<00:00,  5.46it/s]


Train Loss: 1.8599, Val Loss: 1.7593, Val Accuracy: 0.25
New best model saved with accuracy: 0.25

Epoch 2/50


Training: 100%|██████████| 2500/2500 [03:19<00:00, 12.55it/s]
Validating: 100%|██████████| 50/50 [00:09<00:00,  5.10it/s]


Train Loss: 1.0194, Val Loss: 1.0662, Val Accuracy: 0.44
New best model saved with accuracy: 0.44

Epoch 3/50


Training: 100%|██████████| 2500/2500 [03:46<00:00, 11.05it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.21it/s]


Train Loss: 0.8061, Val Loss: 0.9961, Val Accuracy: 0.48
New best model saved with accuracy: 0.48

Epoch 4/50


Training: 100%|██████████| 2500/2500 [04:11<00:00,  9.96it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.25it/s]


Train Loss: 0.7250, Val Loss: 0.9196, Val Accuracy: 0.49
New best model saved with accuracy: 0.49

Epoch 5/50


Training: 100%|██████████| 2500/2500 [04:47<00:00,  8.70it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.16it/s]


Train Loss: 0.6481, Val Loss: 0.8633, Val Accuracy: 0.51
New best model saved with accuracy: 0.51

Epoch 6/50


Training: 100%|██████████| 2500/2500 [10:03<00:00,  4.14it/s]  
Validating: 100%|██████████| 50/50 [00:10<00:00,  4.55it/s]


Train Loss: 0.5960, Val Loss: 0.8627, Val Accuracy: 0.54
New best model saved with accuracy: 0.54

Epoch 7/50


Training: 100%|██████████| 2500/2500 [05:41<00:00,  7.32it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.24it/s]


Train Loss: 0.5212, Val Loss: 0.8601, Val Accuracy: 0.52

Epoch 8/50


Training: 100%|██████████| 2500/2500 [06:03<00:00,  6.87it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.15it/s]


Train Loss: 0.4492, Val Loss: 0.8834, Val Accuracy: 0.52

Epoch 9/50


Training: 100%|██████████| 2500/2500 [06:50<00:00,  6.09it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  6.24it/s]


Train Loss: 0.3967, Val Loss: 0.9139, Val Accuracy: 0.51

Epoch 10/50


Training: 100%|██████████| 2500/2500 [07:50<00:00,  5.32it/s]
Validating: 100%|██████████| 50/50 [00:07<00:00,  6.30it/s]


Train Loss: 0.3427, Val Loss: 0.9418, Val Accuracy: 0.50

Epoch 11/50


Training: 100%|██████████| 2500/2500 [09:21<00:00,  4.46it/s]
Validating: 100%|██████████| 50/50 [00:09<00:00,  5.46it/s]


Train Loss: 0.1529, Val Loss: 1.0626, Val Accuracy: 0.53

Epoch 12/50


Training: 100%|██████████| 2500/2500 [09:23<00:00,  4.43it/s]
Validating: 100%|██████████| 50/50 [00:08<00:00,  5.92it/s]


Train Loss: 0.0869, Val Loss: 1.2251, Val Accuracy: 0.54

Epoch 13/50


Training:  11%|█         | 274/2500 [01:13<09:54,  3.74it/s] 


KeyboardInterrupt: 