In [4]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
from timm import create_model

In [None]:
# Paths
data_dir = "../DATA_PREPARE_ATT_02/AffectNet"  # Updated path for colored images
model_save_path = "efficientnetv2_rw_s_emotion_model.pth"
checkpoint_path = "training_checkpoint_v2_rw_s.pth"

In [6]:
# Configuration
batch_size = 16
initial_lr = 1e-3
num_classes = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Emotion categories
emotion_classes = ["Anger", "Contempt", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]

Using device: cuda


In [7]:
# Data Preparation
def prepare_data_loaders(data_dir, batch_size):
    transform_train = transforms.Compose([
        transforms.Resize((260, 260)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    transform_val_test = transforms.Compose([
        transforms.Resize((260, 260)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    print("Loading datasets...")
    train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=transform_train)
    val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=transform_val_test)
    test_dataset = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=transform_val_test)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    return train_loader, val_loader, test_loader, train_dataset

In [8]:
# Load Model with Adaptive Freezing
def load_model(num_classes):
    print("Loading and configuring the model...")
    model = create_model('efficientnetv2_rw_s', pretrained=True, num_classes=num_classes, in_chans=3)

    # Freeze all layers initially
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze the classifier
    for param in model.get_classifier().parameters():
        param.requires_grad = True

    return model.to(device)

In [9]:
# Adjust Model Trainable Layers
def adjust_trainable_layers(model, num_blocks_to_unfreeze):
    blocks = list(model.blocks)
    num_blocks = len(blocks)
    start_unfreeze = max(0, num_blocks - num_blocks_to_unfreeze)

    print(f"Unfreezing {num_blocks_to_unfreeze}/{num_blocks} blocks...")

    for i, block in enumerate(blocks):
        for param in block.parameters():
            param.requires_grad = i >= start_unfreeze

    return model

In [10]:
# Save Checkpoint
def save_checkpoint(model, optimizer, epoch, path):
    print(f"Saving checkpoint for epoch {epoch}...")
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, path)

In [11]:
# Load Checkpoint
def load_checkpoint(model, optimizer, path):
    print(f"Loading checkpoint from {path}...")
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

In [13]:
# Training Loop with Adaptive Adjustments
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, start_epoch=0):
    best_val_accuracy = 0.0
    num_blocks_to_unfreeze = 1  # Start with one block trainable

    for epoch in range(start_epoch, num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        model.train()
        train_loss = 0.0
        for inputs, labels in tqdm(train_loader, desc="Training"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)
        print(f"Training Loss: {train_loss:.4f}")

        # Validation
        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validating"):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_accuracy = correct / total
        print(f"Validation Loss: {val_loss / len(val_loader):.4f}, Accuracy: {val_accuracy * 100:.2f}%")

        # Save checkpoint
        save_checkpoint(model, optimizer, epoch, checkpoint_path)

        # Save the best model
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), model_save_path)
            print(f"New best model saved with accuracy: {val_accuracy * 100:.2f}%")

        # Adaptive adjustments
        if val_accuracy < 0.7:  # If performance stagnates, unfreeze additional layers
            num_blocks_to_unfreeze += 1
            model = adjust_trainable_layers(model, num_blocks_to_unfreeze)
            optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=initial_lr)

In [16]:
# Main Script
train_loader, val_loader, test_loader, train_dataset = prepare_data_loaders(data_dir, batch_size)
criterion = nn.CrossEntropyLoss()
model = load_model(num_classes)
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=initial_lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

# Resume training if checkpoint exists
start_epoch = 0
if os.path.exists(checkpoint_path):
    start_epoch = load_checkpoint(model, optimizer, checkpoint_path) + 1

train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=20, start_epoch=start_epoch)

Loading datasets...
Loading and configuring the model...

Epoch 1/20


Training: 100%|██████████| 2500/2500 [14:55<00:00,  2.79it/s]


Training Loss: 1.9652


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.37it/s]


Validation Loss: 1.9973, Accuracy: 28.00%
Saving checkpoint for epoch 0...
New best model saved with accuracy: 28.00%
Unfreezing 2/6 blocks...

Epoch 2/20


Training: 100%|██████████| 2500/2500 [14:57<00:00,  2.79it/s]  


Training Loss: 1.2603


Validating: 100%|██████████| 50/50 [00:13<00:00,  3.77it/s]


Validation Loss: 1.4398, Accuracy: 48.38%
Saving checkpoint for epoch 1...
New best model saved with accuracy: 48.38%
Unfreezing 3/6 blocks...

Epoch 3/20


Training: 100%|██████████| 2500/2500 [11:03<00:00,  3.77it/s]


Training Loss: 1.1354


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.21it/s]


Validation Loss: 1.3096, Accuracy: 50.75%
Saving checkpoint for epoch 2...
New best model saved with accuracy: 50.75%
Unfreezing 4/6 blocks...

Epoch 4/20


Training: 100%|██████████| 2500/2500 [15:33<00:00,  2.68it/s]


Training Loss: 1.0777


Validating: 100%|██████████| 50/50 [00:14<00:00,  3.38it/s]


Validation Loss: 1.2698, Accuracy: 52.75%
Saving checkpoint for epoch 3...
New best model saved with accuracy: 52.75%
Unfreezing 5/6 blocks...

Epoch 5/20


Training: 100%|██████████| 2500/2500 [13:44<00:00,  3.03it/s]


Training Loss: 1.0406


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.05it/s]


Validation Loss: 1.3045, Accuracy: 55.38%
Saving checkpoint for epoch 4...
New best model saved with accuracy: 55.38%
Unfreezing 6/6 blocks...

Epoch 6/20


Training: 100%|██████████| 2500/2500 [13:27<00:00,  3.10it/s]


Training Loss: 1.0023


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.17it/s]


Validation Loss: 1.2093, Accuracy: 54.37%
Saving checkpoint for epoch 5...
Unfreezing 7/6 blocks...

Epoch 7/20


Training: 100%|██████████| 2500/2500 [13:26<00:00,  3.10it/s]


Training Loss: 0.9578


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.20it/s]


Validation Loss: 1.3142, Accuracy: 54.50%
Saving checkpoint for epoch 6...
Unfreezing 8/6 blocks...

Epoch 8/20


Training: 100%|██████████| 2500/2500 [13:27<00:00,  3.10it/s]


Training Loss: 0.9260


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.25it/s]


Validation Loss: 1.2309, Accuracy: 54.75%
Saving checkpoint for epoch 7...
Unfreezing 9/6 blocks...

Epoch 9/20


Training: 100%|██████████| 2500/2500 [13:28<00:00,  3.09it/s]


Training Loss: 0.8947


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.23it/s]


Validation Loss: 1.1970, Accuracy: 55.12%
Saving checkpoint for epoch 8...
Unfreezing 10/6 blocks...

Epoch 10/20


Training: 100%|██████████| 2500/2500 [13:28<00:00,  3.09it/s]


Training Loss: 0.8631


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.27it/s]


Validation Loss: 1.2584, Accuracy: 56.12%
Saving checkpoint for epoch 9...
New best model saved with accuracy: 56.12%
Unfreezing 11/6 blocks...

Epoch 11/20


Training: 100%|██████████| 2500/2500 [13:30<00:00,  3.09it/s]


Training Loss: 0.8370


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.11it/s]


Validation Loss: 1.2241, Accuracy: 57.00%
Saving checkpoint for epoch 10...
New best model saved with accuracy: 57.00%
Unfreezing 12/6 blocks...

Epoch 12/20


Training: 100%|██████████| 2500/2500 [13:27<00:00,  3.09it/s]


Training Loss: 0.7954


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.14it/s]


Validation Loss: 1.2828, Accuracy: 56.50%
Saving checkpoint for epoch 11...
Unfreezing 13/6 blocks...

Epoch 13/20


Training: 100%|██████████| 2500/2500 [13:25<00:00,  3.10it/s]


Training Loss: 0.7618


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.25it/s]


Validation Loss: 1.3682, Accuracy: 52.12%
Saving checkpoint for epoch 12...
Unfreezing 14/6 blocks...

Epoch 14/20


Training: 100%|██████████| 2500/2500 [13:26<00:00,  3.10it/s]


Training Loss: 0.7307


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.25it/s]


Validation Loss: 1.2580, Accuracy: 55.75%
Saving checkpoint for epoch 13...
Unfreezing 15/6 blocks...

Epoch 15/20


Training: 100%|██████████| 2500/2500 [13:26<00:00,  3.10it/s]


Training Loss: 0.6983


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.26it/s]


Validation Loss: 1.2797, Accuracy: 55.38%
Saving checkpoint for epoch 14...
Unfreezing 16/6 blocks...

Epoch 16/20


Training: 100%|██████████| 2500/2500 [13:34<00:00,  3.07it/s]


Training Loss: 0.6617


Validating: 100%|██████████| 50/50 [00:10<00:00,  4.95it/s]


Validation Loss: 1.3955, Accuracy: 55.62%
Saving checkpoint for epoch 15...
Unfreezing 17/6 blocks...

Epoch 17/20


Training: 100%|██████████| 2500/2500 [13:28<00:00,  3.09it/s]


Training Loss: 0.6310


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.31it/s]


Validation Loss: 1.3601, Accuracy: 54.62%
Saving checkpoint for epoch 16...
Unfreezing 18/6 blocks...

Epoch 18/20


Training: 100%|██████████| 2500/2500 [13:24<00:00,  3.11it/s]


Training Loss: 0.5989


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.25it/s]


Validation Loss: 1.4066, Accuracy: 55.62%
Saving checkpoint for epoch 17...
Unfreezing 19/6 blocks...

Epoch 19/20


Training: 100%|██████████| 2500/2500 [13:24<00:00,  3.11it/s]


Training Loss: 0.5616


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.24it/s]


Validation Loss: 1.4815, Accuracy: 54.50%
Saving checkpoint for epoch 18...
Unfreezing 20/6 blocks...

Epoch 20/20


Training: 100%|██████████| 2500/2500 [13:25<00:00,  3.10it/s]


Training Loss: 0.5362


Validating: 100%|██████████| 50/50 [00:09<00:00,  5.23it/s]


Validation Loss: 1.4864, Accuracy: 54.62%
Saving checkpoint for epoch 19...
Unfreezing 21/6 blocks...
