# Initiation

In [3]:
import os
import random
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchmetrics import PeakSignalNoiseRatio as PSNR, StructuralSimilarityIndexMeasure as SSIM
from piq import LPIPS
import os
import torch.nn.functional as F
from torchvision.models import vgg19
from torchvision.utils import save_image
from torchvision.datasets import DatasetFolder
from datetime import datetime
from sklearn.model_selection import train_test_split
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
from copy import deepcopy

In [4]:
def set_seed(seed=42):
    """Set all random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set the seed (choose any number you like)
set_seed(42)

# DataLoaders


In [5]:
class RandomAugment:
    def __init__(self):
        self.color_jitter = transforms.ColorJitter(
            contrast=0.05,
            saturation=0.1,
            hue=0.2
        )
        self.gaussian_blur = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.5))
        # self.to_grayscale = transforms.RandomGrayscale(p=0.3)
        self.invert = transforms.RandomInvert(p=0.3)
        
    def __call__(self, img):
        # Apply augmentations in random order
        aug_order = random.sample(['color', 'blur', 'invert'], k=random.randint(0, 3))
        
        for aug in aug_order:
            if aug == 'color' and random.random() < 0.7:  # 70% chance
                img = self.color_jitter(img)
            elif aug == 'blur' and random.random() < 0.5:  # 50% chance
                img = self.gaussian_blur(img)
            elif aug == 'invert':
                img = self.invert(img)
                
        return img
    
def show_true_images(low_batch, high_batch, n=4):

    """
    Shows EXACTLY what's in your dataset with correct color handling.
    Works for both normalized and unnormalized images.
    """

    plt.figure(figsize=(18, 8))

    for i in range(min(n, len(low_batch))):
        # --- Low-light ---
        plt.subplot(2, n, i+1)
        low_img = low_batch[i].permute(1, 2, 0).numpy()

        # Handle normalization if present
        if low_img.min() < 0:  # Likely normalized
            low_img = (low_img * 0.5 + 0.5)  # Reverse imagenet norm

        # Ensure proper image range
        low_img = np.clip(low_img, 0, 1)
        plt.imshow(low_img)
        plt.title(f"Low")
        plt.axis('off')

        # --- Normal-light ---
        plt.subplot(2, n, n+i+1)
        high_img = high_batch[i].permute(1, 2, 0).numpy()

        if high_img.min() < 0:  # Likely normalized
            high_img = (high_img * 0.5 + 0.5)

        high_img = np.clip(high_img, 0, 1)
        plt.imshow(high_img)
        plt.title(f"High")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Run Data

In [8]:
# Your existing dataset code
dataset_path1 = r"/home/ahansviar2/Deep Learning Project (GAN for Light)/downloaded_images"
train_path = f'{dataset_path1}/train'
val_path = f'{dataset_path1}/val'
test_path = f'{dataset_path1}/test'

# Usage in Dataset
transform = transforms.Compose([
    # RandomAugment(),  # Your custom augmentations
    # transforms.GaussianBlur(kernel_size=3, sigma=(0.1,0.5)),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Scales to [-1, 1]
    transforms.Resize((256, 256)), 
    transforms.ToTensor(),
])

# Transform Pipelines
train_input_transform = transforms.Compose([
    # RandomAugment(),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1,0.5)),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

low_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

high_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

target_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

In [11]:
class CleanDataset(Dataset):
    def __init__(self, root_dir, low_transform=None, high_transform=None):
        self.root_dir = root_dir
        self.low_transform = low_transform
        self.high_transform = high_transform
        self.low_dir = os.path.join(root_dir, "low")
        self.high_dir = os.path.join(root_dir, "high")
        self.image_names = sorted(os.listdir(self.low_dir))  # Ensures matching pairs

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        low_img_path = os.path.join(self.low_dir, self.image_names[idx])
        high_img_path = os.path.join(self.high_dir, self.image_names[idx])

        low_img = Image.open(low_img_path).convert("RGB")
        high_img = Image.open(high_img_path).convert("RGB")

        # Apply separate transforms
        if self.low_transform:
            low_img = self.low_transform(low_img)
        if self.high_transform:
            high_img = self.high_transform(high_img)
            
        return low_img, high_img, os.path.basename(low_img_path)

# Initialize datasets ONCE with correct parameters
train_dataset = CleanDataset(
    root_dir=train_path, 
    low_transform=low_transform,
    high_transform=high_transform
)

val_dataset = CleanDataset(
    root_dir=val_path,
    low_transform=low_transform,
    high_transform=high_transform
)

test_dataset = CleanDataset(
    root_dir=test_path,
    low_transform=low_transform,
    high_transform=high_transform
)

# DataLoader setup (this part is correct)
batch_size = 8

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda:0


## Data Checking

In [9]:
for split in ['train', 'val', 'test']:
    low_dir = os.path.join(dataset_path1, split, 'low')
    high_dir = os.path.join(dataset_path1, split, 'high')
    low_files = set(os.listdir(low_dir))
    high_files = set(os.listdir(high_dir))
    assert low_files == high_files, f"Mismatch in {split} set!"
print("✅ All datasets have matched low/high pairs.")

✅ All datasets have matched low/high pairs.


In [12]:
# Test one batch
low_batch, high_batch, _ = next(iter(train_loader))
print(f"Low batch shape: {low_batch.shape}")
print(f"High batch shape: {high_batch.shape}")
print(f"Pixel range - Low: [{low_batch.min():.2f}, {low_batch.max():.2f}]")
print(f"Pixel range - High: [{high_batch.min():.2f}, {high_batch.max():.2f}]")

Low batch shape: torch.Size([8, 3, 256, 256])
High batch shape: torch.Size([8, 3, 256, 256])
Pixel range - Low: [-1.00, 0.98]
Pixel range - High: [-1.00, 1.00]


In [13]:
def validate_image_pairs(dataset, n_samples=5):
    """Plot random pairs to visually inspect alignment/quality"""
    indices = random.sample(range(len(dataset)), n_samples)
    for idx in indices:
        low, high, _ = dataset[idx]
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(low.permute(1, 2, 0).clamp(0, 1))
        plt.title("Low-light")
        plt.subplot(1, 2, 2)
        plt.imshow(high.permute(1, 2, 0).clamp(0, 1))
        plt.title("Normal-light")
        plt.show()

# validate_image_pairs(train_dataset)

In [14]:
def analyze_channel_stats(loader):
    """Calculate mean/std per channel across dataset"""
    channels_sum, channels_sq_sum, num_batches = 0, 0, 0
    for low, high, _, in loader:
        channels_sum += torch.mean(low, dim=[0, 2, 3])
        channels_sq_sum += torch.mean(low**2, dim=[0, 2, 3])
        num_batches += 1
    mean = channels_sum / num_batches
    std = (channels_sq_sum / num_batches - mean**2)**0.5
    print(f"Low-light stats - Mean: {mean}, Std: {std}")

analyze_channel_stats(train_loader)

Low-light stats - Mean: tensor([-0.6094, -0.6131, -0.6181]), Std: tensor([0.6393, 0.6413, 0.6484])


In [15]:
low, high, _ = next(iter(train_loader))
print(f"Batch memory: {low.element_size() * low.nelement() / 1024**2:.2f} MB")

Batch memory: 6.00 MB


# MODEL ARCHITECTURE

## Diffusion Experiment 2

In [16]:
import math
import os
import time
import torch
import torch.nn.functional as F
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM
from piq import LPIPS
from tqdm import tqdm
from torch import nn, optim

class TimeEmbedding(nn.Module):
    def __init__(self, dim, device):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        self.register_buffer('emb', emb)

    def forward(self, t):
        emb = t[:, None] * self.emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class UNetBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm = nn.BatchNorm2d(out_ch)
        self.act = nn.SiLU()

    def forward(self, x, t):
        h = self.conv1(x)
        time_emb = self.act(self.time_mlp(t))
        h = h + time_emb[:, :, None, None]
        h = self.norm(h)
        h = self.act(h)
        h = self.conv2(h)
        h = self.norm(h)
        h = self.act(h)
        return h
    
class UNet(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.time_mlp = nn.Sequential(
            TimeEmbedding(32, device),
            nn.Linear(32, 32),
            nn.SiLU(),
            nn.Linear(32, 32)
        )
        
        # Down blocks
        self.down1 = UNetBlock(3, 64, 32)
        self.down2 = UNetBlock(64, 128, 32)
        self.down3 = UNetBlock(128, 256, 32)
        
        # Bottleneck
        self.bottleneck = UNetBlock(256, 512, 32)
        
        # Up blocks
        self.up1 = UNetBlock(512 + 256, 256, 32)
        self.up2 = UNetBlock(256 + 128, 128, 32)
        self.up3 = UNetBlock(128 + 64, 64, 32)
        
        # Output
        self.out = nn.Conv2d(64, 3, 1)

    def forward(self, x, t):
        t = self.time_mlp(t)
        d1 = self.down1(x, t)
        d2 = self.down2(F.max_pool2d(d1, 2), t)
        d3 = self.down3(F.max_pool2d(d2, 2), t)
        
        h = F.max_pool2d(d3, 2)
        h = self.bottleneck(h, t)
        
        h = F.interpolate(h, scale_factor=2)
        h = torch.cat([h, d3], dim=1)
        h = self.up1(h, t)
        
        h = F.interpolate(h, scale_factor=2)
        h = torch.cat([h, d2], dim=1)
        h = self.up2(h, t)
        
        h = F.interpolate(h, scale_factor=2)
        h = torch.cat([h, d1], dim=1)
        h = self.up3(h, t)
        
        return self.out(h)

class FrequencyCompensatedDecoder(nn.Module):
    def __init__(self, latent_dim=256, output_ch=3):
        super().__init__()
        # Input projection to match latent_dim
        self.input_proj = nn.Conv2d(3, latent_dim, 1)  # Changed from 259 to 3
        
        # LR feature extractor
        self.lr_feature_extractor = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, latent_dim, 3, padding=1)
        )
        
        # Feature fusion
        self.fusion = nn.Sequential(
            nn.Conv2d(latent_dim*2, latent_dim, 1),
            nn.GroupNorm(32, latent_dim),
            nn.SiLU()
        )
        
        # Frequency blocks
        self.aff_blocks = nn.ModuleList([
            AFFBlock(latent_dim) for _ in range(6)
        ])
        
        # Output
        self.output_conv = nn.Sequential(
            nn.Conv2d(latent_dim, 64, 3, padding=1),
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, output_ch, 3, padding=1)
        )

    def forward(self, x, lr_condition):
        # Project input to latent_dim (3 -> 256)
        x = self.input_proj(x)
        
        # Extract LR features (3 -> 256)
        lr_features = self.lr_feature_extractor(lr_condition)
        
        # Fuse features (256 + 256 -> 256)
        x = self.fusion(torch.cat([x, lr_features], dim=1))
        
        # Frequency processing
        for block in self.aff_blocks:
            x = block(x)
            
        return self.output_conv(x)

class AFFBlock(nn.Module):
    """Adaptive Frequency Filtering block with proper FFT handling"""
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm = nn.GroupNorm(32, channels)
        self.act = nn.SiLU()
        
        # Learnable frequency filter
        self.filter = nn.Parameter(torch.ones(1, channels, 1, 1))
        
    def forward(self, x):
        # Spatial processing
        h = self.act(self.norm(self.conv(x)))
        
        # Frequency processing
        fft = torch.fft.rfft2(h, norm='ortho')
        mag = torch.abs(fft)
        phase = torch.angle(fft)
        
        # Apply learned frequency filter
        filtered = mag * self.filter
        
        # Inverse FFT
        real = filtered * torch.cos(phase)
        imag = filtered * torch.sin(phase)
        complex_tensor = torch.complex(real, imag)
        reconstructed = torch.fft.irfft2(complex_tensor, norm='ortho')
        
        return x + reconstructed

# Update SSMoE_UNet to handle input_channels
class SSMoE_UNet(nn.Module):
    def __init__(self, device, num_experts=4, T=1000, in_channels=3):
        super().__init__()
        self.device = device
        self.num_experts = num_experts
        self.T = T
        self.time_embed_dim = 32
        
        # Time embedding
        self.time_mlp = nn.Sequential(
            TimeEmbedding(self.time_embed_dim, device),
            nn.Linear(self.time_embed_dim, self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim)
        )
        
        # Down blocks
        self.down1 = UNetBlock(in_channels, 64, self.time_embed_dim)
        self.down2 = UNetBlock(64, 128, self.time_embed_dim)
        self.down3 = UNetBlock(128, 256, self.time_embed_dim)
        
        # Bottleneck with space MoE
        self.bottleneck = UNetBlock(256, 512, self.time_embed_dim)
        
        # Up blocks
        self.up1 = UNetBlock(512 + 256, 256, self.time_embed_dim)
        self.up2 = UNetBlock(256 + 128, 128, self.time_embed_dim)
        self.up3 = UNetBlock(128 + 64, 64, self.time_embed_dim)
        
        # Output
        self.out = nn.Conv2d(64, 3, 1)
        
        # Initialize sampling MoEs
        self.sampling_experts = nn.ModuleList([
            self._create_expert() for _ in range(num_experts)
        ])
        
    def _create_expert(self):
        """Create one expert network"""
        return nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, 3, 3, padding=1)
        )
    
    def forward(self, x, t):
        # Time embedding
        t_emb = self.time_mlp(t)
        
        # Select expert based on timestep
        stage = (t.float() / self.T * self.num_experts).long().clamp(0, self.num_experts-1)
        expert_out = torch.zeros_like(x, device=self.device)
        
        # Process each expert batch separately
        for i in range(self.num_experts):
            mask = (stage == i)
            if mask.any():
                expert_out[mask] = self.sampling_experts[i](x[mask])
        
        # Downsample
        d1 = self.down1(x + expert_out, t_emb)
        d2 = self.down2(F.max_pool2d(d1, 2), t_emb)
        d3 = self.down3(F.max_pool2d(d2, 2), t_emb)
        
        # Bottleneck
        h = F.max_pool2d(d3, 2)
        h = self.bottleneck(h, t_emb)
        
        # Upsample with skip connections
        h = F.interpolate(h, scale_factor=2)
        h = torch.cat([h, d3], dim=1)
        h = self.up1(h, t_emb)
        
        h = F.interpolate(h, scale_factor=2)
        h = torch.cat([h, d2], dim=1)
        h = self.up2(h, t_emb)
        
        h = F.interpolate(h, scale_factor=2)
        h = torch.cat([h, d1], dim=1)
        h = self.up3(h, t_emb)
        
        return self.out(h)

class LatentDiffusionModel:
    def __init__(self, T=1000, beta_start=1e-4, beta_end=0.02, device='cuda'):
        self.T = T
        self.device = device
        self.betas = torch.linspace(beta_start, beta_end, T, device=device)
        self.alphas = 1. - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        
        # Initialize components with proper T
        self.denoise_unet = SSMoE_UNet(device, T=T).to(device)
        self.decoder = FrequencyCompensatedDecoder().to(device)
        
        # Metrics
        self.lpips = LPIPS(replace_pooling=True).to(device).eval()
        self.ssim = SSIM().to(device).eval()
        
        # Tracking training progress
        self.current_epoch = 0
        self.total_epochs = 200
    
    def forward_diffuse(self, x0, t):
        """Forward diffusion with noise schedule"""
        t = t.to(self.device)
        noise = torch.randn_like(x0)
        sqrt_alpha_bar = torch.sqrt(self.alpha_bars[t])[:, None, None, None]
        sqrt_one_minus_alpha_bar = torch.sqrt(1. - self.alpha_bars[t])[:, None, None, None]
        xt = sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise
        return xt, noise
    
    def train_step(self, x0, lr_condition, optimizer):
        optimizer.zero_grad()
        
        # Sample random timesteps
        t = torch.randint(0, self.T, (x0.size(0),), device=self.device)
        
        # Forward diffusion
        xt, noise = self.forward_diffuse(x0, t)
        
        # Predict noise with UNet
        pred_noise = self.denoise_unet(xt, t)
        
        # Decode to pixel space
        pred_img = self.decoder(xt - pred_noise, lr_condition)
        
        # Hybrid loss
        loss, metrics = self.hybrid_loss(pred_img, x0, pred_noise, noise)
        
        loss.backward()
        optimizer.step()
        return metrics
    
    def hybrid_loss(self, pred, target, noise_pred, true_noise):
        """Combined loss with dynamic weighting"""
        # Base losses
        mse_loss = F.mse_loss(noise_pred, true_noise)
        lpips_loss = self.lpips(pred, target)
        ssim_loss = 1 - self.ssim(pred, target)
        
        # Dynamic weights (progressively focus more on perceptual quality)
        progress = self.current_epoch / self.total_epochs
        lpips_w = 0.4  # Fixed high importance for perceptual quality
        ssim_w = 0.3 * progress  # Increasing structural importance
        mse_w = 1.0 - lpips_w - ssim_w  # Decreasing noise prediction importance
        
        total_loss = mse_w * mse_loss + lpips_w * lpips_loss + ssim_w * ssim_loss
        
        # Additional metrics
        with torch.no_grad():
            psnr = 10 * torch.log10(1 / F.mse_loss(pred, target))
            
        return total_loss, {
            'loss': total_loss.item(),
            'mse': mse_loss.item(),
            'lpips': lpips_loss.item(),
            'ssim': 1 - ssim_loss.item(),
            'psnr': psnr.item()
        }

# Setup device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Ensure checkpoint directory exists
os.makedirs('checkpoints', exist_ok=True)

# Initialize model and optimizer
diffusion = LatentDiffusionModel(device=device)
optimizer = optim.Adam(
    list(diffusion.denoise_unet.parameters()) + 
    list(diffusion.decoder.parameters()),
    lr=1e-4
)

# Training loop
epochs = 200
diffusion.total_epochs = epochs
best_metrics = {'loss': float('inf'), 'psnr': 0, 'ssim': 0, 'lpips': float('inf')}

# Record total training time
start_time = time.time()

for epoch in range(epochs):
    diffusion.current_epoch = epoch
    epoch_metrics = {'loss': 0, 'mse': 0, 'lpips': 0, 'ssim': 0, 'psnr': 0}
    
    for low_imgs, high_imgs, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        low_imgs, high_imgs = low_imgs.to(device), high_imgs.to(device)
        
        # Train step with LR conditioning
        metrics = diffusion.train_step(high_imgs, low_imgs, optimizer)
        
        # Update metrics
        for k in epoch_metrics:
            epoch_metrics[k] += metrics[k]
    
    # Average metrics
    for k in epoch_metrics:
        epoch_metrics[k] /= len(train_loader)
    
    # Save every epoch
    torch.save({
        'epoch': epoch,
        'denoise_unet': diffusion.denoise_unet.state_dict(),
        'decoder': diffusion.decoder.state_dict(),
        'optimizer': optimizer.state_dict(),
        'metrics': epoch_metrics
    }, f'checkpoints/diffusion_epoch_{epoch+1}.pth')
    
    # Save best model
    if epoch_metrics['loss'] < best_metrics['loss']:
        best_metrics = epoch_metrics
        torch.save({
            'epoch': epoch,
            'denoise_unet': diffusion.denoise_unet.state_dict(),
            'decoder': diffusion.decoder.state_dict(),
            'optimizer': optimizer.state_dict(),
            'metrics': epoch_metrics
        }, 'checkpoints/diffusion_best_model.pth')
    
    print(f"Epoch {epoch+1} | "
          f"Loss: {epoch_metrics['loss']:.4f} | "
          f"PSNR: {epoch_metrics['psnr']:.2f} dB | "
          f"SSIM: {epoch_metrics['ssim']:.3f} | "
          f"LPIPS: {epoch_metrics['lpips']:.3f}")

total_training_time = time.time() - start_time
hours, remainder = divmod(total_training_time, 3600)
minutes, seconds = divmod(remainder, 60)

print(f"\nTraining complete! Best metrics: "
      f"Loss: {best_metrics['loss']:.4f}, "
      f"PSNR: {best_metrics['psnr']:.2f} dB, "
      f"SSIM: {best_metrics['ssim']:.3f}, "
      f"LPIPS: {best_metrics['lpips']:.3f}")
print(f"Total training time: {int(hours)}h {int(minutes)}m {seconds:.2f}s")

# Save final model
torch.save({
    'epoch': epochs,
    'denoise_unet': diffusion.denoise_unet.state_dict(),
    'decoder': diffusion.decoder.state_dict(),
    'optimizer': optimizer.state_dict(),
    'metrics': epoch_metrics,
    'training_time': total_training_time
}, 'checkpoints/diffusion_final_model.pth')

Epoch 1/200: 100%|██████████| 49/49 [01:08<00:00,  1.40s/it]


Epoch 1 | Loss: 0.4734 | PSNR: 7.54 dB | SSIM: 0.307 | LPIPS: 0.421


Epoch 2/200: 100%|██████████| 49/49 [00:53<00:00,  1.08s/it]


Epoch 2 | Loss: 0.2134 | PSNR: 11.41 dB | SSIM: 0.453 | LPIPS: 0.267


Epoch 3/200: 100%|██████████| 49/49 [00:52<00:00,  1.08s/it]


Epoch 3 | Loss: 0.1736 | PSNR: 11.94 dB | SSIM: 0.478 | LPIPS: 0.237


Epoch 4/200: 100%|██████████| 49/49 [00:53<00:00,  1.08s/it]


Epoch 4 | Loss: 0.1548 | PSNR: 12.05 dB | SSIM: 0.493 | LPIPS: 0.222


Epoch 5/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 5 | Loss: 0.1389 | PSNR: 12.22 dB | SSIM: 0.508 | LPIPS: 0.214


Epoch 6/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 6 | Loss: 0.1349 | PSNR: 12.56 dB | SSIM: 0.543 | LPIPS: 0.207


Epoch 7/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 7 | Loss: 0.1258 | PSNR: 12.88 dB | SSIM: 0.581 | LPIPS: 0.198


Epoch 8/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 8 | Loss: 0.1179 | PSNR: 13.35 dB | SSIM: 0.622 | LPIPS: 0.191


Epoch 9/200: 100%|██████████| 49/49 [00:53<00:00,  1.08s/it]


Epoch 9 | Loss: 0.1155 | PSNR: 13.90 dB | SSIM: 0.654 | LPIPS: 0.185


Epoch 10/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 10 | Loss: 0.1089 | PSNR: 13.82 dB | SSIM: 0.656 | LPIPS: 0.180


Epoch 11/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 11 | Loss: 0.1105 | PSNR: 13.97 dB | SSIM: 0.658 | LPIPS: 0.181


Epoch 12/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 12 | Loss: 0.1091 | PSNR: 14.18 dB | SSIM: 0.702 | LPIPS: 0.175


Epoch 13/200: 100%|██████████| 49/49 [00:52<00:00,  1.08s/it]


Epoch 13 | Loss: 0.1100 | PSNR: 14.11 dB | SSIM: 0.709 | LPIPS: 0.177


Epoch 14/200: 100%|██████████| 49/49 [00:52<00:00,  1.08s/it]


Epoch 14 | Loss: 0.1044 | PSNR: 14.44 dB | SSIM: 0.752 | LPIPS: 0.171


Epoch 15/200: 100%|██████████| 49/49 [00:52<00:00,  1.08s/it]


Epoch 15 | Loss: 0.0987 | PSNR: 14.21 dB | SSIM: 0.760 | LPIPS: 0.170


Epoch 16/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 16 | Loss: 0.1010 | PSNR: 14.10 dB | SSIM: 0.771 | LPIPS: 0.168


Epoch 17/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 17 | Loss: 0.0955 | PSNR: 14.12 dB | SSIM: 0.785 | LPIPS: 0.166


Epoch 18/200: 100%|██████████| 49/49 [00:52<00:00,  1.08s/it]


Epoch 18 | Loss: 0.0964 | PSNR: 14.18 dB | SSIM: 0.812 | LPIPS: 0.159


Epoch 19/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 19 | Loss: 0.0935 | PSNR: 13.87 dB | SSIM: 0.811 | LPIPS: 0.162


Epoch 20/200: 100%|██████████| 49/49 [00:53<00:00,  1.08s/it]


Epoch 20 | Loss: 0.0954 | PSNR: 13.94 dB | SSIM: 0.835 | LPIPS: 0.154


Epoch 21/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 21 | Loss: 0.0894 | PSNR: 13.72 dB | SSIM: 0.842 | LPIPS: 0.157


Epoch 22/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 22 | Loss: 0.0864 | PSNR: 13.36 dB | SSIM: 0.849 | LPIPS: 0.156


Epoch 23/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 23 | Loss: 0.0882 | PSNR: 13.19 dB | SSIM: 0.852 | LPIPS: 0.158


Epoch 24/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 24 | Loss: 0.0862 | PSNR: 13.14 dB | SSIM: 0.866 | LPIPS: 0.154


Epoch 25/200: 100%|██████████| 49/49 [00:52<00:00,  1.08s/it]


Epoch 25 | Loss: 0.0820 | PSNR: 13.13 dB | SSIM: 0.887 | LPIPS: 0.148


Epoch 26/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 26 | Loss: 0.0846 | PSNR: 13.03 dB | SSIM: 0.892 | LPIPS: 0.150


Epoch 27/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 27 | Loss: 0.0842 | PSNR: 12.98 dB | SSIM: 0.891 | LPIPS: 0.149


Epoch 28/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 28 | Loss: 0.0842 | PSNR: 12.70 dB | SSIM: 0.908 | LPIPS: 0.147


Epoch 29/200: 100%|██████████| 49/49 [00:53<00:00,  1.08s/it]


Epoch 29 | Loss: 0.0830 | PSNR: 12.68 dB | SSIM: 0.912 | LPIPS: 0.147


Epoch 30/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 30 | Loss: 0.0790 | PSNR: 12.36 dB | SSIM: 0.933 | LPIPS: 0.139


Epoch 31/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 31 | Loss: 0.0805 | PSNR: 12.34 dB | SSIM: 0.922 | LPIPS: 0.146


Epoch 32/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 32 | Loss: 0.0777 | PSNR: 12.64 dB | SSIM: 0.927 | LPIPS: 0.141


Epoch 33/200: 100%|██████████| 49/49 [00:53<00:00,  1.08s/it]


Epoch 33 | Loss: 0.0781 | PSNR: 12.02 dB | SSIM: 0.941 | LPIPS: 0.138


Epoch 34/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 34 | Loss: 0.0789 | PSNR: 11.64 dB | SSIM: 0.945 | LPIPS: 0.145


Epoch 35/200: 100%|██████████| 49/49 [00:52<00:00,  1.08s/it]


Epoch 35 | Loss: 0.0771 | PSNR: 11.97 dB | SSIM: 0.958 | LPIPS: 0.134


Epoch 36/200: 100%|██████████| 49/49 [00:52<00:00,  1.08s/it]


Epoch 36 | Loss: 0.0784 | PSNR: 11.35 dB | SSIM: 0.950 | LPIPS: 0.143


Epoch 37/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 37 | Loss: 0.0721 | PSNR: 11.03 dB | SSIM: 0.959 | LPIPS: 0.133


Epoch 38/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 38 | Loss: 0.0722 | PSNR: 10.86 dB | SSIM: 0.968 | LPIPS: 0.131


Epoch 39/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 39 | Loss: 0.0676 | PSNR: 11.07 dB | SSIM: 0.981 | LPIPS: 0.128


Epoch 40/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 40 | Loss: 0.0697 | PSNR: 11.09 dB | SSIM: 0.984 | LPIPS: 0.126


Epoch 41/200: 100%|██████████| 49/49 [00:55<00:00,  1.14s/it]


Epoch 41 | Loss: 0.0693 | PSNR: 10.62 dB | SSIM: 0.987 | LPIPS: 0.125


Epoch 42/200: 100%|██████████| 49/49 [00:58<00:00,  1.20s/it]


Epoch 42 | Loss: 0.0629 | PSNR: 10.69 dB | SSIM: 0.989 | LPIPS: 0.121


Epoch 43/200: 100%|██████████| 49/49 [00:58<00:00,  1.19s/it]


Epoch 43 | Loss: 0.0675 | PSNR: 10.63 dB | SSIM: 0.988 | LPIPS: 0.123


Epoch 44/200: 100%|██████████| 49/49 [00:58<00:00,  1.19s/it]


Epoch 44 | Loss: 0.0692 | PSNR: 9.82 dB | SSIM: 0.986 | LPIPS: 0.133


Epoch 45/200: 100%|██████████| 49/49 [00:58<00:00,  1.20s/it]


Epoch 45 | Loss: 0.0628 | PSNR: 10.24 dB | SSIM: 0.989 | LPIPS: 0.123


Epoch 46/200: 100%|██████████| 49/49 [00:56<00:00,  1.15s/it]


Epoch 46 | Loss: 0.0626 | PSNR: 10.25 dB | SSIM: 0.990 | LPIPS: 0.119


Epoch 47/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 47 | Loss: 0.0618 | PSNR: 10.55 dB | SSIM: 0.990 | LPIPS: 0.117


Epoch 48/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 48 | Loss: 0.0643 | PSNR: 10.45 dB | SSIM: 0.990 | LPIPS: 0.118


Epoch 49/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 49 | Loss: 0.0631 | PSNR: 9.30 dB | SSIM: 0.990 | LPIPS: 0.121


Epoch 50/200: 100%|██████████| 49/49 [00:52<00:00,  1.08s/it]


Epoch 50 | Loss: 0.0613 | PSNR: 9.69 dB | SSIM: 0.992 | LPIPS: 0.117


Epoch 51/200: 100%|██████████| 49/49 [00:52<00:00,  1.08s/it]


Epoch 51 | Loss: 0.0641 | PSNR: 9.78 dB | SSIM: 0.990 | LPIPS: 0.120


Epoch 52/200: 100%|██████████| 49/49 [00:53<00:00,  1.08s/it]


Epoch 52 | Loss: 0.0644 | PSNR: 9.58 dB | SSIM: 0.990 | LPIPS: 0.118


Epoch 53/200: 100%|██████████| 49/49 [00:52<00:00,  1.08s/it]


Epoch 53 | Loss: 0.0625 | PSNR: 9.29 dB | SSIM: 0.992 | LPIPS: 0.118


Epoch 54/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 54 | Loss: 0.0617 | PSNR: 9.34 dB | SSIM: 0.992 | LPIPS: 0.116


Epoch 55/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 55 | Loss: 0.0626 | PSNR: 9.62 dB | SSIM: 0.992 | LPIPS: 0.116


Epoch 56/200: 100%|██████████| 49/49 [00:55<00:00,  1.13s/it]


Epoch 56 | Loss: 0.0570 | PSNR: 10.12 dB | SSIM: 0.992 | LPIPS: 0.106


Epoch 57/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 57 | Loss: 0.0619 | PSNR: 9.52 dB | SSIM: 0.993 | LPIPS: 0.111


Epoch 58/200: 100%|██████████| 49/49 [00:58<00:00,  1.19s/it]


Epoch 58 | Loss: 0.0606 | PSNR: 9.73 dB | SSIM: 0.993 | LPIPS: 0.111


Epoch 59/200: 100%|██████████| 49/49 [00:58<00:00,  1.19s/it]


Epoch 59 | Loss: 0.0594 | PSNR: 9.89 dB | SSIM: 0.992 | LPIPS: 0.114


Epoch 60/200: 100%|██████████| 49/49 [01:01<00:00,  1.25s/it]


Epoch 60 | Loss: 0.0577 | PSNR: 10.06 dB | SSIM: 0.993 | LPIPS: 0.109


Epoch 61/200: 100%|██████████| 49/49 [01:01<00:00,  1.26s/it]


Epoch 61 | Loss: 0.0580 | PSNR: 9.33 dB | SSIM: 0.993 | LPIPS: 0.112


Epoch 62/200: 100%|██████████| 49/49 [00:56<00:00,  1.15s/it]


Epoch 62 | Loss: 0.0606 | PSNR: 9.37 dB | SSIM: 0.993 | LPIPS: 0.110


Epoch 63/200: 100%|██████████| 49/49 [01:00<00:00,  1.24s/it]


Epoch 63 | Loss: 0.0543 | PSNR: 9.26 dB | SSIM: 0.993 | LPIPS: 0.108


Epoch 64/200: 100%|██████████| 49/49 [00:58<00:00,  1.20s/it]


Epoch 64 | Loss: 0.0563 | PSNR: 9.51 dB | SSIM: 0.993 | LPIPS: 0.107


Epoch 65/200: 100%|██████████| 49/49 [00:52<00:00,  1.07s/it]


Epoch 65 | Loss: 0.0570 | PSNR: 9.54 dB | SSIM: 0.993 | LPIPS: 0.110


Epoch 66/200: 100%|██████████| 49/49 [00:53<00:00,  1.09s/it]


Epoch 66 | Loss: 0.0597 | PSNR: 9.61 dB | SSIM: 0.993 | LPIPS: 0.108


Epoch 67/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 67 | Loss: 0.0529 | PSNR: 9.45 dB | SSIM: 0.994 | LPIPS: 0.105


Epoch 68/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 68 | Loss: 0.0578 | PSNR: 9.64 dB | SSIM: 0.993 | LPIPS: 0.112


Epoch 69/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 69 | Loss: 0.0530 | PSNR: 9.40 dB | SSIM: 0.995 | LPIPS: 0.105


Epoch 70/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 70 | Loss: 0.0525 | PSNR: 9.33 dB | SSIM: 0.995 | LPIPS: 0.102


Epoch 71/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 71 | Loss: 0.0544 | PSNR: 9.95 dB | SSIM: 0.995 | LPIPS: 0.102


Epoch 72/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 72 | Loss: 0.0557 | PSNR: 9.51 dB | SSIM: 0.995 | LPIPS: 0.102


Epoch 73/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 73 | Loss: 0.0525 | PSNR: 9.36 dB | SSIM: 0.995 | LPIPS: 0.104


Epoch 74/200: 100%|██████████| 49/49 [00:54<00:00,  1.12s/it]


Epoch 74 | Loss: 0.0535 | PSNR: 9.66 dB | SSIM: 0.995 | LPIPS: 0.105


Epoch 75/200: 100%|██████████| 49/49 [00:57<00:00,  1.18s/it]


Epoch 75 | Loss: 0.0572 | PSNR: 9.71 dB | SSIM: 0.995 | LPIPS: 0.104


Epoch 76/200: 100%|██████████| 49/49 [00:59<00:00,  1.20s/it]


Epoch 76 | Loss: 0.0577 | PSNR: 9.07 dB | SSIM: 0.994 | LPIPS: 0.108


Epoch 77/200: 100%|██████████| 49/49 [00:58<00:00,  1.20s/it]


Epoch 77 | Loss: 0.0526 | PSNR: 9.24 dB | SSIM: 0.996 | LPIPS: 0.103


Epoch 78/200: 100%|██████████| 49/49 [00:58<00:00,  1.19s/it]


Epoch 78 | Loss: 0.0516 | PSNR: 9.02 dB | SSIM: 0.995 | LPIPS: 0.102


Epoch 79/200: 100%|██████████| 49/49 [00:59<00:00,  1.21s/it]


Epoch 79 | Loss: 0.0541 | PSNR: 9.53 dB | SSIM: 0.995 | LPIPS: 0.101


Epoch 80/200: 100%|██████████| 49/49 [00:59<00:00,  1.22s/it]


Epoch 80 | Loss: 0.0522 | PSNR: 9.51 dB | SSIM: 0.995 | LPIPS: 0.102


Epoch 81/200: 100%|██████████| 49/49 [00:59<00:00,  1.21s/it]


Epoch 81 | Loss: 0.0511 | PSNR: 9.18 dB | SSIM: 0.996 | LPIPS: 0.098


Epoch 82/200: 100%|██████████| 49/49 [00:59<00:00,  1.21s/it]


Epoch 82 | Loss: 0.0505 | PSNR: 9.36 dB | SSIM: 0.996 | LPIPS: 0.097


Epoch 83/200: 100%|██████████| 49/49 [00:59<00:00,  1.22s/it]


Epoch 83 | Loss: 0.0499 | PSNR: 9.39 dB | SSIM: 0.996 | LPIPS: 0.098


Epoch 84/200: 100%|██████████| 49/49 [00:59<00:00,  1.22s/it]


Epoch 84 | Loss: 0.0516 | PSNR: 9.31 dB | SSIM: 0.996 | LPIPS: 0.094


Epoch 85/200: 100%|██████████| 49/49 [00:58<00:00,  1.19s/it]


Epoch 85 | Loss: 0.0504 | PSNR: 9.44 dB | SSIM: 0.996 | LPIPS: 0.095


Epoch 86/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 86 | Loss: 0.0544 | PSNR: 9.44 dB | SSIM: 0.996 | LPIPS: 0.099


Epoch 87/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 87 | Loss: 0.0497 | PSNR: 8.83 dB | SSIM: 0.995 | LPIPS: 0.097


Epoch 88/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 88 | Loss: 0.0470 | PSNR: 8.85 dB | SSIM: 0.997 | LPIPS: 0.091


Epoch 89/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 89 | Loss: 0.0453 | PSNR: 9.22 dB | SSIM: 0.996 | LPIPS: 0.093


Epoch 90/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 90 | Loss: 0.0475 | PSNR: 9.11 dB | SSIM: 0.996 | LPIPS: 0.095


Epoch 91/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 91 | Loss: 0.0490 | PSNR: 8.68 dB | SSIM: 0.996 | LPIPS: 0.094


Epoch 92/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 92 | Loss: 0.0513 | PSNR: 9.02 dB | SSIM: 0.996 | LPIPS: 0.099


Epoch 93/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 93 | Loss: 0.0522 | PSNR: 8.69 dB | SSIM: 0.996 | LPIPS: 0.096


Epoch 94/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 94 | Loss: 0.0502 | PSNR: 8.67 dB | SSIM: 0.996 | LPIPS: 0.094


Epoch 95/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 95 | Loss: 0.0481 | PSNR: 9.08 dB | SSIM: 0.997 | LPIPS: 0.093


Epoch 96/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 96 | Loss: 0.0487 | PSNR: 9.08 dB | SSIM: 0.997 | LPIPS: 0.095


Epoch 97/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 97 | Loss: 0.0466 | PSNR: 9.01 dB | SSIM: 0.997 | LPIPS: 0.092


Epoch 98/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 98 | Loss: 0.0457 | PSNR: 9.12 dB | SSIM: 0.997 | LPIPS: 0.086


Epoch 99/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 99 | Loss: 0.0458 | PSNR: 9.21 dB | SSIM: 0.997 | LPIPS: 0.090


Epoch 100/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 100 | Loss: 0.0461 | PSNR: 8.47 dB | SSIM: 0.997 | LPIPS: 0.092


Epoch 101/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 101 | Loss: 0.0459 | PSNR: 9.04 dB | SSIM: 0.997 | LPIPS: 0.088


Epoch 102/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 102 | Loss: 0.0475 | PSNR: 8.79 dB | SSIM: 0.997 | LPIPS: 0.090


Epoch 103/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 103 | Loss: 0.0445 | PSNR: 8.83 dB | SSIM: 0.997 | LPIPS: 0.089


Epoch 104/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 104 | Loss: 0.0488 | PSNR: 8.80 dB | SSIM: 0.996 | LPIPS: 0.093


Epoch 105/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 105 | Loss: 0.0456 | PSNR: 8.84 dB | SSIM: 0.997 | LPIPS: 0.091


Epoch 106/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 106 | Loss: 0.0428 | PSNR: 8.94 dB | SSIM: 0.997 | LPIPS: 0.086


Epoch 107/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 107 | Loss: 0.0454 | PSNR: 8.95 dB | SSIM: 0.997 | LPIPS: 0.088


Epoch 108/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 108 | Loss: 0.0446 | PSNR: 8.36 dB | SSIM: 0.997 | LPIPS: 0.089


Epoch 109/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 109 | Loss: 0.0446 | PSNR: 8.92 dB | SSIM: 0.997 | LPIPS: 0.083


Epoch 110/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 110 | Loss: 0.0434 | PSNR: 8.69 dB | SSIM: 0.997 | LPIPS: 0.086


Epoch 111/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 111 | Loss: 0.0462 | PSNR: 8.91 dB | SSIM: 0.997 | LPIPS: 0.087


Epoch 112/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 112 | Loss: 0.0489 | PSNR: 8.77 dB | SSIM: 0.997 | LPIPS: 0.092


Epoch 113/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 113 | Loss: 0.0442 | PSNR: 8.42 dB | SSIM: 0.997 | LPIPS: 0.090


Epoch 114/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 114 | Loss: 0.0474 | PSNR: 8.52 dB | SSIM: 0.997 | LPIPS: 0.090


Epoch 115/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 115 | Loss: 0.0490 | PSNR: 8.57 dB | SSIM: 0.997 | LPIPS: 0.089


Epoch 116/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 116 | Loss: 0.0433 | PSNR: 8.73 dB | SSIM: 0.998 | LPIPS: 0.086


Epoch 117/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 117 | Loss: 0.0427 | PSNR: 8.41 dB | SSIM: 0.998 | LPIPS: 0.084


Epoch 118/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 118 | Loss: 0.0454 | PSNR: 8.29 dB | SSIM: 0.997 | LPIPS: 0.087


Epoch 119/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 119 | Loss: 0.0418 | PSNR: 8.85 dB | SSIM: 0.998 | LPIPS: 0.086


Epoch 120/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 120 | Loss: 0.0438 | PSNR: 8.28 dB | SSIM: 0.997 | LPIPS: 0.086


Epoch 121/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 121 | Loss: 0.0411 | PSNR: 8.32 dB | SSIM: 0.998 | LPIPS: 0.084


Epoch 122/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 122 | Loss: 0.0424 | PSNR: 8.43 dB | SSIM: 0.997 | LPIPS: 0.087


Epoch 123/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 123 | Loss: 0.0452 | PSNR: 8.36 dB | SSIM: 0.997 | LPIPS: 0.088


Epoch 124/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 124 | Loss: 0.0409 | PSNR: 8.55 dB | SSIM: 0.998 | LPIPS: 0.084


Epoch 125/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 125 | Loss: 0.0463 | PSNR: 8.54 dB | SSIM: 0.997 | LPIPS: 0.085


Epoch 126/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 126 | Loss: 0.0414 | PSNR: 8.52 dB | SSIM: 0.998 | LPIPS: 0.082


Epoch 127/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 127 | Loss: 0.0412 | PSNR: 8.43 dB | SSIM: 0.998 | LPIPS: 0.082


Epoch 128/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 128 | Loss: 0.0429 | PSNR: 8.23 dB | SSIM: 0.997 | LPIPS: 0.084


Epoch 129/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 129 | Loss: 0.0426 | PSNR: 8.44 dB | SSIM: 0.998 | LPIPS: 0.081


Epoch 130/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 130 | Loss: 0.0438 | PSNR: 7.95 dB | SSIM: 0.998 | LPIPS: 0.084


Epoch 131/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 131 | Loss: 0.0427 | PSNR: 8.32 dB | SSIM: 0.997 | LPIPS: 0.084


Epoch 132/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 132 | Loss: 0.0413 | PSNR: 8.65 dB | SSIM: 0.998 | LPIPS: 0.082


Epoch 133/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 133 | Loss: 0.0423 | PSNR: 8.39 dB | SSIM: 0.998 | LPIPS: 0.081


Epoch 134/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 134 | Loss: 0.0411 | PSNR: 8.04 dB | SSIM: 0.998 | LPIPS: 0.084


Epoch 135/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 135 | Loss: 0.0405 | PSNR: 8.75 dB | SSIM: 0.998 | LPIPS: 0.078


Epoch 136/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 136 | Loss: 0.0416 | PSNR: 8.25 dB | SSIM: 0.998 | LPIPS: 0.080


Epoch 137/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 137 | Loss: 0.0404 | PSNR: 8.38 dB | SSIM: 0.998 | LPIPS: 0.083


Epoch 138/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 138 | Loss: 0.0394 | PSNR: 8.29 dB | SSIM: 0.998 | LPIPS: 0.077


Epoch 139/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 139 | Loss: 0.0411 | PSNR: 8.36 dB | SSIM: 0.998 | LPIPS: 0.081


Epoch 140/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 140 | Loss: 0.0422 | PSNR: 7.92 dB | SSIM: 0.998 | LPIPS: 0.080


Epoch 141/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 141 | Loss: 0.0405 | PSNR: 8.14 dB | SSIM: 0.998 | LPIPS: 0.078


Epoch 142/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 142 | Loss: 0.0385 | PSNR: 8.28 dB | SSIM: 0.998 | LPIPS: 0.077


Epoch 143/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 143 | Loss: 0.0393 | PSNR: 8.36 dB | SSIM: 0.998 | LPIPS: 0.077


Epoch 144/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 144 | Loss: 0.0391 | PSNR: 8.44 dB | SSIM: 0.998 | LPIPS: 0.077


Epoch 145/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 145 | Loss: 0.0390 | PSNR: 8.26 dB | SSIM: 0.998 | LPIPS: 0.076


Epoch 146/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 146 | Loss: 0.0388 | PSNR: 8.39 dB | SSIM: 0.998 | LPIPS: 0.079


Epoch 147/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 147 | Loss: 0.0388 | PSNR: 8.74 dB | SSIM: 0.998 | LPIPS: 0.076


Epoch 148/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 148 | Loss: 0.0392 | PSNR: 8.28 dB | SSIM: 0.998 | LPIPS: 0.079


Epoch 149/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 149 | Loss: 0.0375 | PSNR: 8.28 dB | SSIM: 0.998 | LPIPS: 0.073


Epoch 150/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 150 | Loss: 0.0386 | PSNR: 8.64 dB | SSIM: 0.998 | LPIPS: 0.074


Epoch 151/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 151 | Loss: 0.0361 | PSNR: 8.07 dB | SSIM: 0.998 | LPIPS: 0.076


Epoch 152/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 152 | Loss: 0.0399 | PSNR: 7.95 dB | SSIM: 0.998 | LPIPS: 0.081


Epoch 153/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 153 | Loss: 0.0405 | PSNR: 8.31 dB | SSIM: 0.998 | LPIPS: 0.081


Epoch 154/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 154 | Loss: 0.0381 | PSNR: 8.22 dB | SSIM: 0.998 | LPIPS: 0.077


Epoch 155/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 155 | Loss: 0.0378 | PSNR: 8.03 dB | SSIM: 0.998 | LPIPS: 0.077


Epoch 156/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 156 | Loss: 0.0377 | PSNR: 8.42 dB | SSIM: 0.998 | LPIPS: 0.078


Epoch 157/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 157 | Loss: 0.0394 | PSNR: 7.70 dB | SSIM: 0.998 | LPIPS: 0.078


Epoch 158/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 158 | Loss: 0.0395 | PSNR: 8.49 dB | SSIM: 0.998 | LPIPS: 0.077


Epoch 159/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 159 | Loss: 0.0366 | PSNR: 7.83 dB | SSIM: 0.998 | LPIPS: 0.075


Epoch 160/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 160 | Loss: 0.0373 | PSNR: 8.07 dB | SSIM: 0.998 | LPIPS: 0.077


Epoch 161/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 161 | Loss: 0.0369 | PSNR: 8.02 dB | SSIM: 0.998 | LPIPS: 0.074


Epoch 162/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 162 | Loss: 0.0362 | PSNR: 8.08 dB | SSIM: 0.998 | LPIPS: 0.072


Epoch 163/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 163 | Loss: 0.0397 | PSNR: 8.03 dB | SSIM: 0.998 | LPIPS: 0.080


Epoch 164/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 164 | Loss: 0.0342 | PSNR: 8.23 dB | SSIM: 0.998 | LPIPS: 0.070


Epoch 165/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 165 | Loss: 0.0354 | PSNR: 8.22 dB | SSIM: 0.998 | LPIPS: 0.074


Epoch 166/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 166 | Loss: 0.0405 | PSNR: 7.48 dB | SSIM: 0.998 | LPIPS: 0.083


Epoch 167/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 167 | Loss: 0.0364 | PSNR: 7.80 dB | SSIM: 0.998 | LPIPS: 0.074


Epoch 168/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 168 | Loss: 0.0376 | PSNR: 8.10 dB | SSIM: 0.998 | LPIPS: 0.075


Epoch 169/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 169 | Loss: 0.0373 | PSNR: 7.87 dB | SSIM: 0.998 | LPIPS: 0.075


Epoch 170/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 170 | Loss: 0.0359 | PSNR: 7.78 dB | SSIM: 0.999 | LPIPS: 0.073


Epoch 171/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 171 | Loss: 0.0357 | PSNR: 7.82 dB | SSIM: 0.998 | LPIPS: 0.074


Epoch 172/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 172 | Loss: 0.0382 | PSNR: 8.06 dB | SSIM: 0.998 | LPIPS: 0.075


Epoch 173/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 173 | Loss: 0.0365 | PSNR: 8.16 dB | SSIM: 0.998 | LPIPS: 0.075


Epoch 174/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 174 | Loss: 0.0350 | PSNR: 7.95 dB | SSIM: 0.998 | LPIPS: 0.072


Epoch 175/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 175 | Loss: 0.0345 | PSNR: 7.78 dB | SSIM: 0.998 | LPIPS: 0.070


Epoch 176/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 176 | Loss: 0.0347 | PSNR: 7.91 dB | SSIM: 0.998 | LPIPS: 0.072


Epoch 177/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 177 | Loss: 0.0362 | PSNR: 7.87 dB | SSIM: 0.998 | LPIPS: 0.073


Epoch 178/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 178 | Loss: 0.0329 | PSNR: 8.24 dB | SSIM: 0.998 | LPIPS: 0.068


Epoch 179/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 179 | Loss: 0.0347 | PSNR: 7.88 dB | SSIM: 0.998 | LPIPS: 0.070


Epoch 180/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 180 | Loss: 0.0367 | PSNR: 7.65 dB | SSIM: 0.998 | LPIPS: 0.075


Epoch 181/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 181 | Loss: 0.0370 | PSNR: 7.69 dB | SSIM: 0.998 | LPIPS: 0.074


Epoch 182/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 182 | Loss: 0.0362 | PSNR: 8.03 dB | SSIM: 0.998 | LPIPS: 0.073


Epoch 183/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 183 | Loss: 0.0353 | PSNR: 7.33 dB | SSIM: 0.999 | LPIPS: 0.071


Epoch 184/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 184 | Loss: 0.0339 | PSNR: 7.99 dB | SSIM: 0.999 | LPIPS: 0.069


Epoch 185/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 185 | Loss: 0.0345 | PSNR: 8.08 dB | SSIM: 0.999 | LPIPS: 0.070


Epoch 186/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 186 | Loss: 0.0363 | PSNR: 7.78 dB | SSIM: 0.998 | LPIPS: 0.074


Epoch 187/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 187 | Loss: 0.0359 | PSNR: 7.50 dB | SSIM: 0.999 | LPIPS: 0.074


Epoch 188/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 188 | Loss: 0.0348 | PSNR: 7.81 dB | SSIM: 0.999 | LPIPS: 0.072


Epoch 189/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 189 | Loss: 0.0338 | PSNR: 7.86 dB | SSIM: 0.999 | LPIPS: 0.070


Epoch 190/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 190 | Loss: 0.0321 | PSNR: 7.73 dB | SSIM: 0.999 | LPIPS: 0.066


Epoch 191/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 191 | Loss: 0.0318 | PSNR: 7.81 dB | SSIM: 0.999 | LPIPS: 0.067


Epoch 192/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 192 | Loss: 0.0351 | PSNR: 8.00 dB | SSIM: 0.998 | LPIPS: 0.070


Epoch 193/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 193 | Loss: 0.0348 | PSNR: 7.54 dB | SSIM: 0.999 | LPIPS: 0.069


Epoch 194/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 194 | Loss: 0.0350 | PSNR: 7.76 dB | SSIM: 0.998 | LPIPS: 0.073


Epoch 195/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 195 | Loss: 0.0365 | PSNR: 7.63 dB | SSIM: 0.998 | LPIPS: 0.075


Epoch 196/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 196 | Loss: 0.0364 | PSNR: 7.54 dB | SSIM: 0.998 | LPIPS: 0.072


Epoch 197/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 197 | Loss: 0.0341 | PSNR: 7.30 dB | SSIM: 0.999 | LPIPS: 0.071


Epoch 198/200: 100%|██████████| 49/49 [00:50<00:00,  1.03s/it]


Epoch 198 | Loss: 0.0353 | PSNR: 7.61 dB | SSIM: 0.998 | LPIPS: 0.073


Epoch 199/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 199 | Loss: 0.0346 | PSNR: 6.90 dB | SSIM: 0.999 | LPIPS: 0.071


Epoch 200/200: 100%|██████████| 49/49 [00:50<00:00,  1.04s/it]


Epoch 200 | Loss: 0.0353 | PSNR: 8.01 dB | SSIM: 0.999 | LPIPS: 0.073

Training complete! Best metrics: Loss: 0.0318, PSNR: 7.81 dB, SSIM: 0.999, LPIPS: 0.067
Total training time: 3h 1m 45.27s


In [17]:
def evaluate_diffusion(denoise_unet, decoder, test_loader, device, diffusion, save_samples=True, sample_dir="experiment2_diffusion_samples_final"):
    # Initialize metrics
    psnr = PSNR().to(device)
    ssim = SSIM().to(device)
    lpips = LPIPS(replace_pooling=True).to(device)
    
    metrics = {
        'per_image': [],  # List to store per-image results
        'average': {
            'psnr': 0.0,
            'ssim': 0.0,
            'lpips': 0.0
        }
    }
    
    if save_samples:
        os.makedirs(sample_dir, exist_ok=True)
    
    denoise_unet.eval()
    decoder.eval()
    sample_counter = 0
    
    with torch.no_grad():
        for batch_idx, (low_imgs, high_imgs, *_) in enumerate(tqdm(test_loader, desc="Evaluating")):
            low_imgs, high_imgs = low_imgs.to(device), high_imgs.to(device)
            
            # Use t=0 (final timestep) for evaluation
            t = torch.zeros(low_imgs.size(0), dtype=torch.long, device=device)
            
            # Forward diffusion process
            xt, _ = diffusion.forward_diffuse(high_imgs, t)
            
            # Denoise and decode
            pred_noise = denoise_unet(xt, t)
            enhanced_imgs = decoder(xt - pred_noise, low_imgs)
            
            # Clamp outputs to valid range
            enhanced_imgs = enhanced_imgs.clamp(0, 1)
            high_imgs = high_imgs.clamp(0, 1)
            
            # Process each image in the batch individually
            for img_idx in range(low_imgs.size(0)):
                # Calculate metrics for this single image
                img_metrics = {
                    'psnr': psnr(enhanced_imgs[img_idx].unsqueeze(0), high_imgs[img_idx].unsqueeze(0)).item(),
                    'ssim': ssim(enhanced_imgs[img_idx].unsqueeze(0), high_imgs[img_idx].unsqueeze(0)).item(),
                    'lpips': lpips(enhanced_imgs[img_idx].unsqueeze(0), high_imgs[img_idx].unsqueeze(0)).item()
                }
                
                # Create comparison grid (C,H,W)
                comparison = torch.cat([
                    low_imgs[img_idx].clamp(0, 1),  # Low-light
                    enhanced_imgs[img_idx],         # Enhanced
                    high_imgs[img_idx]              # Ground truth
                ], dim=-1)
                
                # Convert to PIL
                comparison_pil = transforms.ToPILImage()(comparison.cpu())
                
                if save_samples and sample_counter < 15:
                    # Add labels to the image
                    draw = ImageDraw.Draw(comparison_pil)
                    try:
                        font = ImageFont.truetype("arial.ttf", 15)
                    except:
                        font = ImageFont.load_default()
                    
                    width = comparison_pil.width
                    draw.text((10, 10), "Low Light", fill="white", font=font)
                    draw.text((width//3 + 10, 10), "Enhanced", fill="white", font=font)
                    draw.text((2*width//3 + 10, 10), "Ground Truth", fill="white", font=font)
                    
                    # Add metric text below each section
                    text_y = comparison_pil.height - 30
                    draw.text((10, text_y), f"PSNR: -", fill="white", font=font)  # Low-light has no PSNR
                    draw.text((width//3 + 10, text_y), 
                             f"PSNR: {img_metrics['psnr']:.2f}\nSSIM: {img_metrics['ssim']:.3f}\nLPIPS: {img_metrics['lpips']:.3f}", 
                             fill="white", font=font)
                    draw.text((2*width//3 + 10, text_y), 
                             f"PSNR: ∞\nSSIM: 1.000\nLPIPS: 0.000", 
                             fill="white", font=font)
                    
                    # Save sample
                    sample_path = os.path.join(sample_dir, f"sample_{sample_counter:02d}.png")
                    comparison_pil.save(sample_path)
                    img_metrics['image_path'] = sample_path
                    sample_counter += 1
                
                # Store per-image results
                metrics['per_image'].append(img_metrics)
                
                # Accumulate for averages
                metrics['average']['psnr'] += img_metrics['psnr']
                metrics['average']['ssim'] += img_metrics['ssim']
                metrics['average']['lpips'] += img_metrics['lpips']
    
    # Calculate final averages
    total_samples = len(metrics['per_image'])
    metrics['average']['psnr'] /= total_samples
    metrics['average']['ssim'] /= total_samples
    metrics['average']['lpips'] /= total_samples
    
    print(f"\nEvaluation Results (Averages):")
    print(f"PSNR: {metrics['average']['psnr']:.2f} dB")
    print(f"SSIM: {metrics['average']['ssim']:.4f}")
    print(f"LPIPS: {metrics['average']['lpips']:.4f}")
    
    return metrics

metrics = evaluate_diffusion(
    denoise_unet=diffusion.denoise_unet,
    decoder=diffusion.decoder,
    test_loader=test_loader,
    device=device,
    diffusion=diffusion,
    save_samples=True
)

# Process all samples
print("\nDetailed Results for All Samples:")
for idx, img_result in enumerate(metrics['per_image']):
    print(f"\nSample {idx + 1}:")
    print(f"  PSNR: {img_result['psnr']:.2f} dB")
    print(f"  SSIM: {img_result['ssim']:.4f}")
    print(f"  LPIPS: {img_result['lpips']:.4f}")
    if 'image_path' in img_result:
        print(f"  Visualization saved at: {img_result['image_path']}")

# Print averages
print("\nAverage Metrics Across All Samples:")
print(f"PSNR: {metrics['average']['psnr']:.2f} dB")
print(f"SSIM: {metrics['average']['ssim']:.4f}")
print(f"LPIPS: {metrics['average']['lpips']:.4f}")

# Save full metrics with timestamp
import datetime
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
metrics_filename = f'diffusion_metrics_{timestamp}.pth'
torch.save(metrics, metrics_filename)
print(f"\nAll metrics saved to: {metrics_filename}")

Evaluating: 100%|██████████| 15/15 [00:06<00:00,  2.47it/s]


Evaluation Results (Averages):
PSNR: 24.31 dB
SSIM: 0.8772
LPIPS: 0.1053

Detailed Results for All Samples:

Sample 1:
  PSNR: 22.06 dB
  SSIM: 0.8629
  LPIPS: 0.1336
  Visualization saved at: experiment2_diffusion_samples_final/sample_00.png

Sample 2:
  PSNR: 20.87 dB
  SSIM: 0.8442
  LPIPS: 0.1023
  Visualization saved at: experiment2_diffusion_samples_final/sample_01.png

Sample 3:
  PSNR: 20.02 dB
  SSIM: 0.8876
  LPIPS: 0.0920
  Visualization saved at: experiment2_diffusion_samples_final/sample_02.png

Sample 4:
  PSNR: 19.74 dB
  SSIM: 0.8932
  LPIPS: 0.1201
  Visualization saved at: experiment2_diffusion_samples_final/sample_03.png

Sample 5:
  PSNR: 27.36 dB
  SSIM: 0.9235
  LPIPS: 0.0791
  Visualization saved at: experiment2_diffusion_samples_final/sample_04.png

Sample 6:
  PSNR: 25.83 dB
  SSIM: 0.9006
  LPIPS: 0.1200
  Visualization saved at: experiment2_diffusion_samples_final/sample_05.png

Sample 7:
  PSNR: 23.39 dB
  SSIM: 0.8523
  LPIPS: 0.1209
  Visualization saved 




In [27]:
import os
import torch
from torchvision import transforms
from PIL import ImageDraw, ImageFont
from tqdm import tqdm

def evaluate_diffusion(denoise_unet, decoder, test_loader, device, diffusion, save_samples=True, sample_dir="sample"):
    # Initialize metrics
    psnr = PSNR().to(device)
    ssim = SSIM().to(device)
    lpips = LPIPS(replace_pooling=True).to(device)
    
    metrics = {
        'per_image': [],
        'average': {
            'psnr': 0.0,
            'ssim': 0.0,
            'lpips': 0.0
        }
    }
    
    if save_samples:
        os.makedirs(sample_dir, exist_ok=True)
    
    denoise_unet.eval()
    decoder.eval()
    sample_counter = 0
    
    with torch.no_grad():
        for batch_idx, (low_imgs, high_imgs, *_) in enumerate(tqdm(test_loader, desc="Evaluating")):
            low_imgs, high_imgs = low_imgs.to(device), high_imgs.to(device)
            
            # Use t=0 (final timestep) for evaluation
            t = torch.zeros(low_imgs.size(0), dtype=torch.long, device=device)
            
            # Forward diffusion process
            xt, _ = diffusion.forward_diffuse(high_imgs, t)
            
            # Denoise and decode
            pred_noise = denoise_unet(xt, t)
            enhanced_imgs = decoder(xt - pred_noise, low_imgs)
            
            # Clamp outputs to valid range
            enhanced_imgs = enhanced_imgs.clamp(0, 1)
            high_imgs = high_imgs.clamp(0, 1)
            
            for img_idx in range(low_imgs.size(0)):
                # Calculate metrics
                img_metrics = {
                    'psnr': psnr(enhanced_imgs[img_idx].unsqueeze(0), high_imgs[img_idx].unsqueeze(0)).item(),
                    'ssim': ssim(enhanced_imgs[img_idx].unsqueeze(0), high_imgs[img_idx].unsqueeze(0)).item(),
                    'lpips': lpips(enhanced_imgs[img_idx].unsqueeze(0), high_imgs[img_idx].unsqueeze(0)).item()
                }
                
                # Denormalize low-light image for visualization
                low_img_denorm = (low_imgs[img_idx] * 0.5) + 0.5
                low_img_denorm = low_img_denorm.clamp(0, 1)
                
                # Create a 4-column comparison image (C, H, W*4)
                comparison = torch.cat([
                    low_imgs[img_idx].clamp(0, 1),  # Low-light (normalized)
                    low_img_denorm,                 # Low-light (denormalized)
                    enhanced_imgs[img_idx],         # Enhanced
                    high_imgs[img_idx]              # Ground truth
                ], dim=-1)
                
                # Convert to PIL
                comparison_pil = transforms.ToPILImage()(comparison.cpu())
                
                if save_samples:
                    # Add labels
                    draw = ImageDraw.Draw(comparison_pil)
                    try:
                        font = ImageFont.truetype("arial.ttf", 15)
                    except:
                        font = ImageFont.load_default()
                    
                    width = comparison_pil.width // 4
                    draw.text((10, 10), "Low Light\n(Normalized)", fill="white", font=font)
                    draw.text((width + 10, 10), "Low Light\n(Denormalized)", fill="white", font=font)
                    draw.text((2 * width + 10, 10), "Enhanced", fill="white", font=font)
                    draw.text((3 * width + 10, 10), "Ground Truth", fill="white", font=font)

                    # Metrics under Enhanced
                    text_y = comparison_pil.height - 40
                    draw.text((2 * width + 10, text_y), 
                              f"PSNR: {img_metrics['psnr']:.2f}\nSSIM: {img_metrics['ssim']:.3f}\nLPIPS: {img_metrics['lpips']:.3f}", 
                              fill="white", font=font)

                    # Save image
                    sample_path = os.path.join(sample_dir, f"sample_{sample_counter:04d}.png")
                    comparison_pil.save(sample_path)
                    img_metrics['image_path'] = sample_path
                    sample_counter += 1
                
                # Store metrics
                metrics['per_image'].append(img_metrics)
                metrics['average']['psnr'] += img_metrics['psnr']
                metrics['average']['ssim'] += img_metrics['ssim']
                metrics['average']['lpips'] += img_metrics['lpips']
    
    # Compute averages
    total_samples = len(metrics['per_image'])
    metrics['average']['psnr'] /= total_samples
    metrics['average']['ssim'] /= total_samples
    metrics['average']['lpips'] /= total_samples
    
    print(f"\nEvaluation Results (Averages):")
    print(f"PSNR: {metrics['average']['psnr']:.2f} dB")
    print(f"SSIM: {metrics['average']['ssim']:.4f}")
    print(f"LPIPS: {metrics['average']['lpips']:.4f}")
    
    return metrics

# Initialize model
diffusion = LatentDiffusionModel(device=device)

# Load checkpoint
checkpoint = torch.load('checkpoints/diffusion_epoch_200.pth', map_location=device)

# Load state_dict into model components
diffusion.denoise_unet.load_state_dict(checkpoint['denoise_unet'])
diffusion.decoder.load_state_dict(checkpoint['decoder'])

# Evaluate
metrics = evaluate_diffusion(
    denoise_unet=diffusion.denoise_unet,
    decoder=diffusion.decoder,
    test_loader=test_loader,  # Your test DataLoader
    device=device,
    diffusion=diffusion,      # Already has forward_diffuse()
    save_samples=True,         # Save all comparison images
    sample_dir="diffusion_sample"
)

Evaluating: 100%|██████████| 15/15 [00:08<00:00,  1.69it/s]


Evaluation Results (Averages):
PSNR: 24.30 dB
SSIM: 0.8772
LPIPS: 0.1054





In [28]:
import os
import torch
from torchvision import transforms
from PIL import ImageDraw, ImageFont
from tqdm import tqdm
from torchmetrics import PeakSignalNoiseRatio as PSNR
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM
from piq import LPIPS

def denormalize(img):
    """Convert image from [-1, 1] to [0, 1] for visualization"""
    return img * 0.5 + 0.5

def evaluate_diffusion(denoise_unet, decoder, test_loader, device, diffusion, save_samples=True, sample_dir="diffusion_normalized_result"):
    # Initialize metrics
    psnr = PSNR().to(device)
    ssim = SSIM().to(device)
    lpips = LPIPS(replace_pooling=True).to(device)

    metrics = {
        'per_image': [],
        'average': {
            'psnr': 0.0,
            'ssim': 0.0,
            'lpips': 0.0
        }
    }

    if save_samples:
        os.makedirs(sample_dir, exist_ok=True)

    denoise_unet.eval()
    decoder.eval()
    sample_counter = 0

    with torch.no_grad():
        for batch_idx, (low_imgs, high_imgs, *_) in enumerate(tqdm(test_loader, desc="Evaluating")):
            low_imgs, high_imgs = low_imgs.to(device), high_imgs.to(device)

            # Use t=0 (final timestep) for evaluation
            t = torch.zeros(low_imgs.size(0), dtype=torch.long, device=device)

            # Forward diffusion process
            xt, _ = diffusion.forward_diffuse(high_imgs, t)

            # Denoise and decode
            pred_noise = denoise_unet(xt, t)
            enhanced_imgs = decoder(xt - pred_noise, low_imgs)

            # Clamp all images to [-1, 1] range first
            low_imgs = low_imgs.clamp(-1, 1)
            enhanced_imgs = enhanced_imgs.clamp(-1, 1)
            high_imgs = high_imgs.clamp(-1, 1)

            for img_idx in range(low_imgs.size(0)):
                # Denormalize for metric computation
                enhanced = denormalize(enhanced_imgs[img_idx])
                high = denormalize(high_imgs[img_idx])

                # Calculate metrics
                img_metrics = {
                    'psnr': psnr(enhanced.unsqueeze(0), high.unsqueeze(0)).item(),
                    'ssim': ssim(enhanced.unsqueeze(0), high.unsqueeze(0)).item(),
                    'lpips': lpips(enhanced.unsqueeze(0), high.unsqueeze(0)).item()
                }

                # Denormalize all images for visualization
                low_img_norm = denormalize(low_imgs[img_idx])               # Low-light (normalized)
                low_img_denorm = low_img_norm.clone()                       # Low-light (denormalized)
                enhanced_img_vis = denormalize(enhanced_imgs[img_idx])      # Enhanced
                high_img_vis = denormalize(high_imgs[img_idx])              # Ground Truth

                # Create a 4-column comparison image (C, H, W*4)
                comparison = torch.cat([
                    low_imgs[img_idx].clamp(0, 1),  # Low-light (still normalized for first col)
                    low_img_denorm.clamp(0, 1),     # Low-light (denormalized)
                    enhanced_img_vis.clamp(0, 1),   # Enhanced
                    high_img_vis.clamp(0, 1)        # Ground truth
                ], dim=-1)

                # Convert to PIL
                comparison_pil = transforms.ToPILImage()(comparison.cpu())

                if save_samples:
                    draw = ImageDraw.Draw(comparison_pil)
                    try:
                        font = ImageFont.truetype("arial.ttf", 15)
                    except:
                        font = ImageFont.load_default()

                    width = comparison_pil.width // 4
                    draw.text((10, 10), "Low Light\n(Normalized)", fill="white", font=font)
                    draw.text((width + 10, 10), "Low Light\n(Denormalized)", fill="white", font=font)
                    draw.text((2 * width + 10, 10), "Enhanced", fill="white", font=font)
                    draw.text((3 * width + 10, 10), "Ground Truth", fill="white", font=font)

                    # Metrics under Enhanced
                    text_y = comparison_pil.height - 40
                    draw.text((2 * width + 10, text_y),
                              f"PSNR: {img_metrics['psnr']:.2f}\nSSIM: {img_metrics['ssim']:.3f}\nLPIPS: {img_metrics['lpips']:.3f}",
                              fill="white", font=font)

                    # Save image
                    sample_path = os.path.join(sample_dir, f"sample_{sample_counter:04d}.png")
                    comparison_pil.save(sample_path)
                    img_metrics['image_path'] = sample_path
                    sample_counter += 1

                # Store metrics
                metrics['per_image'].append(img_metrics)
                metrics['average']['psnr'] += img_metrics['psnr']
                metrics['average']['ssim'] += img_metrics['ssim']
                metrics['average']['lpips'] += img_metrics['lpips']

    # Compute averages
    total_samples = len(metrics['per_image'])
    metrics['average']['psnr'] /= total_samples
    metrics['average']['ssim'] /= total_samples
    metrics['average']['lpips'] /= total_samples

    print(f"\nEvaluation Results (Averages):")
    print(f"PSNR: {metrics['average']['psnr']:.2f} dB")
    print(f"SSIM: {metrics['average']['ssim']:.4f}")
    print(f"LPIPS: {metrics['average']['lpips']:.4f}")

    return metrics

# Initialize model
diffusion = LatentDiffusionModel(device=device)

# Load checkpoint
checkpoint = torch.load('checkpoints/diffusion_epoch_190.pth', map_location=device)

# Load state_dict into model components
diffusion.denoise_unet.load_state_dict(checkpoint['denoise_unet'])
diffusion.decoder.load_state_dict(checkpoint['decoder'])

# Evaluate
metrics = evaluate_diffusion(
    denoise_unet=diffusion.denoise_unet,
    decoder=diffusion.decoder,
    test_loader=test_loader,  # Your test DataLoader
    device=device,
    diffusion=diffusion,      # Already has forward_diffuse()
    save_samples=True         # Save all comparison images
)

Evaluating: 100%|██████████| 15/15 [00:12<00:00,  1.23it/s]


Evaluation Results (Averages):
PSNR: 24.85 dB
SSIM: 0.8796
LPIPS: 0.1200



