In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.utils.class_weight import compute_class_weight
from timm import create_model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Paths
# Dataset root; prepare_data_loaders() reads the "train", "val" and "test" sub-folders beneath it.
data_dir = "../DATA_PREPARE_ATT_06/MixedDataset"
# File that train_model() overwrites with the model's state_dict each time validation accuracy improves.
model_save_path = "adaptive_efficientnetv2_rw_s_emotion_model.pth"
# NOTE(review): defined but not referenced by any of the visible cells, so no
# resumable training checkpoint is actually written — confirm whether this is intended.
checkpoint_path = "adaptive_training_checkpoint_v2_rw_s.pth"

In [3]:
# Configuration
# Emotion categories, one label per output class of the network.
# NOTE(review): presumably mirrors the class sub-folder names of the dataset — confirm.
emotion_classes = [
    "Anger",
    "Contempt",
    "Disgust",
    "Fear",
    "Happy",
    "Neutral",
    "Sad",
    "Surprise",
]

# Training hyper-parameters shared by the cells below.
batch_size = 16
initial_lr = 1e-3
num_classes = 8

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


Using device: cuda


In [4]:
# Data Preparation
def prepare_data_loaders(data_dir, batch_size):
    """Build the train/val/test ``DataLoader`` objects for an image-folder dataset.

    Images are read with ``datasets.ImageFolder`` from the ``train``, ``val``
    and ``test`` sub-folders of ``data_dir``. Every image is converted to a
    single grayscale channel, resized to 260x260, turned into a tensor and
    normalised with mean 0.5 / std 0.5. The training split additionally gets a
    random horizontal flip and ``RandomRotation(10)`` before tensor conversion.

    Args:
        data_dir: Directory containing the ``train``, ``val`` and ``test`` sub-folders.
        batch_size: Batch size used by all three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, train_dataset)``.
        Only the training loader shuffles its samples.
    """
    def geometry_steps():
        # Steps applied to every split before any augmentation.
        return [
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((260, 260)),
        ]

    def tensor_steps():
        # Steps applied to every split after any augmentation.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ]

    # Augmentation is only inserted into the training pipeline.
    augmentation_steps = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
    ]
    transform_train = transforms.Compose(geometry_steps() + augmentation_steps + tensor_steps())
    transform_val_test = transforms.Compose(geometry_steps() + tensor_steps())

    print("Loading datasets...")
    split_transforms = {
        "train": transform_train,
        "val": transform_val_test,
        "test": transform_val_test,
    }
    split_datasets = {
        split: datasets.ImageFolder(os.path.join(data_dir, split), transform=pipeline)
        for split, pipeline in split_transforms.items()
    }

    # Options common to the three loaders; only `shuffle` differs per split.
    shared_options = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    train_loader = DataLoader(split_datasets["train"], shuffle=True, **shared_options)
    val_loader = DataLoader(split_datasets["val"], shuffle=False, **shared_options)
    test_loader = DataLoader(split_datasets["test"], shuffle=False, **shared_options)

    return train_loader, val_loader, test_loader, split_datasets["train"]

In [5]:
# Compute Class Weights
def compute_weights(train_dataset, num_classes):
    """Compute 'balanced' per-class weights for the loss function.

    Args:
        train_dataset: Object exposing ``samples``, an iterable of
            ``(item, label)`` pairs; only the labels are used.
        num_classes: Number of classes; weights are requested for the labels
            ``0 .. num_classes - 1``.

    Returns:
        A ``torch.float`` tensor of the weights produced by
        ``compute_class_weight``, moved to the module-level ``device``.
    """
    targets = [class_index for _item, class_index in train_dataset.samples]
    label_space = np.arange(num_classes)
    balanced = compute_class_weight('balanced', classes=label_space, y=targets)
    weight_tensor = torch.tensor(balanced, dtype=torch.float)
    return weight_tensor.to(device)

In [6]:
# Load Model with Adaptive Freezing
def load_model(num_classes):
    """Create the pretrained EfficientNetV2 model with only its classifier trainable.

    The network is built through timm's ``create_model`` for single-channel
    input (``in_chans=1``) with ``num_classes`` outputs. All parameters are
    frozen and then the parameters of ``model.get_classifier()`` are
    re-enabled, so a freshly loaded model trains its classifier only.

    Args:
        num_classes: Number of output classes for the classifier.

    Returns:
        The configured model, moved to the module-level ``device``.
    """
    print("Loading and configuring the model...")
    model = create_model(
        'efficientnetv2_rw_s',
        pretrained=True,
        num_classes=num_classes,
        in_chans=1,
    )

    # Freeze every parameter of the network first ...
    model.requires_grad_(False)
    # ... then switch gradients back on for the classifier only.
    model.get_classifier().requires_grad_(True)

    return model.to(device)

In [7]:
# Adjust Model Trainable Layers
def adjust_trainable_layers(model, num_blocks_to_unfreeze):
    """Make only the last ``num_blocks_to_unfreeze`` entries of ``model.blocks`` trainable.

    Parameters of the trailing blocks get ``requires_grad=True``; parameters of
    all earlier blocks get ``requires_grad=False``. Parameters outside
    ``model.blocks`` are left untouched.

    Args:
        model: Module exposing an iterable ``blocks`` attribute of sub-modules.
        num_blocks_to_unfreeze: How many trailing blocks should be trainable.
            Values outside ``[0, len(model.blocks)]`` are clamped to that range.

    Returns:
        The same ``model`` object, modified in place.
    """
    blocks = list(model.blocks)
    num_blocks = len(blocks)
    # Clamp the request so the log line reports what is really unfrozen
    # (e.g. "6/6" rather than "7/6" when the caller's counter overshoots).
    num_unfrozen = min(max(num_blocks_to_unfreeze, 0), num_blocks)
    start_unfreeze = num_blocks - num_unfrozen

    print(f"Unfreezing {num_unfrozen}/{num_blocks} blocks...")

    for i, block in enumerate(blocks):
        # Blocks at or after the cut-off index are trainable, earlier ones frozen.
        trainable = i >= start_unfreeze
        for param in block.parameters():
            param.requires_grad = trainable

    return model

In [8]:
# Training Loop with Adaptive Adjustments
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs):
    """Train ``model`` with progressive unfreezing and keep the best weights on disk.

    Every epoch runs one pass over ``train_loader`` followed by one evaluation
    pass over ``val_loader``. Whenever validation accuracy beats the best value
    seen so far, ``model.state_dict()`` is saved to the module-level
    ``model_save_path``. While validation accuracy is below 70%, one more
    trailing block of ``model.blocks`` is unfrozen after the epoch (via
    ``adjust_trainable_layers``) until all blocks are trainable.

    Args:
        model: Network to train; it must expose ``blocks``.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the currently trainable parameters. It is
            extended in place (new parameter group) when blocks are unfrozen,
            so its accumulated state and the scheduler binding are preserved.
        scheduler: Learning-rate scheduler bound to ``optimizer``; stepped once
            per epoch. ``None`` disables scheduling.
        num_epochs: Number of epochs to run.

    Returns:
        None. The model is updated in place.
    """
    best_val_accuracy = 0.0
    # Counter for the unfreeze schedule. It is incremented *before* each
    # adjustment, so the first adjustment unfreezes two blocks.
    # NOTE(review): no block is trainable during the first epoch unless the
    # caller unfroze one beforehand — confirm this is the intended schedule.
    num_blocks_to_unfreeze = 1
    total_blocks = len(model.blocks)

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        model.train()
        train_loss = 0.0
        for inputs, labels in tqdm(train_loader, desc="Training"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        train_loss /= max(len(train_loader), 1)
        print(f"Training Loss: {train_loss:.4f}")

        # Validation
        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validating"):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_accuracy = correct / total if total else 0.0
        print(f"Validation Loss: {val_loss / max(len(val_loader), 1):.4f}, Accuracy: {val_accuracy * 100:.2f}%")

        # Persist the weights whenever validation accuracy reaches a new best.
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), model_save_path)
            print(f"New best model saved with accuracy: {val_accuracy * 100:.2f}%")

        # Adaptive adjustment: while accuracy is below 70%, unfreeze one more
        # block per epoch. Stop once every block is already trainable.
        if val_accuracy < 0.7 and num_blocks_to_unfreeze < total_blocks:
            num_blocks_to_unfreeze += 1
            model = adjust_trainable_layers(model, num_blocks_to_unfreeze)
            # Extend the existing optimizer instead of rebuilding it: a rebuild
            # would drop the optimizer state and leave `scheduler` driving an
            # optimizer that is no longer used.
            _register_newly_trainable_params(model, optimizer, scheduler)

        # Advance the schedule once per epoch, after any new parameter group
        # has been registered so that group receives the scheduled LR as well.
        if scheduler is not None:
            scheduler.step()


def _register_newly_trainable_params(model, optimizer, scheduler):
    """Add parameters that became trainable to ``optimizer`` and ``scheduler``."""
    tracked = {id(p) for group in optimizer.param_groups for p in group["params"]}
    fresh = [p for p in model.parameters() if p.requires_grad and id(p) not in tracked]
    if not fresh:
        return

    # The new group inherits the optimizer's default hyper-parameters.
    optimizer.add_param_group({"params": fresh})
    new_group = optimizer.param_groups[-1]
    new_group.setdefault("initial_lr", new_group["lr"])

    # Schedulers keep one base LR per parameter group, captured when they are
    # constructed; keep that list aligned with the optimizer's groups so the
    # new group is scheduled too.
    base_lrs = getattr(scheduler, "base_lrs", None)
    if isinstance(base_lrs, list):
        base_lrs.append(new_group["initial_lr"])


In [9]:
# Main Script
train_loader, val_loader, test_loader, train_dataset = prepare_data_loaders(data_dir, batch_size)
class_weights = compute_weights(train_dataset, num_classes)
criterion = nn.CrossEntropyLoss(weight=class_weights)
model = load_model(num_classes)
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=initial_lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=20)

Loading datasets...
Loading and configuring the model...

Epoch 1/20


Training: 100%|██████████| 6466/6466 [26:51<00:00,  4.01it/s]  


Training Loss: 1.7632


Validating: 100%|██████████| 275/275 [01:19<00:00,  3.46it/s]


Validation Loss: 1.9976, Accuracy: 29.82%
New best model saved with accuracy: 29.82%
Unfreezing 2/6 blocks...

Epoch 2/20


Training: 100%|██████████| 6466/6466 [38:17<00:00,  2.81it/s]  


Training Loss: 1.1064


Validating: 100%|██████████| 275/275 [00:44<00:00,  6.15it/s]


Validation Loss: 1.3317, Accuracy: 50.52%
New best model saved with accuracy: 50.52%
Unfreezing 3/6 blocks...

Epoch 3/20


Training:  34%|███▎      | 2170/6466 [09:53<19:35,  3.65it/s]


KeyboardInterrupt: 