In [None]:
````xml
<VSCode.Cell language="python">
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from models import DRUNet
from data_utils import load_dataset
from trainer import Trainer, get_default_config, visualize_results

# Pick the GPU when PyTorch reports CUDA as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)
</VSCode.Cell>
<VSCode.Cell language="python">
# Start from the project's DRUNet defaults, then pin the settings for this run.
config = get_default_config('DRUNet')
config.update(dict(
    epochs=90,
    learning_rate=1e-3,
    batch_size=32,
    dataset='mnist',
    noise_type='gaussian',
    noise_level=0.25,
    optimizer='adam',
))

# load_dataset is called positionally with (dataset, batch_size, noise_type,
# noise_level) and yields the three loaders plus the channel count that is
# later used to size the models.
train_loader, val_loader, test_loader, channels = load_dataset(
    *(config[key] for key in ('dataset', 'batch_size', 'noise_type', 'noise_level'))
)

print("Dataset:", config['dataset'])
print("Channels:", channels)
</VSCode.Cell>
<VSCode.Cell language="python">
# DRUNet sized for the dataset's channel count; encoder widths nc and depth nb as configured.
model = DRUNet(
    in_nc=channels,
    out_nc=channels,
    nc=[64, 128, 256, 512],
    nb=4,
)
trainer = Trainer(model, train_loader, val_loader, test_loader, device, config)

print("DRUNet Model:")
print("Total parameters: {:,}".format(sum(param.numel() for param in model.parameters())))
</VSCode.Cell>
<VSCode.Cell language="python">
# Train DRUNet with the Trainer that was constructed from `config` (epochs, learning rate, optimizer, ...).
trainer.train()
</VSCode.Cell>
<VSCode.Cell language="python">
# Evaluate the trained DRUNet; trainer.test() is unpacked as (loss, PSNR, SSIM).
_metrics = trainer.test()
test_loss, test_psnr, test_ssim = _metrics
del _metrics  # keep the notebook namespace identical to a direct unpack
</VSCode.Cell>
<VSCode.Cell language="python">
# Plot the metrics tracked by the trainer — presumably the per-epoch history from trainer.train(); confirm in trainer.py.
trainer.plot_metrics()
</VSCode.Cell>
<VSCode.Cell language="python">
# Show 8 test-set samples passed through the trained DRUNet (exact panel layout is defined by visualize_results; verify there).
visualize_results(model, test_loader, device, num_samples=8)
</VSCode.Cell>
<VSCode.Cell language="python">
print("Model Comparison on Same Test Set")
from models import DnCNN, UNet, RCAN

# Fresh, untrained instances of each architecture, all sized for the dataset's
# channel count. DRUNet reuses the main run's hyper-parameters.
models = {
    'DRUNet': DRUNet(in_nc=channels, out_nc=channels, nc=[64, 128, 256, 512], nb=4),
    'DnCNN': DnCNN(channels=channels, num_layers=17, features=64),
    'UNet': UNet(n_channels=channels, n_classes=channels),
    'RCAN': RCAN(n_channels=channels, n_feats=64, n_blocks=10),
}

results = {}
# Same settings as the main run except for a shorter epoch budget shared by all
# models. A shallow copy suffices because only a top-level key is changed.
test_config = config.copy()
test_config['epochs'] = 30

for name, test_model in models.items():
    print(f"\nQuick training {name}...")
    test_trainer = Trainer(test_model, train_loader, val_loader, test_loader, device, test_config)
    test_trainer.train()
    # Dedicated names: reusing test_psnr / test_ssim here would silently
    # overwrite the main DRUNet's metrics from the evaluation cell.
    _, cmp_psnr, cmp_ssim = test_trainer.test()
    results[name] = (cmp_psnr, cmp_ssim)


def _annotated_bar(ax, labels, values, title, ylabel, fmt, pad,
                   colors=('red', 'blue', 'green', 'orange')):
    """Draw a bar chart on *ax* and write each value just above its bar.

    Args:
        ax: Matplotlib Axes to draw on.
        labels: Category names, one per bar.
        values: Bar heights, aligned with *labels*.
        title: Axes title.
        ylabel: Y-axis label.
        fmt: ``str.format`` template for the per-bar annotation.
        pad: Vertical offset (data units) between a bar top and its annotation.
        colors: Bar colours, applied in order.
    """
    bars = ax.bar(labels, values, color=list(colors))
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.tick_params(axis='x', labelrotation=45)
    for bar, value in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width() / 2, value + pad, fmt.format(value), ha='center')


model_names = list(results)
psnr_values = [results[name][0] for name in model_names]
ssim_values = [results[name][1] for name in model_names]
# Look counts up by name so they stay aligned with model_names, and plot them in
# millions so the bar heights agree with the 'Parameters (M)' axis label.
param_counts = [sum(p.numel() for p in models[name].parameters()) for name in model_names]
param_millions = [count / 1e6 for count in param_counts]

fig, (ax_psnr, ax_ssim, ax_params) = plt.subplots(1, 3, figsize=(15, 5))
_annotated_bar(ax_psnr, model_names, psnr_values,
               'Model Comparison - PSNR', 'PSNR (dB)', '{:.1f}', 0.1)
_annotated_bar(ax_ssim, model_names, ssim_values,
               'Model Comparison - SSIM', 'SSIM', '{:.3f}', 0.005)
_annotated_bar(ax_params, model_names, param_millions,
               'Model Size Comparison', 'Parameters (M)', '{:.1f}M',
               max(param_millions, default=0.0) * 0.01)

plt.tight_layout()
plt.show()
</VSCode.Cell>
````