In [10]:
import os
import numpy as np  # Ensure numpy is imported
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix
from timm import create_model
from tqdm import tqdm
import pickle

In [11]:
# Paths
root_dir = "."  # Root folder where the previous model and new dataset are stored
model_save_path = os.path.join(root_dir, "fine_tuned_efficientnetv2_rw_s.pth")
checkpoint_path = os.path.join(root_dir, "fine_tuning_checkpoint.pth")
class_weights_path = os.path.join(root_dir, "class_weights.pkl")
dataset_info_path = os.path.join(root_dir, "dataset_info.pkl")
pretrained_model_path = os.path.join(root_dir, "adaptive_efficientnetv2_rw_s_emotion_model.pth")
data_dir = os.path.join(root_dir, "../DATA_PREPARE_ATT_05/AffPreProcessed")  # Update this path to point to your new dataset

In [12]:
# Configuration
batch_size = 32
initial_lr = 1e-3
num_epochs = 20
num_classes = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Emotion categories
emotion_classes = ["Anger", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral", "Contempt"]


Using device: cuda


In [13]:
# Data Preparation
def prepare_data_loaders(data_dir, batch_size):
    if os.path.exists(dataset_info_path):
        print("Loading preprocessed datasets...")
        with open(dataset_info_path, "rb") as f:
            train_loader, val_loader, test_loader = pickle.load(f)
        return train_loader, val_loader, test_loader

    transform_train = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((260, 260)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    transform_val_test = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((260, 260)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

    print("Loading datasets...")
    train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=transform_train)
    val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=transform_val_test)
    test_dataset = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=transform_val_test)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # Save datasets for future reuse
    with open(dataset_info_path, "wb") as f:
        pickle.dump((train_loader, val_loader, test_loader), f)

    return train_loader, val_loader, test_loader

In [14]:
# Compute Class Weights
def compute_weights(train_dataset, num_classes):
    if os.path.exists(class_weights_path):
        print("Loading precomputed class weights...")
        with open(class_weights_path, "rb") as f:
            return torch.tensor(pickle.load(f), dtype=torch.float).to(device)

    print("Computing class weights...")
    targets = [label for _, label in train_dataset.samples]
    class_weights = compute_class_weight('balanced', classes=np.arange(num_classes), y=targets)

    # Save class weights for future reuse
    with open(class_weights_path, "wb") as f:
        pickle.dump(class_weights, f)

    return torch.tensor(class_weights, dtype=torch.float).to(device)

In [15]:
# Load Pre-Trained Model
def load_pretrained_model(pretrained_model_path, num_classes, trainable_ratio=0.1):
    print("Loading the pre-trained model...")
    model = create_model('efficientnetv2_rw_s', pretrained=False, num_classes=num_classes, in_chans=1)
    model.load_state_dict(torch.load(pretrained_model_path, map_location=device))

    # Freeze 90% of the layers
    total_layers = len(list(model.parameters()))
    trainable_layers = int(total_layers * trainable_ratio)
    print(f"Total layers: {total_layers}, Trainable layers: {trainable_layers}")

    for i, param in enumerate(model.parameters()):
        param.requires_grad = i >= total_layers - trainable_layers

    return model.to(device)

In [16]:
# Adjust Model Trainable Layers
def adjust_trainable_layers(model, num_blocks_to_unfreeze):
    blocks = list(model.blocks)
    num_blocks = len(blocks)
    start_unfreeze = max(0, num_blocks - num_blocks_to_unfreeze)

    print(f"Unfreezing {num_blocks_to_unfreeze}/{num_blocks} blocks...")
    for i, block in enumerate(blocks):
        for param in block.parameters():
            param.requires_grad = i >= start_unfreeze

    return model

In [17]:
# Training Loop with Dynamic Adjustments
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs):
    best_val_accuracy = 0.0
    num_blocks_to_unfreeze = 1  # Start with one block trainable
    early_stop_counter = 0
    patience = 3  # Early stopping patience

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        model.train()
        train_loss = 0.0
        for inputs, labels in tqdm(train_loader, desc="Training"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)
        print(f"Training Loss: {train_loss:.4f}")

        # Validation
        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validating"):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_accuracy = correct / total
        print(f"Validation Loss: {val_loss / len(val_loader):.4f}, Accuracy: {val_accuracy * 100:.2f}%")

        # Save the best model
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), model_save_path)
            print(f"New best model saved with accuracy: {val_accuracy * 100:.2f}%")
            early_stop_counter = 0
        else:
            early_stop_counter += 1

        # Early Stopping
        if early_stop_counter >= patience:
            print("Early stopping triggered. Training halted.")
            break

        # Dynamic adjustments
        if val_accuracy < best_val_accuracy or epoch % 2 == 0:
            num_blocks_to_unfreeze += 1
            model = adjust_trainable_layers(model, num_blocks_to_unfreeze)
            optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=initial_lr)


In [18]:
# Main Script
if __name__ == "__main__":
    train_loader, val_loader, test_loader = prepare_data_loaders(data_dir, batch_size)
    train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"))
    class_weights = compute_weights(train_dataset, num_classes)

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    model = load_pretrained_model(pretrained_model_path, num_classes)
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=initial_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

    train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs)

Loading preprocessed datasets...
Computing class weights...
Loading the pre-trained model...


  model.load_state_dict(torch.load(pretrained_model_path, map_location=device))


Total layers: 458, Trainable layers: 45

Epoch 1/20


Training: 100%|██████████| 1100/1100 [06:13<00:00,  2.94it/s]


Training Loss: 1.3096


Validating: 100%|██████████| 138/138 [01:09<00:00,  2.00it/s]


Validation Loss: 1.2074, Accuracy: 55.23%
New best model saved with accuracy: 55.23%
Unfreezing 2/6 blocks...

Epoch 2/20


Training: 100%|██████████| 1100/1100 [08:12<00:00,  2.23it/s]


Training Loss: 1.1193


Validating: 100%|██████████| 138/138 [00:38<00:00,  3.57it/s]


Validation Loss: 1.0653, Accuracy: 60.39%
New best model saved with accuracy: 60.39%

Epoch 3/20


Training: 100%|██████████| 1100/1100 [08:50<00:00,  2.07it/s]


Training Loss: 0.9700


Validating: 100%|██████████| 138/138 [01:35<00:00,  1.44it/s]


Validation Loss: 1.0603, Accuracy: 60.66%
New best model saved with accuracy: 60.66%
Unfreezing 3/6 blocks...

Epoch 4/20


Training: 100%|██████████| 1100/1100 [25:54<00:00,  1.41s/it]


Training Loss: 0.8818


Validating: 100%|██████████| 138/138 [00:41<00:00,  3.34it/s]


Validation Loss: 1.0690, Accuracy: 60.61%
Unfreezing 4/6 blocks...

Epoch 5/20


Training: 100%|██████████| 1100/1100 [45:28<00:00,  2.48s/it]


Training Loss: 0.7877


Validating: 100%|██████████| 138/138 [01:35<00:00,  1.44it/s]


Validation Loss: 1.1113, Accuracy: 61.27%
New best model saved with accuracy: 61.27%
Unfreezing 5/6 blocks...

Epoch 6/20


Training:   1%|▏         | 16/1100 [02:25<2:43:48,  9.07s/it]


KeyboardInterrupt: 