# ImgAE-Dx Enhanced Training with Advanced Features

This notebook implements research-grade training for medical image anomaly detection with:
- Validation split and monitoring
- Early stopping mechanism
- Learning rate scheduling (Cosine Annealing)
- Gradient clipping
- Warmup epochs
- Advanced checkpoint management

## Quick Start:
1. Set `CONFIG['model_type'] = 'both'` to train both U-Net and Reversed AE
2. Run cells in order
3. Monitor training progress with validation metrics

## 1. Setup Colab Environment

In [None]:
# Check GPU and mount Google Drive
import torch
import os

# Check GPU
# Report which accelerator this runtime exposes; the model cell later falls
# back to CPU when CUDA is unavailable.
if torch.cuda.is_available():
    # IPython shell capture: `gpu_info` holds the lines printed by nvidia-smi
    # (CSV "name, memory.total" without a header). Only the first line is shown.
    gpu_info = !nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
    print(f"GPU: {gpu_info[0]}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"PyTorch Version: {torch.__version__}")
else:
    print("⚠️ No GPU detected! Please enable GPU in Runtime > Change runtime type")

# Mount Google Drive
# Colab-only import. The checkpoint directory configured further down lives
# under /content/drive/MyDrive, so the mount must succeed before training.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
print("✅ Google Drive mounted")

## 2. Install ImgAE-Dx

In [None]:
# Clone repository if not exists
# Either way the working directory ends up at the repository root, which the
# editable install below (`pip install -e .`) relies on.
if not os.path.exists('/content/ImgAE-Dx'):
    !git clone https://github.com/kinhluan/ImgAE-Dx.git /content/ImgAE-Dx
    %cd /content/ImgAE-Dx
else:
    # Repository already present (e.g. the cell was re-run): just update it.
    %cd /content/ImgAE-Dx
    !git pull

# Add the src directory to Python path
# The membership test keeps repeated runs of this cell from appending duplicates.
# NOTE(review): presumably needed so `imgae_dx` imports work in the current
# kernel right after the editable install — confirm.
import sys
if '/content/ImgAE-Dx/src' not in sys.path:
    sys.path.append('/content/ImgAE-Dx/src')

# Install dependencies
# Versions are not pinned, so each fresh runtime gets the latest releases.
!pip install -e .
!pip install datasets transformers accelerate
!pip install wandb --upgrade

print("✅ ImgAE-Dx installed")

## 3. Enhanced Configuration

In [None]:
# Single source of truth for every tunable value used by the cells below.
CONFIG = {
    # --- Which model(s) to train ---
    'model_type': 'both',  # 'unet', 'reversed_ae', or 'both'

    # --- Where the data comes from ---
    'dataset_source': 'huggingface',

    # --- HuggingFace dataset options ---
    'hf_dataset': 'hf-vision/chest-xray-pneumonia',
    'hf_config': None,
    'hf_split': 'train',
    'hf_streaming': False,
    'hf_token': None,
    'image_column': 'image',
    'label_column': 'labels',

    # --- Core training hyper-parameters ---
    'samples': 8000,        # upper bound on NORMAL images taken from the split
    'epochs': 50,           # maximum; early stopping may end training sooner
    'batch_size': 32,       # memory-safe when both models are trained
    'learning_rate': 5e-5,  # peak learning rate reached after warmup
    'image_size': 128,

    # --- Advanced training features ---
    'validation_split': 0.15,         # fraction held out for validation
    'early_stopping_patience': 8,     # epochs without improvement before stopping
    'lr_scheduler': 'cosine',         # 'cosine' enables warmup + cosine annealing
    'gradient_clip_norm': 1.0,        # max gradient norm
    'warmup_epochs': 3,               # length of the linear warmup
    'min_lr_factor': 0.1,             # floor of the schedule, relative to learning_rate

    # --- T4 runtime options ---
    'mixed_precision': True,
    'memory_limit': 13,      # GB; leaves 3GB headroom
    'gradient_accumulation_steps': 1,
    'num_workers': 4,        # DataLoader worker processes

    # --- Checkpointing ---
    'checkpoint_dir': '/content/drive/MyDrive/imgae_dx_checkpoints',
    'save_frequency': 5,
    'keep_best_only': True,           # skip per-epoch checkpoints to save disk space
    'resume_from_checkpoint': None,   # path of a checkpoint to resume from
    'resume_model_type': None,        # name of the model that checkpoint belongs to

    # --- Logging ---
    'use_wandb': False,
    'wandb_project': 'imgae-dx-t4-colab',
    'wandb_run_name': None,
    'log_frequency': 1,              # log every N batches
}

# Make sure every output location exists before anything tries to write to it.
for _output_dir in (CONFIG['checkpoint_dir'], '/content/outputs/logs', '/content/outputs/plots'):
    os.makedirs(_output_dir, exist_ok=True)

# Echo the settings that matter most so the run is self-describing.
_config_report = [
    "🚀 Enhanced Configuration Set!",
    f"Model: {CONFIG['model_type']}",
    f"Samples: {CONFIG['samples']} (with {CONFIG['validation_split']:.1%} validation)",
    f"Epochs: {CONFIG['epochs']} (early stopping: {CONFIG['early_stopping_patience']})",
    f"Batch size: {CONFIG['batch_size']}",
    f"Learning rate: {CONFIG['learning_rate']} (scheduler: {CONFIG['lr_scheduler']})",
    "Advanced features: Validation monitoring, Early stopping, LR scheduling, Gradient clipping",
]
print("\n".join(_config_report))

## 4. Setup Weights & Biases (Optional)

In [None]:
# Weights & Biases is opt-in; the package is only imported when it is enabled.
if not CONFIG['use_wandb']:
    print("W&B logging disabled")
else:
    import wandb

    wandb.login()

    # Use the configured run name, or derive one from the model type and sample cap.
    run_name = (
        CONFIG['wandb_run_name']
        if CONFIG['wandb_run_name']
        else f"enhanced_{CONFIG['model_type']}_{CONFIG['samples']}samples"
    )
    wandb.init(project=CONFIG['wandb_project'], name=run_name, config=CONFIG)
    print(f"✅ W&B initialized: {run_name}")

## 5. Enhanced Dataset Loading with Validation Split

In [None]:
from datasets import load_dataset
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

print(f"📂 Loading dataset: {CONFIG['hf_dataset']}...")

try:
    # Authentication if needed.
    # `token` is the current keyword: the older `use_auth_token` spelling was
    # deprecated and later removed from `datasets`, and this notebook installs
    # an unpinned (latest) release.
    auth_kwargs = {'token': CONFIG['hf_token']} if CONFIG['hf_token'] else {}

    # Load dataset
    dataset = load_dataset(
        CONFIG['hf_dataset'],
        CONFIG['hf_config'],
        split=CONFIG['hf_split'],
        **auth_kwargs
    )

    # Filter for NORMAL images only (label = 0)
    # For unsupervised anomaly detection, we only train on normal images.
    # `input_columns` hands only the label value to the predicate, so the
    # image column is not decoded for every row just to read a label.
    normal_dataset = dataset.filter(
        lambda label: label == 0,
        input_columns=[CONFIG['label_column']],
    )
    print(f"📊 Filtered to {len(normal_dataset)} NORMAL images from {len(dataset)} total")

    # Take specified number of samples (CONFIG['samples'] is an upper bound;
    # fewer images are used when the split has fewer NORMAL examples).
    if len(normal_dataset) > CONFIG['samples']:
        normal_dataset = normal_dataset.select(range(CONFIG['samples']))

    print(f"✅ Using {len(normal_dataset)} NORMAL images for training")
    print(f"Dataset features: {normal_dataset.features}")

except Exception as e:
    # Surface the reason in the cell output, then re-raise: nothing below can
    # run without the dataset.
    print(f"❌ Error loading dataset: {e}")
    raise

# Preprocessing applied to every image: square resize, single grayscale
# channel, tensor conversion, then normalisation with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((CONFIG['image_size'],) * 2),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Enhanced dataset wrapper
class HFImageDataset(torch.utils.data.Dataset):
    """Map-style dataset that returns transformed images from a HuggingFace dataset.

    Each item is looked up by index, its image is located under the configured
    column (with a few common fallbacks), coerced to an RGB PIL image and passed
    through ``transform``. Labels are not returned.
    """

    def __init__(self, hf_dataset, transform):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]

        # The configured column is tried first, then common alternatives; the
        # first column present in the record wins.
        candidate_columns = (CONFIG['image_column'], 'image', 'img', 'pixel_values')
        image = next(
            (record[name] for name in candidate_columns if name in record),
            None,
        )
        if image is None:
            raise ValueError(f"No image column found. Available: {list(record.keys())}")

        # Anything that is not already a PIL image goes through a NumPy array.
        if not isinstance(image, Image.Image):
            pixels = image if isinstance(image, np.ndarray) else np.array(image)
            image = Image.fromarray(pixels)

        if image.mode != 'RGB':
            image = image.convert('RGB')

        return self.transform(image)

# Wrap the filtered HuggingFace dataset so it yields transformed tensors.
full_dataset = HFImageDataset(normal_dataset, transform)

# Hold out a fixed fraction for validation; the remainder is used for training.
val_size = int(CONFIG['validation_split'] * len(full_dataset))
train_size = len(full_dataset) - val_size

# A seeded generator makes the split identical across runs.
_split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=_split_generator
)

print(f"📊 Dataset split:")
print(f"   Training: {len(train_dataset)} images")
print(f"   Validation: {len(val_dataset)} images")

# Settings shared by both loaders.
_loader_options = dict(
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    pin_memory=True,
)

# Training: shuffled, incomplete final batch dropped.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_options)

# Validation: fixed order, every sample kept.
val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **_loader_options)

print(f"✅ DataLoaders created")
print(f"   Train batches: {len(train_loader)}")
print(f"   Validation batches: {len(val_loader)}")

## 6. Initialize Models

In [None]:
from imgae_dx.models import UNet, ReversedAutoencoder
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# (name, model) pairs, in the order they will be trained.
models_to_train = []

if CONFIG['model_type'] in ('unet', 'both'):
    unet = UNet(in_channels=1, out_channels=1, features=[64, 128, 256, 512]).to(device)
    models_to_train.append(('unet', unet))
    print("✅ U-Net initialized")

if CONFIG['model_type'] in ('reversed_ae', 'both'):
    reversed_ae = ReversedAutoencoder(
        in_channels=1, latent_dim=128, image_size=CONFIG['image_size']
    ).to(device)
    models_to_train.append(('reversed_ae', reversed_ae))
    print("✅ Reversed Autoencoder initialized")

# Report the size of each model that was built.
print("\n📊 Model Parameters:")
for name, model in models_to_train:
    _parameters = list(model.parameters())
    total_params = sum(p.numel() for p in _parameters)
    trainable_params = sum(p.numel() for p in _parameters if p.requires_grad)
    print(f"   {name}: {trainable_params:,} trainable parameters")

## 7. Enhanced Training with Advanced Features

In [None]:
import time
import math
from torch.cuda.amp import GradScaler, autocast
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import torch.nn.utils as torch_utils

# GPU runtime tuning; skipped entirely on CPU-only runtimes.
if torch.cuda.is_available():
    # Favour speed over bit-exact reproducibility in cuDNN.
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    # Cap this process at memory_limit GB, expressed as a fraction of 16 GB.
    _memory_fraction = CONFIG['memory_limit'] / 16.0
    torch.cuda.set_per_process_memory_fraction(_memory_fraction)
    print(f"✅ T4 optimizations enabled (Memory: {CONFIG['memory_limit']}GB)")

class EarlyStopping:
    """Signal when validation loss has stopped improving.

    Call the instance once per epoch with the validation loss. A loss counts as
    an improvement only if it beats the best seen so far by more than
    ``min_delta``; after ``patience`` consecutive non-improving calls the
    instance starts returning ``True``. The flag is never reset afterwards.
    """

    def __init__(self, patience=8, min_delta=0.0001):
        self.patience, self.min_delta = patience, min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            # New best: remember it and restart the patience window.
            self.best_loss = val_loss
            self.counter = 0
            return self.early_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop

def cosine_warmup_scheduler(optimizer, warmup_epochs, total_epochs, min_lr_factor=0.1):
    """Create a per-epoch LR scheduler: linear warmup followed by cosine annealing.

    The returned ``LambdaLR`` multiplies the optimizer's initial learning rate by
    a factor that
      * rises linearly from ``min_lr_factor`` to 1.0 over ``warmup_epochs``,
      * then follows half a cosine from 1.0 down to ``min_lr_factor`` until
        ``total_epochs``,
      * and stays at ``min_lr_factor`` for any epoch beyond ``total_epochs``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        warmup_epochs: Number of warmup epochs (0 disables warmup).
        total_epochs: Epoch index at which the cosine reaches its minimum.
        min_lr_factor: Lowest multiplier, as a fraction of the initial LR.

    Returns:
        torch.optim.lr_scheduler.LambdaLR: call ``step()`` once per epoch.
    """
    decay_epochs = total_epochs - warmup_epochs

    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # Warmup phase: linear increase from min_lr_factor to 1.0
            return min_lr_factor + (1 - min_lr_factor) * epoch / warmup_epochs

        if decay_epochs <= 0:
            # Warmup fills the whole schedule, so there is nothing to anneal.
            # Hold the peak LR instead of dividing by zero.
            return 1.0

        # Cosine annealing phase. Progress is clamped so that stepping past
        # total_epochs keeps the minimum LR rather than climbing back up the
        # cosine.
        progress = min((epoch - warmup_epochs) / decay_epochs, 1.0)
        return min_lr_factor + (1 - min_lr_factor) * 0.5 * (1 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

def validate_model(model, val_loader, criterion, device, mixed_precision=None):
    """Compute the average reconstruction loss of ``model`` over ``val_loader``.

    The model is switched to eval mode and gradients are disabled. Each batch
    loss is weighted by the number of images in the batch, so a short final
    batch does not count as much as a full one. This assumes ``criterion``
    returns a per-batch mean, which is the case for the ``nn.MSELoss()`` used
    by the training function.

    Args:
        model: Autoencoder called as ``model(images)``.
        val_loader: Iterable yielding image tensors (no labels).
        criterion: Loss called as ``criterion(reconstructed, images)``.
        device: Device the images are moved to before the forward pass.
        mixed_precision: Run the forward pass under ``autocast`` when true.
            ``None`` (default) falls back to ``CONFIG['mixed_precision']``.

    Returns:
        float: Sample-weighted mean loss, or ``inf`` if the loader yields nothing.
    """
    if mixed_precision is None:
        mixed_precision = CONFIG['mixed_precision']

    model.eval()
    total_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for images in val_loader:
            images = images.to(device)

            if mixed_precision:
                with autocast():
                    reconstructed = model(images)
                    loss = criterion(reconstructed, images)
            else:
                reconstructed = model(images)
                loss = criterion(reconstructed, images)

            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

    return total_loss / total_samples if total_samples > 0 else float('inf')

def enhanced_train_model(model_name, model, train_loader, val_loader, config):
    """Train one autoencoder with validation, LR scheduling, clipping and early stopping.

    Per epoch: one pass over ``train_loader`` (optionally under mixed precision,
    with gradient-norm clipping), one validation pass, one scheduler step, then
    checkpointing and the early-stopping check. The best weights by validation
    loss are always written to ``<checkpoint_dir>/<model_name>_best.pth``; full
    resumable checkpoints are only written when ``keep_best_only`` is false.

    When resuming from a full checkpoint, the optimizer, the LR schedule
    position, the history and the best validation loss so far are all restored,
    so training continues where it stopped instead of restarting the schedule
    or overwriting a better ``_best.pth``.

    Args:
        model_name: Label used in file names, log keys and messages.
        model: Module called as ``model(images)`` that returns a reconstruction.
        train_loader: Iterable of image batches used for optimisation.
        val_loader: Iterable of image batches used for validation.
        config: Dict of settings (see ``CONFIG``).

    Returns:
        tuple: ``(model, history)`` where ``history`` has per-epoch lists under
        ``'train_loss'``, ``'val_loss'`` and ``'learning_rate'``.

    Raises:
        ValueError: If ``train_loader`` yields no batches in an epoch.
    """
    print(f"\n🚀 Enhanced Training: {model_name}")
    print(f"   Features: Validation monitoring, Early stopping, LR scheduling, Gradient clipping")

    # Setup optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
    criterion = nn.MSELoss()

    # Mixed precision setup
    scaler = GradScaler() if config['mixed_precision'] else None

    # Learning rate scheduler
    if config['lr_scheduler'] == 'cosine':
        scheduler = cosine_warmup_scheduler(
            optimizer,
            config['warmup_epochs'],
            config['epochs'],
            config['min_lr_factor']
        )
    else:
        scheduler = None

    # Early stopping
    early_stopping = EarlyStopping(patience=config['early_stopping_patience'])

    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'learning_rate': []
    }

    # Resume from checkpoint if specified
    start_epoch = 0
    best_val_loss = float('inf')

    if (config.get('resume_from_checkpoint') and
        config.get('resume_model_type') == model_name):

        print(f"📂 Resuming {model_name} from checkpoint...")
        try:
            checkpoint = torch.load(config['resume_from_checkpoint'], map_location=device)

            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_val_loss = checkpoint.get('val_loss', float('inf'))
                if 'history' in checkpoint:
                    history = checkpoint['history']
                    # The checkpoint stores the loss of its own epoch, which
                    # need not be the best one seen; recover the true best so a
                    # better `_best.pth` is not overwritten after resuming.
                    if history.get('val_loss'):
                        best_val_loss = min(best_val_loss, min(history['val_loss']))

                # Put the LR schedule back where it was. Without this, the
                # first scheduler.step() after resuming jumps to epoch 1 of the
                # schedule (i.e. back into warmup).
                if scheduler is not None:
                    if checkpoint.get('scheduler_state_dict') is not None:
                        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                    else:
                        # Older checkpoints carry no scheduler state: align the
                        # epoch counter with the number of completed epochs.
                        scheduler.last_epoch = start_epoch

                # Early stopping measures improvement against the resumed best.
                early_stopping.best_loss = best_val_loss

                print(f"✅ Resumed from epoch {start_epoch}, best val loss: {best_val_loss:.4f}")
            else:
                # Bare state_dict (e.g. a `_best.pth` file): weights only.
                model.load_state_dict(checkpoint)
                print(f"✅ Loaded model weights, starting fresh training")
        except Exception as e:
            # Deliberately best-effort: a bad checkpoint must not block training.
            print(f"⚠️ Could not resume: {e}. Starting fresh...")

    # Training loop
    for epoch in range(start_epoch, config['epochs']):
        # Training phase
        model.train()
        train_loss = 0.0
        train_batches = 0

        # Get current learning rate
        current_lr = optimizer.param_groups[0]['lr']

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']} [LR: {current_lr:.2e}]")

        for batch_idx, images in enumerate(pbar):
            images = images.to(device)

            # Forward pass with mixed precision
            if config['mixed_precision']:
                with autocast():
                    reconstructed = model(images)
                    loss = criterion(reconstructed, images)

                # Backward pass with gradient clipping
                optimizer.zero_grad()
                scaler.scale(loss).backward()

                if config.get('gradient_clip_norm'):
                    # Gradients must be unscaled before their norm is measured.
                    scaler.unscale_(optimizer)
                    torch_utils.clip_grad_norm_(model.parameters(), config['gradient_clip_norm'])

                scaler.step(optimizer)
                scaler.update()
            else:
                reconstructed = model(images)
                loss = criterion(reconstructed, images)

                optimizer.zero_grad()
                loss.backward()

                if config.get('gradient_clip_norm'):
                    torch_utils.clip_grad_norm_(model.parameters(), config['gradient_clip_norm'])

                optimizer.step()

            # Update metrics
            train_loss += loss.item()
            train_batches += 1

            # Update progress bar
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'avg_loss': f'{train_loss/train_batches:.4f}'
            })

            # Log to W&B
            if config['use_wandb'] and batch_idx % config.get('log_frequency', 10) == 0:
                wandb.log({
                    f'{model_name}_batch_loss': loss.item(),
                    f'{model_name}_learning_rate': current_lr,
                    'epoch': epoch,
                    'batch': batch_idx
                })

        # An empty loader (e.g. fewer samples than one batch with drop_last)
        # would otherwise surface as an unexplained ZeroDivisionError.
        if train_batches == 0:
            raise ValueError(
                "train_loader yielded no batches; check the dataset size against batch_size/drop_last"
            )

        # Calculate epoch metrics
        avg_train_loss = train_loss / train_batches

        # Validation phase
        avg_val_loss = validate_model(model, val_loader, criterion, device)

        # Update history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['learning_rate'].append(current_lr)

        print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}, LR = {current_lr:.2e}")

        # Log epoch metrics
        if config['use_wandb']:
            wandb.log({
                f'{model_name}_train_loss': avg_train_loss,
                f'{model_name}_val_loss': avg_val_loss,
                f'{model_name}_learning_rate': current_lr,
                'epoch': epoch + 1
            })

        # Update learning rate
        if scheduler:
            scheduler.step()

        # Save checkpoint
        save_checkpoint = False

        if (epoch + 1) % config['save_frequency'] == 0:
            save_checkpoint = True

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            save_checkpoint = True

            # Save best model
            best_path = f"{config['checkpoint_dir']}/{model_name}_best.pth"
            torch.save(model.state_dict(), best_path)
            print(f"✅ Best model saved: {best_path} (Val Loss: {best_val_loss:.4f})")

        if save_checkpoint and not config.get('keep_best_only', False):
            checkpoint_path = f"{config['checkpoint_dir']}/{model_name}_epoch_{epoch+1}.pth"
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Saved after scheduler.step(), so it already points at the
                # next epoch — exactly what a resumed run needs.
                'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'history': history,
                'config': config
            }, checkpoint_path)

        # Early stopping check
        if early_stopping(avg_val_loss):
            print(f"🛑 Early stopping triggered at epoch {epoch+1}")
            break

    return model, history

# Results of the run, keyed by model name.
all_histories, trained_models = {}, {}

print(f"🚀 Starting enhanced training for {len(models_to_train)} model(s)...")

for model_name, model in models_to_train:
    start_time = time.time()

    # Run the full training procedure for this model.
    trained_model, history = enhanced_train_model(
        model_name, model, train_loader, val_loader, CONFIG
    )
    trained_models[model_name] = trained_model
    all_histories[model_name] = history

    # Wall-clock duration, reported in minutes.
    _elapsed_seconds = time.time() - start_time
    training_time = _elapsed_seconds / 60
    print(f"\n✅ {model_name} enhanced training completed in {training_time:.1f} minutes")

    # Release cached GPU memory before the next model starts.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print(f"\n🎉 All models trained successfully!")

## 8. Enhanced Training Visualization

In [None]:
# Enhanced training visualization: a 2x2 grid summarising every trained model.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Plot 1: Training and Validation Loss
for model_name, history in all_histories.items():
    epochs = range(1, len(history['train_loss']) + 1)
    axes[0, 0].plot(epochs, history['train_loss'], label=f'{model_name} train', linestyle='-')
    axes[0, 0].plot(epochs, history['val_loss'], label=f'{model_name} val', linestyle='--')

axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss (MSE)')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Learning Rate Schedule
for model_name, history in all_histories.items():
    epochs = range(1, len(history['learning_rate']) + 1)
    axes[0, 1].plot(epochs, history['learning_rate'], label=f'{model_name} LR')

axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Learning Rate')
axes[0, 1].set_title('Learning Rate Schedule')
axes[0, 1].set_yscale('log')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Validation Loss Improvement (running minimum of the validation loss)
for model_name, history in all_histories.items():
    val_losses = history['val_loss']
    best_val_loss = [min(val_losses[:i+1]) for i in range(len(val_losses))]
    epochs = range(1, len(best_val_loss) + 1)
    axes[1, 0].plot(epochs, best_val_loss, label=f'{model_name} best val')

axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Best Validation Loss')
axes[1, 0].set_title('Best Validation Loss Progress')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Train vs Val Loss Correlation
for model_name, history in all_histories.items():
    axes[1, 1].scatter(history['train_loss'], history['val_loss'],
                      label=f'{model_name}', alpha=0.7, s=30)

# Add diagonal line (points above it have validation loss > training loss).
# Skipped when nothing was recorded, where min()/max() would raise.
_recorded_losses = [
    value
    for h in all_histories.values()
    for value in h['train_loss'] + h['val_loss']
]
if _recorded_losses:
    min_loss = min(_recorded_losses)
    max_loss = max(_recorded_losses)
    axes[1, 1].plot([min_loss, max_loss], [min_loss, max_loss], 'k--', alpha=0.5, label='Perfect correlation')

axes[1, 1].set_xlabel('Training Loss')
axes[1, 1].set_ylabel('Validation Loss')
axes[1, 1].set_title('Training vs Validation Loss Correlation')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.suptitle('Enhanced Training Analysis', fontsize=16, y=0.98)
plt.tight_layout()

# Save enhanced plots
# Must happen BEFORE plt.show(): with the inline backend the figure is closed
# once shown, and a later plt.savefig() would write an empty image.
fig.savefig(f'{CONFIG["checkpoint_dir"]}/enhanced_training_analysis.png', dpi=300, bbox_inches='tight')
print("✅ Enhanced training analysis saved")

plt.show()

## 9. Test Reconstruction Quality

In [None]:
def visualize_enhanced_reconstructions(models, train_loader, val_loader, num_samples=5):
    """Plot originals and per-model reconstructions for train and validation samples.

    The figure has one row of original images followed by one row per model.
    The left half of each row shows samples from the first training batch, the
    right half samples from the first validation batch. The figure is written
    to ``CONFIG['checkpoint_dir']`` and then displayed.

    Args:
        models: Mapping of model name to trained model.
        train_loader: Loader yielding image batches (first batch is used).
        val_loader: Loader yielding image batches (first batch is used).
        num_samples: Images to show per split; reduced automatically if either
            first batch holds fewer images.
    """
    # Get samples from both train and validation
    train_batch = next(iter(train_loader))[:num_samples].to(device)
    val_batch = next(iter(val_loader))[:num_samples].to(device)

    # A small split can yield a first batch shorter than requested; show what
    # is available instead of failing with an IndexError.
    num_samples = min(num_samples, train_batch.size(0), val_batch.size(0))
    sample_batches = [train_batch[:num_samples], val_batch[:num_samples]]
    group_titles = ['Training Samples', 'Validation Samples']

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(
        len(models) + 1,
        num_samples * 2,
        figsize=(20, 3 * (len(models) + 1)),
        squeeze=False,
    )

    # Original images
    for batch_idx, (batch, group_title) in enumerate(zip(sample_batches, group_titles)):
        for i in range(num_samples):
            col_idx = batch_idx * num_samples + i
            img = batch[i].cpu().squeeze().numpy()
            # Images were normalised to [-1, 1]; fix the colour range to match.
            axes[0, col_idx].imshow(img, cmap='gray', vmin=-1, vmax=1)
            axes[0, col_idx].axis('off')
            if col_idx == 0:
                axes[0, col_idx].set_ylabel('Original', fontsize=12)
            # Put each group's title over its middle column.
            if i == num_samples // 2:
                axes[0, col_idx].set_title(group_title, fontsize=10)

    # Reconstructions for each model
    for model_idx, (model_name, model) in enumerate(models.items()):
        model.eval()

        for batch_idx, batch in enumerate(sample_batches):
            with torch.no_grad():
                recon = model(batch)

            for i in range(num_samples):
                col_idx = batch_idx * num_samples + i
                img = recon[i].cpu().squeeze().numpy()
                axes[model_idx + 1, col_idx].imshow(img, cmap='gray', vmin=-1, vmax=1)
                axes[model_idx + 1, col_idx].axis('off')

                if col_idx == 0:
                    axes[model_idx + 1, col_idx].set_ylabel(model_name, fontsize=12)

    plt.suptitle('Enhanced Reconstruction Quality Assessment', fontsize=16)
    plt.tight_layout()

    # Save figure
    # Must happen BEFORE plt.show(): with the inline backend the figure is
    # closed once shown, and a later plt.savefig() would write an empty image.
    fig.savefig(f'{CONFIG["checkpoint_dir"]}/enhanced_reconstruction_comparison.png', dpi=300, bbox_inches='tight')
    print("✅ Enhanced reconstruction comparison saved")

    plt.show()

# Only draw the comparison when at least one model was trained.
if len(trained_models) > 0:
    visualize_enhanced_reconstructions(
        trained_models, train_loader, val_loader
    )

## 10. Enhanced Results Summary

In [None]:
import json
from datetime import datetime

# Enhanced summary with validation metrics
# CONFIG['samples'] is only an upper bound: fewer images are used when the
# split holds fewer NORMAL examples. Report the number actually used.
actual_samples = len(train_dataset) + len(val_dataset)

summary = {
    'metadata': {
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'notebook_version': 'enhanced_training',
        'dataset': CONFIG['hf_dataset'],
        'total_samples': actual_samples,
        'requested_samples': CONFIG['samples'],
        'train_samples': len(train_dataset),
        'val_samples': len(val_dataset),
        'validation_split': CONFIG['validation_split']
    },
    'config': CONFIG,
    'training_results': {},
    'advanced_features_used': [
        'validation_monitoring',
        'early_stopping',
        'cosine_lr_scheduling',
        'gradient_clipping',
        'warmup_epochs',
        'mixed_precision'
    ]
}

# Calculate detailed metrics for each model
for model_name, history in all_histories.items():
    train_losses = history['train_loss']
    val_losses = history['val_loss']
    learning_rates = history['learning_rate']

    # Ratio of worst to best epoch loss; guarded so an exact-zero loss cannot
    # raise ZeroDivisionError.
    lowest_train_loss = min(train_losses)
    stability = max(train_losses) / lowest_train_loss if lowest_train_loss > 0 else float('inf')

    summary['training_results'][model_name] = {
        'epochs_trained': len(train_losses),
        'final_train_loss': train_losses[-1],
        'final_val_loss': val_losses[-1],
        'best_train_loss': lowest_train_loss,
        'best_val_loss': min(val_losses),
        'best_val_epoch': val_losses.index(min(val_losses)) + 1,
        'final_learning_rate': learning_rates[-1],
        'convergence_achieved': val_losses[-1] < 0.05,  # Threshold for good convergence
        'overfitting_detected': train_losses[-1] < val_losses[-1] * 0.5,  # Rough heuristic
        'training_stability': stability,  # Lower is more stable
    }

# Save enhanced summary
summary_path = f"{CONFIG['checkpoint_dir']}/enhanced_training_summary.json"
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2)

# Print comprehensive summary
print("\n🎯 ENHANCED TRAINING SUMMARY")
print("=" * 80)
print(f"📊 Dataset: {CONFIG['hf_dataset']}")
print(f"📈 Samples: {actual_samples} ({len(train_dataset)} train + {len(val_dataset)} val)")
print(f"🔧 Advanced Features: {', '.join(summary['advanced_features_used'])}")

for model_name, results in summary['training_results'].items():
    print(f"\n🤖 {model_name.upper()} RESULTS:")
    print(f"   Epochs Trained: {results['epochs_trained']}")
    print(f"   Final Train Loss: {results['final_train_loss']:.4f}")
    print(f"   Final Val Loss: {results['final_val_loss']:.4f}")
    print(f"   Best Val Loss: {results['best_val_loss']:.4f} (Epoch {results['best_val_epoch']})")
    print(f"   Final Learning Rate: {results['final_learning_rate']:.2e}")
    print(f"   Convergence: {'✅' if results['convergence_achieved'] else '⚠️'} {'Good' if results['convergence_achieved'] else 'Needs improvement'}")
    print(f"   Overfitting: {'⚠️' if results['overfitting_detected'] else '✅'} {'Detected' if results['overfitting_detected'] else 'Not detected'}")
    print(f"   Training Stability: {results['training_stability']:.2f}x (lower is better)")

print(f"\n✅ Enhanced summary saved to: {summary_path}")
print(f"📁 All checkpoints and artifacts saved to: {CONFIG['checkpoint_dir']}")

# Finish W&B run
if CONFIG['use_wandb']:
    wandb.finish()
    print("✅ W&B run finished")

print("\n🎉 ENHANCED TRAINING COMPLETE! Ready for evaluation phase.")
print("💡 Next step: Use evaluation cells 21-26 from EVALUATION_GUIDE.md to answer research questions.")