In [12]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.utils.class_weight import compute_class_weight
from torch.optim.lr_scheduler import OneCycleLR

In [13]:
# Paths
data_dir = "../DATA_PREPARE_ATT_04/Grayscale_Face_images"
model_save_path = "optimized_efficientnet_b2_emotion_model.pth"
checkpoint_path = "optimized_training_checkpoint.pth"

In [14]:
# Configuration
batch_size = 32
accumulation_steps = 4
num_epochs = 25
initial_lr = 1e-3
num_classes = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Emotion categories
emotion_classes = ["Anger", "Contempt", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]


Using device: cuda


In [15]:
# Early Stopping Class
class EarlyStopping:
    def __init__(self, patience=5, delta=0.01):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

In [17]:
# Data Preparation
def prepare_data_loaders(data_dir, batch_size):
    transform_train = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
        transforms.RandomAffine(degrees=0, shear=10, scale=(0.8, 1.2)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485], std=[0.229]),
        transforms.RandomErasing(p=0.5, scale=(0.02, 0.2))
    ])
    transform_val_test = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485], std=[0.229])
    ])

    print("Loading datasets...")
    with tqdm(total=3, desc="Loading datasets", unit="step") as pbar:
        train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=transform_train)
        pbar.update(1)
        val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=transform_val_test)
        pbar.update(1)
        test_dataset = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=transform_val_test)
        pbar.update(1)
    print("Datasets loaded successfully.")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    return train_loader, val_loader, test_loader, train_dataset

In [18]:
# Compute class weights
def compute_weights(train_dataset, num_classes):
    print("Computing class weights...")
    targets = [label for _, label in tqdm(train_dataset, desc="Computing class weights", unit="sample")]
    class_weights = compute_class_weight('balanced', classes=np.arange(num_classes), y=targets)
    return torch.tensor(class_weights, dtype=torch.float).to(device)

In [19]:
# Load the model and freeze initial layers
def load_model(num_classes, freeze_ratio=0.8):
    print("Loading and configuring the model...")
    model = models.efficientnet_b2(pretrained=True)
    num_layers = len(list(model.parameters()))
    layers_to_freeze = int(num_layers * freeze_ratio)

    with tqdm(total=num_layers, desc="Freezing layers", unit="layer") as pbar:
        for i, param in enumerate(model.parameters()):
            param.requires_grad = False if i < layers_to_freeze else True
            pbar.update(1)

    model.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)  # Adjust for grayscale
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.6),
        nn.Linear(model.classifier[1].in_features, num_classes),
    )
    return model.to(device)

In [20]:
# Loss Function
class WeightedFocalLoss(nn.Module):
    def __init__(self, gamma=2, class_weights=None):
        super(WeightedFocalLoss, self).__init__()
        self.gamma = gamma
        self.class_weights = class_weights

    def forward(self, inputs, targets):
        ce_loss = nn.CrossEntropyLoss(weight=self.class_weights)(inputs, targets)
        pt = torch.exp(-ce_loss)
        return ((1 - pt) ** self.gamma * ce_loss).mean()

In [21]:
# Training Loop
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, model_save_path, checkpoint_path):
    best_val_loss = float("inf")
    early_stopping = EarlyStopping(patience=5)
    start_epoch = 0

    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        print(f"Resuming training from epoch {start_epoch}.")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        train_loss = 0.0
        optimizer.zero_grad()

        print(f"\nEpoch {epoch + 1}/{num_epochs}: Training...")
        with tqdm(total=len(train_loader), desc="Training Batches", unit="batch") as pbar:
            for i, (inputs, labels) in enumerate(train_loader):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels) / accumulation_steps
                loss.backward()
                train_loss += loss.item()
                if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                pbar.update(1)

        train_loss /= len(train_loader)

        # Validation
        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        print(f"Epoch {epoch + 1}/{num_epochs}: Validating...")
        with tqdm(total=len(val_loader), desc="Validation Batches", unit="batch") as pbar:
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                    pbar.update(1)

        val_loss /= len(val_loader)
        val_accuracy = correct / total

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy * 100:.2f}%")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), model_save_path)
            print(f"Model saved at epoch {epoch + 1} with Val Accuracy: {val_accuracy * 100:.2f}%")

        if early_stopping(val_loss):
            print("Early stopping triggered.")
            break

In [None]:
# Main Script
train_loader, val_loader, test_loader, train_dataset = prepare_data_loaders(data_dir, batch_size)
class_weights = compute_weights(train_dataset, num_classes)
criterion = WeightedFocalLoss(gamma=2, class_weights=class_weights)
model = load_model(num_classes)
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=initial_lr, weight_decay=1e-3)
scheduler = OneCycleLR(optimizer, max_lr=initial_lr, steps_per_epoch=len(train_loader), epochs=num_epochs)

train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, model_save_path, checkpoint_path)

Loading datasets...


Loading datasets: 100%|██████████| 3/3 [00:00<00:00,  7.86step/s]


Datasets loaded successfully.
Computing class weights...


Computing class weights: 100%|██████████| 90043/90043 [20:50<00:00, 72.01sample/s] 


Loading and configuring the model...


Freezing layers: 100%|██████████| 301/301 [00:00<00:00, 300663.37layer/s]



Epoch 1/25: Training...


Training Batches: 100%|██████████| 2814/2814 [27:07<00:00,  1.73batch/s]


Epoch 1/25: Validating...


Validation Batches: 100%|██████████| 603/603 [04:02<00:00,  2.49batch/s]


Epoch 1/25, Train Loss: 0.3930, Val Loss: 1.4559, Val Accuracy: 24.48%
Model saved at epoch 1 with Val Accuracy: 24.48%

Epoch 2/25: Training...


Training Batches: 100%|██████████| 2814/2814 [39:19<00:00,  1.19batch/s]  


Epoch 2/25: Validating...


Validation Batches: 100%|██████████| 603/603 [02:34<00:00,  3.91batch/s]


Epoch 2/25, Train Loss: 0.3729, Val Loss: 1.3322, Val Accuracy: 29.49%
Model saved at epoch 2 with Val Accuracy: 29.49%

Epoch 3/25: Training...


Training Batches: 100%|██████████| 2814/2814 [33:24<00:00,  1.40batch/s]  


Epoch 3/25: Validating...


Validation Batches: 100%|██████████| 603/603 [02:50<00:00,  3.54batch/s]


Epoch 3/25, Train Loss: 0.3577, Val Loss: 1.2176, Val Accuracy: 33.52%
Model saved at epoch 3 with Val Accuracy: 33.52%

Epoch 4/25: Training...


Training Batches: 100%|██████████| 2814/2814 [23:13<00:00,  2.02batch/s] 


Epoch 4/25: Validating...


Validation Batches: 100%|██████████| 603/603 [03:13<00:00,  3.11batch/s]


Epoch 4/25, Train Loss: 0.3442, Val Loss: 1.1037, Val Accuracy: 38.50%
Model saved at epoch 4 with Val Accuracy: 38.50%

Epoch 5/25: Training...


Training Batches: 100%|██████████| 2814/2814 [36:47<00:00,  1.27batch/s]  


Epoch 5/25: Validating...


Validation Batches: 100%|██████████| 603/603 [03:13<00:00,  3.12batch/s] 


Epoch 5/25, Train Loss: 0.3330, Val Loss: 1.0335, Val Accuracy: 40.58%
Model saved at epoch 5 with Val Accuracy: 40.58%

Epoch 6/25: Training...


Training Batches: 100%|██████████| 2814/2814 [38:23<00:00,  1.22batch/s]  


Epoch 6/25: Validating...


Validation Batches: 100%|██████████| 603/603 [03:31<00:00,  2.86batch/s]


Epoch 6/25, Train Loss: 0.3211, Val Loss: 0.9569, Val Accuracy: 43.19%
Model saved at epoch 6 with Val Accuracy: 43.19%

Epoch 7/25: Training...


Training Batches: 100%|██████████| 2814/2814 [41:43<00:00,  1.12batch/s]  


Epoch 7/25: Validating...


Validation Batches: 100%|██████████| 603/603 [03:38<00:00,  2.76batch/s] 


Epoch 7/25, Train Loss: 0.3103, Val Loss: 0.9357, Val Accuracy: 44.45%
Model saved at epoch 7 with Val Accuracy: 44.45%

Epoch 8/25: Training...


Training Batches:  20%|██        | 572/2814 [05:39<30:41,  1.22batch/s]  