In [None]:
````xml
<VSCode.Cell language="python">
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from models import RCAN
from data_utils import load_dataset
from trainer import Trainer, get_default_config, visualize_results

# Select the compute device once for the whole notebook: the GPU when PyTorch
# reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
</VSCode.Cell>
<VSCode.Cell language="python">
# Start from the RCAN defaults provided by trainer.get_default_config and
# override the settings specific to this experiment.
config = get_default_config('RCAN')
config['epochs'] = 120
config['learning_rate'] = 1e-4
config['batch_size'] = 32
config['dataset'] = 'mnist'
config['noise_type'] = 'gaussian'
config['noise_level'] = 0.2
config['optimizer'] = 'adam'

# load_dataset takes (dataset, batch_size, noise_type, noise_level) positionally
# and returns the three loaders plus the image channel count.
train_loader, val_loader, test_loader, channels = load_dataset(
    *(config[key] for key in ('dataset', 'batch_size', 'noise_type', 'noise_level'))
)

print(f"Dataset: {config['dataset']}", f"Channels: {channels}", sep="\n")
</VSCode.Cell>
<VSCode.Cell language="python">
# Build the main RCAN model (n_feats=64, n_blocks=12, reduction=16) for the
# dataset's channel count. Move it to `device` explicitly, as the ablation cell
# does for its models, so training does not depend on whether Trainer performs
# the move itself. nn.Module.to returns the module, so `model` stays the
# same object that is handed to Trainer and later to visualize_results.
model = RCAN(n_channels=channels, n_feats=64, n_blocks=12, reduction=16).to(device)
trainer = Trainer(model, train_loader, val_loader, test_loader, device, config)

print("RCAN Model:")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
</VSCode.Cell>
<VSCode.Cell language="python">
# Train the model with the Trainer constructed above from `config`. The epoch
# count, learning rate, etc. are presumably read from that dict inside Trainer
# (confirm in trainer.py).
trainer.train()
</VSCode.Cell>
<VSCode.Cell language="python">
# Evaluate the trained model on the test split. The three returned values are
# kept as notebook globals so later cells can refer to them.
(
    test_loss,
    test_psnr,
    test_ssim,
) = trainer.test()
</VSCode.Cell>
<VSCode.Cell language="python">
# Plot the metrics the Trainer recorded while training. Exactly which curves are
# drawn (presumably per-epoch loss/PSNR/SSIM) is defined in
# Trainer.plot_metrics; verify there.
trainer.plot_metrics()
</VSCode.Cell>
<VSCode.Cell language="python">
# Qualitative check: visualize_results (imported from trainer) shows 8 samples
# drawn from the test loader, processed by the trained `model` on `device`.
visualize_results(model, test_loader, device, num_samples=8)
</VSCode.Cell>
<VSCode.Cell language="python">
print("Ablation Study: Effect of Number of Blocks")
block_configs = [6, 8, 10, 12, 16]
ablation_epochs = 20  # short schedule so the sweep stays affordable
ablation_seed = 0     # identical seed per run so only n_blocks varies
results = []

# Build the shortened config once, before any Trainer exists. Each Trainer then
# receives the ablation epoch count at construction time. Previously the full
# config was passed and swapped out afterwards, so anything Trainer derives
# from config in __init__ would still see the full 120-epoch setting.
quick_config = config.copy()
quick_config['epochs'] = ablation_epochs

for n_blocks in block_configs:
    print(f"\nTesting with {n_blocks} blocks")

    # Reseed before building each model so weight initialization (and any
    # torch/numpy-driven randomness in training, if the loaders use those RNGs)
    # is the same for every configuration. Otherwise run-to-run variance is
    # mixed into the effect of n_blocks.
    torch.manual_seed(ablation_seed)
    np.random.seed(ablation_seed)

    test_model = RCAN(n_channels=channels, n_feats=64, n_blocks=n_blocks, reduction=16).to(device)
    test_trainer = Trainer(test_model, train_loader, val_loader, test_loader, device, quick_config)

    test_trainer.train()
    # Use distinct names so this loop does not overwrite test_psnr / test_ssim,
    # which hold the main 12-block model's results from the evaluation cell.
    _, ablation_psnr, ablation_ssim = test_trainer.test()

    results.append((n_blocks, ablation_psnr, ablation_ssim))

# One panel per metric: (index into each results tuple, y-axis label, title).
block_counts = [r[0] for r in results]
metric_panels = [
    (1, 'PSNR (dB)', 'RCAN: Effect of Number of Blocks on PSNR'),
    (2, 'SSIM', 'RCAN: Effect of Number of Blocks on SSIM'),
]

plt.figure(figsize=(12, 4))
for panel, (col, ylabel, title) in enumerate(metric_panels, start=1):
    plt.subplot(1, 2, panel)
    plt.plot(block_counts, [r[col] for r in results], 'o-')
    # The tested block counts are unevenly spaced; tick exactly at each one.
    plt.xticks(block_counts)
    plt.xlabel('Number of Blocks')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)

plt.tight_layout()
plt.show()
</VSCode.Cell>
````