In [None]:
````xml
<VSCode.Cell language="python">
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
from models import DnCNN, UNet, RCAN, NAFNet, DRUNet
from data_utils import load_dataset, calculate_psnr, calculate_ssim
from trainer import Trainer, get_default_config

# Prefer the GPU when PyTorch reports one; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

# Seed both random number generators with the same fixed value so that
# repeated runs of the notebook start from the same RNG state.
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(42)
</VSCode.Cell>
<VSCode.Cell language="python">
# Experiment settings shared by every model in this notebook.
base_config = dict(
    epochs=50,
    batch_size=32,
    dataset='mnist',
    noise_type='gaussian',
    noise_level=0.2,
    optimizer='adamw',
    weight_decay=1e-4,
)

# load_dataset is called positionally with (dataset, batch_size, noise_type,
# noise_level); its four return values are unpacked as the three loaders and
# the channel count.
loader_args = [
    base_config[key]
    for key in ('dataset', 'batch_size', 'noise_type', 'noise_level')
]
train_loader, val_loader, test_loader, channels = load_dataset(*loader_args)

# Quick sanity summary of what was loaded.
print(f"Loaded {base_config['dataset']} dataset")
print(f"Channels: {channels}")
print(f"Train batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")
</VSCode.Cell>
<VSCode.Cell language="python">
# One instance of each architecture under comparison, keyed by display name.
# Entries are created in a fixed order (DnCNN, UNet, RCAN, NAFNet, DRUNet);
# every constructor receives the dataset's channel count.
models = {}
models['DnCNN'] = DnCNN(channels=channels, num_layers=17, features=64)
models['UNet'] = UNet(n_channels=channels, n_classes=channels, bilinear=False)
models['RCAN'] = RCAN(n_channels=channels, n_feats=64, n_blocks=8, reduction=16)
models['NAFNet'] = NAFNet(
    img_channel=channels,
    width=32,
    middle_blk_num=8,
    enc_blk_nums=[2, 2, 4, 8],
    dec_blk_nums=[2, 2, 2, 2],
)
models['DRUNet'] = DRUNet(
    in_nc=channels,
    out_nc=channels,
    nc=[64, 128, 256, 512],
    nb=4,
)


def _count_parameters(network):
    """Return the total number of scalar elements across network.parameters()."""
    total = 0
    for tensor in network.parameters():
        total += tensor.numel()
    return total


# Report the size of each model before training.
for name, model in models.items():
    param_count = _count_parameters(model)
    print(f"{name}: {param_count:,} parameters")
</VSCode.Cell>
<VSCode.Cell language="python">
# Train every model with its own hyper-parameters, then evaluate it on the
# shared test loader.
#   results[name]        -> dict of test metrics, parameter count, training time
#   training_times[name] -> wall-clock training time in seconds
results = {}
training_times = {}

# Learning rate used by any model without an explicit override below.
DEFAULT_LEARNING_RATE = 1e-3
# Per-model settings layered on top of base_config.
model_overrides = {
    'NAFNet': {'learning_rate': 2e-4, 'batch_size': 16},
    'RCAN': {'learning_rate': 1e-4},
}


def _sync_device():
    """Wait for queued CUDA work so wall-clock timings cover the actual compute."""
    if device.type == 'cuda':
        torch.cuda.synchronize()


for model_name, model in models.items():
    print(f"\n{'='*50}")
    print(f"Training {model_name}")
    print(f"{'='*50}")

    config = base_config.copy()
    config['model_name'] = model_name
    config['learning_rate'] = DEFAULT_LEARNING_RATE
    config.update(model_overrides.get(model_name, {}))

    # The shared loaders were built with base_config['batch_size'], so a
    # per-model batch size only takes effect if the loaders are rebuilt.
    # NOTE(review): assumes Trainer batches exactly as the loaders it is given
    # - confirm against trainer.py.
    if config['batch_size'] != base_config['batch_size']:
        model_train_loader, model_val_loader, _, _ = load_dataset(
            config['dataset'],
            config['batch_size'],
            config['noise_type'],
            config['noise_level'],
        )
    else:
        model_train_loader, model_val_loader = train_loader, val_loader

    # The test loader stays shared so every model is scored on the same loader.
    trainer = Trainer(model, model_train_loader, model_val_loader, test_loader, device, config)

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments.
    _sync_device()
    start_time = time.perf_counter()
    trainer.train()
    _sync_device()
    training_time = time.perf_counter() - start_time
    training_times[model_name] = training_time

    test_loss, test_psnr, test_ssim = trainer.test()

    results[model_name] = {
        'test_loss': test_loss,
        'test_psnr': test_psnr,
        'test_ssim': test_ssim,
        'parameters': sum(p.numel() for p in model.parameters()),
        'training_time': training_time
    }

    print(f"{model_name} Results:")
    print(f"  Test Loss: {test_loss:.6f}")
    print(f"  Test PSNR: {test_psnr:.2f} dB")
    print(f"  Test SSIM: {test_ssim:.4f}")
    print(f"  Training Time: {training_time:.1f}s")
</VSCode.Cell>
<VSCode.Cell language="python">
# Plain-text leaderboard: one row per trained model, every column
# left-aligned to a fixed width.
banner = "=" * 80
print("\n" + banner)
print("COMPREHENSIVE MODEL COMPARISON")
print(banner)

header_cells = ('Model', 'PSNR (dB)', 'SSIM', 'Params (M)', 'Time (s)')
print("{:<12} {:<10} {:<8} {:<12} {:<10}".format(*header_cells))
print("-" * 80)

# Columns: name, PSNR, SSIM, parameter count in millions, training seconds.
row_template = "{:<12} {:<10.2f} {:<8.4f} {:<12.2f} {:<10.1f}"
for model_name, result in results.items():
    print(row_template.format(
        model_name,
        result['test_psnr'],
        result['test_ssim'],
        result['parameters'] / 1e6,
        result['training_time'],
    ))
</VSCode.Cell>
<VSCode.Cell language="python">
# 2x2 grid of bar charts: PSNR, SSIM, model size and training time per model.
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

model_names = list(results.keys())
result_entries = [results[name] for name in model_names]
psnr_values = [entry['test_psnr'] for entry in result_entries]
ssim_values = [entry['test_ssim'] for entry in result_entries]
param_values = [entry['parameters'] / 1e6 for entry in result_entries]
time_values = [entry['training_time'] for entry in result_entries]

colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7']

# For the size and time panels the label sits 2% of the tallest bar above
# each bar; PSNR and SSIM use fixed offsets.
param_pad = max(param_values) * 0.02 if param_values else 0.0
time_pad = max(time_values) * 0.02 if time_values else 0.0

# (axes, values, y-label, title, label offset above bar, label format)
panels = [
    (axes[0, 0], psnr_values, 'PSNR (dB)', 'Peak Signal-to-Noise Ratio', 0.1, '{:.1f}'),
    (axes[0, 1], ssim_values, 'SSIM', 'Structural Similarity Index', 0.005, '{:.3f}'),
    (axes[1, 0], param_values, 'Parameters (Millions)', 'Model Size', param_pad, '{:.1f}M'),
    (axes[1, 1], time_values, 'Training Time (seconds)', 'Training Efficiency', time_pad, '{:.0f}s'),
]

for ax, values, y_label, title, pad, label_format in panels:
    ax.bar(model_names, values, color=colors)
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.tick_params(axis='x', rotation=45)
    # Print each bar's value just above it.
    for x_pos, value in enumerate(values):
        ax.text(x_pos, value + pad, label_format.format(value),
                ha='center', fontweight='bold')

plt.tight_layout()
plt.show()
</VSCode.Cell>
<VSCode.Cell language="python">
print("Performance vs Efficiency Analysis")
# Scatter of test PSNR against parameter count (in millions), one labelled
# point per model. Reuses the `colors` palette defined with the bar charts.
tradeoff_fig, tradeoff_ax = plt.subplots(figsize=(12, 8))

for idx, (model_name, result) in enumerate(results.items()):
    point = (result['parameters'] / 1e6, result['test_psnr'])
    tradeoff_ax.scatter(*point, s=200, alpha=0.7, color=colors[idx], label=model_name)
    # Name tag nudged 5 points up and right so it does not cover the marker.
    tradeoff_ax.annotate(model_name, point,
                         xytext=(5, 5), textcoords='offset points', fontsize=10)

tradeoff_ax.set_xlabel('Model Size (Million Parameters)')
tradeoff_ax.set_ylabel('PSNR (dB)')
tradeoff_ax.set_title('Performance vs Model Size Trade-off')
tradeoff_ax.grid(True, alpha=0.3)
tradeoff_ax.legend()
plt.tight_layout()
plt.show()
</VSCode.Cell>
<VSCode.Cell language="python">
print("Robustness Analysis: Testing on Multiple Noise Levels")
noise_levels = [0.1, 0.15, 0.2, 0.25, 0.3, 0.35]
# Upper bound on the number of test batches scored per noise level, to keep
# the sweep quick.
MAX_EVAL_BATCHES = 50
# Keyed from `models` directly so this cell does not rely on `model_names`
# created in the plotting cell.
robustness_results = {name: [] for name in models}

# Put every model in eval mode on the device the batches are moved to, rather
# than relying on an earlier cell having placed it there.
for model in models.values():
    model.to(device)
    model.eval()

for noise_level in noise_levels:
    print(f"\nTesting noise level: {noise_level}")
    _, _, test_loader_noise, _ = load_dataset(
        base_config['dataset'],
        base_config['batch_size'],
        base_config['noise_type'],
        noise_level
    )

    psnr_totals = dict.fromkeys(models, 0.0)
    num_batches = 0

    with torch.no_grad():
        # The loader is iterated once and each batch is scored by every model,
        # so all models are compared on identical noisy inputs.
        for batch in test_loader_noise:
            if num_batches >= MAX_EVAL_BATCHES:
                break
            noisy, clean, _ = batch
            noisy, clean = noisy.to(device), clean.to(device)

            for model_name, model in models.items():
                output = model(noisy)
                psnr_totals[model_name] += calculate_psnr(output, clean).item()
            num_batches += 1

    for model_name, total_psnr in psnr_totals.items():
        # An empty loader gives NaN (drawn as a gap) instead of a ZeroDivisionError.
        avg_psnr = total_psnr / num_batches if num_batches else float('nan')
        robustness_results[model_name].append(avg_psnr)

plt.figure(figsize=(12, 6))
for model_name, psnr_list in robustness_results.items():
    plt.plot(noise_levels, psnr_list, 'o-', label=model_name, linewidth=2, markersize=6)

plt.xlabel('Noise Level')
plt.ylabel('PSNR (dB)')
plt.title('Model Robustness: Performance vs Noise Level')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
</VSCode.Cell>
<VSCode.Cell language="python">
print("Visual Comparison on Sample Images")
batch = next(iter(test_loader))
noisy, clean, _ = batch
# Show up to four samples; a smaller batch simply yields fewer columns.
num_samples = min(4, noisy.shape[0])
noisy, clean = noisy[:num_samples].to(device), clean[:num_samples]


def _as_image(tensor):
    """Return a CPU tensor laid out so that imshow can draw it.

    Tensors whose leading dimension is 3 or 4 are treated as colour images:
    they are permuted to channels-last and clamped to [0, 1].
    NOTE(review): assumes per-sample tensors are (C, H, W) with values
    nominally in [0, 1] - confirm against data_utils.
    Anything else is squeezed, which turns (1, H, W) into (H, W).
    """
    image = tensor.detach().cpu()
    if image.dim() == 3 and image.shape[0] in (3, 4):
        return image.permute(1, 2, 0).clamp(0, 1)
    return image.squeeze()


outputs = {}
with torch.no_grad():
    for model_name, model in models.items():
        # Make sure the weights sit on the same device as the inputs.
        model.to(device)
        model.eval()
        outputs[model_name] = model(noisy).cpu()

# Rows: noisy input, clean target, then one row per model.
# squeeze=False keeps `axes` two-dimensional even with a single column.
fig, axes = plt.subplots(len(outputs) + 2, num_samples, figsize=(16, 20), squeeze=False)

for i in range(num_samples):
    axes[0, i].imshow(_as_image(noisy[i]), cmap='gray')
    axes[0, i].set_title(f'Noisy Image {i+1}' if i == 0 else '')
    axes[0, i].axis('off')

    axes[1, i].imshow(_as_image(clean[i]), cmap='gray')
    axes[1, i].set_title(f'Clean Image {i+1}' if i == 0 else '')
    axes[1, i].axis('off')

    for j, (model_name, output) in enumerate(outputs.items()):
        axes[j+2, i].imshow(_as_image(output[i]), cmap='gray')
        # Only the first column is titled; its PSNR is for that one sample.
        if i == 0:
            psnr_val = calculate_psnr(output[i:i+1], clean[i:i+1]).item()
            axes[j+2, i].set_title(f'{model_name}\nPSNR: {psnr_val:.1f}dB')
        axes[j+2, i].axis('off')

plt.tight_layout()
plt.show()
</VSCode.Cell>
<VSCode.Cell language="python">
print("FINAL RECOMMENDATIONS")
print("="*60)

if not results:
    # max()/min() on an empty mapping would raise an opaque ValueError.
    print("No trained models to compare - run the training cell first.")
else:
    # Winners are picked from the measured results of this run.
    best_psnr = max(results.items(), key=lambda x: x[1]['test_psnr'])
    best_efficiency = min(results.items(), key=lambda x: x[1]['parameters'])
    best_speed = min(results.items(), key=lambda x: x[1]['training_time'])

    print(f"🏆 Best Overall Performance: {best_psnr[0]} ({best_psnr[1]['test_psnr']:.2f} dB)")
    print(f"⚡ Most Efficient (Smallest): {best_efficiency[0]} ({best_efficiency[1]['parameters']/1e6:.1f}M params)")
    print(f"🚀 Fastest Training: {best_speed[0]} ({best_speed[1]['training_time']:.1f}s)")

    # Fixed descriptive notes; these lines are not derived from `results`.
    print("\n📊 Summary:")
    print("- DnCNN: Classic, simple, fast training")
    print("- U-Net: Good skip connections, balanced performance")
    print("- RCAN: Attention mechanism, good for detail preservation")
    print("- NAFNet: State-of-the-art, best performance but slower")
    print("- DRUNet: Deep unfolding, good theoretical foundation")

    # Take the dataset name from the config so the heading follows whatever
    # dataset the notebook was run on.
    dataset_label = base_config['dataset'].upper()
    print(f"\n🎯 For your {dataset_label} denoising task:")
    print(f"   • Best choice: {best_psnr[0]} for maximum quality")
    # The two lines below are fixed suggestions, independent of `results`.
    print("   • Practical choice: DnCNN for speed and simplicity")
    print("   • Research choice: NAFNet for cutting-edge performance")
</VSCode.Cell>
````