In [None]:
````xml
<VSCode.Cell language="python">
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from models import UNet
from data_utils import load_dataset
from data_utils import calculate_psnr, calculate_ssim
from trainer import Trainer, get_default_config, visualize_results

# Reproducibility: seed the RNGs before any dataset or model is created.
# This makes weight initialisation and DataLoader shuffling repeatable across
# Restart & Run All. torch.manual_seed seeds every device, CUDA included.
# Both torch and numpy are seeded because SOURCE does not show which one
# data_utils uses for noise sampling.
# NOTE(review): cuDNN kernels can still be non-deterministic on GPU.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
</VSCode.Cell>
<VSCode.Cell language="python">
# Experiment settings: start from the UNet defaults, then override the run-specific values.
config = get_default_config('UNet')
config.update(dict(
    epochs=80,
    learning_rate=1e-3,
    batch_size=32,
    dataset='mnist',
    noise_type='gaussian',
    noise_level=0.25,
))

# load_dataset is called positionally with (dataset, batch_size, noise_type, noise_level).
# Its result is unpacked into the three split loaders plus the channel count.
train_loader, val_loader, test_loader, channels = load_dataset(
    *(config[key] for key in ('dataset', 'batch_size', 'noise_type', 'noise_level'))
)

print(
    f"Dataset: {config['dataset']}",
    f"Channels: {channels}",
    f"Train batches: {len(train_loader)}",
    sep="\n",
)
</VSCode.Cell>
<VSCode.Cell language="python">
# The input and output channel counts are both `channels`, so the network maps an image
# to an image of the same depth.
# NOTE(review): bilinear=False presumably selects learned upsampling — confirm in models.UNet.
model = UNet(n_channels=channels, n_classes=channels, bilinear=False)
trainer = Trainer(model, train_loader, val_loader, test_loader, device, config)

param_count = sum(tensor.numel() for tensor in model.parameters())
print("Model Architecture:")
print(f"Total parameters: {param_count:,}")
</VSCode.Cell>
<VSCode.Cell language="python">
# Run training. Hyperparameters come from the `config` passed to Trainer,
# presumably epochs / learning_rate — see trainer.py.
trainer.train()
</VSCode.Cell>
<VSCode.Cell language="python">
# Final evaluation on the held-out test split. Trainer.test returns three values,
# which are unpacked into loss, PSNR and SSIM.
test_metrics = trainer.test()
test_loss, test_psnr, test_ssim = test_metrics
</VSCode.Cell>
<VSCode.Cell language="python">
# Plot the metrics Trainer recorded during training.
# See Trainer.plot_metrics for exactly which curves are drawn.
trainer.plot_metrics()
</VSCode.Cell>
<VSCode.Cell language="python">
# Qualitative check: pass NUM_VIS_SAMPLES test examples through the trained model.
# The figure layout is defined by trainer.visualize_results.
NUM_VIS_SAMPLES = 6
visualize_results(model, test_loader, device, num_samples=NUM_VIS_SAMPLES)
</VSCode.Cell>
<VSCode.Cell language="python">
# Robustness check: evaluate the trained model on test sets corrupted with
# different noise types. All types use the same noise level.
# Fills `results` with noise_type -> (mean PSNR in dB, mean SSIM).
noise_types = ['gaussian', 'speckle', 'salt_pepper']

# A single fixed level keeps the comparison across noise types like-for-like.
# NOTE(review): this differs from the training level (config['noise_level'] = 0.25),
# so the gaussian row will not match test_psnr from trainer.test() — confirm intended.
EVAL_NOISE_LEVEL = 0.2
results = {}

# Nothing below switches the model back to train mode, so one call suffices.
model.eval()
for noise_type in noise_types:
    print(f"\nTesting noise type: {noise_type}")
    # Only the test split is needed; the train/val loaders and channel count are discarded.
    _, _, test_loader_noise, _ = load_dataset(
        config['dataset'],
        config['batch_size'],
        noise_type,
        EVAL_NOISE_LEVEL,
    )

    psnr_sum = 0.0
    ssim_sum = 0.0
    num_samples = 0

    with torch.no_grad():
        for noisy, clean, _ in test_loader_noise:
            noisy, clean = noisy.to(device), clean.to(device)
            output = model(noisy)

            # Weight each batch by its size so a short final batch does not skew the
            # average. This assumes calculate_psnr/ssim return the batch mean
            # — TODO confirm in data_utils.
            batch_size = noisy.size(0)
            psnr_sum += calculate_psnr(output, clean).item() * batch_size
            ssim_sum += calculate_ssim(output, clean).item() * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError(f"Test loader for noise type {noise_type!r} yielded no samples")

    avg_psnr = psnr_sum / num_samples
    avg_ssim = ssim_sum / num_samples
    results[noise_type] = (avg_psnr, avg_ssim)
    print(f"PSNR: {avg_psnr:.2f}dB, SSIM: {avg_ssim:.4f}")

# Side-by-side bar charts of the per-noise-type averages in `results`
# (noise_type -> (psnr, ssim)), with one panel per metric.
labels = list(results)
panels = [
    (0, 'PSNR (dB)', 'U-Net Performance on Different Noise Types'),
    (1, 'SSIM', 'U-Net SSIM on Different Noise Types'),
]

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, (metric_idx, ylabel, title) in zip(axes, panels):
    ax.bar(labels, [results[label][metric_idx] for label in labels])
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.tick_params(axis='x', labelrotation=45)

fig.tight_layout()
plt.show()
</VSCode.Cell>
````