In [29]:
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, WeightedRandomSampler
from tqdm import tqdm

In [30]:
# Paths
data_dir = "../DATA_PREPARE_ATT_03/Splitted_MyDataset"  # Root directory of the split dataset (train, val, test)
model_save_path = "efficientnet_b0_emotion_model.pth"  # Path to save the trained model
metadata_save_path = "saved_metadata"  # Path to save dataset metadata

In [31]:
# Configuration
batch_size = 32
num_epochs = 20
learning_rate = 1e-4
num_classes = 7
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Emotion categories
emotion_classes = ["Anger", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]

# Define focus weights for prioritized emotions (Anger, Sad, Happy, Neutral)
focus_weights = [1.5 if emotion in ["Anger", "Sad", "Happy", "Neutral"] else 1.0 for emotion in emotion_classes]

Using device: cuda


In [32]:
# 1. Save Dataset Metadata and Class Weights
def save_dataset_metadata(train_dataset, val_dataset, test_dataset, class_weights, output_dir="saved_metadata"):
    os.makedirs(output_dir, exist_ok=True)

    # Save class weights
    with open(os.path.join(output_dir, "class_weights.pkl"), "wb") as f:
        pickle.dump(class_weights, f)

    # Save dataset information (indices and targets)
    dataset_info = {
        "train_indices": train_dataset.samples,
        "val_indices": val_dataset.samples,
        "test_indices": test_dataset.samples
    }
    with open(os.path.join(output_dir, "dataset_info.pkl"), "wb") as f:
        pickle.dump(dataset_info, f)

    print("Dataset metadata and class weights saved.")

In [33]:
# 2. Load Saved Metadata
def load_dataset_metadata(output_dir="saved_metadata"):
    # Load class weights
    with open(os.path.join(output_dir, "class_weights.pkl"), "rb") as f:
        class_weights = pickle.load(f)

    # Load dataset information
    with open(os.path.join(output_dir, "dataset_info.pkl"), "rb") as f:
        dataset_info = pickle.load(f)

    print("Dataset metadata and class weights loaded.")
    return class_weights, dataset_info

In [34]:
# 3. Create DataLoaders from Metadata
def get_saved_data_loaders(data_dir, dataset_info, class_weights, batch_size):
    # Transformations
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),  # Normalize to [0, 1]
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # EfficientNet normalization
    ])

    # Load datasets
    train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=transform)
    val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=transform)
    test_dataset = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=transform)

    # Restore indices
    train_dataset.samples = dataset_info["train_indices"]
    val_dataset.samples = dataset_info["val_indices"]
    test_dataset.samples = dataset_info["test_indices"]

    # Create WeightedRandomSampler
    sample_weights = [class_weights[label] for _, label in train_dataset.samples]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    return train_loader, val_loader, test_loader

In [35]:
# 4. Full Preprocessing Workflow
def prepare_data_loaders(data_dir, batch_size, focus_weights, emotion_classes, metadata_save_path):
    # Check if metadata exists
    if os.path.exists(metadata_save_path):
        print("Loading saved metadata...")
        class_weights, dataset_info = load_dataset_metadata(metadata_save_path)
        return get_saved_data_loaders(data_dir, dataset_info, class_weights, batch_size)
    else:
        # Transformations
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        print("Loading datasets...")
        train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=transform)
        val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=transform)
        test_dataset = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=transform)
        print("Datasets loaded successfully.")

        # Compute class weights
        print("Calculating class counts...")
        class_counts = [0] * len(emotion_classes)
        for _, label in tqdm(train_dataset, desc="Counting samples per class"):
            class_counts[label] += 1

        total_samples = sum(class_counts)
        class_weights = [total_samples / count for count in class_counts]
        class_weights = [weight * focus_weights[i] for i, weight in enumerate(class_weights)]
        class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

        # Save metadata
        save_dataset_metadata(train_dataset, val_dataset, test_dataset, class_weights, metadata_save_path)

        # Create WeightedRandomSampler
        print("Creating weighted sampler...")
        sample_weights = []
        for _, label in tqdm(train_dataset, desc="Calculating sample weights"):
            sample_weights.append(class_weights[label])

        sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

        # Data loaders
        print("Creating data loaders...")
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=4, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

        print("Data loaders created successfully.")
        return train_loader, val_loader, test_loader

train_loader, val_loader, test_loader = prepare_data_loaders(data_dir, batch_size, focus_weights, emotion_classes, metadata_save_path)


Loading datasets...
Datasets loaded successfully.
Calculating class counts...


Counting samples per class: 100%|██████████| 92309/92309 [21:37<00:00, 71.17it/s]   


Dataset metadata and class weights saved.
Creating weighted sampler...


Calculating sample weights: 100%|██████████| 92309/92309 [08:27<00:00, 181.98it/s] 

Creating data loaders...
Data loaders created successfully.





In [36]:
# 5. Model Definition
model = models.efficientnet_b0(pretrained=True)  # Load pretrained EfficientNet B0
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)  # Adjust the final layer
model = model.to(device)

In [37]:
# 6. Loss Function and Optimizer
criterion = nn.CrossEntropyLoss()  # Weighted loss will already be applied in the data loaders
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [38]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, model_save_path):
    best_val_loss = float("inf")
    best_val_accuracy = 0.0  # To track the accuracy of the best model

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        print(f"\nEpoch {epoch + 1}/{num_epochs}: Training...")

        # Training phase with progress bar
        for batch_idx, (inputs, labels) in enumerate(tqdm(train_loader, desc="Training Batches", leave=False)):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        print(f"Epoch {epoch + 1}/{num_epochs}: Validating...")

        with torch.no_grad():
            for batch_idx, (inputs, labels) in enumerate(tqdm(val_loader, desc="Validation Batches", leave=False)):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                # Calculate validation accuracy
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_accuracy = correct / total

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy * 100:.2f}%")

        # Save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), model_save_path)
            print(f"Model saved at epoch {epoch + 1} with Val Accuracy: {best_val_accuracy * 100:.2f}%")

    print(f"Training complete. Best model achieved Val Accuracy: {best_val_accuracy * 100:.2f}%")

# Run the updated training function
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, model_save_path)


Epoch 1/20: Training...


                                                                      

Epoch 1/20: Validating...


                                                                     

Epoch 1/20, Train Loss: 1.2257, Val Loss: 1.0225, Val Accuracy: 63.25%
Model saved at epoch 1 with Val Accuracy: 63.25%

Epoch 2/20: Training...


                                                                     

Epoch 2/20: Validating...


                                                                     

Epoch 2/20, Train Loss: 0.9361, Val Loss: 1.0405, Val Accuracy: 63.18%

Epoch 3/20: Training...


                                                                     

Epoch 3/20: Validating...


                                                                     

Epoch 3/20, Train Loss: 0.7635, Val Loss: 1.1118, Val Accuracy: 62.90%

Epoch 4/20: Training...


                                                                     

Epoch 4/20: Validating...


                                                                     

Epoch 4/20, Train Loss: 0.6237, Val Loss: 1.1679, Val Accuracy: 63.34%

Epoch 5/20: Training...


                                                                       

Epoch 5/20: Validating...


                                                                     

Epoch 5/20, Train Loss: 0.5092, Val Loss: 1.2653, Val Accuracy: 62.75%

Epoch 6/20: Training...


                                                          

KeyboardInterrupt: 

In [None]:
# 8. Test the Model with Progress Bar
def test_model(model, test_loader, emotion_classes, model_save_path):
    model.load_state_dict(torch.load(model_save_path))
    model.eval()

    all_predictions = []
    all_labels = []
    print("Testing the model...")
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Testing Progress"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            probabilities = torch.softmax(outputs, dim=1)
            all_predictions.append(probabilities.cpu())
            all_labels.append(labels.cpu())

    # Combine all batches
    all_predictions = torch.cat(all_predictions)
    all_labels = torch.cat(all_labels)

    # Calculate per-class accuracy
    correct_predictions = (torch.argmax(all_predictions, dim=1) == all_labels).float()
    accuracy = correct_predictions.mean().item()

    print(f"Test Accuracy: {accuracy * 100:.2f}%")
    return all_predictions, all_labels

test_predictions, test_labels = test_model(model, test_loader, emotion_classes, model_save_path)