<a href="https://colab.research.google.com/github/michealamanya/machine_learning/blob/main/mood_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
Enhanced Facial Expression Recognition (FER) System - IMPROVED VERSION
======================================================================
Key improvements over previous version:
✅ Larger input size (96x96) for better feature extraction
✅ Focal Loss to handle class imbalance better
✅ Mixup augmentation for regularization
✅ Cosine annealing with warm restarts
✅ Gradient clipping for stability
✅ More aggressive augmentation
✅ Fixed ONNX export
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, models
from PIL import Image
import kagglehub
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter

# ============================================================================
# IMPROVED CONFIGURATION
# ============================================================================

class Config:
    """
    Hyper-parameters, feature switches and output paths for the FER pipeline.

    Every setting is a plain class attribute and is read as ``Config.NAME``.
    """

    # Input geometry: 96x96 three-channel images (the previous version used 48x48).
    IMG_HEIGHT = 96
    IMG_WIDTH = 96
    IMG_CHANNELS = 3

    # Core training loop.
    BATCH_SIZE = 48             # kept modest because the inputs are larger
    LEARNING_RATE = 5e-4        # initial learning rate for AdamW
    NUM_EPOCHS = 150
    EARLY_STOP_PATIENCE = 20    # epochs without a new best val accuracy before stopping

    # Classifier head and loss.
    DROPOUT_RATE = 0.4
    USE_PRETRAINED = True
    USE_FOCAL_LOSS = True       # False -> class-weighted cross-entropy
    FOCAL_ALPHA = 0.25
    FOCAL_GAMMA = 2.0

    # Mixup regularisation.
    USE_MIXUP = True
    MIXUP_ALPHA = 0.2           # Beta(alpha, alpha) parameter for the mixing coefficient

    # Test-time augmentation.
    # NOTE(review): neither setting is read by the training script in this cell — confirm before relying on them.
    USE_TTA = True
    TTA_TRANSFORMS = 8

    # Optimiser extras.
    WEIGHT_DECAY = 1e-4
    GRADIENT_CLIP = 1.0         # max gradient norm; a value <= 0 disables clipping

    # Learning-rate schedule.
    USE_COSINE_ANNEALING = True  # False -> ReduceLROnPlateau on validation loss
    T_MAX = 15                   # used as T_0 (first restart period, in epochs) of CosineAnnealingWarmRestarts
    ETA_MIN = 1e-6

    # Runtime device and artefact locations.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MODEL_SAVE_PATH = 'best_fer_improved.pth'
    ONNX_SAVE_PATH = 'fer_model_improved.onnx'

# ============================================================================
# FOCAL LOSS - Better for imbalanced classes
# ============================================================================

class FocalLoss(nn.Module):
    """
    Multi-class focal loss with optional per-class weights.

    Each sample contributes ``alpha * w[target] * (1 - p_t) ** gamma * CE``,
    where ``CE = -log p_t`` and ``p_t`` is the softmax probability of the true
    class. The ``(1 - p_t) ** gamma`` factor down-weights easy, confidently
    correct examples so training focuses on hard, misclassified ones.

    Args:
        alpha: Scalar multiplier applied to every sample's loss.
        gamma: Focusing parameter; ``0`` reduces to scaled (weighted) cross-entropy.
        class_weights: Optional 1-D tensor of per-class weights, indexed by target.

    Returns (from ``forward``):
        The mean of the per-sample focal losses (a scalar tensor).
    """
    def __init__(self, alpha=0.25, gamma=2.0, class_weights=None):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.class_weights = class_weights

    def forward(self, inputs, targets):
        # p_t must come from the *unweighted* cross-entropy. Passing the class
        # weights into cross_entropy here would turn exp(-CE) into p_t ** w and
        # corrupt the focal modulating factor.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        # Clamp guards against tiny negative values from rounding (which would
        # give NaN for a non-integer gamma).
        modulating = (1 - pt).clamp(min=0) ** self.gamma
        focal_loss = self.alpha * modulating * ce_loss
        if self.class_weights is not None:
            # Apply the per-class weight of each sample's true class afterwards.
            weights = self.class_weights.to(device=focal_loss.device, dtype=focal_loss.dtype)
            focal_loss = focal_loss * weights[targets]
        return focal_loss.mean()

# ============================================================================
# MIXUP AUGMENTATION
# ============================================================================

def mixup_data(x, y, alpha=0.2):
    """
    Blend every sample in a batch with a randomly chosen partner from the same batch.

    The mixing coefficient ``lam`` is drawn from ``Beta(alpha, alpha)`` when
    ``alpha > 0`` and is fixed at ``1`` (no mixing) otherwise.

    Returns:
        ``(mixed_x, y_a, y_b, lam)``: ``y_a`` is the original label tensor and
        ``y_b`` holds the labels of the partner samples.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    shuffled = torch.randperm(x.size(0)).to(x.device)
    partner_x = x[shuffled]
    partner_y = y[shuffled]

    mixed_x = lam * x + (1 - lam) * partner_x
    return mixed_x, y, partner_y, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """
    Combine the losses against both label sets of a mixup batch.

    Returns ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    return lam * loss_first + (1 - lam) * loss_second

# ============================================================================
# DATASET WITH IMPROVED AUGMENTATION
# ============================================================================

EMOTIONS = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
EMOTION_TO_IDX = {emotion: idx for idx, emotion in enumerate(EMOTIONS)}

class EmotionDataset(Dataset):
    """
    Image-folder dataset laid out as ``root_dir/<emotion>/<image file>``.

    Only sub-directories named in ``EMOTIONS`` are scanned (missing ones are
    skipped). Files ending in .jpg/.jpeg/.png (case-insensitive) are included,
    and each is labelled with the index of its directory's emotion.
    Each sample is loaded as grayscale and replicated into three identical RGB
    channels before ``transform`` is applied.

    Args:
        root_dir: Directory containing one sub-directory per emotion.
        transform: Optional callable applied to the PIL RGB image.
    """

    # File suffixes accepted as images (compared against the lower-cased name).
    _IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []

        for emotion in EMOTIONS:
            emotion_path = os.path.join(root_dir, emotion)
            if not os.path.isdir(emotion_path):
                continue
            # os.listdir order is arbitrary; sorting makes sample indices (and
            # therefore any index-based split) reproducible across machines.
            for img_name in sorted(os.listdir(emotion_path)):
                if img_name.lower().endswith(self._IMAGE_EXTENSIONS):
                    self.images.append(os.path.join(emotion_path, img_name))
                    self.labels.append(EMOTION_TO_IDX[emotion])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        # The context manager closes the file handle promptly; convert() returns
        # a new, fully loaded in-memory image that outlives the `with` block.
        with Image.open(img_path) as raw:
            gray = raw.convert('L')
        image = Image.merge('RGB', (gray, gray, gray))
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

    def get_class_weights(self):
        """
        Return inverse-frequency class weights as a float32 tensor of length ``len(EMOTIONS)``.

        Each weight is ``total / (num_classes * count)``. A class with no samples
        gets weight ``0.0`` instead of causing a lookup failure.
        """
        label_counts = Counter(self.labels)
        total = len(self.labels)
        num_classes = len(EMOTIONS)
        weights = [
            total / (num_classes * label_counts[cls]) if label_counts[cls] else 0.0
            for cls in range(num_classes)
        ]
        return torch.tensor(weights, dtype=torch.float32)

# Standard ImageNet per-channel statistics — the usual normalization for
# torchvision's pretrained ResNet backbones.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: aggressive geometric/photometric augmentation on the PIL
# image, then tensor conversion, normalization and tensor-level erasing.
_train_steps = [
    # Fixed-size resize first so every subsequent op sees the same geometry.
    transforms.Resize((Config.IMG_HEIGHT, Config.IMG_WIDTH)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=25),
    transforms.RandomResizedCrop(Config.IMG_HEIGHT, scale=(0.75, 1.0)),
    # Samples are grayscale replicated to RGB by EmotionDataset, so the
    # saturation component here should have little or no effect.
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
    transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), scale=(0.85, 1.15)),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    # RandomErasing operates on tensors, so it must follow ToTensor/Normalize.
    transforms.RandomErasing(p=0.2),
]
train_transform = transforms.Compose(_train_steps)

# Deterministic pipeline for validation/test: resize, tensor, normalize.
_eval_steps = [
    transforms.Resize((Config.IMG_HEIGHT, Config.IMG_WIDTH)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
val_test_transform = transforms.Compose(_eval_steps)

# ============================================================================
# IMPROVED MODEL ARCHITECTURE
# ============================================================================

class ImprovedEmotionResNet(nn.Module):
    """
    ResNet-18 backbone followed by a two-hidden-layer MLP classification head.

    The backbone's own ``fc`` layer is replaced with ``nn.Identity`` so it
    outputs pooled features. The head is
    ``Dropout -> Linear(512) -> BatchNorm1d -> ReLU -> Dropout -> Linear(256)
    -> BatchNorm1d -> ReLU -> Dropout -> Linear(num_classes)`` and produces raw
    logits. The second and third dropout rates are scaled to 0.8x and 0.6x of
    ``dropout_rate``.

    Note: the BatchNorm1d layers require batches larger than one sample in
    training mode.

    Args:
        num_classes: Number of output logits.
        pretrained: Load ImageNet-pretrained backbone weights when True.
        dropout_rate: Base dropout probability for the head.
    """
    def __init__(self, num_classes=7, pretrained=True, dropout_rate=0.4):
        super().__init__()

        # torchvision >= 0.13 selects weights via `weights=`; the boolean
        # `pretrained=` flag is deprecated there. DEFAULT resolves to the same
        # ImageNet weights `pretrained=True` loaded. Fall back for older releases.
        weights_enum = getattr(models, 'ResNet18_Weights', None)
        if weights_enum is None:
            self.backbone = models.resnet18(pretrained=pretrained)
        else:
            self.backbone = models.resnet18(
                weights=weights_enum.DEFAULT if pretrained else None
            )
        num_features = self.backbone.fc.in_features

        # Drop the ImageNet classifier; the backbone now emits feature vectors.
        self.backbone.fc = nn.Identity()

        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.8),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.6),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Return class logits for an image batch ``x``."""
        features = self.backbone(x)
        return self.classifier(features)

# ============================================================================
# IMPROVED TRAINING LOOP
# ============================================================================

def train_epoch_improved(model, loader, criterion, optimizer, device, use_mixup=True):
    """
    Run a single training epoch.

    Each batch is optionally mixed up (only when both ``use_mixup`` and
    ``Config.USE_MIXUP`` are truthy). Gradients are clipped to
    ``Config.GRADIENT_CLIP`` when that value is positive.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen. Under mixup,
        accuracy is approximated as the ``lam``-weighted agreement of the
        predictions with both label sets.
    """
    model.train()
    loss_sum = 0.0
    hits = 0
    seen = 0

    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        mixing = use_mixup and Config.USE_MIXUP

        if mixing:
            batch_x, y_first, y_second, lam = mixup_data(batch_x, batch_y, Config.MIXUP_ALPHA)

        optimizer.zero_grad()
        logits = model(batch_x)
        if mixing:
            loss = mixup_criterion(criterion, logits, y_first, y_second, lam)
        else:
            loss = criterion(logits, batch_y)

        loss.backward()
        if Config.GRADIENT_CLIP > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), Config.GRADIENT_CLIP)
        optimizer.step()

        loss_sum += loss.item() * batch_x.size(0)
        preds = logits.max(dim=1).indices
        seen += batch_y.size(0)

        if mixing:
            hits += (lam * (preds == y_first).sum().item() +
                     (1 - lam) * (preds == y_second).sum().item())
        else:
            hits += (preds == batch_y).sum().item()

    return loss_sum / seen, 100 * hits / seen

def validate(model, loader, criterion, device):
    """
    Evaluate ``model`` on ``loader`` without gradient tracking.

    The model is switched to eval mode first.

    Returns:
        ``(mean_loss, accuracy_percent)`` averaged over all samples.
    """
    model.eval()
    loss_sum, hits, seen = 0.0, 0, 0

    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = model(batch_x)

            loss_sum += criterion(logits, batch_y).item() * batch_x.size(0)
            seen += batch_y.size(0)
            hits += (logits.max(dim=1).indices == batch_y).sum().item()

    return loss_sum / seen, 100 * hits / seen

# ============================================================================
# MAIN EXECUTION
# ============================================================================

if __name__ == '__main__':
    # Script-only dependencies.
    import copy
    from torch.utils.data import Subset

    # Seed for the train/validation split so the held-out set is reproducible.
    SPLIT_SEED = 42
    # Page-locked host memory only helps when copying batches to a CUDA device.
    use_pin_memory = Config.DEVICE.type == 'cuda'

    print("=" * 70)
    print("IMPROVED FER TRAINING")
    print("=" * 70)
    print(f"Device: {Config.DEVICE}")
    print(f"Image size: {Config.IMG_HEIGHT}x{Config.IMG_WIDTH}")
    print(f"Focal Loss: {Config.USE_FOCAL_LOSS}")
    print(f"Mixup: {Config.USE_MIXUP}")
    print(f"Cosine Annealing: {Config.USE_COSINE_ANNEALING}")

    # Download dataset
    print("\nDownloading FER2013 dataset...")
    path = kagglehub.dataset_download("msambare/fer2013")
    train_dir = os.path.join(path, "train")
    test_dir = os.path.join(path, "test")

    # Load datasets
    print("Loading datasets...")
    full_train_dataset = EmotionDataset(train_dir, transform=train_transform)
    test_dataset = EmotionDataset(test_dir, transform=val_test_transform)

    class_weights = full_train_dataset.get_class_weights().to(Config.DEVICE)
    print(f"Class weights: {class_weights}")

    # Split. Validation must not see training augmentation (random crops,
    # rotations, erasing), otherwise model selection is driven by noisy,
    # augmented accuracy. A shallow copy shares the exact same file list but
    # uses the deterministic val/test transform.
    val_source_dataset = copy.copy(full_train_dataset)
    val_source_dataset.transform = val_test_transform

    val_size = int(0.15 * len(full_train_dataset))
    train_size = len(full_train_dataset) - val_size
    split_generator = torch.Generator().manual_seed(SPLIT_SEED)
    permutation = torch.randperm(len(full_train_dataset), generator=split_generator).tolist()
    train_dataset = Subset(full_train_dataset, permutation[:train_size])
    val_dataset = Subset(val_source_dataset, permutation[train_size:])

    # The classifier head uses BatchNorm1d, which errors on a 1-sample batch in
    # training mode; drop the last batch only when it would be that size.
    drop_singleton_batch = train_size % Config.BATCH_SIZE == 1

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE,
                              shuffle=True, num_workers=2, pin_memory=use_pin_memory,
                              drop_last=drop_singleton_batch)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE,
                            shuffle=False, num_workers=2, pin_memory=use_pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=Config.BATCH_SIZE,
                             shuffle=False, num_workers=2, pin_memory=use_pin_memory)

    print(f"Training samples: {train_size}")
    print(f"Validation samples: {val_size}")
    print(f"Test samples: {len(test_dataset)}")

    # Initialize model
    model = ImprovedEmotionResNet(
        num_classes=len(EMOTIONS),
        pretrained=Config.USE_PRETRAINED,
        dropout_rate=Config.DROPOUT_RATE
    ).to(Config.DEVICE)

    # Loss function
    if Config.USE_FOCAL_LOSS:
        criterion = FocalLoss(
            alpha=Config.FOCAL_ALPHA,
            gamma=Config.FOCAL_GAMMA,
            class_weights=class_weights
        )
    else:
        criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=Config.LEARNING_RATE,
        weight_decay=Config.WEIGHT_DECAY
    )

    # Scheduler (both variants are stepped once per epoch below)
    if Config.USE_COSINE_ANNEALING:
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=Config.T_MAX,
            T_mult=2,
            eta_min=Config.ETA_MIN
        )
    else:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5
        )

    # Training loop
    print("\n" + "=" * 70)
    print("Starting training...")
    print("=" * 70)

    best_val_acc = 0.0
    patience_counter = 0

    for epoch in range(Config.NUM_EPOCHS):
        train_loss, train_acc = train_epoch_improved(
            model, train_loader, criterion, optimizer, Config.DEVICE
        )
        val_loss, val_acc = validate(model, val_loader, criterion, Config.DEVICE)

        if Config.USE_COSINE_ANNEALING:
            scheduler.step()
        else:
            scheduler.step(val_loss)

        current_lr = optimizer.param_groups[0]['lr']

        print(f"Epoch [{epoch+1:3d}/{Config.NUM_EPOCHS}] | "
              f"Loss: {train_loss:.4f}/{val_loss:.4f} | "
              f"Acc: {train_acc:.2f}%/{val_acc:.2f}% | "
              f"LR: {current_lr:.6f}")

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
            }, Config.MODEL_SAVE_PATH)
            print(f"  ✓ Best model saved! (Val Acc: {val_acc:.2f}%)")
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= Config.EARLY_STOP_PATIENCE:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

    print(f"\nTraining complete! Best validation accuracy: {best_val_acc:.2f}%")

    # Evaluate the best checkpoint; map_location lets a checkpoint saved on one
    # device be restored on whatever device is available now.
    print("\nEvaluating on test set...")
    checkpoint = torch.load(Config.MODEL_SAVE_PATH, map_location=Config.DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])

    test_loss, test_acc = validate(model, test_loader, criterion, Config.DEVICE)
    print(f"Test Accuracy: {test_acc:.2f}%")

    # Get predictions for report
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(Config.DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())

    print("\n" + "=" * 70)
    print("Classification Report:")
    print("=" * 70)
    print(classification_report(all_labels, all_preds, target_names=EMOTIONS, digits=4))

    # Export to ONNX (best effort: a failure is reported, not fatal)
    print("\nExporting to ONNX...")
    try:
        model.eval()
        dummy_input = torch.randn(1, 3, Config.IMG_HEIGHT, Config.IMG_WIDTH).to(Config.DEVICE)

        torch.onnx.export(
            model,
            dummy_input,
            Config.ONNX_SAVE_PATH,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
        )
        print(f"✓ ONNX model saved: {Config.ONNX_SAVE_PATH}")
    except Exception as e:
        print(f"✗ ONNX export failed: {e}")
        print("  Install onnx: !pip install onnx onnxscript")

    print(f"\n{'=' * 70}")
    print("TRAINING SUMMARY")
    print(f"{'=' * 70}")
    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
    print(f"Test Accuracy: {test_acc:.2f}%")
    print(f"Model saved: {Config.MODEL_SAVE_PATH}")

IMPROVED FER TRAINING
Device: cuda
Image size: 96x96
Focal Loss: True
Mixup: True
Cosine Annealing: True

Downloading FER2013 dataset...
Using Colab cache for faster access to the 'fer2013' dataset.
Loading datasets...
Class weights: tensor([1.0266, 9.4066, 1.0010, 0.5684, 0.8260, 0.8491, 1.2934],
       device='cuda:0')
Training samples: 24403
Validation samples: 4306
Test samples: 7178





Starting training...
Epoch [  1/150] | Loss: 0.3470/0.3033 | Acc: 17.71%/22.53% | LR: 0.000495
  ✓ Best model saved! (Val Acc: 22.53%)
Epoch [  2/150] | Loss: 0.3203/0.2862 | Acc: 23.90%/29.89% | LR: 0.000478
  ✓ Best model saved! (Val Acc: 29.89%)
Epoch [  3/150] | Loss: 0.3022/0.2606 | Acc: 30.62%/35.11% | LR: 0.000452
  ✓ Best model saved! (Val Acc: 35.11%)
Epoch [  4/150] | Loss: 0.2859/0.2627 | Acc: 35.78%/35.81% | LR: 0.000417
  ✓ Best model saved! (Val Acc: 35.81%)
Epoch [  5/150] | Loss: 0.2775/0.2544 | Acc: 36.97%/40.36% | LR: 0.000375
  ✓ Best model saved! (Val Acc: 40.36%)
Epoch [  6/150] | Loss: 0.2701/0.2345 | Acc: 39.31%/41.69% | LR: 0.000328
  ✓ Best model saved! (Val Acc: 41.69%)
Epoch [  7/150] | Loss: 0.2634/0.2255 | Acc: 39.92%/42.64% | LR: 0.000277
  ✓ Best model saved! (Val Acc: 42.64%)
Epoch [  8/150] | Loss: 0.2581/0.2246 | Acc: 40.79%/43.22% | LR: 0.000224
  ✓ Best model saved! (Val Acc: 43.22%)
Epoch [  9/150] | Loss: 0.2526/0.2207 | Acc: 42.32%/44.91% | LR: 0