# Initiation

In [1]:
import os
import random
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchmetrics import PeakSignalNoiseRatio as PSNR, StructuralSimilarityIndexMeasure as SSIM
from piq import LPIPS
import os
import torch.nn.functional as F
from torchvision.models import vgg19
from torchvision.utils import save_image
from torchvision.datasets import DatasetFolder
from datetime import datetime
from sklearn.model_selection import train_test_split
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
from copy import deepcopy

In [2]:
def set_seed(seed=42):
    """Set all random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set the seed (choose any number you like)
set_seed(42)

# DataLoaders


In [3]:
class RandomAugment:
    def __init__(self):
        self.color_jitter = transforms.ColorJitter(
            contrast=0.05,
            saturation=0.1,
            hue=0.2
        )
        self.gaussian_blur = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.5))
        # self.to_grayscale = transforms.RandomGrayscale(p=0.3)
        self.invert = transforms.RandomInvert(p=0.3)
        
    def __call__(self, img):
        # Apply augmentations in random order
        aug_order = random.sample(['color', 'blur', 'invert'], k=random.randint(0, 3))
        
        for aug in aug_order:
            if aug == 'color' and random.random() < 0.7:  # 70% chance
                img = self.color_jitter(img)
            elif aug == 'blur' and random.random() < 0.5:  # 50% chance
                img = self.gaussian_blur(img)
            elif aug == 'invert':
                img = self.invert(img)
                
        return img
    
def show_true_images(low_batch, high_batch, n=4):

    """
    Shows EXACTLY what's in your dataset with correct color handling.
    Works for both normalized and unnormalized images.
    """

    plt.figure(figsize=(18, 8))

    for i in range(min(n, len(low_batch))):
        # --- Low-light ---
        plt.subplot(2, n, i+1)
        low_img = low_batch[i].permute(1, 2, 0).numpy()

        # Handle normalization if present
        if low_img.min() < 0:  # Likely normalized
            low_img = (low_img * 0.5 + 0.5)  # Reverse imagenet norm

        # Ensure proper image range
        low_img = np.clip(low_img, 0, 1)
        plt.imshow(low_img)
        plt.title(f"Low")
        plt.axis('off')

        # --- Normal-light ---
        plt.subplot(2, n, n+i+1)
        high_img = high_batch[i].permute(1, 2, 0).numpy()

        if high_img.min() < 0:  # Likely normalized
            high_img = (high_img * 0.5 + 0.5)

        high_img = np.clip(high_img, 0, 1)
        plt.imshow(high_img)
        plt.title(f"High")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Run Data

In [4]:
# Your existing dataset code
dataset_path1 = r"/home/ahansviar2/Deep Learning Project (GAN for Light)/downloaded_images"
train_path = f'{dataset_path1}/train'
val_path = f'{dataset_path1}/val'
test_path = f'{dataset_path1}/test'

# Usage in Dataset
transform = transforms.Compose([
    # RandomAugment(),  # Your custom augmentations
    # transforms.GaussianBlur(kernel_size=3, sigma=(0.1,0.5)),
    # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Scales to [-1, 1]
    transforms.Resize((256, 256)), 
    transforms.ToTensor(),
])

low_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

high_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [5]:
class CleanDataset(Dataset):
    def __init__(self, root_dir, low_transform=None, high_transform=None):
        self.root_dir = root_dir
        self.low_transform = low_transform
        self.high_transform = high_transform
        self.low_dir = os.path.join(root_dir, "low")
        self.high_dir = os.path.join(root_dir, "high")
        self.image_names = sorted(os.listdir(self.low_dir))  # Ensures matching pairs

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        low_img_path = os.path.join(self.low_dir, self.image_names[idx])
        high_img_path = os.path.join(self.high_dir, self.image_names[idx])

        low_img = Image.open(low_img_path).convert("RGB")
        high_img = Image.open(high_img_path).convert("RGB")

        # Apply separate transforms
        if self.low_transform:
            low_img = self.low_transform(low_img)
        if self.high_transform:
            high_img = self.high_transform(high_img)
            
        return low_img, high_img, os.path.basename(low_img_path)

# Initialize datasets ONCE with correct parameters
train_dataset = CleanDataset(
    root_dir=train_path, 
    low_transform=low_transform,
    high_transform=high_transform
)

val_dataset = CleanDataset(
    root_dir=val_path,
    low_transform=low_transform,
    high_transform=high_transform
)

test_dataset = CleanDataset(
    root_dir=test_path,
    low_transform=low_transform,
    high_transform=high_transform
)

# DataLoader setup (this part is correct)
batch_size = 8

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1
)

device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda:5


## Data Checking

In [6]:
for split in ['train', 'val', 'test']:
    low_dir = os.path.join(dataset_path1, split, 'low')
    high_dir = os.path.join(dataset_path1, split, 'high')
    low_files = set(os.listdir(low_dir))
    high_files = set(os.listdir(high_dir))
    assert low_files == high_files, f"Mismatch in {split} set!"
print("✅ All datasets have matched low/high pairs.")

✅ All datasets have matched low/high pairs.


In [7]:
# Test one batch
low_batch, high_batch, _ = next(iter(train_loader))
print(f"Low batch shape: {low_batch.shape}")
print(f"High batch shape: {high_batch.shape}")
print(f"Pixel range - Low: [{low_batch.min():.2f}, {low_batch.max():.2f}]")
print(f"Pixel range - High: [{high_batch.min():.2f}, {high_batch.max():.2f}]")

Low batch shape: torch.Size([8, 3, 256, 256])
High batch shape: torch.Size([8, 3, 256, 256])
Pixel range - Low: [0.00, 1.00]
Pixel range - High: [0.00, 1.00]


In [8]:
def validate_image_pairs(dataset, n_samples=5):
    """Plot random pairs to visually inspect alignment/quality"""
    indices = random.sample(range(len(dataset)), n_samples)
    for idx in indices:
        low, high, _ = dataset[idx]
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(low.permute(1, 2, 0).clamp(0, 1))
        plt.title("Low-light")
        plt.subplot(1, 2, 2)
        plt.imshow(high.permute(1, 2, 0).clamp(0, 1))
        plt.title("Normal-light")
        plt.show()

# validate_image_pairs(train_dataset)

In [9]:
def analyze_channel_stats(loader):
    """Calculate mean/std per channel across dataset"""
    channels_sum, channels_sq_sum, num_batches = 0, 0, 0
    for low, high, _, in loader:
        channels_sum += torch.mean(low, dim=[0, 2, 3])
        channels_sq_sum += torch.mean(low**2, dim=[0, 2, 3])
        num_batches += 1
    mean = channels_sum / num_batches
    std = (channels_sq_sum / num_batches - mean**2)**0.5
    print(f"Low-light stats - Mean: {mean}, Std: {std}")

analyze_channel_stats(train_loader)

Low-light stats - Mean: tensor([0.1954, 0.1936, 0.1911]), Std: tensor([0.3196, 0.3206, 0.3242])


In [10]:
low, high, _ = next(iter(train_loader))
print(f"Batch memory: {low.element_size() * low.nelement() / 1024**2:.2f} MB")

Batch memory: 6.00 MB


# MODEL ARCHITECTURE

## Diffusion Experiment 2

In [11]:
import math
import os
import time
import torch
import torch.nn.functional as F
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM
from piq import LPIPS
from tqdm import tqdm
from torch import nn, optim

class TimeEmbedding(nn.Module):
    def __init__(self, dim, device):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        self.register_buffer('emb', emb)

    def forward(self, t):
        emb = t[:, None] * self.emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class UNetBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm = nn.BatchNorm2d(out_ch)
        self.act = nn.SiLU()

    def forward(self, x, t):
        h = self.conv1(x)
        time_emb = self.act(self.time_mlp(t))
        h = h + time_emb[:, :, None, None]
        h = self.norm(h)
        h = self.act(h)
        h = self.conv2(h)
        h = self.norm(h)
        h = self.act(h)
        return h
    
class UNet(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.time_mlp = nn.Sequential(
            TimeEmbedding(32, device),
            nn.Linear(32, 32),
            nn.SiLU(),
            nn.Linear(32, 32)
        )
        
        # Down blocks
        self.down1 = UNetBlock(3, 64, 32)
        self.down2 = UNetBlock(64, 128, 32)
        self.down3 = UNetBlock(128, 256, 32)
        
        # Bottleneck
        self.bottleneck = UNetBlock(256, 512, 32)
        
        # Up blocks
        self.up1 = UNetBlock(512 + 256, 256, 32)
        self.up2 = UNetBlock(256 + 128, 128, 32)
        self.up3 = UNetBlock(128 + 64, 64, 32)
        
        # Output
        self.out = nn.Conv2d(64, 3, 1)

    def forward(self, x, t):
        t = self.time_mlp(t)
        d1 = self.down1(x, t)
        d2 = self.down2(F.max_pool2d(d1, 2), t)
        d3 = self.down3(F.max_pool2d(d2, 2), t)
        
        h = F.max_pool2d(d3, 2)
        h = self.bottleneck(h, t)
        
        h = F.interpolate(h, scale_factor=2)
        h = torch.cat([h, d3], dim=1)
        h = self.up1(h, t)
        
        h = F.interpolate(h, scale_factor=2)
        h = torch.cat([h, d2], dim=1)
        h = self.up2(h, t)
        
        h = F.interpolate(h, scale_factor=2)
        h = torch.cat([h, d1], dim=1)
        h = self.up3(h, t)
        
        return self.out(h)

class FrequencyCompensatedDecoder(nn.Module):
    def __init__(self, latent_dim=256, output_ch=3):
        super().__init__()
        # Input projection to match latent_dim
        self.input_proj = nn.Conv2d(3, latent_dim, 1)  # Changed from 259 to 3
        
        # LR feature extractor
        self.lr_feature_extractor = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, latent_dim, 3, padding=1)
        )
        
        # Feature fusion
        self.fusion = nn.Sequential(
            nn.Conv2d(latent_dim*2, latent_dim, 1),
            nn.GroupNorm(32, latent_dim),
            nn.SiLU()
        )
        
        # Frequency blocks
        self.aff_blocks = nn.ModuleList([
            AFFBlock(latent_dim) for _ in range(6)
        ])
        
        # Output
        self.output_conv = nn.Sequential(
            nn.Conv2d(latent_dim, 64, 3, padding=1),
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, output_ch, 3, padding=1)
        )

    def forward(self, x, lr_condition):
        # Project input to latent_dim (3 -> 256)
        x = self.input_proj(x)
        
        # Extract LR features (3 -> 256)
        lr_features = self.lr_feature_extractor(lr_condition)
        
        # Fuse features (256 + 256 -> 256)
        x = self.fusion(torch.cat([x, lr_features], dim=1))
        
        # Frequency processing
        for block in self.aff_blocks:
            x = block(x)
            
        return self.output_conv(x)

class AFFBlock(nn.Module):
    """Adaptive Frequency Filtering block with proper FFT handling"""
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm = nn.GroupNorm(32, channels)
        self.act = nn.SiLU()
        
        # Learnable frequency filter
        self.filter = nn.Parameter(torch.ones(1, channels, 1, 1))
        
    def forward(self, x):
        # Spatial processing
        h = self.act(self.norm(self.conv(x)))
        
        # Frequency processing
        fft = torch.fft.rfft2(h, norm='ortho')
        mag = torch.abs(fft)
        phase = torch.angle(fft)
        
        # Apply learned frequency filter
        filtered = mag * self.filter
        
        # Inverse FFT
        real = filtered * torch.cos(phase)
        imag = filtered * torch.sin(phase)
        complex_tensor = torch.complex(real, imag)
        reconstructed = torch.fft.irfft2(complex_tensor, norm='ortho')
        
        return x + reconstructed

# Update SSMoE_UNet to handle input_channels
class SSMoE_UNet(nn.Module):
    def __init__(self, device, num_experts=4, T=1000, in_channels=3):
        super().__init__()
        self.device = device
        self.num_experts = num_experts
        self.T = T
        self.time_embed_dim = 32
        
        # Time embedding
        self.time_mlp = nn.Sequential(
            TimeEmbedding(self.time_embed_dim, device),
            nn.Linear(self.time_embed_dim, self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim)
        )
        
        # Down blocks
        self.down1 = UNetBlock(in_channels, 64, self.time_embed_dim)
        self.down2 = UNetBlock(64, 128, self.time_embed_dim)
        self.down3 = UNetBlock(128, 256, self.time_embed_dim)
        
        # Bottleneck with space MoE
        self.bottleneck = UNetBlock(256, 512, self.time_embed_dim)
        
        # Up blocks
        self.up1 = UNetBlock(512 + 256, 256, self.time_embed_dim)
        self.up2 = UNetBlock(256 + 128, 128, self.time_embed_dim)
        self.up3 = UNetBlock(128 + 64, 64, self.time_embed_dim)
        
        # Output
        self.out = nn.Conv2d(64, 3, 1)
        
        # Initialize sampling MoEs
        self.sampling_experts = nn.ModuleList([
            self._create_expert() for _ in range(num_experts)
        ])
        
    def _create_expert(self):
        """Create one expert network"""
        return nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, 3, 3, padding=1)
        )
    
    def forward(self, x, t):
        # Time embedding
        t_emb = self.time_mlp(t)
        
        # Select expert based on timestep
        stage = (t.float() / self.T * self.num_experts).long().clamp(0, self.num_experts-1)
        expert_out = torch.zeros_like(x, device=self.device)
        
        # Process each expert batch separately
        for i in range(self.num_experts):
            mask = (stage == i)
            if mask.any():
                expert_out[mask] = self.sampling_experts[i](x[mask])
        
        # Downsample
        d1 = self.down1(x + expert_out, t_emb)
        d2 = self.down2(F.max_pool2d(d1, 2), t_emb)
        d3 = self.down3(F.max_pool2d(d2, 2), t_emb)
        
        # Bottleneck
        h = F.max_pool2d(d3, 2)
        h = self.bottleneck(h, t_emb)
        
        # Upsample with skip connections
        h = F.interpolate(h, scale_factor=2)
        h = torch.cat([h, d3], dim=1)
        h = self.up1(h, t_emb)
        
        h = F.interpolate(h, scale_factor=2)
        h = torch.cat([h, d2], dim=1)
        h = self.up2(h, t_emb)
        
        h = F.interpolate(h, scale_factor=2)
        h = torch.cat([h, d1], dim=1)
        h = self.up3(h, t_emb)
        
        return self.out(h)

class LatentDiffusionModel:
    def __init__(self, T=1000, beta_start=1e-4, beta_end=0.02, device='cuda'):
        self.T = T
        self.device = device
        self.betas = torch.linspace(beta_start, beta_end, T, device=device)
        self.alphas = 1. - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        
        # Initialize components with proper T
        self.denoise_unet = SSMoE_UNet(device, T=T).to(device)
        self.decoder = FrequencyCompensatedDecoder().to(device)
        
        # Metrics
        self.lpips = LPIPS(replace_pooling=True).to(device).eval()
        self.ssim = SSIM().to(device).eval()
        
        # Tracking training progress
        self.current_epoch = 0
        self.total_epochs = 200
    
    def forward_diffuse(self, x0, t):
        """Forward diffusion with noise schedule"""
        t = t.to(self.device)
        noise = torch.randn_like(x0)
        sqrt_alpha_bar = torch.sqrt(self.alpha_bars[t])[:, None, None, None]
        sqrt_one_minus_alpha_bar = torch.sqrt(1. - self.alpha_bars[t])[:, None, None, None]
        xt = sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise
        return xt, noise
    
    def train_step(self, x0, lr_condition, optimizer):
        optimizer.zero_grad()
        
        # Sample random timesteps
        t = torch.randint(0, self.T, (x0.size(0),), device=self.device)
        
        # Forward diffusion
        xt, noise = self.forward_diffuse(x0, t)
        
        # Predict noise with UNet
        pred_noise = self.denoise_unet(xt, t)
        
        # Decode to pixel space
        pred_img = self.decoder(xt - pred_noise, lr_condition)
        
        # Hybrid loss
        loss, metrics = self.hybrid_loss(pred_img, x0, pred_noise, noise)
        
        loss.backward()
        optimizer.step()
        return metrics
    
    def hybrid_loss(self, pred, target, noise_pred, true_noise):
        """Combined loss with dynamic weighting"""
        # Base losses
        mse_loss = F.mse_loss(noise_pred, true_noise)
        lpips_loss = self.lpips(pred, target)
        ssim_loss = 1 - self.ssim(pred, target)
        
        # Dynamic weights (progressively focus more on perceptual quality)
        progress = self.current_epoch / self.total_epochs
        lpips_w = 0.4  # Fixed high importance for perceptual quality
        ssim_w = 0.3 * progress  # Increasing structural importance
        mse_w = 1.0 - lpips_w - ssim_w  # Decreasing noise prediction importance
        
        total_loss = mse_w * mse_loss + lpips_w * lpips_loss + ssim_w * ssim_loss
        
        # Additional metrics
        with torch.no_grad():
            psnr = 10 * torch.log10(1 / F.mse_loss(pred, target))
            
        return total_loss, {
            'loss': total_loss.item(),
            'mse': mse_loss.item(),
            'lpips': lpips_loss.item(),
            'ssim': 1 - ssim_loss.item(),
            'psnr': psnr.item()
        }

# Setup device
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

# Ensure checkpoint directory exists
os.makedirs('checkpoints_diffusion_nonorm', exist_ok=True)

# Initialize model and optimizer
diffusion = LatentDiffusionModel(device=device)
optimizer = optim.Adam(
    list(diffusion.denoise_unet.parameters()) + 
    list(diffusion.decoder.parameters()),
    lr=1e-4
)

# Training loop
epochs = 200
diffusion.total_epochs = epochs
best_metrics = {'loss': float('inf'), 'psnr': 0, 'ssim': 0, 'lpips': float('inf')}

# Record total training time
start_time = time.time()

for epoch in range(epochs):
    diffusion.current_epoch = epoch
    epoch_metrics = {'loss': 0, 'mse': 0, 'lpips': 0, 'ssim': 0, 'psnr': 0}
    
    for low_imgs, high_imgs, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        low_imgs, high_imgs = low_imgs.to(device), high_imgs.to(device)
        
        # Train step with LR conditioning
        metrics = diffusion.train_step(high_imgs, low_imgs, optimizer)
        
        # Update metrics
        for k in epoch_metrics:
            epoch_metrics[k] += metrics[k]
    
    # Average metrics
    for k in epoch_metrics:
        epoch_metrics[k] /= len(train_loader)
    
    # Save every epoch
    torch.save({
        'epoch': epoch,
        'denoise_unet': diffusion.denoise_unet.state_dict(),
        'decoder': diffusion.decoder.state_dict(),
        'optimizer': optimizer.state_dict(),
        'metrics': epoch_metrics
    }, f'checkpoints_diffusion_nonorm/diffusion_epoch_{epoch+1}.pth')
    
    # Save best model
    if epoch_metrics['loss'] < best_metrics['loss']:
        best_metrics = epoch_metrics
        torch.save({
            'epoch': epoch,
            'denoise_unet': diffusion.denoise_unet.state_dict(),
            'decoder': diffusion.decoder.state_dict(),
            'optimizer': optimizer.state_dict(),
            'metrics': epoch_metrics
        }, 'checkpoints_diffusion_nonorm/diffusion_best_model.pth')
    
    print(f"Epoch {epoch+1} | "
          f"Loss: {epoch_metrics['loss']:.4f} | "
          f"PSNR: {epoch_metrics['psnr']:.2f} dB | "
          f"SSIM: {epoch_metrics['ssim']:.3f} | "
          f"LPIPS: {epoch_metrics['lpips']:.3f}")

total_training_time = time.time() - start_time
hours, remainder = divmod(total_training_time, 3600)
minutes, seconds = divmod(remainder, 60)

print(f"\nTraining complete! Best metrics: "
      f"Loss: {best_metrics['loss']:.4f}, "
      f"PSNR: {best_metrics['psnr']:.2f} dB, "
      f"SSIM: {best_metrics['ssim']:.3f}, "
      f"LPIPS: {best_metrics['lpips']:.3f}")
print(f"Total training time: {int(hours)}h {int(minutes)}m {seconds:.2f}s")

# Save final model
torch.save({
    'epoch': epochs,
    'denoise_unet': diffusion.denoise_unet.state_dict(),
    'decoder': diffusion.decoder.state_dict(),
    'optimizer': optimizer.state_dict(),
    'metrics': epoch_metrics,
    'training_time': total_training_time
}, 'checkpoints_diffusion_nonorm/diffusion_final_model.pth')

Epoch 1/200: 100%|██████████| 49/49 [00:48<00:00,  1.02it/s]


Epoch 1 | Loss: 0.4660 | PSNR: 12.68 dB | SSIM: 0.606 | LPIPS: 0.439


Epoch 2/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 2 | Loss: 0.2007 | PSNR: 15.65 dB | SSIM: 0.781 | LPIPS: 0.276


Epoch 3/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 3 | Loss: 0.1730 | PSNR: 16.08 dB | SSIM: 0.790 | LPIPS: 0.267


Epoch 4/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 4 | Loss: 0.1532 | PSNR: 16.21 dB | SSIM: 0.794 | LPIPS: 0.250


Epoch 5/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 5 | Loss: 0.1384 | PSNR: 16.40 dB | SSIM: 0.801 | LPIPS: 0.244


Epoch 6/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 6 | Loss: 0.1358 | PSNR: 16.60 dB | SSIM: 0.814 | LPIPS: 0.234


Epoch 7/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 7 | Loss: 0.1324 | PSNR: 16.62 dB | SSIM: 0.818 | LPIPS: 0.240


Epoch 8/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 8 | Loss: 0.1214 | PSNR: 16.85 dB | SSIM: 0.828 | LPIPS: 0.226


Epoch 9/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 9 | Loss: 0.1204 | PSNR: 16.86 dB | SSIM: 0.832 | LPIPS: 0.223


Epoch 10/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 10 | Loss: 0.1125 | PSNR: 17.85 dB | SSIM: 0.862 | LPIPS: 0.216


Epoch 11/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 11 | Loss: 0.1107 | PSNR: 18.45 dB | SSIM: 0.880 | LPIPS: 0.210


Epoch 12/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 12 | Loss: 0.1035 | PSNR: 18.53 dB | SSIM: 0.887 | LPIPS: 0.198


Epoch 13/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 13 | Loss: 0.1055 | PSNR: 18.62 dB | SSIM: 0.894 | LPIPS: 0.198


Epoch 14/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 14 | Loss: 0.1032 | PSNR: 18.69 dB | SSIM: 0.893 | LPIPS: 0.195


Epoch 15/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 15 | Loss: 0.1010 | PSNR: 18.63 dB | SSIM: 0.892 | LPIPS: 0.201


Epoch 16/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 16 | Loss: 0.1004 | PSNR: 18.69 dB | SSIM: 0.903 | LPIPS: 0.191


Epoch 17/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 17 | Loss: 0.0961 | PSNR: 18.81 dB | SSIM: 0.907 | LPIPS: 0.188


Epoch 18/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 18 | Loss: 0.0988 | PSNR: 18.81 dB | SSIM: 0.909 | LPIPS: 0.189


Epoch 19/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 19 | Loss: 0.0912 | PSNR: 18.78 dB | SSIM: 0.911 | LPIPS: 0.183


Epoch 20/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 20 | Loss: 0.0986 | PSNR: 18.75 dB | SSIM: 0.910 | LPIPS: 0.188


Epoch 21/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 21 | Loss: 0.0899 | PSNR: 18.82 dB | SSIM: 0.917 | LPIPS: 0.183


Epoch 22/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 22 | Loss: 0.0866 | PSNR: 18.96 dB | SSIM: 0.917 | LPIPS: 0.178


Epoch 23/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 23 | Loss: 0.0883 | PSNR: 18.89 dB | SSIM: 0.919 | LPIPS: 0.181


Epoch 24/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 24 | Loss: 0.0852 | PSNR: 19.08 dB | SSIM: 0.922 | LPIPS: 0.175


Epoch 25/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 25 | Loss: 0.0849 | PSNR: 19.03 dB | SSIM: 0.927 | LPIPS: 0.176


Epoch 26/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 26 | Loss: 0.0839 | PSNR: 19.00 dB | SSIM: 0.927 | LPIPS: 0.172


Epoch 27/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 27 | Loss: 0.0837 | PSNR: 19.05 dB | SSIM: 0.928 | LPIPS: 0.169


Epoch 28/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 28 | Loss: 0.0887 | PSNR: 19.04 dB | SSIM: 0.929 | LPIPS: 0.176


Epoch 29/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 29 | Loss: 0.0870 | PSNR: 19.14 dB | SSIM: 0.934 | LPIPS: 0.175


Epoch 30/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 30 | Loss: 0.0825 | PSNR: 19.21 dB | SSIM: 0.937 | LPIPS: 0.166


Epoch 31/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 31 | Loss: 0.0799 | PSNR: 19.19 dB | SSIM: 0.937 | LPIPS: 0.163


Epoch 32/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 32 | Loss: 0.0831 | PSNR: 19.14 dB | SSIM: 0.937 | LPIPS: 0.170


Epoch 33/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 33 | Loss: 0.0793 | PSNR: 19.21 dB | SSIM: 0.938 | LPIPS: 0.160


Epoch 34/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 34 | Loss: 0.0776 | PSNR: 19.23 dB | SSIM: 0.942 | LPIPS: 0.158


Epoch 35/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 35 | Loss: 0.0787 | PSNR: 19.19 dB | SSIM: 0.942 | LPIPS: 0.159


Epoch 36/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 36 | Loss: 0.0787 | PSNR: 19.29 dB | SSIM: 0.948 | LPIPS: 0.161


Epoch 37/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 37 | Loss: 0.0750 | PSNR: 19.28 dB | SSIM: 0.951 | LPIPS: 0.156


Epoch 38/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 38 | Loss: 0.0749 | PSNR: 19.24 dB | SSIM: 0.949 | LPIPS: 0.156


Epoch 39/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 39 | Loss: 0.0722 | PSNR: 19.37 dB | SSIM: 0.950 | LPIPS: 0.150


Epoch 40/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 40 | Loss: 0.0734 | PSNR: 19.23 dB | SSIM: 0.953 | LPIPS: 0.150


Epoch 41/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 41 | Loss: 0.0758 | PSNR: 19.14 dB | SSIM: 0.952 | LPIPS: 0.152


Epoch 42/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 42 | Loss: 0.0703 | PSNR: 19.36 dB | SSIM: 0.957 | LPIPS: 0.148


Epoch 43/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 43 | Loss: 0.0736 | PSNR: 19.20 dB | SSIM: 0.958 | LPIPS: 0.151


Epoch 44/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 44 | Loss: 0.0711 | PSNR: 19.31 dB | SSIM: 0.960 | LPIPS: 0.148


Epoch 45/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 45 | Loss: 0.0683 | PSNR: 19.21 dB | SSIM: 0.961 | LPIPS: 0.144


Epoch 46/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 46 | Loss: 0.0686 | PSNR: 19.16 dB | SSIM: 0.964 | LPIPS: 0.144


Epoch 47/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 47 | Loss: 0.0713 | PSNR: 19.09 dB | SSIM: 0.961 | LPIPS: 0.148


Epoch 48/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 48 | Loss: 0.0695 | PSNR: 19.04 dB | SSIM: 0.964 | LPIPS: 0.144


Epoch 49/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 49 | Loss: 0.0671 | PSNR: 19.11 dB | SSIM: 0.965 | LPIPS: 0.141


Epoch 50/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 50 | Loss: 0.0659 | PSNR: 19.19 dB | SSIM: 0.968 | LPIPS: 0.140


Epoch 51/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 51 | Loss: 0.0664 | PSNR: 19.05 dB | SSIM: 0.969 | LPIPS: 0.137


Epoch 52/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 52 | Loss: 0.0668 | PSNR: 19.08 dB | SSIM: 0.971 | LPIPS: 0.136


Epoch 53/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 53 | Loss: 0.0642 | PSNR: 18.87 dB | SSIM: 0.971 | LPIPS: 0.135


Epoch 54/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 54 | Loss: 0.0659 | PSNR: 18.81 dB | SSIM: 0.972 | LPIPS: 0.138


Epoch 55/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 55 | Loss: 0.0688 | PSNR: 18.75 dB | SSIM: 0.972 | LPIPS: 0.143


Epoch 56/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 56 | Loss: 0.0636 | PSNR: 18.78 dB | SSIM: 0.973 | LPIPS: 0.134


Epoch 57/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 57 | Loss: 0.0652 | PSNR: 18.81 dB | SSIM: 0.975 | LPIPS: 0.132


Epoch 58/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 58 | Loss: 0.0638 | PSNR: 18.78 dB | SSIM: 0.977 | LPIPS: 0.133


Epoch 59/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 59 | Loss: 0.0624 | PSNR: 18.45 dB | SSIM: 0.974 | LPIPS: 0.132


Epoch 60/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 60 | Loss: 0.0602 | PSNR: 18.40 dB | SSIM: 0.978 | LPIPS: 0.128


Epoch 61/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 61 | Loss: 0.0612 | PSNR: 18.39 dB | SSIM: 0.977 | LPIPS: 0.130


Epoch 62/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 62 | Loss: 0.0616 | PSNR: 18.46 dB | SSIM: 0.980 | LPIPS: 0.129


Epoch 63/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 63 | Loss: 0.0583 | PSNR: 18.19 dB | SSIM: 0.981 | LPIPS: 0.128


Epoch 64/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 64 | Loss: 0.0616 | PSNR: 18.27 dB | SSIM: 0.980 | LPIPS: 0.129


Epoch 65/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 65 | Loss: 0.0606 | PSNR: 18.37 dB | SSIM: 0.981 | LPIPS: 0.125


Epoch 66/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 66 | Loss: 0.0616 | PSNR: 18.13 dB | SSIM: 0.980 | LPIPS: 0.127


Epoch 67/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 67 | Loss: 0.0582 | PSNR: 17.94 dB | SSIM: 0.982 | LPIPS: 0.127


Epoch 68/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 68 | Loss: 0.0593 | PSNR: 18.33 dB | SSIM: 0.982 | LPIPS: 0.128


Epoch 69/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 69 | Loss: 0.0570 | PSNR: 17.77 dB | SSIM: 0.983 | LPIPS: 0.125


Epoch 70/200: 100%|██████████| 49/49 [00:53<00:00,  1.10s/it]


Epoch 70 | Loss: 0.0567 | PSNR: 18.04 dB | SSIM: 0.983 | LPIPS: 0.124


Epoch 71/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 71 | Loss: 0.0612 | PSNR: 17.90 dB | SSIM: 0.983 | LPIPS: 0.130


Epoch 72/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 72 | Loss: 0.0585 | PSNR: 17.89 dB | SSIM: 0.985 | LPIPS: 0.122


Epoch 73/200: 100%|██████████| 49/49 [00:48<00:00,  1.02it/s]


Epoch 73 | Loss: 0.0559 | PSNR: 17.57 dB | SSIM: 0.986 | LPIPS: 0.122


Epoch 74/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 74 | Loss: 0.0538 | PSNR: 17.78 dB | SSIM: 0.986 | LPIPS: 0.118


Epoch 75/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 75 | Loss: 0.0574 | PSNR: 17.70 dB | SSIM: 0.985 | LPIPS: 0.119


Epoch 76/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 76 | Loss: 0.0564 | PSNR: 17.53 dB | SSIM: 0.986 | LPIPS: 0.117


Epoch 77/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 77 | Loss: 0.0560 | PSNR: 17.67 dB | SSIM: 0.987 | LPIPS: 0.121


Epoch 78/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 78 | Loss: 0.0541 | PSNR: 17.25 dB | SSIM: 0.989 | LPIPS: 0.117


Epoch 79/200: 100%|██████████| 49/49 [00:53<00:00,  1.10s/it]


Epoch 79 | Loss: 0.0546 | PSNR: 17.47 dB | SSIM: 0.987 | LPIPS: 0.116


Epoch 80/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 80 | Loss: 0.0547 | PSNR: 17.43 dB | SSIM: 0.989 | LPIPS: 0.117


Epoch 81/200: 100%|██████████| 49/49 [00:53<00:00,  1.10s/it]


Epoch 81 | Loss: 0.0522 | PSNR: 17.60 dB | SSIM: 0.989 | LPIPS: 0.113


Epoch 82/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 82 | Loss: 0.0528 | PSNR: 17.04 dB | SSIM: 0.990 | LPIPS: 0.114


Epoch 83/200: 100%|██████████| 49/49 [00:54<00:00,  1.12s/it]


Epoch 83 | Loss: 0.0549 | PSNR: 17.30 dB | SSIM: 0.990 | LPIPS: 0.117


Epoch 84/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 84 | Loss: 0.0572 | PSNR: 16.92 dB | SSIM: 0.990 | LPIPS: 0.116


Epoch 85/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 85 | Loss: 0.0538 | PSNR: 17.08 dB | SSIM: 0.990 | LPIPS: 0.113


Epoch 86/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 86 | Loss: 0.0539 | PSNR: 16.80 dB | SSIM: 0.990 | LPIPS: 0.112


Epoch 87/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 87 | Loss: 0.0521 | PSNR: 16.54 dB | SSIM: 0.991 | LPIPS: 0.114


Epoch 88/200: 100%|██████████| 49/49 [00:54<00:00,  1.12s/it]


Epoch 88 | Loss: 0.0516 | PSNR: 16.74 dB | SSIM: 0.991 | LPIPS: 0.111


Epoch 89/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 89 | Loss: 0.0485 | PSNR: 16.76 dB | SSIM: 0.993 | LPIPS: 0.108


Epoch 90/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 90 | Loss: 0.0499 | PSNR: 16.55 dB | SSIM: 0.992 | LPIPS: 0.109


Epoch 91/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 91 | Loss: 0.0538 | PSNR: 16.63 dB | SSIM: 0.992 | LPIPS: 0.112


Epoch 92/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 92 | Loss: 0.0514 | PSNR: 16.37 dB | SSIM: 0.992 | LPIPS: 0.110


Epoch 93/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 93 | Loss: 0.0524 | PSNR: 16.26 dB | SSIM: 0.993 | LPIPS: 0.111


Epoch 94/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 94 | Loss: 0.0495 | PSNR: 16.35 dB | SSIM: 0.994 | LPIPS: 0.107


Epoch 95/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 95 | Loss: 0.0497 | PSNR: 16.50 dB | SSIM: 0.993 | LPIPS: 0.108


Epoch 96/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 96 | Loss: 0.0510 | PSNR: 16.25 dB | SSIM: 0.992 | LPIPS: 0.109


Epoch 97/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 97 | Loss: 0.0498 | PSNR: 16.06 dB | SSIM: 0.993 | LPIPS: 0.109


Epoch 98/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 98 | Loss: 0.0488 | PSNR: 16.18 dB | SSIM: 0.994 | LPIPS: 0.104


Epoch 99/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 99 | Loss: 0.0481 | PSNR: 15.64 dB | SSIM: 0.994 | LPIPS: 0.104


Epoch 100/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 100 | Loss: 0.0487 | PSNR: 16.05 dB | SSIM: 0.994 | LPIPS: 0.106


Epoch 101/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 101 | Loss: 0.0495 | PSNR: 15.70 dB | SSIM: 0.994 | LPIPS: 0.105


Epoch 102/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 102 | Loss: 0.0478 | PSNR: 16.26 dB | SSIM: 0.994 | LPIPS: 0.104


Epoch 103/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 103 | Loss: 0.0469 | PSNR: 15.65 dB | SSIM: 0.994 | LPIPS: 0.105


Epoch 104/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 104 | Loss: 0.0517 | PSNR: 16.08 dB | SSIM: 0.994 | LPIPS: 0.112


Epoch 105/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 105 | Loss: 0.0469 | PSNR: 15.55 dB | SSIM: 0.994 | LPIPS: 0.103


Epoch 106/200: 100%|██████████| 49/49 [00:54<00:00,  1.11s/it]


Epoch 106 | Loss: 0.0467 | PSNR: 15.47 dB | SSIM: 0.994 | LPIPS: 0.103


Epoch 107/200: 100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


Epoch 107 | Loss: 0.0480 | PSNR: 15.04 dB | SSIM: 0.995 | LPIPS: 0.104


Epoch 108/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 108 | Loss: 0.0471 | PSNR: 16.04 dB | SSIM: 0.994 | LPIPS: 0.104


Epoch 109/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 109 | Loss: 0.0478 | PSNR: 15.40 dB | SSIM: 0.995 | LPIPS: 0.103


Epoch 110/200: 100%|██████████| 49/49 [00:48<00:00,  1.01it/s]


Epoch 110 | Loss: 0.0449 | PSNR: 15.74 dB | SSIM: 0.995 | LPIPS: 0.099


Epoch 111/200: 100%|██████████| 49/49 [00:51<00:00,  1.04s/it]


Epoch 111 | Loss: 0.0471 | PSNR: 15.60 dB | SSIM: 0.996 | LPIPS: 0.097


Epoch 112/200: 100%|██████████| 49/49 [00:51<00:00,  1.05s/it]


Epoch 112 | Loss: 0.0467 | PSNR: 15.52 dB | SSIM: 0.995 | LPIPS: 0.098


Epoch 113/200: 100%|██████████| 49/49 [00:51<00:00,  1.05s/it]


Epoch 113 | Loss: 0.0452 | PSNR: 15.41 dB | SSIM: 0.995 | LPIPS: 0.101


Epoch 114/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 114 | Loss: 0.0484 | PSNR: 15.80 dB | SSIM: 0.994 | LPIPS: 0.102


Epoch 115/200: 100%|██████████| 49/49 [00:51<00:00,  1.04s/it]


Epoch 115 | Loss: 0.0471 | PSNR: 15.51 dB | SSIM: 0.995 | LPIPS: 0.098


Epoch 116/200: 100%|██████████| 49/49 [00:51<00:00,  1.04s/it]


Epoch 116 | Loss: 0.0447 | PSNR: 15.01 dB | SSIM: 0.994 | LPIPS: 0.097


Epoch 117/200: 100%|██████████| 49/49 [00:48<00:00,  1.00it/s]


Epoch 117 | Loss: 0.0452 | PSNR: 15.19 dB | SSIM: 0.995 | LPIPS: 0.099


Epoch 118/200: 100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


Epoch 118 | Loss: 0.0458 | PSNR: 15.71 dB | SSIM: 0.995 | LPIPS: 0.100


Epoch 119/200: 100%|██████████| 49/49 [00:51<00:00,  1.04s/it]


Epoch 119 | Loss: 0.0427 | PSNR: 15.47 dB | SSIM: 0.995 | LPIPS: 0.096


Epoch 120/200: 100%|██████████| 49/49 [00:49<00:00,  1.00s/it]


Epoch 120 | Loss: 0.0448 | PSNR: 15.65 dB | SSIM: 0.995 | LPIPS: 0.099


Epoch 121/200: 100%|██████████| 49/49 [00:51<00:00,  1.05s/it]


Epoch 121 | Loss: 0.0438 | PSNR: 15.19 dB | SSIM: 0.995 | LPIPS: 0.099


Epoch 122/200: 100%|██████████| 49/49 [00:49<00:00,  1.02s/it]


Epoch 122 | Loss: 0.0425 | PSNR: 15.57 dB | SSIM: 0.996 | LPIPS: 0.095


Epoch 123/200: 100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


Epoch 123 | Loss: 0.0447 | PSNR: 15.59 dB | SSIM: 0.995 | LPIPS: 0.096


Epoch 124/200: 100%|██████████| 49/49 [00:48<00:00,  1.01it/s]


Epoch 124 | Loss: 0.0418 | PSNR: 15.45 dB | SSIM: 0.996 | LPIPS: 0.093


Epoch 125/200: 100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


Epoch 125 | Loss: 0.0454 | PSNR: 15.11 dB | SSIM: 0.995 | LPIPS: 0.097


Epoch 126/200: 100%|██████████| 49/49 [00:48<00:00,  1.00it/s]


Epoch 126 | Loss: 0.0429 | PSNR: 15.32 dB | SSIM: 0.996 | LPIPS: 0.095


Epoch 127/200: 100%|██████████| 49/49 [00:49<00:00,  1.00s/it]


Epoch 127 | Loss: 0.0429 | PSNR: 15.27 dB | SSIM: 0.995 | LPIPS: 0.094


Epoch 128/200: 100%|██████████| 49/49 [00:49<00:00,  1.00s/it]


Epoch 128 | Loss: 0.0426 | PSNR: 15.25 dB | SSIM: 0.996 | LPIPS: 0.093


Epoch 129/200: 100%|██████████| 49/49 [00:48<00:00,  1.00it/s]


Epoch 129 | Loss: 0.0435 | PSNR: 15.42 dB | SSIM: 0.996 | LPIPS: 0.092


Epoch 130/200: 100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


Epoch 130 | Loss: 0.0442 | PSNR: 15.45 dB | SSIM: 0.996 | LPIPS: 0.093


Epoch 131/200: 100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


Epoch 131 | Loss: 0.0427 | PSNR: 15.06 dB | SSIM: 0.997 | LPIPS: 0.091


Epoch 132/200: 100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


Epoch 132 | Loss: 0.0413 | PSNR: 15.20 dB | SSIM: 0.997 | LPIPS: 0.090


Epoch 133/200: 100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


Epoch 133 | Loss: 0.0420 | PSNR: 15.40 dB | SSIM: 0.996 | LPIPS: 0.088


Epoch 134/200: 100%|██████████| 49/49 [00:50<00:00,  1.02s/it]


Epoch 134 | Loss: 0.0410 | PSNR: 14.89 dB | SSIM: 0.997 | LPIPS: 0.091


Epoch 135/200: 100%|██████████| 49/49 [00:48<00:00,  1.01it/s]


Epoch 135 | Loss: 0.0416 | PSNR: 15.38 dB | SSIM: 0.996 | LPIPS: 0.090


Epoch 136/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 136 | Loss: 0.0421 | PSNR: 15.34 dB | SSIM: 0.996 | LPIPS: 0.091


Epoch 137/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 137 | Loss: 0.0394 | PSNR: 15.07 dB | SSIM: 0.996 | LPIPS: 0.088


Epoch 138/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 138 | Loss: 0.0397 | PSNR: 15.50 dB | SSIM: 0.997 | LPIPS: 0.087


Epoch 139/200: 100%|██████████| 49/49 [00:48<00:00,  1.01it/s]


Epoch 139 | Loss: 0.0422 | PSNR: 15.07 dB | SSIM: 0.997 | LPIPS: 0.093


Epoch 140/200: 100%|██████████| 49/49 [00:48<00:00,  1.01it/s]


Epoch 140 | Loss: 0.0419 | PSNR: 15.29 dB | SSIM: 0.997 | LPIPS: 0.088


Epoch 141/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 141 | Loss: 0.0417 | PSNR: 15.42 dB | SSIM: 0.997 | LPIPS: 0.091


Epoch 142/200: 100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


Epoch 142 | Loss: 0.0401 | PSNR: 15.24 dB | SSIM: 0.997 | LPIPS: 0.089


Epoch 143/200: 100%|██████████| 49/49 [00:49<00:00,  1.01s/it]


Epoch 143 | Loss: 0.0415 | PSNR: 15.07 dB | SSIM: 0.997 | LPIPS: 0.091


Epoch 144/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 144 | Loss: 0.0416 | PSNR: 14.82 dB | SSIM: 0.996 | LPIPS: 0.090


Epoch 145/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 145 | Loss: 0.0398 | PSNR: 15.58 dB | SSIM: 0.996 | LPIPS: 0.086


Epoch 146/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 146 | Loss: 0.0390 | PSNR: 15.29 dB | SSIM: 0.997 | LPIPS: 0.087


Epoch 147/200: 100%|██████████| 49/49 [00:49<00:00,  1.00s/it]


Epoch 147 | Loss: 0.0388 | PSNR: 14.98 dB | SSIM: 0.996 | LPIPS: 0.086


Epoch 148/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 148 | Loss: 0.0388 | PSNR: 15.17 dB | SSIM: 0.997 | LPIPS: 0.087


Epoch 149/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 149 | Loss: 0.0399 | PSNR: 15.52 dB | SSIM: 0.997 | LPIPS: 0.086


Epoch 150/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 150 | Loss: 0.0398 | PSNR: 15.30 dB | SSIM: 0.997 | LPIPS: 0.086


Epoch 151/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 151 | Loss: 0.0378 | PSNR: 15.00 dB | SSIM: 0.997 | LPIPS: 0.086


Epoch 152/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 152 | Loss: 0.0393 | PSNR: 15.34 dB | SSIM: 0.997 | LPIPS: 0.086


Epoch 153/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 153 | Loss: 0.0391 | PSNR: 15.21 dB | SSIM: 0.997 | LPIPS: 0.087


Epoch 154/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 154 | Loss: 0.0387 | PSNR: 15.16 dB | SSIM: 0.997 | LPIPS: 0.087


Epoch 155/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 155 | Loss: 0.0386 | PSNR: 15.44 dB | SSIM: 0.996 | LPIPS: 0.086


Epoch 156/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 156 | Loss: 0.0375 | PSNR: 15.16 dB | SSIM: 0.997 | LPIPS: 0.084


Epoch 157/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 157 | Loss: 0.0380 | PSNR: 14.97 dB | SSIM: 0.998 | LPIPS: 0.084


Epoch 158/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 158 | Loss: 0.0391 | PSNR: 15.44 dB | SSIM: 0.997 | LPIPS: 0.086


Epoch 159/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 159 | Loss: 0.0364 | PSNR: 15.23 dB | SSIM: 0.998 | LPIPS: 0.082


Epoch 160/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 160 | Loss: 0.0369 | PSNR: 15.44 dB | SSIM: 0.996 | LPIPS: 0.083


Epoch 161/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 161 | Loss: 0.0371 | PSNR: 14.94 dB | SSIM: 0.997 | LPIPS: 0.082


Epoch 162/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 162 | Loss: 0.0380 | PSNR: 15.16 dB | SSIM: 0.997 | LPIPS: 0.084


Epoch 163/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 163 | Loss: 0.0375 | PSNR: 15.55 dB | SSIM: 0.998 | LPIPS: 0.083


Epoch 164/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 164 | Loss: 0.0381 | PSNR: 14.96 dB | SSIM: 0.997 | LPIPS: 0.087


Epoch 165/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 165 | Loss: 0.0363 | PSNR: 15.38 dB | SSIM: 0.997 | LPIPS: 0.082


Epoch 166/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 166 | Loss: 0.0380 | PSNR: 14.90 dB | SSIM: 0.997 | LPIPS: 0.084


Epoch 167/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 167 | Loss: 0.0369 | PSNR: 15.42 dB | SSIM: 0.997 | LPIPS: 0.082


Epoch 168/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 168 | Loss: 0.0372 | PSNR: 15.24 dB | SSIM: 0.997 | LPIPS: 0.082


Epoch 169/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 169 | Loss: 0.0368 | PSNR: 14.69 dB | SSIM: 0.997 | LPIPS: 0.082


Epoch 170/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 170 | Loss: 0.0370 | PSNR: 15.25 dB | SSIM: 0.997 | LPIPS: 0.083


Epoch 171/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 171 | Loss: 0.0386 | PSNR: 15.21 dB | SSIM: 0.997 | LPIPS: 0.088


Epoch 172/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 172 | Loss: 0.0387 | PSNR: 15.26 dB | SSIM: 0.997 | LPIPS: 0.085


Epoch 173/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 173 | Loss: 0.0363 | PSNR: 15.11 dB | SSIM: 0.997 | LPIPS: 0.081


Epoch 174/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 174 | Loss: 0.0352 | PSNR: 15.32 dB | SSIM: 0.997 | LPIPS: 0.080


Epoch 175/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 175 | Loss: 0.0355 | PSNR: 15.19 dB | SSIM: 0.998 | LPIPS: 0.079


Epoch 176/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 176 | Loss: 0.0353 | PSNR: 15.62 dB | SSIM: 0.997 | LPIPS: 0.079


Epoch 177/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 177 | Loss: 0.0349 | PSNR: 15.08 dB | SSIM: 0.998 | LPIPS: 0.078


Epoch 178/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 178 | Loss: 0.0348 | PSNR: 15.17 dB | SSIM: 0.998 | LPIPS: 0.078


Epoch 179/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 179 | Loss: 0.0356 | PSNR: 15.34 dB | SSIM: 0.998 | LPIPS: 0.079


Epoch 180/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 180 | Loss: 0.0350 | PSNR: 15.49 dB | SSIM: 0.998 | LPIPS: 0.078


Epoch 181/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 181 | Loss: 0.0352 | PSNR: 15.73 dB | SSIM: 0.997 | LPIPS: 0.078


Epoch 182/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 182 | Loss: 0.0359 | PSNR: 15.03 dB | SSIM: 0.998 | LPIPS: 0.080


Epoch 183/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 183 | Loss: 0.0355 | PSNR: 14.95 dB | SSIM: 0.998 | LPIPS: 0.078


Epoch 184/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 184 | Loss: 0.0363 | PSNR: 15.35 dB | SSIM: 0.998 | LPIPS: 0.081


Epoch 185/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 185 | Loss: 0.0389 | PSNR: 15.27 dB | SSIM: 0.998 | LPIPS: 0.087


Epoch 186/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 186 | Loss: 0.0382 | PSNR: 15.18 dB | SSIM: 0.997 | LPIPS: 0.084


Epoch 187/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 187 | Loss: 0.0344 | PSNR: 15.14 dB | SSIM: 0.998 | LPIPS: 0.078


Epoch 188/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 188 | Loss: 0.0341 | PSNR: 15.20 dB | SSIM: 0.998 | LPIPS: 0.076


Epoch 189/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 189 | Loss: 0.0331 | PSNR: 15.02 dB | SSIM: 0.998 | LPIPS: 0.076


Epoch 190/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 190 | Loss: 0.0336 | PSNR: 15.14 dB | SSIM: 0.998 | LPIPS: 0.076


Epoch 191/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 191 | Loss: 0.0332 | PSNR: 15.44 dB | SSIM: 0.998 | LPIPS: 0.075


Epoch 192/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 192 | Loss: 0.0335 | PSNR: 15.22 dB | SSIM: 0.998 | LPIPS: 0.074


Epoch 193/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 193 | Loss: 0.0337 | PSNR: 15.55 dB | SSIM: 0.998 | LPIPS: 0.073


Epoch 194/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 194 | Loss: 0.0334 | PSNR: 15.41 dB | SSIM: 0.998 | LPIPS: 0.076


Epoch 195/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 195 | Loss: 0.0337 | PSNR: 15.26 dB | SSIM: 0.998 | LPIPS: 0.075


Epoch 196/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 196 | Loss: 0.0364 | PSNR: 14.83 dB | SSIM: 0.998 | LPIPS: 0.079


Epoch 197/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 197 | Loss: 0.0330 | PSNR: 15.31 dB | SSIM: 0.998 | LPIPS: 0.075


Epoch 198/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 198 | Loss: 0.0345 | PSNR: 15.57 dB | SSIM: 0.998 | LPIPS: 0.076


Epoch 199/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 199 | Loss: 0.0350 | PSNR: 15.49 dB | SSIM: 0.998 | LPIPS: 0.079


Epoch 200/200: 100%|██████████| 49/49 [00:47<00:00,  1.03it/s]


Epoch 200 | Loss: 0.0340 | PSNR: 15.29 dB | SSIM: 0.998 | LPIPS: 0.075

Training complete! Best metrics: Loss: 0.0330, PSNR: 15.31 dB, SSIM: 0.998, LPIPS: 0.075
Total training time: 2h 46m 13.89s


In [21]:
import os
import torch
from torchvision import transforms
from PIL import ImageDraw, ImageFont
from tqdm import tqdm

def evaluate_diffusion(denoise_unet, decoder, test_loader, device, diffusion, save_samples=True, sample_dir="sample"):
    # Initialize metrics
    psnr = PSNR().to(device)
    ssim = SSIM().to(device)
    lpips = LPIPS(replace_pooling=True).to(device)
    
    metrics = {
        'per_image': [],
        'average': {
            'psnr': 0.0,
            'ssim': 0.0,
            'lpips': 0.0
        }
    }
    
    if save_samples:
        os.makedirs(sample_dir, exist_ok=True)
    
    denoise_unet.eval()
    decoder.eval()
    sample_counter = 0
    
    with torch.no_grad():
        for batch_idx, (low_imgs, high_imgs, *_) in enumerate(tqdm(test_loader, desc="Evaluating")):
            low_imgs, high_imgs = low_imgs.to(device), high_imgs.to(device)
            
            # Use t=0 (final timestep) for evaluation
            t = torch.zeros(low_imgs.size(0), dtype=torch.long, device=device)
            
            # Forward diffusion process
            xt, _ = diffusion.forward_diffuse(high_imgs, t)
            
            # Denoise and decode
            pred_noise = denoise_unet(xt, t)
            enhanced_imgs = decoder(xt - pred_noise, low_imgs)
            
            # Clamp outputs to valid range
            enhanced_imgs = enhanced_imgs.clamp(0, 1)
            high_imgs = high_imgs.clamp(0, 1)
            
            for img_idx in range(low_imgs.size(0)):
                # Calculate metrics
                img_metrics = {
                    'psnr': psnr(enhanced_imgs[img_idx].unsqueeze(0), high_imgs[img_idx].unsqueeze(0)).item(),
                    'ssim': ssim(enhanced_imgs[img_idx].unsqueeze(0), high_imgs[img_idx].unsqueeze(0)).item(),
                    'lpips': lpips(enhanced_imgs[img_idx].unsqueeze(0), high_imgs[img_idx].unsqueeze(0)).item()
                }
                
                # Denormalize low-light image for visualization
                low_img_denorm = (low_imgs[img_idx] * 0.5) + 0.5
                low_img_denorm = low_img_denorm.clamp(0, 1)
                
                # Create a 4-column comparison image (C, H, W*4)
                comparison = torch.cat([
                    low_imgs[img_idx].clamp(0, 1),  # Low-light (normalized)
                    low_img_denorm,                 # Low-light (denormalized)
                    enhanced_imgs[img_idx],         # Enhanced
                    high_imgs[img_idx]              # Ground truth
                ], dim=-1)
                
                # Convert to PIL
                comparison_pil = transforms.ToPILImage()(comparison.cpu())
                
                if save_samples:
                    # Add labels
                    draw = ImageDraw.Draw(comparison_pil)
                    try:
                        font = ImageFont.truetype("arial.ttf", 15)
                    except:
                        font = ImageFont.load_default()
                    
                    width = comparison_pil.width // 4
                    draw.text((10, 10), "Low Light\n(Normalized)", fill="white", font=font)
                    draw.text((width + 10, 10), "Low Light\n(Denormalized)", fill="white", font=font)
                    draw.text((2 * width + 10, 10), "Enhanced", fill="white", font=font)
                    draw.text((3 * width + 10, 10), "Ground Truth", fill="white", font=font)

                    # Metrics under Enhanced
                    text_y = comparison_pil.height - 40
                    draw.text((2 * width + 10, text_y), 
                              f"PSNR: {img_metrics['psnr']:.2f}\nSSIM: {img_metrics['ssim']:.3f}\nLPIPS: {img_metrics['lpips']:.3f}", 
                              fill="white", font=font)

                    # Save image
                    sample_path = os.path.join(sample_dir, f"sample_{sample_counter:04d}.png")
                    comparison_pil.save(sample_path)
                    img_metrics['image_path'] = sample_path
                    sample_counter += 1
                
                # Store metrics
                metrics['per_image'].append(img_metrics)
                metrics['average']['psnr'] += img_metrics['psnr']
                metrics['average']['ssim'] += img_metrics['ssim']
                metrics['average']['lpips'] += img_metrics['lpips']
    
    # Compute averages
    total_samples = len(metrics['per_image'])
    metrics['average']['psnr'] /= total_samples
    metrics['average']['ssim'] /= total_samples
    metrics['average']['lpips'] /= total_samples
    
    print(f"\nEvaluation Results (Averages):")
    print(f"PSNR: {metrics['average']['psnr']:.2f} dB")
    print(f"SSIM: {metrics['average']['ssim']:.4f}")
    print(f"LPIPS: {metrics['average']['lpips']:.4f}")
    
    return metrics

# Initialize model
diffusion = LatentDiffusionModel(device=device)

# Load checkpoint
checkpoint = torch.load('checkpoints_diffusion_nonorm/diffusion_epoch_39.pth', map_location=device)

# Load state_dict into model components
diffusion.denoise_unet.load_state_dict(checkpoint['denoise_unet'])
diffusion.decoder.load_state_dict(checkpoint['decoder'])

# Evaluate
metrics = evaluate_diffusion(
    denoise_unet=diffusion.denoise_unet,
    decoder=diffusion.decoder,
    test_loader=test_loader,  # Your test DataLoader
    device=device,
    diffusion=diffusion,      # Already has forward_diffuse()
    save_samples=True,         # Save all comparison images
    sample_dir="diffusion_nonorm_sample"
)

Evaluating: 100%|██████████| 15/15 [00:02<00:00,  6.27it/s]


Evaluation Results (Averages):
PSNR: 20.18 dB
SSIM: 0.8569
LPIPS: 0.1440



