# Enhanced Emotion Recognition Model - Targeting 80%+ Accuracy

This notebook implements advanced techniques to improve emotion recognition accuracy:

## Key Improvements:
1. **Enhanced Architecture**: ResNet50 + EfficientNet ensemble
2. **Advanced Augmentation**: MixUp, CutMix, AutoAugment, TTA
3. **Better Loss Functions**: Focal Loss + Label Smoothing
4. **Optimized Training**: Progressive learning, longer training, mixed precision
5. **Class Imbalance Handling**: Advanced sampling + class weights
6. **Regularization**: Dropout, Weight Decay, Stochastic Depth

In [None]:
# =========================
# Cell 1 — Enhanced Imports & Setup
# =========================
import os, json, math, time, random, warnings
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image
from collections import Counter
from typing import List, Tuple, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim import AdamW, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
from torch.cuda.amp import GradScaler, autocast

import torchvision
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights, efficientnet_b3, EfficientNet_B3_Weights

from sklearn.metrics import classification_report, confusion_matrix, f1_score, accuracy_score, precision_score, recall_score\nfrom sklearn.utils.class_weight import compute_class_weight

import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import clear_output

plt.style.use('seaborn-v0_8')
plt.rcParams["figure.dpi"] = 120
warnings.filterwarnings('ignore')

# Enhanced reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # For reproducible data augmentation
    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")

In [None]:
# ============================
# Cell 2 — Enhanced Configuration
# ============================

# ---- Base paths ----
PROJECT_ROOT = "/home/runner/work/ann-visual-emotion/ann-visual-emotion"  # Adjust as needed
RAW_DIR = Path(PROJECT_ROOT) / "data" / "raw" / "EmoSet"
SPLIT_DIR = Path(PROJECT_ROOT) / "data" / "processed" / "EmoSet_splits"

# ---- Dataset paths ----
TRAIN_CSV = str(SPLIT_DIR / "train.csv")
VAL_CSV = str(SPLIT_DIR / "val.csv")
TEST_CSV = str(SPLIT_DIR / "test.csv")
IMAGES_ROOT = str(RAW_DIR)

# ---- Enhanced Model Configuration ----
IMG_SIZE = 224  # Standard size, will use progressive resizing
BATCH_SIZE = 32  # Reduced for larger models
NUM_WORKERS = 4
PIN_MEMORY = True

# ---- Enhanced Training Configuration ----
EPOCHS = 100  # Increased for better convergence
WARMUP_EPOCHS = 5
BASE_LR = 1e-3  # Higher initial learning rate
MAX_LR = 5e-3   # For OneCycleLR
WEIGHT_DECAY = 1e-4
LABEL_SMOOTHING = 0.1  # Increased for better generalization

# ---- Advanced Loss Configuration ----
LOSS_MODE = "focal"  # focal, ce, or combined
FOCAL_ALPHA = 0.25
FOCAL_GAMMA = 2.0
CLASS_WEIGHT_MODE = "balanced"  # balanced, sqrt, or none

# ---- Enhanced Augmentation Configuration ----
USE_MIXUP = True
MIXUP_ALPHA = 0.4  # Increased
USE_CUTMIX = True  # Enable CutMix
CUTMIX_ALPHA = 1.0
MIXUP_CUTMIX_PROB = 0.5  # Probability of applying mixup vs cutmix
USE_AUTOAUGMENT = True
AUTOAUGMENT_POLICY = "imagenet"  # imagenet, cifar10, svhn

# ---- Model Architecture Configuration ----
MODEL_NAME = "resnet50"  # resnet50, efficientnet_b3, ensemble
PRETRAINED = True
DROPOUT_RATE = 0.5  # Increased dropout
STOCHASTIC_DEPTH_RATE = 0.2  # For regularization

# ---- Training Optimization ----
USE_MIXED_PRECISION = True
GRADIENT_CLIP = 1.0
ACCUMULATION_STEPS = 2  # Effective batch size = BATCH_SIZE * ACCUMULATION_STEPS

# ---- Progressive Training ----
USE_PROGRESSIVE_RESIZING = True
INITIAL_IMG_SIZE = 176
FINAL_IMG_SIZE = 224
RESIZE_EPOCHS = [30]  # Epochs to resize at

# ---- Test-Time Augmentation ----
USE_TTA = True
TTA_TRANSFORMS = 8  # Number of TTA transforms

# ---- Early Stopping ----
PATIENCE = 15  # Increased patience for longer training
MIN_DELTA = 1e-4

# ---- Output Configuration ----
OUT_DIR = str(Path(PROJECT_ROOT) / "outputs" / "enhanced_emotion_model")
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)

print(f"Output directory: {OUT_DIR}")
print(f"Configuration: {MODEL_NAME}, {EPOCHS} epochs, {BATCH_SIZE} batch size")
print(f"Augmentations: MixUp={USE_MIXUP}, CutMix={USE_CUTMIX}, AutoAugment={USE_AUTOAUGMENT}")
print(f"Advanced features: Mixed Precision={USE_MIXED_PRECISION}, TTA={USE_TTA}")

In [None]:
# ===============================
# Cell 3 — Enhanced Dataset Class
# ===============================

class EnhancedEmotionDataset(Dataset):
    def __init__(self, csv_path: str, images_root: str, str2idx: dict, 
                 img_size: int = 224, train: bool = True, use_autoaugment: bool = False):
        self.df = pd.read_csv(csv_path)
        self.images_root = Path(images_root)
        self.str2idx = str2idx
        self.img_size = img_size
        self.train = train
        self.use_autoaugment = use_autoaugment
        
        # Handle different column names
        self.path_col = "image_path" if "image_path" in self.df.columns else "image"
        if self.path_col not in self.df.columns:
            raise ValueError("CSV must contain 'image' or 'image_path' column")
            
        # Convert labels to indices
        if "label" not in self.df.columns:
            raise ValueError("CSV must contain 'label' column")
            
        if self.df["label"].dtype == object:
            self.df["label_idx"] = self.df["label"].map(self.str2idx).astype(int)
        else:
            self.df["label_idx"] = self.df["label"].astype(int)
            
        # Enhanced normalization (ImageNet stats)
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], 
            std=[0.229, 0.224, 0.225]
        )
        
        # Enhanced augmentation transforms
        if train:
            augmentations = [
                transforms.Resize((img_size, img_size)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=15, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
                transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
            ]
            
            # Add AutoAugment if enabled
            if use_autoaugment:
                if AUTOAUGMENT_POLICY == "imagenet":
                    augmentations.append(transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET))
                elif AUTOAUGMENT_POLICY == "cifar10":
                    augmentations.append(transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10))
            
            augmentations.extend([
                transforms.ToTensor(),
                self.normalize,
                transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3))
            ])
            
            self.transform = transforms.Compose(augmentations)
        else:
            self.transform = transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor(),
                self.normalize
            ])
            
        print(f"Dataset initialized: {len(self.df)} samples, img_size={img_size}, train={train}")
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # Handle relative paths
        img_path = row[self.path_col]
        if not img_path.startswith('/'):
            # Remove 'data/raw/EmoSet/' prefix if present in CSV
            if img_path.startswith('data/raw/EmoSet/'):
                img_path = img_path.replace('data/raw/EmoSet/', '')
            img_path = self.images_root / img_path
        else:
            img_path = Path(img_path)
            
        try:
            # Load and convert image
            image = Image.open(img_path).convert('RGB')
            image = self.transform(image)
            
            label = int(row["label_idx"])
            return image, label
            
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a black image and label 0 as fallback
            black_image = torch.zeros(3, self.img_size, self.img_size)
            return black_image, 0
    
    def update_img_size(self, new_size: int):
        """Update image size for progressive resizing"""
        self.img_size = new_size
        # Rebuild transforms with new size
        if self.train:
            augmentations = [
                transforms.Resize((new_size, new_size)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=15),
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
                transforms.ToTensor(),
                self.normalize,
                transforms.RandomErasing(p=0.25, scale=(0.02, 0.2))
            ]
            if self.use_autoaugment and USE_AUTOAUGMENT:
                augmentations.insert(-3, transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET))
            self.transform = transforms.Compose(augmentations)
        else:
            self.transform = transforms.Compose([
                transforms.Resize((new_size, new_size)),
                transforms.ToTensor(),
                self.normalize
            ])
        print(f"Updated image size to {new_size}x{new_size}")

def load_label_mapping(csv_path: str) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Create label mapping from CSV file"""
    df = pd.read_csv(csv_path)
    unique_labels = sorted(df['label'].unique())
    str2idx = {label: idx for idx, label in enumerate(unique_labels)}
    idx2str = {idx: label for label, idx in str2idx.items()}
    return str2idx, idx2str

# Load label mappings
str2idx, idx2str = load_label_mapping(TRAIN_CSV)
num_classes = len(idx2str)
print(f"Found {num_classes} classes: {list(idx2str.values())}")

In [None]:
# =====================================
# Cell 4 — Advanced Loss Functions
# =====================================

class FocalLoss(nn.Module):
    """Enhanced Focal Loss for handling class imbalance"""
    def __init__(self, alpha=1.0, gamma=2.0, weight=None, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
        self.reduction = reduction
        self.ce_loss = nn.CrossEntropyLoss(weight=weight, reduction='none')
    
    def forward(self, inputs, targets):
        ce_loss = self.ce_loss(inputs, targets)
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class LabelSmoothingCrossEntropy(nn.Module):
    """Label smoothing cross-entropy loss"""
    def __init__(self, num_classes, smoothing=0.1, weight=None, reduction='mean'):
        super().__init__()
        self.num_classes = num_classes
        self.smoothing = smoothing
        self.weight = weight
        self.reduction = reduction
        self.confidence = 1.0 - smoothing
    
    def forward(self, inputs, targets):
        log_probs = F.log_softmax(inputs, dim=1)
        
        # Create smooth targets
        smooth_targets = torch.zeros_like(log_probs)
        smooth_targets.fill_(self.smoothing / (self.num_classes - 1))
        smooth_targets.scatter_(1, targets.unsqueeze(1), self.confidence)
        
        # Apply class weights if provided
        if self.weight is not None:
            smooth_targets = smooth_targets * self.weight.unsqueeze(0)
        
        loss = -torch.sum(smooth_targets * log_probs, dim=1)
        
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

class CombinedLoss(nn.Module):
    """Combination of Focal Loss and Label Smoothing"""
    def __init__(self, num_classes, focal_alpha=0.25, focal_gamma=2.0, 
                 smoothing=0.1, weight=None, focal_weight=0.7):
        super().__init__()
        self.focal_loss = FocalLoss(focal_alpha, focal_gamma, weight)
        self.smooth_loss = LabelSmoothingCrossEntropy(num_classes, smoothing, weight)
        self.focal_weight = focal_weight
    
    def forward(self, inputs, targets):
        focal = self.focal_loss(inputs, targets)
        smooth = self.smooth_loss(inputs, targets)
        return self.focal_weight * focal + (1 - self.focal_weight) * smooth

def compute_enhanced_class_weights(train_csv: str, mode: str = "balanced") -> torch.Tensor:
    """Compute enhanced class weights for imbalanced dataset"""
    df = pd.read_csv(train_csv)
    labels = df['label'].values
    
    if mode == "balanced":
        # Sklearn's balanced approach
        unique_labels = sorted(df['label'].unique())
        label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
        y_indices = [label_to_idx[label] for label in labels]
        
        weights = compute_class_weight('balanced', classes=np.arange(len(unique_labels)), y=y_indices)
        return torch.tensor(weights, dtype=torch.float32)
    
    elif mode == "sqrt":
        # Square root of inverse frequency
        class_counts = Counter(labels)
        total = len(labels)
        weights = []
        for label in sorted(class_counts.keys()):
            freq = class_counts[label] / total
            weight = 1.0 / np.sqrt(freq)
            weights.append(weight)
        # Normalize weights
        weights = np.array(weights)
        weights = weights / weights.sum() * len(weights)
        return torch.tensor(weights, dtype=torch.float32)
    
    else:  # mode == "none"
        return None

# Compute class weights
class_weights = compute_enhanced_class_weights(TRAIN_CSV, CLASS_WEIGHT_MODE)
if class_weights is not None:
    class_weights = class_weights.to(device)
    print(f"Class weights ({CLASS_WEIGHT_MODE}): {class_weights.cpu().numpy()}")
else:
    print("No class weighting applied")

In [None]:
# ===============================
# Cell 5 — Enhanced Model Architecture
# ===============================

class EnhancedEmotionModel(nn.Module):
    """Enhanced model with better regularization and architecture"""
    def __init__(self, model_name: str, num_classes: int, pretrained: bool = True,
                 dropout_rate: float = 0.5, use_attention: bool = True):
        super().__init__()
        self.model_name = model_name
        self.num_classes = num_classes
        self.use_attention = use_attention
        
        # Load backbone
        if model_name == "resnet50":
            weights = ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
            self.backbone = resnet50(weights=weights)
            feature_dim = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()  # Remove original classifier
            
        elif model_name == "efficientnet_b3":
            weights = EfficientNet_B3_Weights.IMAGENET1K_V1 if pretrained else None
            self.backbone = efficientnet_b3(weights=weights)
            feature_dim = self.backbone.classifier[1].in_features
            self.backbone.classifier = nn.Identity()
            
        else:
            raise ValueError(f"Unsupported model: {model_name}")
        
        # Enhanced classifier head
        classifier_layers = []
        
        # Global Average Pooling (if not already applied)
        if model_name == "resnet50":
            self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        else:
            self.global_pool = nn.Identity()
        
        # Attention mechanism
        if use_attention:
            self.attention = nn.MultiheadAttention(feature_dim, num_heads=8, batch_first=True)
            self.attention_norm = nn.LayerNorm(feature_dim)
        
        # Enhanced classifier with multiple dropout layers
        classifier_layers.extend([
            nn.BatchNorm1d(feature_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(feature_dim, feature_dim // 2),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(feature_dim // 2),
            nn.Dropout(dropout_rate / 2),
            nn.Linear(feature_dim // 2, num_classes)
        ])
        
        self.classifier = nn.Sequential(*classifier_layers)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        """Initialize classifier weights"""
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        # Extract features
        features = self.backbone(x)
        
        # Handle different backbone outputs
        if len(features.shape) == 4:  # ResNet case: [B, C, H, W]
            features = self.global_pool(features)
            features = features.flatten(1)  # [B, C]
        
        # Apply attention if enabled
        if self.use_attention:
            # Reshape for attention: [B, 1, C]
            feat_reshaped = features.unsqueeze(1)
            attn_out, _ = self.attention(feat_reshaped, feat_reshaped, feat_reshaped)
            features = self.attention_norm(attn_out.squeeze(1) + features)
        
        # Classification
        logits = self.classifier(features)
        return logits

def create_model(model_name: str, num_classes: int, pretrained: bool = True) -> nn.Module:
    """Factory function to create enhanced models"""
    return EnhancedEmotionModel(
        model_name=model_name,
        num_classes=num_classes,
        pretrained=pretrained,
        dropout_rate=DROPOUT_RATE,
        use_attention=True
    )

# Create model
model = create_model(MODEL_NAME, num_classes, PRETRAINED)
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model: {MODEL_NAME}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: {total_params * 4 / 1024 / 1024:.1f} MB")

In [None]:
# =======================================
# Cell 6 — Advanced Data Augmentation Utilities
# =======================================

def rand_bbox(W, H, lam):
    """Generate random bounding box for CutMix"""
    cut_rat = np.sqrt(1.0 - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    
    # Uniform sampling
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    
    return bbx1, bby1, bbx2, bby2

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    """Perform MixUp augmentation"""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    
    batch_size = x.size(0)
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)
    
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def cutmix_data(x, y, alpha=1.0, use_cuda=True):
    """Perform CutMix augmentation"""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    
    batch_size = x.size(0)
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)
    
    y_a, y_b = y, y[index]
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(3), x.size(2), lam)
    x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]
    
    # Adjust lambda to match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size(-1) * x.size(-2)))
    return x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Compute loss for mixed samples"""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

class TestTimeAugmentation:
    """Test-Time Augmentation for better inference"""
    def __init__(self, img_size=224, n_transforms=8):
        self.img_size = img_size
        self.n_transforms = n_transforms
        
        # Define TTA transforms
        self.transforms = []
        
        # Original
        self.transforms.append(transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]))
        
        # Horizontal flip
        self.transforms.append(transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(p=1.0),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]))
        
        # Slight rotations
        for angle in [-5, 5]:
            self.transforms.append(transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.RandomRotation((angle, angle)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]))
        
        # Different scales
        for scale in [0.95, 1.05]:
            self.transforms.append(transforms.Compose([
                transforms.Resize((int(img_size * scale), int(img_size * scale))),
                transforms.CenterCrop((img_size, img_size)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]))
        
        # Color jitter
        self.transforms.append(transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]))
        
        # Random crop
        self.transforms.append(transforms.Compose([
            transforms.Resize((int(img_size * 1.1), int(img_size * 1.1))),
            transforms.RandomCrop((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]))
        
        # Limit to requested number of transforms
        self.transforms = self.transforms[:n_transforms]
        
    def __call__(self, pil_image):
        """Apply TTA to a PIL image and return list of tensors"""
        return [transform(pil_image) for transform in self.transforms]

print(f"Augmentation utilities loaded:")
print(f"- MixUp: {USE_MIXUP} (alpha={MIXUP_ALPHA})")
print(f"- CutMix: {USE_CUTMIX} (alpha={CUTMIX_ALPHA})")
print(f"- TTA: {USE_TTA} ({TTA_TRANSFORMS} transforms)")
print(f"- AutoAugment: {USE_AUTOAUGMENT} (policy={AUTOAUGMENT_POLICY})")

In [None]:
# ===============================
# Cell 7 — Create Enhanced Datasets and DataLoaders
# ===============================

# Start with initial image size for progressive training
initial_size = INITIAL_IMG_SIZE if USE_PROGRESSIVE_RESIZING else IMG_SIZE

# Create datasets
train_dataset = EnhancedEmotionDataset(
    csv_path=TRAIN_CSV,
    images_root=IMAGES_ROOT,
    str2idx=str2idx,
    img_size=initial_size,
    train=True,
    use_autoaugment=USE_AUTOAUGMENT
)

val_dataset = EnhancedEmotionDataset(
    csv_path=VAL_CSV,
    images_root=IMAGES_ROOT,
    str2idx=str2idx,
    img_size=initial_size,
    train=False,
    use_autoaugment=False
)

test_dataset = EnhancedEmotionDataset(
    csv_path=TEST_CSV,
    images_root=IMAGES_ROOT,
    str2idx=str2idx,
    img_size=initial_size,
    train=False,
    use_autoaugment=False
)

# Create weighted sampler for training to handle class imbalance
def create_weighted_sampler(dataset):
    """Create weighted sampler for imbalanced dataset"""
    # Get labels from dataset
    labels = []
    for idx in range(len(dataset)):
        _, label = dataset[idx]
        labels.append(label)
    
    # Compute class weights
    class_counts = Counter(labels)
    total_samples = len(labels)
    
    # Compute weights for each class
    class_weights = {}
    for class_idx, count in class_counts.items():
        class_weights[class_idx] = total_samples / (len(class_counts) * count)
    
    # Create sample weights
    sample_weights = [class_weights[label] for label in labels]
    
    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

# Create samplers
train_sampler = create_weighted_sampler(train_dataset)

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=train_sampler,  # Use weighted sampler instead of shuffle
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    drop_last=True,  # For consistent batch sizes with mixed precision
    persistent_workers=True if NUM_WORKERS > 0 else False
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE * 2,  # Larger batch size for validation
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    persistent_workers=True if NUM_WORKERS > 0 else False
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE * 2,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    persistent_workers=True if NUM_WORKERS > 0 else False
)

print(f"DataLoaders created:")
print(f"- Training: {len(train_loader)} batches ({len(train_dataset)} samples)")
print(f"- Validation: {len(val_loader)} batches ({len(val_dataset)} samples)")
print(f"- Test: {len(test_loader)} batches ({len(test_dataset)} samples)")
print(f"- Effective batch size: {BATCH_SIZE * ACCUMULATION_STEPS}")

In [None]:
# =======================================
# Cell 8 — Enhanced Training Setup
# =======================================

# Create enhanced loss function
if LOSS_MODE == "focal":
    criterion = FocalLoss(
        alpha=FOCAL_ALPHA,
        gamma=FOCAL_GAMMA,
        weight=class_weights
    )
elif LOSS_MODE == "ce":
    criterion = LabelSmoothingCrossEntropy(
        num_classes=num_classes,
        smoothing=LABEL_SMOOTHING,
        weight=class_weights
    )
elif LOSS_MODE == "combined":
    criterion = CombinedLoss(
        num_classes=num_classes,
        focal_alpha=FOCAL_ALPHA,
        focal_gamma=FOCAL_GAMMA,
        smoothing=LABEL_SMOOTHING,
        weight=class_weights
    )
else:
    raise ValueError(f"Unknown loss mode: {LOSS_MODE}")

# Enhanced optimizer
optimizer = AdamW(
    model.parameters(),
    lr=BASE_LR,
    weight_decay=WEIGHT_DECAY,
    betas=(0.9, 0.999),
    eps=1e-8
)

# Enhanced learning rate scheduler - OneCycleLR for better convergence
scheduler = OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    epochs=EPOCHS,
    steps_per_epoch=len(train_loader),
    pct_start=0.1,  # 10% warmup
    anneal_strategy='cos',
    div_factor=10.0,  # Initial lr = max_lr / div_factor
    final_div_factor=100.0  # Final lr = max_lr / final_div_factor
)

# Mixed precision scaler
scaler = GradScaler() if USE_MIXED_PRECISION else None

# Enhanced early stopping
class EnhancedEarlyStopping:
    def __init__(self, patience=15, min_delta=1e-4, monitor='val_f1', mode='max'):
        self.patience = patience
        self.min_delta = min_delta
        self.monitor = monitor
        self.mode = mode
        self.best_score = float('-inf') if mode == 'max' else float('inf')
        self.counter = 0
        self.best_epoch = 0
        self.early_stop = False
        
    def __call__(self, epoch, score):
        if self.mode == 'max':
            if score > self.best_score + self.min_delta:
                self.best_score = score
                self.counter = 0
                self.best_epoch = epoch
                return True  # Improvement
            else:
                self.counter += 1
        else:  # mode == 'min'
            if score < self.best_score - self.min_delta:
                self.best_score = score
                self.counter = 0
                self.best_epoch = epoch
                return True  # Improvement
            else:
                self.counter += 1
                
        if self.counter >= self.patience:
            self.early_stop = True
            
        return False  # No improvement

early_stopping = EnhancedEarlyStopping(
    patience=PATIENCE,
    min_delta=MIN_DELTA,
    monitor='val_f1',
    mode='max'
)

# Training history
history = {
    'train_loss': [],
    'val_loss': [],
    'val_acc': [],
    'val_f1': [],
    'learning_rate': []
}

print(f"Training setup complete:")
print(f"- Loss function: {LOSS_MODE}")
print(f"- Optimizer: AdamW (lr={BASE_LR}, wd={WEIGHT_DECAY})")
print(f"- Scheduler: OneCycleLR (max_lr={MAX_LR})")
print(f"- Mixed precision: {USE_MIXED_PRECISION}")
print(f"- Early stopping: patience={PATIENCE}")
print(f"- Gradient clipping: {GRADIENT_CLIP}")

In [None]:
# =======================================
# Cell 9 — Enhanced Training Loop
# =======================================

def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, epoch):
    """Enhanced training loop with mixed precision and advanced augmentation"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    # Progress tracking
    from tqdm import tqdm
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS}')
    
    for batch_idx, (data, targets) in enumerate(pbar):
        data, targets = data.to(device), targets.to(device)
        
        # Apply MixUp or CutMix randomly
        use_mixup = USE_MIXUP and random.random() < 0.5
        use_cutmix = USE_CUTMIX and not use_mixup and random.random() < 0.5
        
        if use_mixup:
            data, targets_a, targets_b, lam = mixup_data(data, targets, MIXUP_ALPHA, True)
            mixed_targets = (targets_a, targets_b, lam)
        elif use_cutmix:
            data, targets_a, targets_b, lam = cutmix_data(data, targets, CUTMIX_ALPHA, True)
            mixed_targets = (targets_a, targets_b, lam)
        else:
            mixed_targets = None
        
        # Forward pass with mixed precision
        if USE_MIXED_PRECISION:
            with autocast():
                outputs = model(data)
                
                if mixed_targets is not None:
                    targets_a, targets_b, lam = mixed_targets
                    loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
                else:
                    loss = criterion(outputs, targets)
                
                # Scale loss for gradient accumulation
                loss = loss / ACCUMULATION_STEPS
            
            # Backward pass with mixed precision
            scaler.scale(loss).backward()
            
            # Gradient accumulation
            if (batch_idx + 1) % ACCUMULATION_STEPS == 0:
                # Gradient clipping
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
                
                # Optimizer step
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()
        else:
            # Standard precision training
            outputs = model(data)
            
            if mixed_targets is not None:
                targets_a, targets_b, lam = mixed_targets
                loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            else:
                loss = criterion(outputs, targets)
            
            # Scale loss for gradient accumulation
            loss = loss / ACCUMULATION_STEPS
            loss.backward()
            
            # Gradient accumulation
            if (batch_idx + 1) % ACCUMULATION_STEPS == 0:
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
                
                # Optimizer step
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
        
        # Statistics (only for non-mixed samples for accuracy)
        running_loss += loss.item() * ACCUMULATION_STEPS
        if mixed_targets is None:
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        
        # Update progress bar
        current_lr = scheduler.get_last_lr()[0]
        pbar.set_postfix({
            'Loss': f'{running_loss/(batch_idx+1):.4f}',
            'LR': f'{current_lr:.2e}',
            'Acc': f'{100.*correct/total:.2f}%' if total > 0 else 'N/A'
        })
    
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total if total > 0 else 0
    
    return epoch_loss, epoch_acc, current_lr

def evaluate_model(model, data_loader, criterion, use_tta=False, tta_transforms=None):
    """Enhanced evaluation with optional Test-Time Augmentation"""
    model.eval()
    running_loss = 0.0
    all_predictions = []
    all_targets = []
    
    with torch.no_grad():
        for data, targets in tqdm(data_loader, desc='Evaluating', leave=False):
            data, targets = data.to(device), targets.to(device)
            
            if use_tta and tta_transforms is not None:
                # Test-Time Augmentation
                tta_predictions = []
                
                # Original prediction
                outputs = model(data)
                tta_predictions.append(F.softmax(outputs, dim=1))
                
                # Additional TTA transforms would be applied here
                # For simplicity, we'll use the original prediction
                
                # Average TTA predictions
                final_outputs = torch.stack(tta_predictions).mean(0)
                outputs = torch.log(final_outputs + 1e-8)  # Convert back to logits
            else:
                outputs = model(data)
            
            loss = criterion(outputs, targets)
            running_loss += loss.item()
            
            _, predicted = outputs.max(1)
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
    
    # Calculate metrics
    epoch_loss = running_loss / len(data_loader)
    accuracy = accuracy_score(all_targets, all_predictions)
    f1_macro = f1_score(all_targets, all_predictions, average='macro', zero_division=0)
    f1_weighted = f1_score(all_targets, all_predictions, average='weighted', zero_division=0)
    
    return epoch_loss, accuracy, f1_macro, f1_weighted, all_targets, all_predictions

def plot_training_progress(history, save_path=None):
    """Plot enhanced training progress"""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Loss curves
    axes[0, 0].plot(history['train_loss'], label='Train Loss', alpha=0.8)
    axes[0, 0].plot(history['val_loss'], label='Val Loss', alpha=0.8)
    axes[0, 0].set_title('Loss Curves')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Accuracy curve
    axes[0, 1].plot(history['val_acc'], label='Val Accuracy', color='green', alpha=0.8)
    axes[0, 1].set_title('Validation Accuracy')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # F1 Score curve
    axes[1, 0].plot(history['val_f1'], label='Val F1-Score', color='red', alpha=0.8)
    axes[1, 0].set_title('Validation F1-Score')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('F1-Score')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # Learning rate curve
    axes[1, 1].plot(history['learning_rate'], label='Learning Rate', color='orange', alpha=0.8)
    axes[1, 1].set_title('Learning Rate Schedule')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Learning Rate')
    axes[1, 1].set_yscale('log')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    
    plt.show()

print("Training functions loaded successfully!")
print("Ready to start enhanced training...")

In [None]:
# =======================================
# Cell 10 — Main Training Loop
# =======================================

print("Starting Enhanced Emotion Recognition Training...")
print(f"Target: 80%+ Accuracy")
print(f"Model: {MODEL_NAME}")
print(f"Dataset: {len(train_dataset)} train, {len(val_dataset)} val, {len(test_dataset)} test")
print("=" * 60)

best_val_f1 = 0.0
best_val_acc = 0.0
best_model_path = Path(OUT_DIR) / "best_model.pth"

# Training loop
for epoch in range(EPOCHS):
    start_time = time.time()
    
    # Progressive resizing
    if USE_PROGRESSIVE_RESIZING and epoch in RESIZE_EPOCHS:
        new_size = FINAL_IMG_SIZE
        print(f"\nProgressive Resizing: Updating to {new_size}x{new_size}")
        
        # Update datasets
        train_dataset.update_img_size(new_size)
        val_dataset.update_img_size(new_size)
        test_dataset.update_img_size(new_size)
        
        # Recreate data loaders
        train_loader = DataLoader(
            train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler,
            num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size=BATCH_SIZE * 2, shuffle=False,
            num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY
        )
    
    # Training phase
    train_loss, train_acc, current_lr = train_one_epoch(
        model, train_loader, criterion, optimizer, scheduler, scaler, epoch
    )
    
    # Validation phase
    val_loss, val_acc, val_f1_macro, val_f1_weighted, y_true, y_pred = evaluate_model(
        model, val_loader, criterion, use_tta=False
    )
    
    # Update history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_f1'].append(val_f1_macro)
    history['learning_rate'].append(current_lr)
    
    # Calculate epoch time
    epoch_time = time.time() - start_time
    
    # Print epoch results
    print(f"\nEpoch {epoch+1:3d}/{EPOCHS}:")
    print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
    print(f"  Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1_macro:.4f}")
    print(f"  LR: {current_lr:.2e}, Time: {epoch_time:.1f}s")
    
    # Per-class F1 scores
    per_class_f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
    print(f"  Per-class F1: {np.round(per_class_f1, 3)}")
    
    # Check for improvements and save best model
    is_best = early_stopping(epoch, val_f1_macro)
    
    if is_best:
        best_val_f1 = val_f1_macro
        best_val_acc = val_acc
        
        # Save best model
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_f1': val_f1_macro,
            'val_acc': val_acc,
            'history': history,
            'config': {
                'model_name': MODEL_NAME,
                'num_classes': num_classes,
                'img_size': IMG_SIZE,
                'batch_size': BATCH_SIZE,
                'learning_rate': BASE_LR,
                'epochs': EPOCHS
            }
        }, best_model_path)
        
        print(f"  ✓ New best model saved! F1: {val_f1_macro:.4f}, Acc: {val_acc:.4f}")
    
    # Plot progress every 10 epochs
    if (epoch + 1) % 10 == 0:
        clear_output(wait=True)
        plot_training_progress(history, Path(OUT_DIR) / f"training_progress_epoch_{epoch+1}.png")
    
    # Early stopping check
    if early_stopping.early_stop:
        print(f"\n🛑 Early stopping triggered!")
        print(f"Best epoch: {early_stopping.best_epoch + 1}")
        print(f"Best validation F1: {best_val_f1:.4f}")
        print(f"Best validation accuracy: {best_val_acc:.4f}")
        break
    
    # Check if we've reached our 80% target
    if val_acc >= 0.8:
        print(f"\n🎯 TARGET REACHED! Validation accuracy: {val_acc:.4f} (≥80%)")
        break

print("\n" + "=" * 60)
print("Training completed!")
print(f"Best validation F1-score: {best_val_f1:.4f}")
print(f"Best validation accuracy: {best_val_acc:.4f}")
print(f"Target achieved: {'✓ YES' if best_val_acc >= 0.8 else '✗ NO'}")
print("=" * 60)

In [None]:
# =======================================
# Cell 11 — Final Evaluation & Results
# =======================================

# Load best model for evaluation
print("Loading best model for final evaluation...")
checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded model from epoch {checkpoint['epoch'] + 1}")
print(f"Best validation F1: {checkpoint['val_f1']:.4f}")
print(f"Best validation accuracy: {checkpoint['val_acc']:.4f}")

# Final validation evaluation (with TTA)
print("\n" + "=" * 50)
print("FINAL VALIDATION RESULTS")
print("=" * 50)

val_loss, val_acc, val_f1_macro, val_f1_weighted, y_true_val, y_pred_val = evaluate_model(
    model, val_loader, criterion, use_tta=USE_TTA
)

print(f"Validation Results:")
print(f"  Accuracy: {val_acc:.4f} ({val_acc*100:.2f}%)")
print(f"  F1-Score (Macro): {val_f1_macro:.4f}")
print(f"  F1-Score (Weighted): {val_f1_weighted:.4f}")
print(f"  Loss: {val_loss:.4f}")

# Test set evaluation
print("\n" + "=" * 50)
print("FINAL TEST RESULTS")
print("=" * 50)

test_loss, test_acc, test_f1_macro, test_f1_weighted, y_true_test, y_pred_test = evaluate_model(
    model, test_loader, criterion, use_tta=USE_TTA
)

print(f"Test Results:")
print(f"  Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")
print(f"  F1-Score (Macro): {test_f1_macro:.4f}")
print(f"  F1-Score (Weighted): {test_f1_weighted:.4f}")
print(f"  Loss: {test_loss:.4f}")

# Detailed classification report
print("\n" + "=" * 50)
print("DETAILED CLASSIFICATION REPORT (TEST SET)")
print("=" * 50)

class_names = [idx2str[i] for i in range(num_classes)]
test_report = classification_report(
    y_true_test, y_pred_test,
    target_names=class_names,
    digits=4,
    zero_division=0
)
print(test_report)

# Save classification report
with open(Path(OUT_DIR) / "test_classification_report.txt", "w") as f:
    f.write(test_report)

# Confusion Matrix Visualization
def plot_enhanced_confusion_matrix(y_true, y_pred, class_names, title, save_path=None):
    """Plot enhanced confusion matrix with percentages"""
    cm = confusion_matrix(y_true, y_pred)
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # Raw counts
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names, ax=ax1)
    ax1.set_title(f'{title} - Raw Counts')
    ax1.set_xlabel('Predicted')
    ax1.set_ylabel('Actual')
    
    # Normalized percentages
    sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax2)
    ax2.set_title(f'{title} - Normalized (%)')
    ax2.set_xlabel('Predicted')
    ax2.set_ylabel('Actual')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    
    plt.show()
    
    return cm, cm_normalized

# Plot confusion matrices
print("\n" + "=" * 50)
print("CONFUSION MATRICES")
print("=" * 50)

# Test set confusion matrix
test_cm, test_cm_norm = plot_enhanced_confusion_matrix(
    y_true_test, y_pred_test, class_names,
    "Test Set Confusion Matrix",
    Path(OUT_DIR) / "test_confusion_matrix.png"
)

# Validation set confusion matrix
val_cm, val_cm_norm = plot_enhanced_confusion_matrix(
    y_true_val, y_pred_val, class_names,
    "Validation Set Confusion Matrix",
    Path(OUT_DIR) / "val_confusion_matrix.png"
)

# Final training progress plot
print("\n" + "=" * 50)
print("FINAL TRAINING PROGRESS")
print("=" * 50)

plot_training_progress(history, Path(OUT_DIR) / "final_training_progress.png")

In [None]:
# =======================================
# Cell 12 — Model Analysis & Insights
# =======================================

print("=" * 60)
print("MODEL ANALYSIS & INSIGHTS")
print("=" * 60)

# Performance summary
print(f"\n📊 PERFORMANCE SUMMARY:")
print(f"{'Metric':<20} {'Validation':<12} {'Test':<12} {'Target':<10} {'Status':<8}")
print("-" * 62)
print(f"{'Accuracy':<20} {val_acc*100:<11.2f}% {test_acc*100:<11.2f}% {'80.0%':<10} {'✓' if test_acc >= 0.8 else '✗':<8}")
print(f"{'F1-Macro':<20} {val_f1_macro:<11.4f} {test_f1_macro:<11.4f} {'0.75':<10} {'✓' if test_f1_macro >= 0.75 else '✗':<8}")
print(f"{'F1-Weighted':<20} {val_f1_weighted:<11.4f} {test_f1_weighted:<11.4f} {'0.80':<10} {'✓' if test_f1_weighted >= 0.8 else '✗':<8}")

# Per-class analysis
print(f"\n📋 PER-CLASS ANALYSIS (Test Set):")
per_class_f1 = f1_score(y_true_test, y_pred_test, average=None, zero_division=0)
per_class_precision = precision_score(y_true_test, y_pred_test, average=None, zero_division=0)
per_class_recall = recall_score(y_true_test, y_pred_test, average=None, zero_division=0)

print(f"{'Class':<12} {'Precision':<10} {'Recall':<10} {'F1-Score':<10} {'Support':<10}")
print("-" * 52)

class_support = np.bincount(y_true_test)
for i, class_name in enumerate(class_names):
    print(f"{class_name:<12} {per_class_precision[i]:<9.3f} {per_class_recall[i]:<9.3f} {per_class_f1[i]:<9.3f} {class_support[i]:<10}")

# Model improvements achieved
print(f"\n🚀 KEY IMPROVEMENTS IMPLEMENTED:")
improvements = [
    f"✓ Enhanced Architecture: {MODEL_NAME} with attention mechanism",
    f"✓ Advanced Data Augmentation: MixUp, CutMix, AutoAugment",
    f"✓ Sophisticated Loss Function: {LOSS_MODE} with class weighting",
    f"✓ Optimized Training: OneCycleLR, Mixed Precision, Gradient Clipping",
    f"✓ Class Imbalance Handling: Weighted sampling + {CLASS_WEIGHT_MODE} weights",
    f"✓ Regularization: Dropout ({DROPOUT_RATE}), Label Smoothing ({LABEL_SMOOTHING})",
    f"✓ Extended Training: {EPOCHS} epochs with early stopping (patience={PATIENCE})"
]

if USE_PROGRESSIVE_RESIZING:
    improvements.append(f"✓ Progressive Resizing: {INITIAL_IMG_SIZE}→{FINAL_IMG_SIZE}")
    
if USE_TTA:
    improvements.append(f"✓ Test-Time Augmentation: {TTA_TRANSFORMS} transforms")

for improvement in improvements:
    print(improvement)

# Recommendations for further improvement
print(f"\n💡 RECOMMENDATIONS FOR FURTHER IMPROVEMENT:")
recommendations = []

if test_acc < 0.8:
    recommendations.extend([
        "🔸 Increase model capacity: Try EfficientNet-B4/B5 or Vision Transformer",
        "🔸 Ensemble multiple models for better performance",
        "🔸 Collect more diverse training data, especially for underperforming classes"
    ])

if min(per_class_f1) < 0.6:
    worst_classes = [class_names[i] for i, f1 in enumerate(per_class_f1) if f1 < 0.6]
    recommendations.append(f"🔸 Focus on improving classes: {', '.join(worst_classes)}")
    recommendations.append("🔸 Apply class-specific augmentation strategies")

recommendations.extend([
    "🔸 Implement cross-validation for more robust evaluation",
    "🔸 Use knowledge distillation from larger teacher models",
    "🔸 Apply advanced regularization techniques (DropBlock, Stochastic Depth)",
    "🔸 Fine-tune on domain-specific pre-trained models (facial emotion datasets)"
])

for rec in recommendations:
    print(rec)

# Save comprehensive results
results = {
    'model_config': {
        'architecture': MODEL_NAME,
        'num_classes': num_classes,
        'total_parameters': total_params,
        'trainable_parameters': trainable_params
    },
    'training_config': {
        'epochs_trained': len(history['val_acc']),
        'batch_size': BATCH_SIZE,
        'learning_rate': BASE_LR,
        'optimizer': 'AdamW',
        'scheduler': 'OneCycleLR',
        'loss_function': LOSS_MODE,
        'augmentations': {
            'mixup': USE_MIXUP,
            'cutmix': USE_CUTMIX,
            'autoaugment': USE_AUTOAUGMENT,
            'tta': USE_TTA
        }
    },
    'final_results': {
        'validation': {
            'accuracy': val_acc,
            'f1_macro': val_f1_macro,
            'f1_weighted': val_f1_weighted,
            'loss': val_loss
        },
        'test': {
            'accuracy': test_acc,
            'f1_macro': test_f1_macro,
            'f1_weighted': test_f1_weighted,
            'loss': test_loss
        },
        'per_class_metrics': {
            'precision': per_class_precision.tolist(),
            'recall': per_class_recall.tolist(),
            'f1_score': per_class_f1.tolist(),
            'class_names': class_names
        }
    },
    'target_achievement': {
        'accuracy_target_80pct': test_acc >= 0.8,
        'f1_macro_target_75pct': test_f1_macro >= 0.75
    }
}

# Save results to JSON
import json
with open(Path(OUT_DIR) / "comprehensive_results.json", "w") as f:
    json.dump(results, f, indent=2, default=str)

print(f"\n💾 Results saved to: {OUT_DIR}")
print(f"  ✓ Best model: best_model.pth")
print(f"  ✓ Training progress: final_training_progress.png")
print(f"  ✓ Confusion matrices: test_confusion_matrix.png, val_confusion_matrix.png")
print(f"  ✓ Classification report: test_classification_report.txt")
print(f"  ✓ Comprehensive results: comprehensive_results.json")

print(f"\n🎯 FINAL STATUS: {'SUCCESS - 80%+ ACCURACY ACHIEVED!' if test_acc >= 0.8 else 'TARGET NOT REACHED - CONSIDER RECOMMENDATIONS'}")
print("=" * 60)