In [1]:
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, WeightedRandomSampler
from tqdm import tqdm

In [2]:
# Paths
data_dir = "../DATA_PREPARE_ATT_03/Splitted_MyDataset02"  # Root directory of the split dataset (train, val, test)
model_save_path = "efficientnet_b0_emotion_model.pth"  # Path to save the trained model
metadata_save_path = "saved_metadata"  # Path to save dataset metadata
checkpoint_path = "training_checkpoint.pth"  # Path to save training checkpoint

In [3]:
# Configuration
batch_size = 32
num_epochs = 20
learning_rate = 1e-4
num_classes = 7
patience = 5  # Early stopping patience
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Emotion categories
emotion_classes = ["Anger", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]

# Define focus weights for prioritized emotions (Anger, Sad, Happy, Neutral)
focus_weights = [1.5 if emotion in ["Anger", "Sad", "Happy", "Neutral"] else 1.0 for emotion in emotion_classes]


Using device: cuda


In [4]:
# Early Stopping Class
class EarlyStopping:
    def __init__(self, patience=3, delta=0.01):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

In [5]:
# Save and Load Metadata
def save_dataset_metadata(train_dataset, val_dataset, test_dataset, class_weights, output_dir="saved_metadata"):
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "class_weights.pkl"), "wb") as f:
        pickle.dump(class_weights, f)
    dataset_info = {
        "train_indices": train_dataset.samples,
        "val_indices": val_dataset.samples,
        "test_indices": test_dataset.samples,
    }
    with open(os.path.join(output_dir, "dataset_info.pkl"), "wb") as f:
        pickle.dump(dataset_info, f)
    print("Dataset metadata and class weights saved.")

def load_dataset_metadata(output_dir="saved_metadata"):
    with open(os.path.join(output_dir, "class_weights.pkl"), "rb") as f:
        class_weights = pickle.load(f)
    with open(os.path.join(output_dir, "dataset_info.pkl"), "rb") as f:
        dataset_info = pickle.load(f)
    print("Dataset metadata and class weights loaded.")
    return class_weights, dataset_info

In [6]:
# Define the Missing Function
def get_saved_data_loaders(data_dir, dataset_info, class_weights, batch_size):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=transform)
    val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=transform)
    test_dataset = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=transform)

    train_dataset.samples = dataset_info["train_indices"]
    val_dataset.samples = dataset_info["val_indices"]
    test_dataset.samples = dataset_info["test_indices"]

    sample_weights = [class_weights[label] for _, label in train_dataset.samples]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    return train_loader, val_loader, test_loader

In [7]:
# Preprocessing Workflow
def prepare_data_loaders(data_dir, batch_size, focus_weights, emotion_classes, metadata_save_path):
    if os.path.exists(metadata_save_path):
        print("Loading saved metadata...")
        class_weights, dataset_info = load_dataset_metadata(metadata_save_path)
        return get_saved_data_loaders(data_dir, dataset_info, class_weights, batch_size)
    else:
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(30),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        transform_val_test = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        print("Loading datasets...")
        train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=transform_train)
        val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=transform_val_test)
        test_dataset = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=transform_val_test)
        print("Datasets loaded successfully.")

        print("Calculating class counts...")
        class_counts = [0] * len(emotion_classes)
        for _, label in tqdm(train_dataset, desc="Counting samples per class"):
            class_counts[label] += 1

        total_samples = sum(class_counts)
        class_weights = [total_samples / count for count in class_counts]
        class_weights = [weight * focus_weights[i] for i, weight in enumerate(class_weights)]
        save_dataset_metadata(train_dataset, val_dataset, test_dataset, class_weights, metadata_save_path)

        print("Creating weighted sampler...")
        sample_weights = [class_weights[label] for _, label in train_dataset.samples]
        sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

        print("Creating data loaders...")
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=4, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

        return train_loader, val_loader, test_loader

# Create DataLoaders
train_loader, val_loader, test_loader = prepare_data_loaders(data_dir, batch_size, focus_weights, emotion_classes, metadata_save_path)


Loading datasets...
Datasets loaded successfully.
Calculating class counts...


Counting samples per class: 100%|██████████| 98902/98902 [43:40<00:00, 37.74it/s]  


Dataset metadata and class weights saved.
Creating weighted sampler...
Creating data loaders...


In [8]:
# Model Definition with Dropout
model = models.efficientnet_b0(pretrained=True)
model.classifier = nn.Sequential(
    nn.Dropout(p=0.5),  # Increase dropout rate
    nn.Linear(model.classifier[1].in_features, num_classes),
)
model = model.to(device)



In [9]:
# Loss, Optimizer, and Scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)  # Added weight decay
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

In [10]:
# Training with Early Stopping and Checkpointing
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, model_save_path, checkpoint_path):
    best_val_loss = float("inf")
    early_stopping = EarlyStopping(patience=patience)

    start_epoch = 0

    # Load checkpoint if available
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        print(f"Resuming training from epoch {start_epoch}.")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        train_loss = 0.0
        print(f"\nEpoch {epoch + 1}/{num_epochs}: Training...")
        for inputs, labels in tqdm(train_loader, desc="Training Batches", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        print(f"Epoch {epoch + 1}/{num_epochs}: Validating...")
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validation Batches", leave=False):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        val_loss /= len(val_loader)
        val_accuracy = correct / total

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy * 100:.2f}%")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), model_save_path)
            print(f"Model saved at epoch {epoch + 1} with Val Accuracy: {val_accuracy * 100:.2f}%")

        # Save checkpoint
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
        }, checkpoint_path)

        scheduler.step(val_loss)
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, model_save_path, checkpoint_path)


Epoch 1/20: Training...


                                                                     

Epoch 1/20: Validating...


                                                                     

Epoch 1/20, Train Loss: 1.5619, Val Loss: 1.2202, Val Accuracy: 55.17%
Model saved at epoch 1 with Val Accuracy: 55.17%

Epoch 2/20: Training...


                                                                     

Epoch 2/20: Validating...


                                                                     

Epoch 2/20, Train Loss: 1.3659, Val Loss: 1.0990, Val Accuracy: 59.41%
Model saved at epoch 2 with Val Accuracy: 59.41%

Epoch 3/20: Training...


                                                                       

Epoch 3/20: Validating...


                                                                     

Epoch 3/20, Train Loss: 1.3045, Val Loss: 1.0766, Val Accuracy: 60.46%
Model saved at epoch 3 with Val Accuracy: 60.46%

Epoch 4/20: Training...


                                                                      

Epoch 4/20: Validating...


                                                                     

Epoch 4/20, Train Loss: 1.2626, Val Loss: 1.0580, Val Accuracy: 61.04%
Model saved at epoch 4 with Val Accuracy: 61.04%

Epoch 5/20: Training...


                                                                       

Epoch 5/20: Validating...


                                                                     

Epoch 5/20, Train Loss: 1.2389, Val Loss: 1.0395, Val Accuracy: 61.85%
Model saved at epoch 5 with Val Accuracy: 61.85%

Epoch 6/20: Training...


                                                                     

Epoch 6/20: Validating...


                                                                     

Epoch 6/20, Train Loss: 1.2159, Val Loss: 1.0030, Val Accuracy: 63.76%
Model saved at epoch 6 with Val Accuracy: 63.76%

Epoch 7/20: Training...


                                                                     

Epoch 7/20: Validating...


                                                                     

Epoch 7/20, Train Loss: 1.1996, Val Loss: 1.0191, Val Accuracy: 62.90%

Epoch 8/20: Training...


                                                                       

Epoch 8/20: Validating...


                                                                     

Epoch 8/20, Train Loss: 1.1755, Val Loss: 0.9990, Val Accuracy: 63.96%
Model saved at epoch 8 with Val Accuracy: 63.96%

Epoch 9/20: Training...


                                                                     

Epoch 9/20: Validating...


                                                                     

Epoch 9/20, Train Loss: 1.1634, Val Loss: 1.0212, Val Accuracy: 63.05%

Epoch 10/20: Training...


                                                                       

Epoch 10/20: Validating...


                                                                     

Epoch 10/20, Train Loss: 1.1536, Val Loss: 0.9939, Val Accuracy: 64.03%
Model saved at epoch 10 with Val Accuracy: 64.03%

Epoch 11/20: Training...


                                                                     

Epoch 11/20: Validating...


                                                                     

Epoch 11/20, Train Loss: 1.1432, Val Loss: 1.0265, Val Accuracy: 63.19%
Early stopping triggered.




In [11]:
# Testing with Per-Class Accuracy
def test_model(model, test_loader, emotion_classes, model_save_path):
    model.load_state_dict(torch.load(model_save_path))
    model.eval()
    correct = [0] * len(emotion_classes)
    total = [0] * len(emotion_classes)
    print("Testing the model...")
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Testing Progress"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            for i in range(len(labels)):
                label = labels[i].item()
                total[label] += 1
                correct[label] += (predicted[i] == label).item()

    for i, emotion in enumerate(emotion_classes):
        accuracy = 100 * correct[i] / total[i] if total[i] > 0 else 0
        print(f"{emotion}: {accuracy:.2f}%")
    overall_accuracy = sum(correct) / sum(total)
    print(f"Overall Test Accuracy: {overall_accuracy * 100:.2f}%")

test_model(model, test_loader, emotion_classes, model_save_path)

  model.load_state_dict(torch.load(model_save_path))


Testing the model...


Testing Progress: 100%|██████████| 413/413 [01:20<00:00,  5.16it/s]

Anger: 68.18%
Disgust: 26.89%
Fear: 53.30%
Happy: 77.03%
Neutral: 67.96%
Sad: 64.28%
Surprise: 52.60%
Overall Test Accuracy: 64.86%



