# ImgAE-Dx Enhanced Training with Advanced Features (Fixed)

This notebook implements research-grade training for medical image anomaly detection with:
- Validation split and monitoring
- Early stopping mechanism
- Learning rate scheduling (Cosine Annealing)
- Gradient clipping
- Warmup epochs
- **FIXED:** Advanced checkpoint management with proper save_frequency

## Quick Start:
1. Set `CONFIG['model_type'] = 'both'` to train both U-Net and Reversed AE
2. Set `CONFIG['keep_best_only'] = False` to enable regular checkpoints every N epochs
3. Run cells in order
4. Monitor training progress with validation metrics

## 1. Setup Colab Environment

In [None]:
# Check GPU and mount Google Drive
# This cell reports which GPU (if any) the runtime exposes and mounts Google Drive
# so that later cells can write checkpoints under /content/drive.
import torch
import os

# Check GPU
if torch.cuda.is_available():
    # IPython `!` capture: the command's output lines are collected into a list,
    # so gpu_info[0] is the first line printed by nvidia-smi (name, total memory).
    gpu_info = !nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
    print(f"GPU: {gpu_info[0]}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"PyTorch Version: {torch.__version__}")
else:
    print("⚠️ No GPU detected! Please enable GPU in Runtime > Change runtime type")

# Mount Google Drive
# force_remount=True is passed so the mount is redone when this cell is re-run.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
print("✅ Google Drive mounted")

## 2. Install ImgAE-Dx

In [None]:
# Clone repository if not exists
if not os.path.exists('/content/ImgAE-Dx'):
    !git clone https://github.com/kinhluan/ImgAE-Dx.git /content/ImgAE-Dx
    %cd /content/ImgAE-Dx
else:
    %cd /content/ImgAE-Dx
    !git pull

# Add the src directory to Python path
import sys
if '/content/ImgAE-Dx/src' not in sys.path:
    sys.path.append('/content/ImgAE-Dx/src')

# Install dependencies
!pip install -e .
!pip install datasets transformers accelerate
!pip install wandb --upgrade

print("✅ ImgAE-Dx installed")

## 3. Enhanced Configuration (FIXED)

In [None]:
# OPTIMIZED CONFIG - tuned for the hf-vision/chest-xray-pneumonia dataset
# Single dictionary holding every tunable setting read by the cells below.
CONFIG = {
    # Model settings
    'model_type': 'both',  # 'unet', 'reversed_ae', or 'both'
    
    # Dataset source
    'dataset_source': 'huggingface',
    
    # HuggingFace Dataset settings
    'hf_dataset': 'hf-vision/chest-xray-pneumonia',
    'hf_config': None,
    'hf_split': 'train',
    'hf_streaming': False,
    'hf_token': None,
    'image_column': 'image',
    'label_column': 'labels',
    
    # Training settings - OPTIMIZED for a small dataset (1,341 NORMAL images)
    'samples': 1200,        # ~90% of the available NORMAL images (1,341 total)
    'epochs': 250,          # More epochs because there is little data
    'batch_size': 16,       # Smaller batch size for the small dataset
    'learning_rate': 5e-6,  # Lower LR for stable training
    'image_size': 128,
    
    # Advanced training features - TUNED
    'validation_split': 0.25,         # Larger split because the original val set has only 8 images
    'early_stopping_patience': 50,    # Higher patience because training needs to run long
    'lr_scheduler': 'cosine',         # Better convergence
    'gradient_clip_norm': 0.5,        # Lowered for the small dataset
    'warmup_epochs': 10,              # Longer warmup for stability
    'min_lr_factor': 0.0001,          # Very small final LR
    
    # T4 optimizations - TUNED for the small dataset
    'mixed_precision': True,
    'memory_limit': 13,                  # Leave 3GB headroom
    'gradient_accumulation_steps': 4,    # Effective batch = 16*4 = 64
    'num_workers': 2,                    # Reduced for memory
    
    # Enhanced checkpointing
    'checkpoint_dir': '/content/drive/MyDrive/imgae_dx_enhanced_checkpoints',
    'save_frequency': 10,             # Save every 10 epochs
    'keep_best_only': False,          # Also keep the regular (periodic) checkpoints
    'resume_from_checkpoint': None,   # Path to resume from
    'resume_model_type': None,        # Which model to resume
    
    # Logging
    'use_wandb': False,
    'wandb_project': 'imgae-dx-enhanced',
    'wandb_run_name': None,
    'log_frequency': 5,               # Log every 5 batches (few batches per epoch)
}

# Create the checkpoint directory and the local output folders
for output_dir in (CONFIG['checkpoint_dir'], '/content/outputs/logs', '/content/outputs/plots'):
    os.makedirs(output_dir, exist_ok=True)

# Values derived from CONFIG that the report below prints
train_count = int(CONFIG['samples'] * (1 - CONFIG['validation_split']))
val_count = int(CONFIG['samples'] * CONFIG['validation_split'])
effective_batch = CONFIG['batch_size'] * CONFIG['gradient_accumulation_steps']
final_lr = CONFIG['learning_rate'] * CONFIG['min_lr_factor']
samples_flag = '✅' if CONFIG['samples'] <= 1341 else '❌ TOO MANY!'

# Validation checks
print("🔍 DATASET & CONFIG VALIDATION:")
print(f"   Dataset: {CONFIG['hf_dataset']}")
print("   Max NORMAL images available: 1,341")
print(f"   Requested samples: {CONFIG['samples']} {samples_flag}")
print(f"   Train samples: {train_count}")
print(f"   Val samples: {val_count}")

print("\n🚀 OPTIMIZED CONFIG FOR SMALL DATASET:")
print(f"   Model: {CONFIG['model_type']}")
print(f"   Samples: {CONFIG['samples']} ({CONFIG['validation_split']:.1%} validation)")
print(f"   Epochs: {CONFIG['epochs']} (patience: {CONFIG['early_stopping_patience']})")
print(f"   Batch: {CONFIG['batch_size']} x {CONFIG['gradient_accumulation_steps']} = {effective_batch} effective")
print(f"   LR: {CONFIG['learning_rate']:.0e} → {final_lr:.0e}")
print(f"   Checkpoints: Every {CONFIG['save_frequency']} epochs + best model")

# Estimated training time (the factors 2 and 4 are the per-batch costs used by the estimate)
batches_per_epoch = train_count // CONFIG['batch_size']
total_batches = CONFIG['epochs'] * batches_per_epoch
print("\n⏱️ TRAINING ESTIMATES:")
print(f"   Batches per epoch: {batches_per_epoch}")
print(f"   Est. time per model: ~{total_batches * 2 / 3600:.1f} hours")
print(f"   Total est. time (both models): ~{total_batches * 4 / 3600:.1f} hours")

print("\n💡 OPTIMIZATION NOTES:")
for note in (
    "   - Small dataset requires more epochs and lower LR",
    "   - High validation split (25%) due to tiny original val set",
    "   - Gradient accumulation compensates for small batch size",
    "   - Early stopping patience increased for convergence",
    "   - Target: Train loss < 0.05, Val loss < 0.08 for good anomaly detection",
):
    print(note)

## 4. Setup Weights & Biases (Optional)

In [None]:
# Optional Weights & Biases setup; skipped entirely unless CONFIG['use_wandb'] is truthy
if CONFIG['use_wandb']:
    import wandb
    
    wandb.login()
    
    run_name = CONFIG['wandb_run_name'] or f"enhanced_{CONFIG['model_type']}_{CONFIG['samples']}samples"
    # CONFIG also carries the Hugging Face access token; strip it so the
    # credential is never uploaded as a run hyperparameter.
    logged_config = {key: value for key, value in CONFIG.items() if key != 'hf_token'}
    wandb.init(
        project=CONFIG['wandb_project'],
        name=run_name,
        config=logged_config
    )
    print(f"✅ W&B initialized: {run_name}")
else:
    print("W&B logging disabled")

## 5. Enhanced Dataset Loading with Validation Split

In [None]:
from datasets import load_dataset
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

print(f"📂 Loading dataset: {CONFIG['hf_dataset']}...")

try:
    # Authentication if needed.
    # `token` is the current keyword; `use_auth_token` is deprecated in `datasets`.
    auth_kwargs = {'token': CONFIG['hf_token']} if CONFIG['hf_token'] else {}
    
    # Load dataset
    dataset = load_dataset(
        CONFIG['hf_dataset'],
        CONFIG['hf_config'],
        split=CONFIG['hf_split'],
        **auth_kwargs
    )
    
    # Filter for NORMAL images only (label = 0)
    # For unsupervised anomaly detection, we only train on normal images.
    # input_columns hands just the label value to the predicate, so the
    # predicate never needs the image column of a row.
    normal_dataset = dataset.filter(
        lambda label: label == 0,
        input_columns=[CONFIG['label_column']]
    )
    print(f"📊 Filtered to {len(normal_dataset)} NORMAL images from {len(dataset)} total")
    
    # Take specified number of samples (the first CONFIG['samples'] rows)
    if len(normal_dataset) > CONFIG['samples']:
        normal_dataset = normal_dataset.select(range(CONFIG['samples']))
    
    print(f"✅ Using {len(normal_dataset)} NORMAL images for training")
    print(f"Dataset features: {normal_dataset.features}")
    
except Exception as e:
    # Report the failure, then re-raise: later cells cannot run without the data
    print(f"❌ Error loading dataset: {e}")
    raise

# Preprocessing pipeline applied to every image: resize to a square,
# collapse to one channel, convert to a tensor, then normalise with mean 0.5 / std 0.5
target_size = (CONFIG['image_size'], CONFIG['image_size'])
preprocessing_steps = [
    transforms.Resize(target_size),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(preprocessing_steps)

# Enhanced dataset wrapper
class HFImageDataset(torch.utils.data.Dataset):
    """Map-style dataset that returns the transformed image of each wrapped row.

    Rows are expected to be mappings. The image is looked up under
    ``CONFIG['image_column']`` first and then under a few common fallback names.
    """

    def __init__(self, hf_dataset, transform):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        # The first candidate column present in the row wins
        candidate_columns = (CONFIG['image_column'], 'image', 'img', 'pixel_values')
        image = next((item[col] for col in candidate_columns if col in item), None)
        if image is None:
            raise ValueError(f"No image column found. Available: {list(item.keys())}")

        # Anything that is not already a PIL image goes through a numpy array
        if not isinstance(image, Image.Image):
            pixels = image if isinstance(image, np.ndarray) else np.array(image)
            image = Image.fromarray(pixels)

        # Bring every image to RGB before handing it to the transform
        if image.mode != 'RGB':
            image = image.convert('RGB')

        return self.transform(image)

# Wrap the filtered Hugging Face dataset so it yields transformed tensors
full_dataset = HFImageDataset(normal_dataset, transform)

# Hold out CONFIG['validation_split'] of the images; the remainder is used for training
val_size = int(CONFIG['validation_split'] * len(full_dataset))
train_size = len(full_dataset) - val_size

# Fixed seed so the split is the same on every run
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=split_generator
)

print("📊 Dataset split:")
print(f"   Training: {len(train_dataset)} images")
print(f"   Validation: {len(val_dataset)} images")

# Options shared by both loaders
loader_common = dict(
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    pin_memory=True,
)

# Training batches are shuffled and the incomplete last batch is dropped
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_common)

# Validation keeps the original order and every sample
val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **loader_common)

print("✅ DataLoaders created")
print(f"   Train batches: {len(train_loader)}")
print(f"   Validation batches: {len(val_loader)}")

## 6. Initialize Models

In [None]:
from imgae_dx.models import UNet, ReversedAutoencoder
import torch.nn as nn

# Pick the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# (name, model) pairs in the order they will be trained
models_to_train = []
requested_model = CONFIG['model_type']

if requested_model == 'unet' or requested_model == 'both':
    unet = UNet(in_channels=1, out_channels=1, features=[64, 128, 256, 512])
    unet = unet.to(device)
    models_to_train.append(('unet', unet))
    print("✅ U-Net initialized")

if requested_model == 'reversed_ae' or requested_model == 'both':
    reversed_ae = ReversedAutoencoder(in_channels=1, latent_dim=128, image_size=CONFIG['image_size'])
    reversed_ae = reversed_ae.to(device)
    models_to_train.append(('reversed_ae', reversed_ae))
    print("✅ Reversed Autoencoder initialized")

# Report the number of trainable parameters of each model
print("\n📊 Model Parameters:")
for name, model in models_to_train:
    total_params = 0
    trainable_params = 0
    for p in model.parameters():
        total_params += p.numel()
        if p.requires_grad:
            trainable_params += p.numel()
    print(f"   {name}: {trainable_params:,} trainable parameters")

## 7. Enhanced Training with FIXED Checkpoint Logic

In [None]:
import time
import math
from torch.cuda.amp import GradScaler, autocast
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import torch.nn.utils as torch_utils

# T4 optimizations
if torch.cuda.is_available():
    # Let cuDNN auto-tune kernels; determinism is explicitly not requested
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    # Turn the GB budget into a fraction of the memory this device actually
    # reports instead of assuming a 16 GB card. The clamp keeps the value inside
    # the [0, 1] range that set_per_process_memory_fraction accepts, so a budget
    # larger than the device no longer raises.
    device_total_gb = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory / 1024 ** 3
    memory_fraction = max(0.0, min(1.0, CONFIG['memory_limit'] / device_total_gb))
    torch.cuda.set_per_process_memory_fraction(memory_fraction)
    print(f"✅ T4 optimizations enabled (Memory: {CONFIG['memory_limit']}GB)")

class EarlyStopping:
    """Signal a stop after ``patience`` consecutive checks without improvement.

    A check counts as an improvement only when the loss is lower than the best
    seen so far by more than ``min_delta``. Once the stop flag is set it stays
    set, even if a later loss improves.
    """

    def __init__(self, patience=8, min_delta=0.0001):
        self.patience, self.min_delta = patience, min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and return whether training should stop."""
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            return self.early_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop

def cosine_warmup_scheduler(optimizer, warmup_epochs, total_epochs, min_lr_factor=0.1):
    """Create a LambdaLR scheduler with linear warmup followed by cosine annealing.

    The returned multiplier rises linearly from ``min_lr_factor`` to 1 over the
    first ``warmup_epochs`` epochs, then follows half a cosine from 1 back down
    to ``min_lr_factor`` at ``total_epochs``. Past ``total_epochs`` it stays at
    ``min_lr_factor``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        warmup_epochs: Number of warmup epochs (0 disables warmup).
        total_epochs: Epoch at which the cosine phase reaches its minimum.
        min_lr_factor: Lowest multiplier applied to the optimizer's base LR.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.
    """
    # Length of the cosine phase. At least 1 so the progress ratio below never
    # divides by zero when total_epochs <= warmup_epochs.
    decay_epochs = max(1, total_epochs - warmup_epochs)

    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # Warmup phase: linear increase from min_lr_factor to 1
            return min_lr_factor + (1 - min_lr_factor) * epoch / warmup_epochs
        # Cosine annealing phase. Progress is clamped to 1 so the multiplier
        # does not climb back up if the scheduler is stepped past total_epochs.
        progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
        return min_lr_factor + (1 - min_lr_factor) * 0.5 * (1 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

def validate_model(model, val_loader, criterion, device, mixed_precision=None):
    """Return the mean validation loss of ``model`` over ``val_loader``.

    The model is switched to eval mode (and left there) and gradients are
    disabled. Each batch loss is weighted by its number of samples, so a smaller
    final batch does not get the same weight as a full one. This assumes
    ``criterion`` returns a per-batch mean.

    Args:
        model: Module called as ``model(images)``.
        val_loader: Iterable yielding image batches (tensors).
        criterion: Callable ``criterion(reconstructed, images)`` returning a
            scalar loss tensor.
        device: Device the batches are moved to.
        mixed_precision: Run the forward pass under ``autocast`` when truthy.
            ``None`` (default) falls back to ``CONFIG['mixed_precision']``.

    Returns:
        The sample-weighted mean loss as a float, or ``float('inf')`` when the
        loader yields no samples.
    """
    if mixed_precision is None:
        mixed_precision = CONFIG['mixed_precision']

    model.eval()
    weighted_loss_sum = 0.0
    sample_count = 0

    with torch.no_grad():
        for images in val_loader:
            images = images.to(device)

            if mixed_precision:
                with autocast():
                    reconstructed = model(images)
                    loss = criterion(reconstructed, images)
            else:
                reconstructed = model(images)
                loss = criterion(reconstructed, images)

            batch_samples = images.size(0)
            weighted_loss_sum += loss.item() * batch_samples
            sample_count += batch_samples

    return weighted_loss_sum / sample_count if sample_count > 0 else float('inf')

def enhanced_train_model(model_name, model, train_loader, val_loader, config):
    """Train one reconstruction model with validation, LR scheduling and checkpointing.

    Every epoch runs a gradient-accumulated training pass (optionally under
    mixed precision and with gradient clipping), evaluates on ``val_loader``
    through ``validate_model``, steps the LR scheduler, writes checkpoints and
    consults early stopping.

    Relies on names defined elsewhere in the notebook: ``device``, ``nn``,
    ``GradScaler``, ``autocast``, ``tqdm``, ``torch_utils``, ``EarlyStopping``,
    ``cosine_warmup_scheduler``, ``validate_model`` and, when W&B logging is
    enabled, ``wandb``.

    Args:
        model_name: Label used in log lines, W&B metric names and checkpoint
            file names.
        model: Module called as ``model(images)``; its output is compared with
            the input using ``nn.MSELoss``.
        train_loader: Iterable of image batches; must support ``len()``.
        val_loader: Iterable of image batches used for validation.
        config: Settings dict. Required keys: ``learning_rate``, ``epochs``,
            ``mixed_precision``, ``lr_scheduler``, ``warmup_epochs``,
            ``min_lr_factor``, ``early_stopping_patience``,
            ``gradient_accumulation_steps``, ``save_frequency``,
            ``keep_best_only``, ``checkpoint_dir``, ``use_wandb``. Optional
            keys: ``gradient_clip_norm``, ``log_frequency``,
            ``resume_from_checkpoint``, ``resume_model_type``.

    Returns:
        ``(model, history)`` where ``history`` maps ``train_loss``,
        ``val_loss`` and ``learning_rate`` to lists with one entry per epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    import warnings  # stdlib; only used while fast-forwarding the scheduler on resume

    print(f"\n🚀 Enhanced Training: {model_name}")
    print(f"   Features: Validation monitoring, Early stopping, LR scheduling, Gradient clipping")
    print(f"   Gradient Accumulation: {config['gradient_accumulation_steps']} steps")
    print(f"   Checkpointing: Every {config['save_frequency']} epochs {'+ best model' if not config['keep_best_only'] else '(best only)'}")

    # An empty loader would otherwise surface later as a ZeroDivisionError when
    # the epoch average is computed.
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError(
            f"train_loader for '{model_name}' yields no batches; "
            "check the dataset size, batch_size and drop_last"
        )

    # Setup optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
    criterion = nn.MSELoss()

    # Mixed precision setup
    scaler = GradScaler() if config['mixed_precision'] else None

    # Learning rate scheduler
    if config['lr_scheduler'] == 'cosine':
        scheduler = cosine_warmup_scheduler(
            optimizer,
            config['warmup_epochs'],
            config['epochs'],
            config['min_lr_factor']
        )
    else:
        scheduler = None

    # Early stopping
    early_stopping = EarlyStopping(patience=config['early_stopping_patience'])

    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'learning_rate': []
    }

    # Resume from checkpoint if specified
    start_epoch = 0
    best_val_loss = float('inf')

    if (config.get('resume_from_checkpoint') and
        config.get('resume_model_type') == model_name):

        print(f"📂 Resuming {model_name} from checkpoint...")
        try:
            checkpoint = torch.load(config['resume_from_checkpoint'], map_location=device)

            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_val_loss = checkpoint.get('val_loss', float('inf'))
                if 'history' in checkpoint:
                    history = checkpoint['history']
                    # The checkpoint's 'val_loss' is the loss of the saved epoch,
                    # not necessarily the best one; take the true best from the
                    # history so a worse model cannot overwrite the best file.
                    if history.get('val_loss'):
                        best_val_loss = min(best_val_loss, min(history['val_loss']))
                print(f"✅ Resumed from epoch {start_epoch}, best val loss: {best_val_loss:.4f}")
            else:
                # Bare state dict: weights only, training restarts at epoch 0
                model.load_state_dict(checkpoint)
                print(f"✅ Loaded model weights, starting fresh training")
        except Exception as e:
            # Deliberate best-effort: a failed resume must not abort the run
            print(f"⚠️ Could not resume: {e}. Starting fresh...")

    if start_epoch > 0:
        # Early stopping should measure improvement against the resumed best
        early_stopping.best_loss = best_val_loss
        if scheduler is not None:
            # The scheduler was created at epoch 0. Fast-forward it so the LR
            # continues where the schedule left off instead of dropping back
            # into warmup on the first scheduler.step() after resuming.
            with warnings.catch_warnings():
                # Stepping before any optimizer.step() triggers a UserWarning
                # that does not apply to this fast-forward.
                warnings.simplefilter('ignore', UserWarning)
                for _ in range(start_epoch):
                    scheduler.step()

    # Training loop with gradient accumulation
    accumulation_steps = config['gradient_accumulation_steps']

    for epoch in range(start_epoch, config['epochs']):
        # Training phase
        model.train()
        train_loss = 0.0
        train_batches = 0

        # Get current learning rate
        current_lr = optimizer.param_groups[0]['lr']

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']} [LR: {current_lr:.2e}]")

        # Zero gradients at the start
        optimizer.zero_grad()

        for batch_idx, images in enumerate(pbar):
            images = images.to(device)

            # Weights are updated every accumulation_steps batches and on the
            # last batch of the epoch. Note that a final, shorter group is still
            # scaled by 1/accumulation_steps.
            is_update_step = ((batch_idx + 1) % accumulation_steps == 0
                              or (batch_idx + 1) == num_batches)

            # Forward pass with mixed precision
            if config['mixed_precision']:
                with autocast():
                    reconstructed = model(images)
                    loss = criterion(reconstructed, images)
                    # Scale loss by accumulation steps
                    loss = loss / accumulation_steps

                # Backward pass
                scaler.scale(loss).backward()

                if is_update_step:
                    if config.get('gradient_clip_norm'):
                        # Gradients must be unscaled before their norm is clipped
                        scaler.unscale_(optimizer)
                        torch_utils.clip_grad_norm_(model.parameters(), config['gradient_clip_norm'])

                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
            else:
                reconstructed = model(images)
                loss = criterion(reconstructed, images)
                # Scale loss by accumulation steps
                loss = loss / accumulation_steps

                loss.backward()

                if is_update_step:
                    if config.get('gradient_clip_norm'):
                        torch_utils.clip_grad_norm_(model.parameters(), config['gradient_clip_norm'])

                    optimizer.step()
                    optimizer.zero_grad()

            # Update metrics (use unscaled loss for tracking)
            batch_loss = loss.item() * accumulation_steps
            train_loss += batch_loss
            train_batches += 1

            # Update progress bar
            pbar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'avg_loss': f'{train_loss/train_batches:.4f}',
                'accum': f'{(batch_idx % accumulation_steps) + 1}/{accumulation_steps}'
            })

            # Log to W&B
            if config['use_wandb'] and batch_idx % config.get('log_frequency', 10) == 0:
                wandb.log({
                    f'{model_name}_batch_loss': batch_loss,
                    f'{model_name}_learning_rate': current_lr,
                    'epoch': epoch,
                    'batch': batch_idx
                })

        # Calculate epoch metrics
        avg_train_loss = train_loss / train_batches

        # Validation phase
        avg_val_loss = validate_model(model, val_loader, criterion, device)

        # Update history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['learning_rate'].append(current_lr)

        print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}, LR = {current_lr:.2e}")

        # Log epoch metrics
        if config['use_wandb']:
            wandb.log({
                f'{model_name}_train_loss': avg_train_loss,
                f'{model_name}_val_loss': avg_val_loss,
                f'{model_name}_learning_rate': current_lr,
                'epoch': epoch + 1
            })

        # Update learning rate
        if scheduler:
            scheduler.step()

        # Checkpoint saving logic
        is_regular_save = (epoch + 1) % config['save_frequency'] == 0

        # Check if this is the best model so far
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss

            # Always save best model (weights only)
            best_path = f"{config['checkpoint_dir']}/{model_name}_best.pth"
            torch.save(model.state_dict(), best_path)
            print(f"✅ Best model saved: {best_path} (Val Loss: {best_val_loss:.4f})")

        # Save regular checkpoint based on save_frequency
        if is_regular_save:
            if config.get('keep_best_only', False):
                print(f"⏭️ Epoch {epoch+1}: Regular checkpoint skipped (keep_best_only=True)")
            else:
                checkpoint_path = f"{config['checkpoint_dir']}/{model_name}_epoch_{epoch+1}.pth"
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': avg_train_loss,
                    'val_loss': avg_val_loss,
                    'history': history,
                    # Keep the access token out of files written to disk
                    'config': {key: value for key, value in config.items() if key != 'hf_token'}
                }, checkpoint_path)
                print(f"💾 Regular checkpoint saved: {checkpoint_path}")

        # Early stopping check
        if early_stopping(avg_val_loss):
            print(f"🛑 Early stopping triggered at epoch {epoch+1}")
            break

    return model, history

# Train models
all_histories = {}
trained_models = {}

# Build the strategy text up front. Nesting an f-string that reuses the outer
# quote character inside another f-string is a SyntaxError before Python 3.12.
if CONFIG['keep_best_only']:
    checkpoint_strategy = 'Best model only'
else:
    checkpoint_strategy = f"Every {CONFIG['save_frequency']} epochs + best model"

print(f"🚀 Starting enhanced training for {len(models_to_train)} model(s)...")
print(f"💾 Checkpoint strategy: {checkpoint_strategy}")

for model_name, model in models_to_train:
    start_time = time.time()
    
    # Enhanced training
    trained_model, history = enhanced_train_model(
        model_name, model, train_loader, val_loader, CONFIG
    )
    
    # Store results
    trained_models[model_name] = trained_model
    all_histories[model_name] = history
    
    # Calculate training time (in minutes)
    training_time = (time.time() - start_time) / 60
    print(f"\n✅ {model_name} enhanced training completed in {training_time:.1f} minutes")
    
    # Clear cache between models
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print(f"\n🎉 All models trained successfully with FIXED checkpoint logic and gradient accumulation!")

## 8. Enhanced Training Visualization

In [None]:
# 3x3 dashboard figure; every panel below draws into one cell of `axes`
fig, axes = plt.subplots(3, 3, figsize=(20, 15))

# One-line run description reused in the figure title
training_info = " | ".join([
    f"Dataset: {CONFIG['hf_dataset'].split('/')[-1]}",
    f"Samples: {CONFIG['samples']}",
    f"Batch: {CONFIG['batch_size']}x{CONFIG['gradient_accumulation_steps']}",
    f"LR: {CONFIG['learning_rate']:.0e}",
])

# Plot 1: training (solid) and validation (dashed) loss per model
loss_ax = axes[0, 0]
for model_name, history in all_histories.items():
    epochs = range(1, len(history['train_loss']) + 1)
    for history_key, split_label, dash in (('train_loss', 'train', '-'), ('val_loss', 'val', '--')):
        loss_ax.plot(epochs, history[history_key], label=f'{model_name} {split_label}',
                     linestyle=dash, linewidth=2)

loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss (MSE)')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()
loss_ax.grid(True, alpha=0.3)
loss_ax.set_yscale('log')  # Log scale for better visualization

# Plot 2: learning rate recorded at the start of each epoch
lr_ax = axes[0, 1]
for model_name, history in all_histories.items():
    lr_values = history['learning_rate']
    epochs = range(1, len(lr_values) + 1)
    lr_ax.plot(epochs, lr_values, label=f'{model_name} LR', linewidth=2)

lr_ax.set_xlabel('Epoch')
lr_ax.set_ylabel('Learning Rate')
lr_ax.set_title('Learning Rate Schedule')
lr_ax.set_yscale('log')
lr_ax.legend()
lr_ax.grid(True, alpha=0.3)

# Plot 3: epoch-over-epoch loss improvement (positive means the loss went down)
improve_ax = axes[0, 2]
for model_name, history in all_histories.items():
    train_losses = history['train_loss']
    val_losses = history['val_loss']

    # Negated difference between each epoch and the one before it
    train_improvement = [-(cur - prev) for prev, cur in zip(train_losses, train_losses[1:])]
    val_improvement = [-(cur - prev) for prev, cur in zip(val_losses, val_losses[1:])]

    # The first difference belongs to epoch 2
    epochs = range(2, len(train_losses) + 1)
    improve_ax.plot(epochs, train_improvement, label=f'{model_name} train improve', linestyle='-', alpha=0.7)
    improve_ax.plot(epochs, val_improvement, label=f'{model_name} val improve', linestyle='--', alpha=0.7)

improve_ax.set_xlabel('Epoch')
improve_ax.set_ylabel('Loss Improvement Rate')
improve_ax.set_title('Loss Improvement per Epoch')
improve_ax.legend()
improve_ax.grid(True, alpha=0.3)
improve_ax.axhline(y=0, color='red', linestyle=':', alpha=0.5)

# Plot 4: running minimum of the validation loss
best_ax = axes[1, 0]
for model_name, history in all_histories.items():
    val_losses = history['val_loss']

    # best_val_loss[i] is the lowest validation loss seen up to epoch i+1
    best_val_loss = []
    for value in val_losses:
        if best_val_loss:
            best_val_loss.append(min(best_val_loss[-1], value))
        else:
            best_val_loss.append(value)

    epochs = range(1, len(best_val_loss) + 1)
    best_ax.plot(epochs, best_val_loss, label=f'{model_name} best val', linewidth=2)

best_ax.set_xlabel('Epoch')
best_ax.set_ylabel('Best Validation Loss')
best_ax.set_title('Best Validation Loss Progress')
best_ax.legend()
best_ax.grid(True, alpha=0.3)
best_ax.set_yscale('log')

# Plot 5: one point per epoch, training loss against validation loss
corr_ax = axes[1, 1]
colors = ['blue', 'red', 'green', 'purple']
for i, (model_name, history) in enumerate(all_histories.items()):
    point_color = colors[i % len(colors)]  # palette repeats after four models
    corr_ax.scatter(history['train_loss'], history['val_loss'],
                    label=f'{model_name}', alpha=0.7, s=30, c=point_color)

# Diagonal reference line spanning the full range of observed losses
per_model_losses = [h['train_loss'] + h['val_loss'] for h in all_histories.values()]
min_loss = min(map(min, per_model_losses))
max_loss = max(map(max, per_model_losses))
corr_ax.plot([min_loss, max_loss], [min_loss, max_loss], 'k--', alpha=0.5, label='Perfect correlation')

corr_ax.set_xlabel('Training Loss')
corr_ax.set_ylabel('Validation Loss')
corr_ax.set_title('Training vs Validation Loss Correlation')
corr_ax.legend()
corr_ax.grid(True, alpha=0.3)
corr_ax.set_xscale('log')
corr_ax.set_yscale('log')

# Plot 6: text panel with per-model statistics and the key configuration values
training_stats = []
for model_name, history in all_histories.items():
    train_hist = history['train_loss']
    val_hist = history['val_loss']
    lr_hist = history['learning_rate']
    lowest_val = min(val_hist)

    training_stats.append({
        'Model': model_name,
        'Final Train Loss': f"{train_hist[-1]:.4f}",
        'Final Val Loss': f"{val_hist[-1]:.4f}",
        'Best Val Loss': f"{lowest_val:.4f}",
        'Best Val Epoch': f"{val_hist.index(lowest_val) + 1}",
        'Total Epochs': f"{len(train_hist)}",
        'Final LR': f"{lr_hist[-1]:.2e}",
        'LR Reduction': f"{lr_hist[0] / lr_hist[-1]:.1f}x",
        # Final validation loss below the 0.08 target counts as converged
        'Convergence': '✅' if val_hist[-1] < 0.08 else '⚠️',
        # Warn when the final train loss is below 70% of the final val loss
        'Overfitting': '⚠️' if train_hist[-1] < val_hist[-1] * 0.7 else '✅'
    })

# Assemble the panel text piece by piece
summary_ax = axes[1, 2]
summary_ax.axis('off')
summary_parts = ["TRAINING SUMMARY\n", "="*50, "\n\n"]

for stats in training_stats:
    summary_parts.extend([
        f"🤖 {stats['Model'].upper()}:\n",
        f"   Final Train Loss: {stats['Final Train Loss']}\n",
        f"   Final Val Loss: {stats['Final Val Loss']}\n",
        f"   Best Val Loss: {stats['Best Val Loss']} (Epoch {stats['Best Val Epoch']})\n",
        f"   Epochs: {stats['Total Epochs']}\n",
        f"   LR: {stats['Final LR']} (reduced {stats['LR Reduction']})\n",
        f"   Convergence: {stats['Convergence']} | Overfitting: {stats['Overfitting']}\n\n",
    ])

# Configuration summary
summary_parts.extend([
    "📊 CONFIGURATION:\n",
    f"   Dataset: {CONFIG['hf_dataset']}\n",
    f"   Samples: {CONFIG['samples']} ({CONFIG['validation_split']:.0%} val split)\n",
    f"   Batch Size: {CONFIG['batch_size']} x {CONFIG['gradient_accumulation_steps']} = {CONFIG['batch_size'] * CONFIG['gradient_accumulation_steps']}\n",
    f"   Learning Rate: {CONFIG['learning_rate']:.0e}\n",
    f"   Scheduler: {CONFIG['lr_scheduler']}\n",
    f"   Early Stopping: {CONFIG['early_stopping_patience']} epochs\n",
    f"   Gradient Clip: {CONFIG['gradient_clip_norm']}\n",
])
summary_text = "".join(summary_parts)

summary_ax.text(0.05, 0.95, summary_text, transform=summary_ax.transAxes,
                fontsize=9, verticalalignment='top', fontfamily='monospace',
                bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8))

# Plot 7: Loss Smoothed (Moving Average)
# Window length: 1/20 of the first model's epoch count, but at least 5
window_size = max(5, len(list(all_histories.values())[0]['train_loss']) // 20)


def moving_average(data, window):
    """Return the mean of every full-length sliding window over ``data``."""
    return [sum(data[i:i+window])/window for i in range(len(data)-window+1)]


smooth_ax = axes[2, 0]
for model_name, history in all_histories.items():
    train_losses = history['train_loss']
    val_losses = history['val_loss']

    # Histories shorter than one window have nothing to smooth
    if len(train_losses) < window_size:
        continue

    train_smooth = moving_average(train_losses, window_size)
    val_smooth = moving_average(val_losses, window_size)
    # Each averaged value is plotted at the last epoch of its window
    epochs = range(window_size, len(train_losses) + 1)

    smooth_ax.plot(epochs, train_smooth, label=f'{model_name} train (MA{window_size})', linestyle='-', alpha=0.8)
    smooth_ax.plot(epochs, val_smooth, label=f'{model_name} val (MA{window_size})', linestyle='--', alpha=0.8)

smooth_ax.set_xlabel('Epoch')
smooth_ax.set_ylabel('Loss (Smoothed)')
smooth_ax.set_title(f'Smoothed Loss (Moving Average, window={window_size})')
smooth_ax.legend()
smooth_ax.grid(True, alpha=0.3)
smooth_ax.set_yscale('log')

# Plot 8: validation loss against the learning rate of the same epoch (LR-finder style)
lr_scatter_ax = axes[2, 1]
for model_name, history in all_histories.items():
    lr_scatter_ax.scatter(history['learning_rate'], history['val_loss'],
                          label=f'{model_name}', alpha=0.6, s=20)

lr_scatter_ax.set_xlabel('Learning Rate')
lr_scatter_ax.set_ylabel('Validation Loss')
lr_scatter_ax.set_title('Validation Loss vs Learning Rate')
lr_scatter_ax.set_xscale('log')
lr_scatter_ax.set_yscale('log')
lr_scatter_ax.legend()
lr_scatter_ax.grid(True, alpha=0.3)

# Plot 9: text panel with the training timeline and the pass/fail check per model
timeline_ax = axes[2, 2]
timeline_ax.axis('off')
timeline_parts = ["TRAINING PROGRESS TIMELINE\n", "="*30, "\n\n"]

for model_name, history in all_histories.items():
    epochs_trained = len(history['train_loss'])
    best_epoch = history['val_loss'].index(min(history['val_loss'])) + 1

    timeline_parts.extend([
        f"🤖 {model_name.upper()}:\n",
        f"   Epochs: {epochs_trained}\n",
        f"   Best model at epoch: {best_epoch}\n",
        f"   Loss reduction: {history['train_loss'][0]:.4f} → {history['train_loss'][-1]:.4f}\n",
        f"   Val loss reduction: {history['val_loss'][0]:.4f} → {min(history['val_loss']):.4f}\n\n",
    ])

# Research question check: best validation loss against the 0.08 target
timeline_parts.append("🎯 RESEARCH QUESTIONS:\n")
for model_name, history in all_histories.items():
    best_val_loss = min(history['val_loss'])
    rq1_passed = best_val_loss < 0.08  # Threshold for good anomaly detection
    verdict = '✅ PASSED' if rq1_passed else '❌ FAILED'
    timeline_parts.append(f"   RQ1 ({model_name}): {verdict} (Val Loss: {best_val_loss:.4f})\n")

timeline_parts.extend([
    "\n💡 NOTES:\n",
    "   Target: Val Loss < 0.08 for anomaly detection\n",
    f"   Early stopping patience: {CONFIG['early_stopping_patience']}\n",
    f"   Checkpoints saved every {CONFIG['save_frequency']} epochs\n",
])
timeline_text = "".join(timeline_parts)

timeline_ax.text(0.05, 0.95, timeline_text, transform=timeline_ax.transAxes,
                 fontsize=9, verticalalignment='top', fontfamily='monospace',
                 bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.8))

# Title and layout for the whole dashboard
fig.suptitle(f'Enhanced Training Analysis - {training_info}', fontsize=14, y=0.98)
fig.tight_layout()

# Display the figure first, then write it to disk
plt.show()

# Timestamped file name inside the checkpoint directory
from datetime import datetime
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
plot_filename = "{}/enhanced_training_analysis_{}.png".format(CONFIG["checkpoint_dir"], timestamp)

# Method 1: save through the Figure object, which is still available after show()
save_options = dict(dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
fig.savefig(plot_filename, **save_options)
plt.close(fig)  # Close to free memory

# Method 2: Alternative save with PIL (more reliable on Colab)
try:
    # Create a new identical figure for saving
    fig_save, axes_save = plt.subplots(3, 3, figsize=(20, 15))
    
    # Re-plot everything for the save figure
    # Plot 1: Training and Validation Loss
    for model_name, history in all_histories.items():
        epochs = range(1, len(history['train_loss']) + 1)
        axes_save[0, 0].plot(epochs, history['train_loss'], label=f'{model_name} train', linestyle='-', linewidth=2)
        axes_save[0, 0].plot(epochs, history['val_loss'], label=f'{model_name} val', linestyle='--', linewidth=2)
    
    axes_save[0, 0].set_xlabel('Epoch')
    axes_save[0, 0].set_ylabel('Loss (MSE)')
    axes_save[0, 0].set_title('Training and Validation Loss')
    axes_save[0, 0].legend()
    axes_save[0, 0].grid(True, alpha=0.3)
    axes_save[0, 0].set_yscale('log')
    
    # Plot 2: Learning Rate Schedule  
    for model_name, history in all_histories.items():
        epochs = range(1, len(history['learning_rate']) + 1)
        axes_save[0, 1].plot(epochs, history['learning_rate'], label=f'{model_name} LR', linewidth=2)
    
    axes_save[0, 1].set_xlabel('Epoch')
    axes_save[0, 1].set_ylabel('Learning Rate')
    axes_save[0, 1].set_title('Learning Rate Schedule')
    axes_save[0, 1].set_yscale('log')
    axes_save[0, 1].legend()
    axes_save[0, 1].grid(True, alpha=0.3)
    
    # Add summary text
    summary_text_short = f"Training Results Summary\n"
    summary_text_short += f"Dataset: {CONFIG['hf_dataset']}\n"
    summary_text_short += f"Samples: {CONFIG['samples']}\n"
    for model_name, history in all_histories.items():
        summary_text_short += f"{model_name}: Best Val Loss = {min(history['val_loss']):.4f}\n"
    
    axes_save[0, 2].text(0.1, 0.5, summary_text_short, transform=axes_save[0, 2].transAxes,
                        fontsize=12, verticalalignment='center',
                        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.8))
    axes_save[0, 2].axis('off')
    
    # Hide other subplots for cleaner save
    for i in range(3):
        for j in range(3):
            if (i, j) not in [(0, 0), (0, 1), (0, 2)]:
                axes_save[i, j].axis('off')
    
    plt.suptitle(f'Training Analysis - {training_info}', fontsize=16, y=0.95)
    plt.tight_layout()
    
    # Save the simplified version
    plot_filename_simple = f'{CONFIG["checkpoint_dir"]}/training_analysis_simple_{timestamp}.png'
    fig_save.savefig(plot_filename_simple, dpi=300, bbox_inches='tight', 
                    facecolor='white', edgecolor='none', format='png')
    plt.close(fig_save)
    
    print(f"✅ Simple training plot saved: {plot_filename_simple}")
    
except Exception as e:
    print(f"⚠️ Alternative save method failed: {e}")

print(f"✅ Enhanced training analysis saved: {plot_filename}")

# Also save training stats as JSON
import json
import os

stats_filename = f'{CONFIG["checkpoint_dir"]}/training_stats_{timestamp}.json'
detailed_stats = {
    'config': CONFIG,
    'training_results': {}
}

for model_name, history in all_histories.items():
    best_val_loss = min(history['val_loss'])
    # bool(...) guards against numpy-typed losses: comparing a numpy scalar
    # yields numpy.bool_, which the json module refuses to serialise.
    target_reached = bool(best_val_loss < 0.08)
    detailed_stats['training_results'][model_name] = {
        'history': history,
        'final_train_loss': history['train_loss'][-1],
        'final_val_loss': history['val_loss'][-1],
        'best_val_loss': best_val_loss,
        'best_val_epoch': history['val_loss'].index(best_val_loss) + 1,  # 1-based
        'total_epochs': len(history['train_loss']),
        'convergence_achieved': target_reached,
        'research_question_1_passed': target_reached
    }

# Serialise fully in memory before touching the file, so a serialisation error
# cannot leave a truncated JSON file behind. Values json cannot encode natively
# are converted via .tolist() when available (numpy / torch scalars and arrays)
# and stringified otherwise.
stats_json = json.dumps(
    detailed_stats,
    indent=2,
    default=lambda o: o.tolist() if hasattr(o, 'tolist') else str(o),
)
with open(stats_filename, 'w') as f:
    f.write(stats_json)

print(f"✅ Training statistics saved: {stats_filename}")
print(f"📁 All files saved to: {CONFIG['checkpoint_dir']}")

# Flush OS write buffers when a Drive mount point is present.
# NOTE(review): this flushes local filesystem buffers only; whether the Drive
# client has finished uploading is not verified here.
if os.path.exists('/content/drive'):
    os.sync()  # same effect as the `sync` shell command, without spawning a shell
    print("🔄 Files synced to Google Drive")

## 9. Test Reconstruction Quality

In [None]:
def visualize_enhanced_reconstructions(models, train_loader, val_loader, num_samples=5):
    """Plot originals and per-model reconstructions for train and validation samples.

    The grid has one row of original images followed by one row per model.
    The left ``num_samples`` columns come from the first training batch, the
    right ``num_samples`` columns from the first validation batch. The figure
    is written to ``CONFIG["checkpoint_dir"]`` and then displayed.

    Args:
        models: Mapping of display name -> model. Each model is switched to
            eval mode (and left in it) and called on the sample batches.
        train_loader: Iterable whose first item is a sliceable batch exposing
            ``.to(device)`` (assumed to be an image tensor; TODO confirm).
        val_loader: Same as ``train_loader``, for the validation split.
        num_samples: Number of images shown per split.

    Returns:
        None. Side effects: saves a PNG, shows the figure, prints a status line.
    """
    n_rows = len(models) + 1
    n_cols = num_samples * 2
    # squeeze=False keeps `axes` two-dimensional even for a single-row grid
    # (e.g. when `models` is empty), so axes[row, col] indexing is always valid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 3 * n_rows), squeeze=False)

    # Get samples from both train and validation
    train_batch = next(iter(train_loader))[:num_samples].to(device)
    val_batch = next(iter(val_loader))[:num_samples].to(device)
    sample_batches = [train_batch, val_batch]
    split_titles = ['Training Samples', 'Validation Samples']

    # Original images
    for batch_idx, batch in enumerate(sample_batches):
        # A loader batch may hold fewer than num_samples images; unused cells stay blank.
        for i in range(min(num_samples, len(batch))):
            col_idx = batch_idx * num_samples + i
            img = batch[i].cpu().squeeze().numpy()
            # vmin/vmax assume images normalised to [-1, 1] (TODO confirm against the transforms)
            axes[0, col_idx].imshow(img, cmap='gray', vmin=-1, vmax=1)
        # Title goes on the middle column of each split.
        axes[0, batch_idx * num_samples + num_samples // 2].set_title(split_titles[batch_idx], fontsize=10)
    axes[0, 0].set_ylabel('Original', fontsize=12)

    # Reconstructions for each model
    for model_idx, (model_name, model) in enumerate(models.items()):
        row = model_idx + 1
        model.eval()

        for batch_idx, batch in enumerate(sample_batches):
            with torch.no_grad():
                recon = model(batch)

            for i in range(min(num_samples, len(recon))):
                col_idx = batch_idx * num_samples + i
                img = recon[i].cpu().squeeze().numpy()
                axes[row, col_idx].imshow(img, cmap='gray', vmin=-1, vmax=1)

        axes[row, 0].set_ylabel(model_name, fontsize=12)

    # Hide ticks and spines instead of calling axis('off'): axis('off') would
    # also hide the row labels set through set_ylabel above.
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    fig.suptitle('Enhanced Reconstruction Quality Assessment (FIXED)', fontsize=16)
    fig.tight_layout()

    # Save through the Figure handle and before showing. With the inline
    # backend plt.show() closes the figure, so a later plt.savefig() would
    # write a new, empty canvas.
    fig.savefig(f'{CONFIG["checkpoint_dir"]}/enhanced_reconstruction_comparison_fixed.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # free the figure's memory
    print("✅ Enhanced reconstruction comparison saved")

# Render the reconstruction grid only when at least one trained model exists
if trained_models:
    visualize_enhanced_reconstructions(
        models=trained_models,
        train_loader=train_loader,
        val_loader=val_loader,
    )

## 10. Enhanced Results Summary (FIXED)

In [None]:
import json
from datetime import datetime

# Enhanced summary with validation metrics

# Human-readable descriptions of what the checkpoint logic kept on disk.
# They are built once so that the metadata, the per-model entries and the
# closing message all agree with each other.
if CONFIG['keep_best_only']:
    checkpoint_strategy = "Best model only"
    checkpoints_saved = "Best model only"
    checkpoint_behavior = "Best model only saved"
else:
    checkpoint_strategy = f"Every {CONFIG['save_frequency']} epochs + best model"
    checkpoints_saved = f"Best model + regular every {CONFIG['save_frequency']} epochs"
    checkpoint_behavior = f"Regular checkpoints every {CONFIG['save_frequency']} epochs + best model"


def _ratio_or_none(numerator, denominator):
    """Return numerator / denominator, or None when the denominator is zero."""
    return numerator / denominator if denominator else None


def _format_ratio(value, spec):
    """Format a ratio such as '2.5x'; 'n/a' when it could not be computed."""
    return "n/a" if value is None else f"{value:{spec}}x"


summary = {
    'metadata': {
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'notebook_version': 'enhanced_training_fixed',
        'dataset': CONFIG['hf_dataset'],
        'total_samples': CONFIG['samples'],
        'train_samples': len(train_dataset),
        'val_samples': len(val_dataset),
        'validation_split': CONFIG['validation_split'],
        'checkpoint_strategy': checkpoint_strategy
    },
    'config': CONFIG,
    'training_results': {},
    'advanced_features_used': [
        'validation_monitoring',
        'early_stopping',
        'cosine_lr_scheduling_with_warmup',
        'gradient_clipping',
        'mixed_precision',
        'fixed_checkpoint_logic'
    ]
}

# Calculate detailed metrics for each model
for model_name, history in all_histories.items():
    train_losses = history['train_loss']
    val_losses = history['val_loss']
    learning_rates = history['learning_rate']
    best_val_loss = min(val_losses)

    summary['training_results'][model_name] = {
        'epochs_trained': len(train_losses),
        'final_train_loss': train_losses[-1],
        'final_val_loss': val_losses[-1],
        'best_train_loss': min(train_losses),
        'best_val_loss': best_val_loss,
        'best_val_epoch': val_losses.index(best_val_loss) + 1,  # 1-based
        'final_learning_rate': learning_rates[-1],
        'initial_learning_rate': learning_rates[0],
        # The schedule may end at a learning rate of exactly 0; report None
        # instead of raising ZeroDivisionError.
        'lr_reduction_factor': _ratio_or_none(learning_rates[0], learning_rates[-1]),
        # bool(...) keeps the flags JSON-serialisable when losses are numpy scalars.
        # NOTE(review): the training-analysis cell defines 'convergence_achieved'
        # as best val loss < 0.08; this one uses final val loss < 0.05. Confirm
        # which definition is intended.
        'convergence_achieved': bool(val_losses[-1] < 0.05),  # Threshold for good convergence
        'overfitting_detected': bool(train_losses[-1] < val_losses[-1] * 0.7),  # Rough heuristic
        'training_stability': _ratio_or_none(max(train_losses), min(train_losses)),  # Lower is more stable
        'validation_stability': _ratio_or_none(max(val_losses), best_val_loss),
        'checkpoints_saved': checkpoints_saved
    }

# Save enhanced summary. Serialise in memory first so that a failure cannot
# leave a truncated file. Values json cannot encode natively are converted via
# .tolist() when available (numpy / torch scalars and arrays), else stringified.
summary_path = f"{CONFIG['checkpoint_dir']}/enhanced_training_summary_fixed.json"
summary_json = json.dumps(
    summary,
    indent=2,
    default=lambda o: o.tolist() if hasattr(o, 'tolist') else str(o),
)
with open(summary_path, 'w') as f:
    f.write(summary_json)

# Print comprehensive summary
print("\n🎯 ENHANCED TRAINING SUMMARY (FIXED CHECKPOINT LOGIC)")
print("=" * 80)
print(f"📊 Dataset: {CONFIG['hf_dataset']}")
print(f"📈 Samples: {CONFIG['samples']} ({len(train_dataset)} train + {len(val_dataset)} val)")
print(f"💾 Checkpoint Strategy: {summary['metadata']['checkpoint_strategy']}")
print(f"🔧 Advanced Features: {', '.join(summary['advanced_features_used'])}")

for model_name, results in summary['training_results'].items():
    converged = results['convergence_achieved']
    overfit = results['overfitting_detected']

    print(f"\n🤖 {model_name.upper()} RESULTS:")
    print(f"   Epochs Trained: {results['epochs_trained']}")
    print(f"   Final Train Loss: {results['final_train_loss']:.4f}")
    print(f"   Final Val Loss: {results['final_val_loss']:.4f}")
    print(f"   Best Val Loss: {results['best_val_loss']:.4f} (Epoch {results['best_val_epoch']})")
    print(f"   Learning Rate: {results['initial_learning_rate']:.2e} → {results['final_learning_rate']:.2e} (reduction: {_format_ratio(results['lr_reduction_factor'], '.1f')})")
    print(f"   Convergence: {'✅' if converged else '⚠️'} {'Good' if converged else 'Needs improvement'}")
    print(f"   Overfitting: {'⚠️' if overfit else '✅'} {'Detected' if overfit else 'Not detected'}")
    print(f"   Training Stability: {_format_ratio(results['training_stability'], '.2f')} (lower is better)")
    print(f"   Validation Stability: {_format_ratio(results['validation_stability'], '.2f')} (lower is better)")
    print(f"   Checkpoints: {results['checkpoints_saved']}")

print(f"\n✅ Enhanced summary saved to: {summary_path}")
print(f"📁 All checkpoints and artifacts saved to: {CONFIG['checkpoint_dir']}")

# Check saved files
import os
saved_files = os.listdir(CONFIG['checkpoint_dir'])
checkpoint_files = [f for f in saved_files if f.endswith('.pth')]
if checkpoint_files:
    print("\n💾 Checkpoint files saved:")
    for f in sorted(checkpoint_files):
        print(f"   {f}")

# Finish W&B run
if CONFIG['use_wandb']:
    wandb.finish()
    print("✅ W&B run finished")

print("\n🎉 ENHANCED TRAINING COMPLETE WITH FIXED CHECKPOINT LOGIC!")
print("💡 Next step: Use evaluation cells 21-26 from EVALUATION_GUIDE.md to answer research questions.")
# The message is pre-built above: nesting the same quote character inside an
# f-string is a SyntaxError on Python versions before 3.12.
print(f"🔧 Checkpoint behavior: {checkpoint_behavior}")