Multi-Modal Mood Matcher - Visual Feature Extraction Training Notebook
========================================================================
This notebook trains a ResNet-50/EfficientNet backbone for emotion recognition
and extracts identity-invariant, emotion-sensitive embeddings.

Target: Train on AffectNet/RAF-DB, extract penultimate layer embeddings (512-D or 2048-D)

In [1]:
# ============================================================================
# 1. IMPORTS AND SETUP
# ============================================================================

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
import json
import random

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import timm  # For EfficientNet and other modern architectures

# Image Processing
import cv2
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Face Detection (MediaPipe alternative using OpenCV)
# For production, integrate MediaPipe: pip install mediapipe
try:
    import mediapipe as mp
    MEDIAPIPE_AVAILABLE = True
except ImportError:
    MEDIAPIPE_AVAILABLE = False
    print("MediaPipe not available. Using OpenCV cascade for face detection.")

# Metrics and Logging
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split
import wandb  # Optional: for experiment tracking

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [3]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cpu


In [4]:
# ============================================================================
# 2. CONFIGURATION
# ============================================================================

class Config:
    # Paths
    DATA_ROOT = "./data/affectnet"  # Change to your dataset path
    OUTPUT_DIR = "./outputs/emotion_model"
    CHECKPOINT_DIR = "./checkpoints"
    
    # Model Architecture
    BACKBONE = "efficientnet_b3"  # Options: resnet50, efficientnet_b0, efficientnet_b3
    PRETRAINED = True
    EMBEDDING_DIM = 1536  # ResNet50: 2048, EfficientNet-B0: 1280, EfficientNet-B3: 1536
    NUM_CLASSES = 8  # AffectNet: neutral, happy, sad, surprise, fear, disgust, anger, contempt
    
    # Training Hyperparameters
    BATCH_SIZE = 64
    NUM_EPOCHS = 50
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-4
    SCHEDULER = "cosine"  # Options: cosine, step, plateau
    
    # Data Augmentation
    IMG_SIZE = 224
    TRAIN_AUGMENT = True
    
    # Training Strategy
    FREEZE_BACKBONE_EPOCHS = 5  # Freeze backbone for first N epochs
    GRAD_CLIP = 1.0
    MIXED_PRECISION = True  # Use automatic mixed precision
    
    # Early Stopping
    PATIENCE = 10
    MIN_DELTA = 0.001
    
    # Logging
    USE_WANDB = False  # Set to True to enable W&B logging
    WANDB_PROJECT = "emotion-embedding"
    
    # Emotion Labels (AffectNet) - mapping from string labels to indices
    EMOTION_TO_IDX = {
        'neutral': 0,
        'happy': 1,
        'sad': 2,
        'surprise': 3,
        'fear': 4,
        'disgust': 5,
        'anger': 6,
        'contempt': 7
    }
    
    # Reverse mapping for display
    EMOTION_LABELS = {v: k for k, v in EMOTION_TO_IDX.items()}

config = Config()
os.makedirs(config.OUTPUT_DIR, exist_ok=True)
os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)

In [5]:
# ============================================================================
# 3. FACE DETECTION AND PREPROCESSING
# ============================================================================

class FaceDetector:
    """Face detection and alignment using MediaPipe or OpenCV"""
    
    def __init__(self, use_mediapipe=MEDIAPIPE_AVAILABLE):
        self.use_mediapipe = use_mediapipe
        
        if self.use_mediapipe:
            self.mp_face_detection = mp.solutions.face_detection
            self.face_detection = self.mp_face_detection.FaceDetection(
                model_selection=1, 
                min_detection_confidence=0.5
            )
        else:
            # Fallback to OpenCV Haar Cascade
            cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
            self.face_cascade = cv2.CascadeClassifier(cascade_path)
    
    def detect_and_crop(self, image):
        """
        Detect face and return cropped, aligned face image
        Args:
            image: numpy array (H, W, 3) in RGB
        Returns:
            cropped_face: numpy array or None if no face detected
        """
        if self.use_mediapipe:
            results = self.face_detection.process(image)
            if not results.detections:
                return None
            
            detection = results.detections[0]  # Take first face
            bboxC = detection.location_data.relative_bounding_box
            h, w, _ = image.shape
            
            x = int(bboxC.xmin * w)
            y = int(bboxC.ymin * h)
            w_box = int(bboxC.width * w)
            h_box = int(bboxC.height * h)
            
            # Add margin
            margin = 0.2
            x = max(0, int(x - w_box * margin))
            y = max(0, int(y - h_box * margin))
            w_box = int(w_box * (1 + 2 * margin))
            h_box = int(h_box * (1 + 2 * margin))
            
            cropped = image[y:y+h_box, x:x+w_box]
            return cropped if cropped.size > 0 else None
        else:
            # OpenCV detection
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            faces = self.face_cascade.detectMultiScale(gray, 1.3, 5)
            
            if len(faces) == 0:
                return None
            
            # Get image dimensions
            img_h, img_w = image.shape[:2]
            
            x, y, w, h = faces[0]  # Take first face
            margin = int(0.2 * w)
            x = max(0, x - margin)
            y = max(0, y - margin)
            w = min(w + 2 * margin, img_w - x)  # Ensure we don't exceed image width
            h = min(h + 2 * margin, img_h - y)  # Ensure we don't exceed image height
            
            cropped = image[y:y+h, x:x+w]
            return cropped if cropped.size > 0 else None

face_detector = FaceDetector()

I0000 00:00:1766729195.353764       1 gl_context.cc:344] GL version: 2.1 (2.1 Metal - 90.5), renderer: Apple M2


INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


In [6]:
# ============================================================================
# 4. DATASET CLASS
# ============================================================================

class EmotionDataset(Dataset):
    """
    Dataset for emotion recognition from face images
    
    Expected data structure:
    data_root/
        Train/
            anger/
            happy/
            ...
        Test/
            Anger/
            happy/
            ...
        labels.csv (optional - with columns: pth, label)
    """
    
    def __init__(self, data_root, split='Train', transform=None, 
                 csv_path=None, use_face_detection=True):
        self.data_root = Path(data_root)
        self.split = split
        self.transform = transform
        self.use_face_detection = use_face_detection
        
        # Load data
        if csv_path:
            self.samples = self._load_from_csv(csv_path)
        else:
            self.samples = self._load_from_directory()
        
        print(f"{split} dataset: {len(self.samples)} samples")
        self._print_distribution()
    
    def _load_from_csv(self, csv_path):
        """Load dataset from CSV file"""
        df = pd.read_csv(csv_path)
        samples = []
        for _, row in df.iterrows():
            # CSV has columns: pth, label (string)
            img_path = self.data_root / self.split / row['pth']
            if img_path.exists():
                label_str = row['label'].lower()  # Normalize to lowercase
                if label_str in config.EMOTION_TO_IDX:
                    samples.append({
                        'path': str(img_path),
                        'label': config.EMOTION_TO_IDX[label_str]
                    })
        return samples
    
    def _load_from_directory(self):
        """Load dataset from directory structure"""
        samples = []
        split_dir = self.data_root / self.split
        
        if not split_dir.exists():
            raise ValueError(f"Split directory not found: {split_dir}")
        
        for emotion_dir in split_dir.iterdir():
            if not emotion_dir.is_dir():
                continue
            
            # Get emotion label from folder name (normalize to lowercase)
            emotion_name = emotion_dir.name.lower()
            
            if emotion_name not in config.EMOTION_TO_IDX:
                print(f"Warning: Unknown emotion folder '{emotion_dir.name}', skipping...")
                continue
            
            label = config.EMOTION_TO_IDX[emotion_name]
            
            for img_path in emotion_dir.glob('*'):
                if img_path.suffix.lower() in ['.jpg', '.jpeg', '.png']:
                    samples.append({'path': str(img_path), 'label': label})
        
        return samples
    
    def _print_distribution(self):
        """Print label distribution"""
        labels = [s['label'] for s in self.samples]
        unique, counts = np.unique(labels, return_counts=True)
        print(f"\nLabel distribution in {self.split}:")
        for label, count in zip(unique, counts):
            emotion_name = config.EMOTION_LABELS.get(label, f"class_{label}")
            print(f"  {emotion_name}: {count} ({100*count/len(labels):.1f}%)")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        
        # Load image
        image = cv2.imread(sample['path'])
        if image is None:
            # Return a black image if loading fails
            image = np.zeros((config.IMG_SIZE, config.IMG_SIZE, 3), dtype=np.uint8)
        else:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Face detection and cropping
        if self.use_face_detection:
            face = face_detector.detect_and_crop(image)
            if face is None:
                # If no face detected, use whole image
                face = image
        else:
            face = image
        
        # Convert to PIL for transforms
        face = Image.fromarray(face)
        
        # Apply transforms
        if self.transform:
            face = self.transform(face)
        
        label = sample['label']
        
        return face, label

In [7]:
# ============================================================================
# 5. DATA AUGMENTATION
# ============================================================================

def get_transforms(split='train', img_size=224):
    """Get data transforms for training and validation"""
    
    if split == 'train':
        return transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])
    else:
        return transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])

In [8]:
# ============================================================================
# 6. MODEL ARCHITECTURE
# ============================================================================

class EmotionEmbeddingModel(nn.Module):
    """
    Emotion recognition model with embedding extraction capability
    
    Architecture:
    - Backbone (ResNet50/EfficientNet) -> Feature Extractor
    - Penultimate Layer -> Embedding (512/2048-D)
    - Classification Head -> Emotion Classes
    
    For inference: Remove classification head and extract embeddings
    """
    
    def __init__(self, backbone='resnet50', num_classes=7, 
                 pretrained=True, embedding_dim=2048):
        super().__init__()
        
        self.backbone_name = backbone
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        
        # Load backbone
        if backbone == 'resnet50':
            self.backbone = models.resnet50(pretrained=pretrained)
            in_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()  # Remove final FC layer
            
        elif backbone.startswith('efficientnet'):
            self.backbone = timm.create_model(backbone, pretrained=pretrained)
            in_features = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
            
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")
        
        # Embedding projection (if needed)
        if in_features != embedding_dim:
            self.embedding_projection = nn.Sequential(
                nn.Linear(in_features, embedding_dim),
                nn.BatchNorm1d(embedding_dim),
                nn.ReLU(inplace=True)
            )
        else:
            self.embedding_projection = nn.Identity()
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(embedding_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
        
        self._init_weights()
    
    def _init_weights(self):
        """Initialize weights for new layers"""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x, return_embedding=False):
        """
        Forward pass
        
        Args:
            x: Input tensor (B, 3, H, W)
            return_embedding: If True, return embedding instead of logits
        
        Returns:
            If return_embedding=True: embedding (B, embedding_dim)
            Else: logits (B, num_classes)
        """
        # Extract features from backbone
        features = self.backbone(x)
        
        # Get embeddings
        embeddings = self.embedding_projection(features)
        
        if return_embedding:
            return embeddings
        
        # Classification
        logits = self.classifier(embeddings)
        
        return logits
    
    def freeze_backbone(self):
        """Freeze backbone weights"""
        for param in self.backbone.parameters():
            param.requires_grad = False
        print("Backbone frozen")
    
    def unfreeze_backbone(self):
        """Unfreeze backbone weights"""
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("Backbone unfrozen")
    
    def get_embedding(self, x):
        """Extract embedding (for inference)"""
        with torch.no_grad():
            return self.forward(x, return_embedding=True)

In [9]:
# ============================================================================
# 7. TRAINING UTILITIES
# ============================================================================

class FocalLoss(nn.Module):
    """Focal Loss for handling class imbalance"""
    
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none', weight=self.alpha)
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class AverageMeter:
    """Computes and stores the average and current value"""
    
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

class EarlyStopping:
    """Early stopping to stop training when validation loss doesn't improve"""
    
    def __init__(self, patience=10, min_delta=0.001, mode='min'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        
    def __call__(self, score):
        if self.best_score is None:
            self.best_score = score
        elif self._is_improvement(score):
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        
        return self.early_stop
    
    def _is_improvement(self, score):
        if self.mode == 'min':
            return score < (self.best_score - self.min_delta)
        else:
            return score > (self.best_score + self.min_delta)


In [10]:
# ============================================================================
# 8. TRAINING LOOP
# ============================================================================

class Trainer:
    """Training and validation logic"""
    
    def __init__(self, model, train_loader, val_loader, config):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        
        # Loss function (use Focal Loss for imbalanced datasets)
        self.criterion = FocalLoss(alpha=None, gamma=2.0)
        
        # Optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.LEARNING_RATE,
            weight_decay=config.WEIGHT_DECAY
        )
        
        # Learning rate scheduler
        if config.SCHEDULER == 'cosine':
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=config.NUM_EPOCHS
            )
        elif config.SCHEDULER == 'step':
            self.scheduler = optim.lr_scheduler.StepLR(
                self.optimizer, step_size=10, gamma=0.1
            )
        else:
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode='min', patience=5, factor=0.5
            )
        
        # Mixed precision training - only enable on CUDA
        if config.MIXED_PRECISION and torch.cuda.is_available():
            self.scaler = torch.amp.GradScaler('cuda')
        else:
            self.scaler = None
        
        # Early stopping
        self.early_stopping = EarlyStopping(
            patience=config.PATIENCE, 
            min_delta=config.MIN_DELTA,
            mode='min'
        )
        
        # Training history
        self.history = {
            'train_loss': [], 'train_acc': [],
            'val_loss': [], 'val_acc': [],
            'lr': []
        }
        
        self.best_val_acc = 0.0
    
    def train_epoch(self, epoch):
        """Train for one epoch"""
        self.model.train()
        
        losses = AverageMeter()
        accs = AverageMeter()
        
        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{self.config.NUM_EPOCHS}')
        
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            
            # Mixed precision training
            if self.scaler:
                with torch.cuda.amp.autocast():
                    outputs = self.model(images)
                    loss = self.criterion(outputs, labels)
                
                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                
                # Gradient clipping
                if self.config.GRAD_CLIP:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.config.GRAD_CLIP
                    )
                
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                
                self.optimizer.zero_grad()
                loss.backward()
                
                if self.config.GRAD_CLIP:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.config.GRAD_CLIP
                    )
                
                self.optimizer.step()
            
            # Calculate accuracy
            _, preds = torch.max(outputs, 1)
            acc = (preds == labels).float().mean()
            
            # Update meters
            losses.update(loss.item(), images.size(0))
            accs.update(acc.item(), images.size(0))
            
            # Update progress bar
            pbar.set_postfix({
                'loss': f'{losses.avg:.4f}',
                'acc': f'{accs.avg:.4f}',
                'lr': f'{self.optimizer.param_groups[0]["lr"]:.6f}'
            })
        
        return losses.avg, accs.avg
    
    def validate(self):
        """Validate the model"""
        self.model.eval()
        
        losses = AverageMeter()
        accs = AverageMeter()
        
        all_preds = []
        all_labels = []
        
        with torch.no_grad():
            for images, labels in tqdm(self.val_loader, desc='Validating'):
                images = images.to(device)
                labels = labels.to(device)
                
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                
                _, preds = torch.max(outputs, 1)
                acc = (preds == labels).float().mean()
                
                losses.update(loss.item(), images.size(0))
                accs.update(acc.item(), images.size(0))
                
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        return losses.avg, accs.avg, all_preds, all_labels
    
    def train(self):
        """Full training loop"""
        print("\n" + "="*80)
        print("STARTING TRAINING")
        print("="*80 + "\n")
        
        # Freeze backbone for initial epochs if specified
        if self.config.FREEZE_BACKBONE_EPOCHS > 0:
            self.model.freeze_backbone()
        
        for epoch in range(self.config.NUM_EPOCHS):
            # Unfreeze backbone after specified epochs
            if epoch == self.config.FREEZE_BACKBONE_EPOCHS:
                self.model.unfreeze_backbone()
            
            # Train
            train_loss, train_acc = self.train_epoch(epoch)
            
            # Validate
            val_loss, val_acc, val_preds, val_labels = self.validate()
            
            # Update scheduler
            if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.scheduler.step(val_loss)
            else:
                self.scheduler.step()
            
            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['lr'].append(self.optimizer.param_groups[0]['lr'])
            
            # Print epoch summary
            print(f"\nEpoch {epoch+1}/{self.config.NUM_EPOCHS}")
            print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
            print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
            print(f"LR: {self.optimizer.param_groups[0]['lr']:.6f}")
            
            # Save best model
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.save_checkpoint(epoch, is_best=True)
                print(f"✓ New best model saved! Val Acc: {val_acc:.4f}")
            
            # Regular checkpoint
            if (epoch + 1) % 5 == 0:
                self.save_checkpoint(epoch, is_best=False)
            
            # Early stopping
            if self.early_stopping(val_loss):
                print(f"\nEarly stopping triggered at epoch {epoch+1}")
                break
            
            print("-" * 80)
        
        print("\n" + "="*80)
        print("TRAINING COMPLETED")
        print(f"Best Validation Accuracy: {self.best_val_acc:.4f}")
        print("="*80 + "\n")
        
        return self.history
    
    def save_checkpoint(self, epoch, is_best=False):
        """Save model checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_acc': self.best_val_acc,
            'history': self.history,
            'config': vars(self.config)
        }
        
        if is_best:
            path = os.path.join(self.config.CHECKPOINT_DIR, 'best_model.pth')
        else:
            path = os.path.join(self.config.CHECKPOINT_DIR, f'checkpoint_epoch_{epoch+1}.pth')
        
        torch.save(checkpoint, path)

In [11]:
# ============================================================================
# 9. EVALUATION AND VISUALIZATION
# ============================================================================

def plot_training_history(history, save_path=None):
    """Plot training and validation metrics"""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # Loss
    axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
    axes[0].plot(history['val_loss'], label='Val Loss', marker='s')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training and Validation Loss')
    axes[0].legend()
    axes[0].grid(True)
    
    # Accuracy
    axes[1].plot(history['train_acc'], label='Train Acc', marker='o')
    axes[1].plot(history['val_acc'], label='Val Acc', marker='s')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Training and Validation Accuracy')
    axes[1].legend()
    axes[1].grid(True)
    
    # Learning Rate
    axes[2].plot(history['lr'], label='Learning Rate', marker='o', color='orange')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Learning Rate')
    axes[2].set_title('Learning Rate Schedule')
    axes[2].set_yscale('log')
    axes[2].legend()
    axes[2].grid(True)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def plot_confusion_matrix(y_true, y_pred, labels, save_path=None):
    """Plot confusion matrix"""
    cm = confusion_matrix(y_true, y_pred)
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=labels, yticklabels=labels)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def evaluate_model(model, test_loader):
    """Comprehensive model evaluation"""
    model.eval()
    
    all_preds = []
    all_labels = []
    all_embeddings = []
    
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc='Evaluating'):
            images = images.to(device)
            
            # Get predictions
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            
            # Get embeddings
            embeddings = model(images, return_embedding=True)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_embeddings.append(embeddings.cpu().numpy())
    
    all_embeddings = np.vstack(all_embeddings)
    
    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    print(f"\nTest Accuracy: {accuracy:.4f}")
    
    # Classification report
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, 
                                target_names=list(config.EMOTION_LABELS.values())))
    
    # Confusion matrix
    plot_confusion_matrix(all_labels, all_preds, 
                         list(config.EMOTION_LABELS.values()),
                         save_path=os.path.join(config.OUTPUT_DIR, 'confusion_matrix.png'))
    
    return all_embeddings, all_labels, all_preds

In [12]:
# ============================================================================
# 10. EMBEDDING EXTRACTION FOR INFERENCE
# ============================================================================

class EmbeddingExtractor:
    """Extract embeddings from trained model for inference"""
    
    def __init__(self, model_path, config):
        self.config = config
        self.device = device
        
        # Load model
        self.model = EmotionEmbeddingModel(
            backbone=config.BACKBONE,
            num_classes=config.NUM_CLASSES,
            embedding_dim=config.EMBEDDING_DIM
        ).to(self.device)
        
        # Load checkpoint
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()
        
        print(f"Model loaded from {model_path}")
        print(f"Embedding dimension: {config.EMBEDDING_DIM}")
    
    def extract_embedding(self, image):
        """
        Extract embedding from a single image
        
        Args:
            image: PIL Image or numpy array (H, W, 3)
        
        Returns:
            embedding: numpy array of shape (embedding_dim,)
        """
        # Preprocess
        transform = get_transforms(split='val', img_size=self.config.IMG_SIZE)
        
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        
        image_tensor = transform(image).unsqueeze(0).to(self.device)
        
        # Extract embedding
        with torch.no_grad():
            embedding = self.model.get_embedding(image_tensor)
        
        return embedding.cpu().numpy().squeeze()
    
    def extract_embeddings_batch(self, images):
        """
        Extract embeddings from a batch of images
        
        Args:
            images: List of PIL Images or numpy arrays
        
        Returns:
            embeddings: numpy array of shape (N, embedding_dim)
        """
        transform = get_transforms(split='val', img_size=self.config.IMG_SIZE)
        
        image_tensors = []
        for img in images:
            if isinstance(img, np.ndarray):
                img = Image.fromarray(img)
            image_tensors.append(transform(img))
        
        batch = torch.stack(image_tensors).to(self.device)
        
        with torch.no_grad():
            embeddings = self.model.get_embedding(batch)
        
        return embeddings.cpu().numpy()
    
    def extract_video_embeddings(self, video_path, sample_rate=5):
        """
        Extract embeddings from video frames
        
        Args:
            video_path: Path to video file
            sample_rate: Extract every Nth frame
        
        Returns:
            embeddings: numpy array of shape (num_frames, embedding_dim)
            frame_indices: List of frame indices
        """
        cap = cv2.VideoCapture(video_path)
        
        embeddings = []
        frame_indices = []
        frame_count = 0
        
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            
            if frame_count % sample_rate == 0:
                # Convert BGR to RGB
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                
                # Detect and crop face
                face = face_detector.detect_and_crop(frame_rgb)
                if face is not None:
                    # Extract embedding
                    embedding = self.extract_embedding(face)
                    embeddings.append(embedding)
                    frame_indices.append(frame_count)
            
            frame_count += 1
        
        cap.release()
        
        if len(embeddings) == 0:
            return None, None
        
        return np.array(embeddings), frame_indices

In [13]:
# ============================================================================
# 11. MAIN EXECUTION
# ============================================================================

def main():
    """Main training pipeline"""
    
    print("="*80)
    print("EMOTION EMBEDDING MODEL TRAINING PIPELINE")
    print("="*80)
    
    # Initialize W&B (optional)
    if config.USE_WANDB:
        wandb.init(project=config.WANDB_PROJECT, config=vars(config))
    
    # Create datasets
    print("\nLoading datasets...")
    train_dataset = EmotionDataset(
        data_root=config.DATA_ROOT,
        split='Train',  # Use actual folder name
        transform=get_transforms('train', config.IMG_SIZE),
        use_face_detection=True
    )
    
    val_dataset = EmotionDataset(
        data_root=config.DATA_ROOT,
        split='Test',  # Use actual folder name (Test instead of val)
        transform=get_transforms('val', config.IMG_SIZE),
        use_face_detection=True
    )
    
    # Create data loaders
    # Note: num_workers=0 required for Jupyter notebooks (multiprocessing issues)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=0,  # Set to 0 for Jupyter notebook compatibility
        pin_memory=False  # Disable for MPS/CPU
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        num_workers=0,  # Set to 0 for Jupyter notebook compatibility
        pin_memory=False  # Disable for MPS/CPU
    )
    
    # Initialize model
    print(f"\nInitializing model: {config.BACKBONE}")
    model = EmotionEmbeddingModel(
        backbone=config.BACKBONE,
        num_classes=config.NUM_CLASSES,
        pretrained=config.PRETRAINED,
        embedding_dim=config.EMBEDDING_DIM
    )
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
    # Initialize trainer
    trainer = Trainer(model, train_loader, val_loader, config)
    
    # Train model
    history = trainer.train()
    
    # Plot training history
    plot_training_history(
        history,
        save_path=os.path.join(config.OUTPUT_DIR, 'training_history.png')
    )
    
    # Evaluate on validation set
    print("\nFinal evaluation on validation set...")
    model_path = os.path.join(config.CHECKPOINT_DIR, 'best_model.pth')
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    val_embeddings, val_labels, val_preds = evaluate_model(model, val_loader)
    
    # Save embeddings for analysis
    np.save(os.path.join(config.OUTPUT_DIR, 'val_embeddings.npy'), val_embeddings)
    np.save(os.path.join(config.OUTPUT_DIR, 'val_labels.npy'), val_labels)
    
    print("\n" + "="*80)
    print("TRAINING PIPELINE COMPLETED SUCCESSFULLY!")
    print(f"Best model saved at: {model_path}")
    print(f"Outputs saved in: {config.OUTPUT_DIR}")
    print("="*80)
    
    # Demonstration of embedding extraction
    print("\n" + "="*80)
    print("DEMONSTRATION: EMBEDDING EXTRACTION")
    print("="*80)
    
    extractor = EmbeddingExtractor(model_path, config)
    
    # Example: Extract embedding from first validation image
    sample_image, sample_label = val_dataset[0]
    sample_image_pil = transforms.ToPILImage()(sample_image)
    
    embedding = extractor.extract_embedding(sample_image_pil)
    print(f"\nExtracted embedding shape: {embedding.shape}")
    print(f"True emotion: {config.EMOTION_LABELS[sample_label]}")
    print(f"Embedding (first 10 values): {embedding[:10]}")

if __name__ == '__main__':
    main()

EMOTION EMBEDDING MODEL TRAINING PIPELINE

Loading datasets...
Train dataset: 16108 samples

Label distribution in Train:
  neutral: 2758 (17.1%)
  happy: 2340 (14.5%)
  sad: 3091 (19.2%)
  surprise: 2119 (13.2%)
  fear: 1512 (9.4%)
  disgust: 1229 (7.6%)
  anger: 1500 (9.3%)
  contempt: 1559 (9.7%)
Test dataset: 14518 samples

Label distribution in Test:
  neutral: 2368 (16.3%)
  happy: 2704 (18.6%)
  sad: 1584 (10.9%)
  surprise: 1920 (13.2%)
  fear: 1664 (11.5%)
  disgust: 1248 (8.6%)
  anger: 1718 (11.8%)
  contempt: 1312 (9.0%)

Initializing model: efficientnet_b3
Total parameters: 11,488,304
Trainable parameters: 11,488,304

STARTING TRAINING

Backbone frozen


Epoch 1/50: 100%|██████████| 252/252 [26:43<00:00,  6.36s/it, loss=11.2302, acc=0.1874, lr=0.000100]
Validating: 100%|██████████| 227/227 [23:47<00:00,  6.29s/it]



Epoch 1/50
Train Loss: 11.2302 | Train Acc: 0.1874
Val Loss: 5.1088 | Val Acc: 0.2823
LR: 0.000100
✓ New best model saved! Val Acc: 0.2823
--------------------------------------------------------------------------------


Epoch 2/50: 100%|██████████| 252/252 [27:50<00:00,  6.63s/it, loss=8.1157, acc=0.2722, lr=0.000100]
Validating: 100%|██████████| 227/227 [25:05<00:00,  6.63s/it]



Epoch 2/50
Train Loss: 8.1157 | Train Acc: 0.2722
Val Loss: 4.4158 | Val Acc: 0.3302
LR: 0.000100
✓ New best model saved! Val Acc: 0.3302
--------------------------------------------------------------------------------


Epoch 3/50: 100%|██████████| 252/252 [36:29<00:00,  8.69s/it, loss=7.2560, acc=0.2998, lr=0.000100]
Validating: 100%|██████████| 227/227 [27:17<00:00,  7.22s/it]



Epoch 3/50
Train Loss: 7.2560 | Train Acc: 0.2998
Val Loss: 4.0638 | Val Acc: 0.3456
LR: 0.000099
✓ New best model saved! Val Acc: 0.3456
--------------------------------------------------------------------------------


Epoch 4/50:  64%|██████▍   | 162/252 [21:40<12:02,  8.03s/it, loss=6.9971, acc=0.3130, lr=0.000099]


KeyboardInterrupt: 