In [1]:
import os
import glob
import random
import math
import time
from PIL import Image, UnidentifiedImageError
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import torchvision.transforms.functional as TF
from torchvision import transforms, models

torch.backends.cudnn.benchmark = True

In [2]:
!nvidia-smi

Sat Nov  1 08:29:05 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.06             Driver Version: 535.183.06   CUDA Version: 12.6     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off | 00000000:4E:00.0 Off |                   On |
| N/A   38C    P0              91W / 400W |                  N/A |     N/A      Default |
|                                         |                      |              Enabled |
+-----------------------------------------+----------------------+----------------------+

+------------------------------------------------------------------

In [3]:
# Root path = inner dataset folder
root = os.path.join("dataset_1-6-2024", "dataset_1-6-2024")
rgb_dir     = os.path.join(root, "class1")  # RGB
thermal_dir = os.path.join(root, "class2")  # Thermal

def list_images(d):
    exts = ("*.jpg","*.jpeg","*.png","*.bmp","*.tif","*.tiff")
    files = []
    for e in exts:
        files.extend(glob.glob(os.path.join(d, e)))
    return sorted(files)

rgb_files     = list_images(rgb_dir)
thermal_files = list_images(thermal_dir)
print(f"Found RGB: {len(rgb_files)} | Thermal: {len(thermal_files)}")

# Direct pairing (no skip/verification) ‚Äî keep original pairing logic
n = min(len(rgb_files), len(thermal_files))
rgb_files = rgb_files[:n]
thermal_files = thermal_files[:n]
print(f"‚úÖ Using pairs: {len(thermal_files)}")

Found RGB: 148828 | Thermal: 148828
‚úÖ Using pairs: 148828


In [4]:
class ThermalRGBDataset(Dataset):
    """
    - Reads paired thermal and RGB images.
    - Applies synchronized augmentations for better training.
    - Normalizes images to the [-1, 1] range.
    """

    def __init__(self, thermal_files, rgb_files, image_size=256, resize_to=286, augment=True):
        n = min(len(thermal_files), len(rgb_files))
        self.t_files = list(thermal_files[:n])
        self.r_files = list(rgb_files[:n])
        self.img_size = image_size
        self.resize_to = resize_to
        self.augment = augment

    def __len__(self):
        return len(self.t_files)

    def _normalize(self, t, channels):
        return TF.normalize(t, mean=[0.5] * channels, std=[0.5] * channels)

    def __getitem__(self, idx):
        t_img = Image.open(self.t_files[idx]).convert("L")
        r_img = Image.open(self.r_files[idx]).convert("RGB")

        t_img = TF.resize(t_img, [self.resize_to, self.resize_to], antialias=True)
        r_img = TF.resize(r_img, [self.resize_to, self.resize_to], antialias=True)

        if self.augment:
            i, j, h, w = transforms.RandomCrop.get_params(r_img, output_size=(self.img_size, self.img_size))
            t_img = TF.crop(t_img, i, j, h, w)
            r_img = TF.crop(r_img, i, j, h, w)

            if random.random() < 0.5:
                t_img = TF.hflip(t_img)
                r_img = TF.hflip(r_img)

        else:
            t_img = TF.center_crop(t_img, [self.img_size, self.img_size])
            r_img = TF.center_crop(r_img, [self.img_size, self.img_size])

        t = TF.to_tensor(t_img)
        r = TF.to_tensor(r_img)

        t = self._normalize(t, 1)
        r = self._normalize(r, 3)

        return t, r

In [5]:
import torch
import multiprocessing

dataset = ThermalRGBDataset(thermal_files, rgb_files, image_size=256, resize_to=286, augment=True)

split_file = "train_val_split.pt"
recreate_split = not os.path.exists(split_file)

if not recreate_split:
    try:
        data = torch.load(split_file)
        train_idx, val_idx = data["train"], data["val"]
        if max(train_idx + val_idx) >= len(dataset) or len(train_idx) + len(val_idx) != len(dataset):
            recreate_split = True
    except:
        recreate_split = True

if recreate_split:
    idx = torch.randperm(len(dataset)).tolist()
    trn = int(0.9 * len(dataset))
    train_idx, val_idx = idx[:trn], idx[trn:]
    torch.save({"train": train_idx, "val": val_idx}, split_file)

train_set = Subset(dataset, train_idx)
val_set = Subset(dataset, val_idx)

# 10-way split for progressive training
parts = 10
sizes = [len(train_set) // parts] * (parts - 1)
sizes.append(len(train_set) - sum(sizes))
train_parts = random_split(train_set, sizes)

def make_loader(subset, batch_size=32, shuffle=True):
    num_workers = 3
    return DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4,
        persistent_workers=True,
    )

val_loader = make_loader(val_set, batch_size=32, shuffle=False)

  data = torch.load(split_file)


In [6]:
# # --- GENERATOR WARM-UP ---
# # Only run this if starting from scratch
# if start_part == 1:
#     print("üöÄ Starting Generator Warm-up for 5 epochs...")
#     warmup_epochs = 5
#     gen.train()
#     # Use the first part of the data for pre-training
#     warmup_loader = make_loader(train_parts[0], batch_size=32, shuffle=True)

#     for epoch in range(warmup_epochs):
#         for thermal_img, real_rgb in warmup_loader:
#             thermal_img, real_rgb = thermal_img.to(device), real_rgb.to(device)
            
#             opt_gen.zero_grad()
#             with torch.amp.autocast(device_type=device.type, enabled=(device.type == "cuda")):
#                 fake_rgb = gen(thermal_img)
#                 # Calculate only reconstruction losses
#                 loss_g_l1 = l1(fake_rgb, real_rgb) * lambda_L1
#                 loss_g_perceptual = perceptual_loss(fake_rgb, real_rgb) * lambda_perceptual
#                 warmup_loss = loss_g_l1 + loss_g_perceptual
            
#             scaler.scale(warmup_loss).backward()
#             scaler.step(opt_gen)
#             scaler.update()
        
#         print(f"Warm-up Epoch {epoch+1}/{warmup_epochs}, Loss: {warmup_loss.item():.4f}")
#     print("‚úÖ Generator Warm-up complete!")

In [7]:
class ResidualConv(nn.Module):
    def __init__(self, in_c, out_c, norm=True):
        super().__init__()
        self.proj = nn.Conv2d(in_c, out_c, 1, 1, 0, bias=False) if in_c != out_c else None
        self.conv1 = nn.Conv2d(in_c, out_c, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_c) if norm else nn.Identity()
        self.act = nn.GELU()
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_c) if norm else nn.Identity()

    def forward(self, x):
        identity = x if self.proj is None else self.proj(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act(out)
        out = self.conv2(out)
        out = self.bn2(out)
        return self.act(out + identity)


class DenseUNetGenerator(nn.Module):
    def __init__(self, in_channels=1, out_channels=3, base=32):
        super().__init__()
        self.down1 = self.down_block(in_channels, base, norm=False)
        self.down2 = self.down_block(base, base * 2)
        self.down3 = self.down_block(base * 2, base * 4)
        self.down4 = self.down_block(base * 4, base * 8)
        self.down5 = self.down_block(base * 8, base * 8)
        self.down6 = self.down_block(base * 8, base * 8)
        self.down7 = self.down_block(base * 8, base * 8)
        self.down8 = self.down_block(base * 8, base * 8, norm=False)

        self.up1 = self.up_block(base * 8, base * 8, drop=True)
        self.up2 = self.up_block(base * 16, base * 8, drop=True)
        self.up3 = self.up_block(base * 16, base * 8, drop=True)
        self.up4 = self.up_block(base * 16, base * 8)
        self.up5 = self.up_block(base * 16, base * 4)
        self.up6 = self.up_block(base * 8, base * 2)
        self.up7 = self.up_block(base * 4, base)

        self.final = nn.Sequential(nn.ConvTranspose2d(base * 2, out_channels, 4, 2, 1, bias=False), nn.Tanh())

    def down_block(self, in_c, out_c, norm=True):
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_c) if norm else nn.Identity(),
            nn.GELU(),
            ResidualConv(out_c, out_c, norm=norm),
        )

    def up_block(self, in_c, out_c, drop=False):
        layers = [
            nn.ConvTranspose2d(in_c, out_c, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.GELU(),
            ResidualConv(out_c, out_c),
        ]
        if drop:
            layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        u1 = self.up1(d8)
        u1 = torch.cat([u1, d7], dim=1)
        u2 = self.up2(u1)
        u2 = torch.cat([u2, d6], dim=1)
        u3 = self.up3(u2)
        u3 = torch.cat([u3, d5], dim=1)
        u4 = self.up4(u3)
        u4 = torch.cat([u4, d4], dim=1)
        u5 = self.up5(u4)
        u5 = torch.cat([u5, d3], dim=1)
        u6 = self.up6(u5)
        u6 = torch.cat([u6, d2], dim=1)
        u7 = self.up7(u6)
        u7 = torch.cat([u7, d1], dim=1)

        return self.final(u7)

In [8]:
class DBlock(nn.Module):
    def __init__(self, in_c, out_c, norm=True, use_spectral=True):
        super().__init__()
        conv = nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False)
        if use_spectral:
            conv = nn.utils.spectral_norm(conv)
        self.conv = conv
        self.bn = nn.BatchNorm2d(out_c) if norm else nn.Identity()
        self.act = nn.LeakyReLU(0.2, inplace=True)
        self.res = ResidualConv(out_c, out_c, norm=norm)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return self.res(x)


class PatchDiscriminatorDense(nn.Module):
    def __init__(self, in_channels=4, base=64):
        super().__init__()
        self.b1 = DBlock(in_channels, base, norm=False, use_spectral=True)
        self.b2 = DBlock(base, base * 2, use_spectral=True)
        self.b3 = DBlock(base * 2, base * 4, use_spectral=True)
        self.b4 = DBlock(base * 4, base * 8, use_spectral=True)
        self.b5 = DBlock(base * 8, base * 8, use_spectral=True)

        self.features = nn.ModuleList([self.b1, self.b2, self.b3, self.b4, self.b5])

        self.out = nn.Conv2d(base * 8, 1, 3, padding=1, bias=False)

    def forward(self, x, return_features=False):
        features = []
        for layer in self.features:
            x = layer(x)
            if return_features:
                features.append(x)
        output = self.out(x)
        if return_features:
            return output, features
        return output

In [9]:
class VGGPerceptualLoss(nn.Module):
    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        self.slice1 = nn.Sequential()
        self.slice2 = nn.Sequential()
        self.slice3 = nn.Sequential()

        for x in range(4):
            self.slice1.add_module(str(x), vgg[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg[x])

        for param in self.parameters():
            param.requires_grad = False

        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.l1 = nn.L1Loss()

    def forward(self, x, y):
        x = (x + 1) / 2
        y = (y + 1) / 2

        x = self.normalize(x)
        y = self.normalize(y)

        feat_x1 = self.slice1(x)
        feat_y1 = self.slice1(y)
        feat_x2 = self.slice2(feat_x1)
        feat_y2 = self.slice2(feat_y1)
        feat_x3 = self.slice3(feat_x2)
        feat_y3 = self.slice3(feat_y2)

        loss1 = self.l1(feat_x1, feat_y1)
        loss2 = self.l1(feat_x2, feat_y2)
        loss3 = self.l1(feat_x3, feat_y3)

        return loss1 + loss2 + loss3

In [10]:
!nvidia-smi


Sat Nov  1 08:29:07 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.06             Driver Version: 535.183.06   CUDA Version: 12.6     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off | 00000000:4E:00.0 Off |                   On |
| N/A   37C    P0              90W / 400W |                  N/A |     N/A      Default |
|                                         |                      |              Enabled |
+-----------------------------------------+----------------------+----------------------+

+------------------------------------------------------------------

In [11]:
import inspect
from torch.optim.lr_scheduler import StepLR, MultiStepLR # Import MultiStepLR

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

gen = DenseUNetGenerator().to(device)
disc = PatchDiscriminatorDense().to(device)
perceptual_loss = VGGPerceptualLoss().to(device)

def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, a=0.2, mode="fan_in", nonlinearity="leaky_relu")
        if getattr(m, "bias", None) is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

gen.apply(weights_init)
disc.apply(weights_init)
print("Applied weights_init.")

# --- Final Hyperparameter Balance ---
lambda_L1 = 100
lambda_perceptual = 0.1
lambda_GAN = 0.1
epochs_per_part = 100

# --- NEW: Define optimizers and schedulers here ---
# This ensures they exist for the warm-up and checkpoint loading.
opt_gen = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_disc = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
scheduler_gen = MultiStepLR(opt_gen, milestones=[30], gamma=0.1)
scheduler_disc = MultiStepLR(opt_disc, milestones=[30], gamma=0.1)
# ---------------------------------------------------

# Loss functions
criterion_adv = nn.BCEWithLogitsLoss()
l1 = nn.L1Loss()

# Mixed precision scaler
scaler = torch.amp.GradScaler(enabled=(device.type == "cuda"))

def count_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

print(f"Params ‚Üí Gen: {count_params(gen):,}, Disc: {count_params(disc):,}")

Using device: cuda
Applied weights_init.
Params ‚Üí Gen: 25,008,064, Disc: 17,949,440


In [12]:
from collections import OrderedDict

checkpoint_dir = "checkpoints-9_10_latest" # Make sure this matches your folder name
os.makedirs(checkpoint_dir, exist_ok=True)

latest_checkpoint_path = None
# Filter the list to only include checkpoint files
checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')]

if checkpoint_files:
    latest_checkpoint_path = max(
        [os.path.join(checkpoint_dir, f) for f in checkpoint_files],
        key=os.path.getctime,
    )

start_part = 1
start_epoch = 1

if latest_checkpoint_path:
    print(f"‚úÖ Resuming training from checkpoint: {latest_checkpoint_path}")
    checkpoint = torch.load(latest_checkpoint_path, map_location=torch.device("cpu"))

    gen.load_state_dict(checkpoint["gen_state_dict"])
    disc.load_state_dict(checkpoint["disc_state_dict"])
    opt_gen.load_state_dict(checkpoint["opt_gen_state_dict"])
    opt_disc.load_state_dict(checkpoint["opt_disc_state_dict"])
    val_history = checkpoint.get("val_history", [])
    
    start_part = checkpoint["part"]
    start_epoch = checkpoint.get("epoch", 0) + 1

    if start_epoch > epochs_per_part:
        start_part += 1
        start_epoch = 1
    
    print(f"‚û°Ô∏è Resuming from Part {start_part}, Epoch {start_epoch}")
    
else:
    print("‚úÖ Starting training from scratch.")
    # --- NEW: Initialize val_history for a fresh start ---
    val_history = []
    # ----------------------------------------------------
    
    # This section for generator warm-up will only run on a fresh start
    if 'train_parts' in locals() and len(train_parts) > 0:
        print("\nüöÄ Starting Generator Warm-up for 5 epochs...")
        warmup_epochs = 5
        gen.train()
        warmup_loader = make_loader(train_parts[0], batch_size=32, shuffle=True)
        for epoch in range(warmup_epochs):
            for thermal_img, real_rgb in warmup_loader:
                thermal_img, real_rgb = thermal_img.to(device), real_rgb.to(device)
                
                opt_gen.zero_grad() 
                
                with torch.amp.autocast(device_type=device.type, enabled=(device.type == "cuda")):
                    fake_rgb = gen(thermal_img)
                    loss_g_l1 = l1(fake_rgb, real_rgb) * lambda_L1
                    loss_g_perceptual = perceptual_loss(fake_rgb, real_rgb) * lambda_perceptual
                    warmup_loss = loss_g_l1 + loss_g_perceptual
                scaler.scale(warmup_loss).backward()
                scaler.step(opt_gen)
                scaler.update()
            print(f"Warm-up Epoch {epoch+1}/{warmup_epochs}, Loss: {warmup_loss.item():.4f}")
        print("‚úÖ Generator Warm-up complete!")

‚úÖ Resuming training from checkpoint: checkpoints-9_10_latest/checkpoint_part_3_epoch_40.pt


  checkpoint = torch.load(latest_checkpoint_path, map_location=torch.device("cpu"))


‚û°Ô∏è Resuming from Part 3, Epoch 41


In [13]:
# # Cell 8
# # if hasattr(torch, "compile"):
    
# #     gen = torch.compile(gen)
# #     disc = torch.compile(disc)

# # Cell 8 (optional compile)
# if hasattr(torch, "compile"):
#     try:
#         gen = torch.compile(gen)
#         disc = torch.compile(disc)
#         print("‚úÖ Models compiled with torch.compile()")
#     except Exception as e:
#         print("‚ö†Ô∏è torch.compile not applied:", e)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR # Make sure this is imported
from math import exp
import matplotlib.pyplot as plt
import glob
import os
import time
from tqdm import tqdm

# --- Assume these are defined in a previous cell ---
# gen, disc, criterion_adv, l1, perceptual_loss, scaler
# lambda_GAN, lambda_L1, lambda_perceptual
# train_parts, val_loader, make_loader
# start_part, start_epoch, epochs_per_part
# device
# --------------------------------------------------

# ==============================================================================
# CELL 1: TORCH-BASED EVALUATION METRICS
# ==============================================================================

def denorm01(x):
    """Denormalizes from [-1, 1] to [0, 1]"""
    return (x * 0.5 + 0.5).clamp(0, 1)

def psnr_torch(img1, img2, data_range=1.0):
    """
    Calculates PSNR for a batch of images (range [0, 1]).
    """
    # img1 and img2 have shape [B, C, H, W]
    mse = F.mse_loss(img1, img2)
    psnr_val = 20 * torch.log10(data_range / torch.sqrt(mse))
    return psnr_val

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2 / float(2*sigma**2)) for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel=1):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window

def ssim_torch(img1, img2, window_size=11, window=None, size_average=True, val_range=1.0):
    """
    Calculates SSIM for a batch of images (range [0, 1]).
    """
    (_, channel, _, _) = img1.size()
    
    if window is None:
        real_window = create_window(window_size, channel).to(img1.device)
    else:
        real_window = window.to(img1.device)

    K1 = 0.01
    K2 = 0.03
    L = val_range
    C1 = (K1 * L) ** 2
    C2 = (K2 * L) ** 2

    mu1 = F.conv2d(img1, real_window, padding=window_size//2, groups=channel)
    mu2 = F.conv2d(img2, real_window, padding=window_size//2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, real_window, padding=window_size//2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, real_window, padding=window_size//2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, real_window, padding=window_size//2, groups=channel) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

@torch.no_grad()
def evaluate(gen, loader, device="cuda"):
    """
    NEW fast evaluation function using torch-based metrics.
    [FIXED for AMP/Half-precision mismatch]
    """
    gen.eval()
    psnr_sum = 0.0
    ssim_sum = 0.0
    n = 0
    
    # Pre-create the SSIM window (in float32) and move it to the device once
    ssim_window = create_window(11, 3).to(device) # 3 channels for RGB

    for t, r in loader:
        t, r = t.to(device), r.to(device)
        
        # --- ONLY autocast the model inference ---
        with torch.amp.autocast(device_type=device.type, enabled=(device.type == "cuda")):
            f = gen(t) # f will be float16
        
        # --- Convert BOTH images to float32 for metrics ---
        r_denorm = denorm01(r)  # r is already float32
        
        # *** THE FIX IS HERE ***
        # Cast f from float16 back to float32 before denorming and calculating metrics
        f_denorm = denorm01(f.float()) 
        
        # Now, r_denorm, f_denorm, and ssim_window are all float32
        
        # Calculate metrics for the whole batch on the GPU
        batch_size = t.size(0)
        psnr_sum += psnr_torch(r_denorm, f_denorm, data_range=1.0).item() * batch_size
        ssim_sum += ssim_torch(r_denorm, f_denorm, window=ssim_window, val_range=1.0).item() * batch_size
        n += batch_size
        
    # Return the average
    return psnr_sum / max(1, n), ssim_sum / max(1, n)

print("‚úÖ New Torch-based PSNR, SSIM, and evaluate functions defined.")


# ==============================================================================
# CELL 2: PLOTTING & CHECKPOINTING HELPERS
# ==============================================================================

# --- Directories ---
checkpoint_dir = "checkpoints-9_10_latest"
graph_dir = "graphs-9_10_latest"
best_model_path = os.path.join(checkpoint_dir, "best_model.pt")
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(graph_dir, exist_ok=True)

def plot_losses(history, save_path):
    """
    Plots and saves the training history.  
    Handles gaps in validation data (where psnr/ssim are None).
    """
    plt.figure(figsize=(15, 10))
    
    # Extract data
    epochs = list(range(len(history))) # Use simple integer indices for plotting
    epoch_labels = [f"{h['part']}-{h['epoch']}" for h in history]
    
    # These will contain None for skipped epochs, which plt.plot handles correctly
    psnr = [h['psnr'] for h in history]  
    ssim = [h['ssim'] for h in history]
    g_loss = [h['g_loss'] for h in history]
    d_loss = [h['d_loss'] for h in history]
    
    # Get indices *only* where validation was actually run
    val_indices = [i for i, h in enumerate(history) if h['psnr'] is not None]
    val_labels = [epoch_labels[i] for i in val_indices]

    # Plot PSNR
    plt.subplot(2, 2, 1)
    # Plot will create breaks for 'None' values
    plt.plot(epochs, psnr, label='Validation PSNR', color='blue', marker='o', linestyle='-')
    plt.title('Validation PSNR (Higher is Better)')
    plt.xlabel('Epoch (Part-Epoch)')
    plt.ylabel('PSNR (dB)')
    if len(val_indices) > 0: # Set ticks only to evaluated points
           plt.xticks(ticks=val_indices, labels=val_labels, rotation=45)
    plt.legend()
    plt.grid(True)
    
    # Plot SSIM
    plt.subplot(2, 2, 2)
    plt.plot(epochs, ssim, label='Validation SSIM', color='green', marker='o', linestyle='-')
    plt.title('Validation SSIM (Higher is Better)')
    plt.xlabel('Epoch (Part-Epoch)')
    plt.ylabel('SSIM')
    if len(val_indices) > 0: # Set ticks only to evaluated points
           plt.xticks(ticks=val_indices, labels=val_labels, rotation=45)
    plt.legend()
    plt.grid(True)
    
    # Plot Generator Loss (runs every epoch)
    plt.subplot(2, 2, 3)
    plt.plot(epochs, g_loss, label='Train Generator Loss', color='red')
    plt.title('Train Generator Loss (Lower is Better)')
    plt.xlabel('Epoch (Part-Epoch)')
    # Set ticks to match validation for readability
    if len(val_indices) > 0:  
           plt.xticks(ticks=val_indices, labels=val_labels, rotation=45)
    plt.legend()
    plt.grid(True)
    
    # Plot Discriminator Loss (runs every epoch)
    plt.subplot(2, 2, 4)
    plt.plot(epochs, d_loss, label='Train Discriminator Loss', color='purple')
    plt.title('Train Discriminator Loss (Lower is Better)')
    plt.xlabel('Epoch (Part-Epoch)')
    # Set ticks to match validation for readability
    if len(val_indices) > 0:
           plt.xticks(ticks=val_indices, labels=val_labels, rotation=45)
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig(save_path)
    # plt.show() # Show plot in notebook
    plt.close()

def save_checkpoint(state, is_best, part, epoch):
    """
    Saves the current model state and manages checkpoint rotation.
    Keeps only the 3 latest checkpoints and the single best model.
    """
    try:
        # 1. Save the "best" model if this is it
        if is_best:
            torch.save(state, best_model_path)
            print(f"    *** New Best Model! Saved to {best_model_path} ***")
        
        # 2. Save the "latest" checkpoint for this epoch
        filename = f"checkpoint_part_{part}_epoch_{epoch}.pt"
        latest_path = os.path.join(checkpoint_dir, filename)
        torch.save(state, latest_path)
        print(f"    üíæ Checkpoint saved to {latest_path}")

        # 3. Clean up old checkpoints (keep only 3 latest)
        checkpoints = sorted(
            glob.glob(os.path.join(checkpoint_dir, "checkpoint_part_*.pt")), 
            key=os.path.getmtime
        )
        
        if len(checkpoints) > 3:
            for cp_path in checkpoints[:-3]:
                print(f"    üóëÔ∏è Removing old checkpoint: {cp_path}")
                try:
                    os.remove(cp_path)
                    
                    # Also remove the corresponding graph
                    part_epoch_str = cp_path.split('checkpoint_')[-1].split('.pt')[0]
                    graph_to_remove = os.path.join(graph_dir, f"graph_{part_epoch_str}.png")
                    if os.path.exists(graph_to_remove):
                        os.remove(graph_to_remove)
                        
                except Exception as e:
                    print(f"    Error removing old file {cp_path} or its graph: {e}")
                    
    except Exception as e:
        print(f"    Error during save_checkpoint: {e}")

# --- Initialize History & Load Best Validation Metric ---

# IMPORTANT: If you are resuming training, you must load 'val_history' 
# from your checkpoint in the *previous* cell, along with start_part/start_epoch.
# If this is a new run, we'll create an empty list.
if 'val_history' not in globals():
    print("Initializing new 'val_history' list.")
    val_history = []

best_val_psnr = -1.0
if os.path.exists(best_model_path):
    try:
        print(f"Found existing best model at: {best_model_path}")
        best_checkpoint = torch.load(best_model_path, map_location=device)
        if 'val_history' in best_checkpoint and len(best_checkpoint['val_history']) > 0:
             # Find the max PSNR from history, filtering out None values
             valid_psnrs = [h['psnr'] for h in best_checkpoint['val_history'] if h['psnr'] is not None]
             if valid_psnrs:
                   best_val_psnr = max(valid_psnrs)
        print(f"Resuming with best_val_psnr: {best_val_psnr:.2f} dB")
    except Exception as e:
        print(f"Could not read best_val_psnr from best model: {e}. Starting with -1.0.")

print("‚úÖ Helper functions for checkpointing and plotting are defined.")


# ==============================================================================
# CELL 3: MAIN TRAINING LOOP (Using "Perfect" Original Logic)
# ==============================================================================

print("\n" + "="*50)
print("--- Starting Main Training Loop ---")
print(f"Strategy: Train on ONE part, reset optimizers.")
print(f"Resuming from Part: {start_part}, Epoch: {start_epoch}")
print(f"Tracking Best PSNR (current best: {best_val_psnr:.2f} dB)")
print(f"Hyperparameters: L1_Œª={lambda_L1}, Perceptual_Œª={lambda_perceptual}, GAN_Œª={lambda_GAN}")
print("Evaluation will run every 20 epochs or on the last epoch of each part.")
print("="*50 + "\n")

# --- Loop based on your ORIGINAL "perfect" code ---
for part_idx in range(start_part, len(train_parts) + 1):
    if part_idx > len(train_parts):
        break
    
    # --- FINAL STRATEGY: Train on ONLY the current part (from original code) ---
    train_loader = make_loader(train_parts[part_idx - 1], batch_size=32, shuffle=True)
    print(f"\n--- Starting Training on Part {part_idx}/{len(train_parts)} ---")
    print(f"Dataset size: {len(train_loader.dataset)} images")

    # --- Reset optimizers and use the Scheduler (from original code) ---
    print("‚ú® Resetting optimizers and using MultiStepLR schedule.")
    opt_gen = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_disc = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
    # This scheduler will drop the LR at epoch 30
    scheduler_gen = MultiStepLR(opt_gen, milestones=[30], gamma=0.1) 
    scheduler_disc = MultiStepLR(opt_disc, milestones=[30], gamma=0.1)
    # -----------------------------------------------------------------

    for epoch in range(start_epoch, epochs_per_part + 1):
        gen.train()
        disc.train()
        
        epoch_start_time = time.time()
        g_loss_sum = 0.0
        d_loss_sum = 0.0
        
        loop = tqdm(train_loader, desc=f"Part {part_idx}/{len(train_parts)} | Epoch {epoch}/{epochs_per_part}", leave=False)

        for step, (thermal_img, real_rgb) in enumerate(loop, 1):
            thermal_img, real_rgb = thermal_img.to(device), real_rgb.to(device)
            
            # ---------------------
            #  Train Discriminator (FROM ORIGINAL "PERFECT" CODE)
            # ---------------------
            opt_disc.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type=device.type, enabled=(device.type == "cuda")):
                fake_rgb = gen(thermal_img)
                disc_real_input = torch.cat((thermal_img, real_rgb), 1)
                disc_fake_input = torch.cat((thermal_img, fake_rgb.detach()), 1)
                disc_real_output = disc(disc_real_input)
                disc_fake_output = disc(disc_fake_input)
                
                # *** THIS IS THE FIX: Restored label smoothing from your original code ***
                real_label = torch.ones_like(disc_real_output) * 0.9 
                fake_label = torch.zeros_like(disc_fake_output)
                
                loss_d_real = criterion_adv(disc_real_output, real_label)
                loss_d_fake = criterion_adv(disc_fake_output, fake_label)
                loss_d = (loss_d_real + loss_d_fake) / 2
            
            scaler.scale(loss_d).backward()
            scaler.step(opt_disc)

            # -----------------
            #  Train Generator (FROM ORIGINAL "PERFECT" CODE)
            # -----------------
            opt_gen.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type=device.type, enabled=(device.type == "cuda")):
                disc_fake_input = torch.cat((thermal_img, fake_rgb), 1)
                disc_fake_output = disc(disc_fake_input)
                loss_g_gan_raw = criterion_adv(disc_fake_output, torch.ones_like(disc_fake_output))
                loss_g_l1_raw = l1(fake_rgb, real_rgb)
                loss_g_perceptual_raw = perceptual_loss(fake_rgb, real_rgb)
                
                loss_g_gan = loss_g_gan_raw * lambda_GAN
                loss_g_l1 = loss_g_l1_raw * lambda_L1
                loss_g_perceptual = loss_g_perceptual_raw * lambda_perceptual
                loss_g = loss_g_gan + loss_g_l1 + loss_g_perceptual
            
            scaler.scale(loss_g).backward()
            scaler.step(opt_gen)
            scaler.update()
            
            # --- Update running losses and TQDM bar ---
            g_loss_sum += loss_g.item()
            d_loss_sum += loss_d.item()
            loop.set_postfix(
                G_Loss=f"{loss_g.item():.4f}",
                D_Loss=f"{loss_d.item():.4f}",
                G_L1=f"{loss_g_l1.item():.4f}"
            )

        # --- End of Epoch ---
        scheduler_gen.step()
        scheduler_disc.step()

        avg_g_loss = g_loss_sum / len(train_loader)
        avg_d_loss = d_loss_sum / len(train_loader)
        epoch_time = time.time() - epoch_start_time
        current_lr = scheduler_gen.get_last_lr()[0]
        
        print(f"\n--- Epoch Summary: Part {part_idx}/{len(train_parts)} | Epoch {epoch}/{epochs_per_part} | Time: {epoch_time:.2f}s ---")
        print(f"  Train Loss ‚Üí G: {avg_g_loss:.4f}, D: {avg_d_loss:.4f} | LR: {current_lr:.1e}")

        val_psnr, val_ssim = None, None # Initialize as None
        
        # --- Run Validation, Checkpointing, and Plotting ---
        if epoch % 20 == 0 or epoch == epochs_per_part:
            print(f"  Running evaluation for epoch {epoch}...")
            
            # --- Run Validation (uses new fast evaluate()) ---
            val_psnr, val_ssim = evaluate(gen, val_loader, device=device)
            print(f"  Val Metrics ‚Üí PSNR: {val_psnr:.2f} dB, SSIM: {val_ssim:.4f}")

            # --- Checkpoint Saving ---
            is_best = val_psnr > best_val_psnr
            if is_best:
                best_val_psnr = val_psnr # Update best score

            state = {
                'part': part_idx,
                'epoch': epoch,
                'gen_state_dict': gen.state_dict(),
                'disc_state_dict': disc.state_dict(),
                'opt_gen_state_dict': opt_gen.state_dict(),
                'opt_disc_state_dict': opt_disc.state_dict(),
                'val_history': val_history, # Pass the whole history
            }
            
            save_checkpoint(state, is_best, part_idx, epoch)
            
            # --- Append to history and Plot ---
            val_history.append({
                "part": part_idx, "epoch": epoch, 
                "psnr": val_psnr, "ssim": val_ssim, 
                "g_loss": avg_g_loss, "d_loss": avg_d_loss
            })
            
            graph_save_path = os.path.join(graph_dir, f"graph_part_{part_idx}_epoch_{epoch}.png")
            plot_losses(val_history, save_path=graph_save_path) # Plot the full history
        
        else:
            print("  Skipping evaluation, checkpointing, and graphing for this epoch.")
            # Append training losses, but None for validation metrics
            val_history.append({
                "part": part_idx, "epoch": epoch, 
                "psnr": None, # Append None
                "ssim": None, # Append None
                "g_loss": avg_g_loss, "d_loss": avg_d_loss
            })

    # Reset start_epoch to 1 for the next part (from original code)
    start_epoch = 1

print("\nüéâ --- Training Finished --- üéâ")

‚úÖ New Torch-based PSNR, SSIM, and evaluate functions defined.
Found existing best model at: checkpoints-9_10_latest/best_model.pt


  best_checkpoint = torch.load(best_model_path, map_location=device)


Resuming with best_val_psnr: 19.17 dB
‚úÖ Helper functions for checkpointing and plotting are defined.

--- Starting Main Training Loop ---
Strategy: Train on ONE part, reset optimizers.
Resuming from Part: 3, Epoch: 41
Tracking Best PSNR (current best: 19.17 dB)
Hyperparameters: L1_Œª=100, Perceptual_Œª=0.1, GAN_Œª=0.1
Evaluation will run every 20 epochs or on the last epoch of each part.


--- Starting Training on Part 3/10 ---
Dataset size: 13394 images
‚ú® Resetting optimizers and using MultiStepLR schedule.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 41/100 | Time: 111.17s ---
  Train Loss ‚Üí G: 14.6067, D: 0.3064 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 42/100 | Time: 106.27s ---
  Train Loss ‚Üí G: 14.4565, D: 0.2677 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 43/100 | Time: 106.15s ---
  Train Loss ‚Üí G: 14.3232, D: 0.2657 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 44/100 | Time: 106.06s ---
  Train Loss ‚Üí G: 14.2507, D: 0.2589 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 45/100 | Time: 106.04s ---
  Train Loss ‚Üí G: 13.9457, D: 0.2674 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 46/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 13.9505, D: 0.2641 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 47/100 | Time: 106.24s ---
  Train Loss ‚Üí G: 13.9247, D: 0.2561 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 48/100 | Time: 106.14s ---
  Train Loss ‚Üí G: 14.2833, D: 0.2528 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 49/100 | Time: 106.06s ---
  Train Loss ‚Üí G: 14.1165, D: 0.2456 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 50/100 | Time: 106.12s ---
  Train Loss ‚Üí G: 13.8831, D: 0.2475 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 51/100 | Time: 106.20s ---
  Train Loss ‚Üí G: 13.7463, D: 0.2361 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 52/100 | Time: 106.17s ---
  Train Loss ‚Üí G: 13.7443, D: 0.2467 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 53/100 | Time: 106.13s ---
  Train Loss ‚Üí G: 13.5768, D: 0.2487 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 54/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 13.6655, D: 0.2422 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 55/100 | Time: 106.21s ---
  Train Loss ‚Üí G: 13.4529, D: 0.2498 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 56/100 | Time: 106.17s ---
  Train Loss ‚Üí G: 13.5167, D: 0.2414 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 57/100 | Time: 106.04s ---
  Train Loss ‚Üí G: 13.3996, D: 0.2368 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 58/100 | Time: 106.04s ---
  Train Loss ‚Üí G: 13.4365, D: 0.2504 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 59/100 | Time: 106.05s ---
  Train Loss ‚Üí G: 13.4649, D: 0.2382 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 60/100 | Time: 106.06s ---
  Train Loss ‚Üí G: 13.4561, D: 0.2431 | LR: 2.0e-04
  Running evaluation for epoch 60...




  Val Metrics ‚Üí PSNR: 18.37 dB, SSIM: 0.6001
    üíæ Checkpoint saved to checkpoints-9_10_latest/checkpoint_part_3_epoch_60.pt
    üóëÔ∏è Removing old checkpoint: checkpoints-9_10_latest/checkpoint_part_2_epoch_100.pt


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 61/100 | Time: 106.19s ---
  Train Loss ‚Üí G: 13.9223, D: 0.2310 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 62/100 | Time: 106.06s ---
  Train Loss ‚Üí G: 13.5913, D: 0.2266 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 63/100 | Time: 106.13s ---
  Train Loss ‚Üí G: 13.3468, D: 0.2503 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 64/100 | Time: 106.01s ---
  Train Loss ‚Üí G: 13.2467, D: 0.2511 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 65/100 | Time: 106.03s ---
  Train Loss ‚Üí G: 13.6039, D: 0.2228 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 66/100 | Time: 106.15s ---
  Train Loss ‚Üí G: 13.6897, D: 0.2377 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 67/100 | Time: 106.14s ---
  Train Loss ‚Üí G: 13.3568, D: 0.2432 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 68/100 | Time: 106.03s ---
  Train Loss ‚Üí G: 13.2897, D: 0.2288 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 69/100 | Time: 106.00s ---
  Train Loss ‚Üí G: 13.3227, D: 0.2285 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 70/100 | Time: 105.99s ---
  Train Loss ‚Üí G: 13.1974, D: 0.2470 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 71/100 | Time: 105.98s ---
  Train Loss ‚Üí G: 12.6360, D: 0.1728 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 72/100 | Time: 106.12s ---
  Train Loss ‚Üí G: 12.7469, D: 0.1766 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 73/100 | Time: 106.16s ---
  Train Loss ‚Üí G: 12.8643, D: 0.1766 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 74/100 | Time: 106.14s ---
  Train Loss ‚Üí G: 12.8526, D: 0.1753 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 75/100 | Time: 106.19s ---
  Train Loss ‚Üí G: 12.8707, D: 0.1733 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 76/100 | Time: 106.17s ---
  Train Loss ‚Üí G: 12.8830, D: 0.1746 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 77/100 | Time: 106.22s ---
  Train Loss ‚Üí G: 12.8691, D: 0.1751 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 78/100 | Time: 106.21s ---
  Train Loss ‚Üí G: 12.8557, D: 0.1760 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 79/100 | Time: 106.12s ---
  Train Loss ‚Üí G: 12.9188, D: 0.1756 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 80/100 | Time: 106.07s ---
  Train Loss ‚Üí G: 12.9002, D: 0.1781 | LR: 2.0e-05
  Running evaluation for epoch 80...
  Val Metrics ‚Üí PSNR: 18.56 dB, SSIM: 0.6077
    üíæ Checkpoint saved to checkpoints-9_10_latest/checkpoint_part_3_epoch_80.pt
    üóëÔ∏è Removing old checkpoint: checkpoints-9_10_latest/checkpoint_part_3_epoch_20.pt


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 81/100 | Time: 106.22s ---
  Train Loss ‚Üí G: 12.8533, D: 0.1762 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 82/100 | Time: 106.08s ---
  Train Loss ‚Üí G: 12.8340, D: 0.1808 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 83/100 | Time: 106.08s ---
  Train Loss ‚Üí G: 12.8547, D: 0.1754 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 84/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 12.9128, D: 0.1757 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 85/100 | Time: 106.17s ---
  Train Loss ‚Üí G: 12.9083, D: 0.1828 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 86/100 | Time: 106.13s ---
  Train Loss ‚Üí G: 12.8674, D: 0.1803 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 87/100 | Time: 106.08s ---
  Train Loss ‚Üí G: 12.8362, D: 0.1763 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 88/100 | Time: 106.15s ---
  Train Loss ‚Üí G: 12.8871, D: 0.1801 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 89/100 | Time: 106.18s ---
  Train Loss ‚Üí G: 12.9349, D: 0.1777 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 90/100 | Time: 106.07s ---
  Train Loss ‚Üí G: 12.8679, D: 0.1790 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 91/100 | Time: 105.99s ---
  Train Loss ‚Üí G: 12.8413, D: 0.1843 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 92/100 | Time: 106.05s ---
  Train Loss ‚Üí G: 12.8783, D: 0.1760 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 93/100 | Time: 106.15s ---
  Train Loss ‚Üí G: 12.8922, D: 0.1794 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 94/100 | Time: 106.24s ---
  Train Loss ‚Üí G: 12.8644, D: 0.1802 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 95/100 | Time: 106.42s ---
  Train Loss ‚Üí G: 12.8248, D: 0.1783 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 96/100 | Time: 106.37s ---
  Train Loss ‚Üí G: 12.8465, D: 0.1862 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 97/100 | Time: 106.15s ---
  Train Loss ‚Üí G: 12.8882, D: 0.1799 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 98/100 | Time: 106.14s ---
  Train Loss ‚Üí G: 12.8854, D: 0.1876 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 3/10 | Epoch 99/100 | Time: 106.05s ---
  Train Loss ‚Üí G: 12.8763, D: 0.1775 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                         


--- Epoch Summary: Part 3/10 | Epoch 100/100 | Time: 106.00s ---
  Train Loss ‚Üí G: 12.8855, D: 0.1844 | LR: 2.0e-05
  Running evaluation for epoch 100...
  Val Metrics ‚Üí PSNR: 18.64 dB, SSIM: 0.6101
    üíæ Checkpoint saved to checkpoints-9_10_latest/checkpoint_part_3_epoch_100.pt
    üóëÔ∏è Removing old checkpoint: checkpoints-9_10_latest/checkpoint_part_3_epoch_40.pt

--- Starting Training on Part 4/10 ---
Dataset size: 13394 images
‚ú® Resetting optimizers and using MultiStepLR schedule.


                                                                                                                       


--- Epoch Summary: Part 4/10 | Epoch 1/100 | Time: 106.55s ---
  Train Loss ‚Üí G: 14.4341, D: 0.2553 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                       


--- Epoch Summary: Part 4/10 | Epoch 2/100 | Time: 106.18s ---
  Train Loss ‚Üí G: 14.3029, D: 0.2469 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                       


--- Epoch Summary: Part 4/10 | Epoch 3/100 | Time: 106.25s ---
  Train Loss ‚Üí G: 14.0820, D: 0.2558 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                       


--- Epoch Summary: Part 4/10 | Epoch 4/100 | Time: 106.15s ---
  Train Loss ‚Üí G: 13.9755, D: 0.2326 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                       


--- Epoch Summary: Part 4/10 | Epoch 5/100 | Time: 106.19s ---
  Train Loss ‚Üí G: 13.7462, D: 0.2399 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                       


--- Epoch Summary: Part 4/10 | Epoch 6/100 | Time: 106.25s ---
  Train Loss ‚Üí G: 13.8698, D: 0.2251 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                       


--- Epoch Summary: Part 4/10 | Epoch 7/100 | Time: 106.17s ---
  Train Loss ‚Üí G: 13.8274, D: 0.2345 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                       


--- Epoch Summary: Part 4/10 | Epoch 8/100 | Time: 105.99s ---
  Train Loss ‚Üí G: 14.0650, D: 0.2152 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                       


--- Epoch Summary: Part 4/10 | Epoch 9/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 13.6194, D: 0.2442 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 10/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 13.8879, D: 0.2083 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 11/100 | Time: 106.00s ---
  Train Loss ‚Üí G: 13.9925, D: 0.2194 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 12/100 | Time: 105.98s ---
  Train Loss ‚Üí G: 13.8384, D: 0.2226 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 13/100 | Time: 106.16s ---
  Train Loss ‚Üí G: 13.5208, D: 0.2349 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 14/100 | Time: 106.18s ---
  Train Loss ‚Üí G: 13.4092, D: 0.2201 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 15/100 | Time: 106.21s ---
  Train Loss ‚Üí G: 13.4368, D: 0.2186 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 16/100 | Time: 106.24s ---
  Train Loss ‚Üí G: 13.4692, D: 0.2161 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 17/100 | Time: 106.15s ---
  Train Loss ‚Üí G: 13.3854, D: 0.2216 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 18/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 13.1872, D: 0.2163 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 19/100 | Time: 106.43s ---
  Train Loss ‚Üí G: 13.4149, D: 0.2222 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 20/100 | Time: 106.03s ---
  Train Loss ‚Üí G: 13.3674, D: 0.2239 | LR: 2.0e-04
  Running evaluation for epoch 20...
  Val Metrics ‚Üí PSNR: 18.96 dB, SSIM: 0.6293
    üíæ Checkpoint saved to checkpoints-9_10_latest/checkpoint_part_4_epoch_20.pt
    üóëÔ∏è Removing old checkpoint: checkpoints-9_10_latest/checkpoint_part_3_epoch_60.pt


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 21/100 | Time: 106.12s ---
  Train Loss ‚Üí G: 13.4264, D: 0.2230 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 22/100 | Time: 106.07s ---
  Train Loss ‚Üí G: 13.2951, D: 0.2141 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 23/100 | Time: 106.09s ---
  Train Loss ‚Üí G: 13.4128, D: 0.1930 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 24/100 | Time: 106.15s ---
  Train Loss ‚Üí G: 13.2786, D: 0.2133 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 25/100 | Time: 106.25s ---
  Train Loss ‚Üí G: 13.1044, D: 0.2400 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 26/100 | Time: 106.06s ---
  Train Loss ‚Üí G: 13.2472, D: 0.2005 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 27/100 | Time: 106.03s ---
  Train Loss ‚Üí G: 13.1499, D: 0.2091 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 28/100 | Time: 106.04s ---
  Train Loss ‚Üí G: 13.1098, D: 0.2366 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 29/100 | Time: 106.10s ---
  Train Loss ‚Üí G: 13.1103, D: 0.2063 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 30/100 | Time: 106.09s ---
  Train Loss ‚Üí G: 13.4512, D: 0.2010 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 31/100 | Time: 106.10s ---
  Train Loss ‚Üí G: 12.8085, D: 0.1714 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 32/100 | Time: 105.96s ---
  Train Loss ‚Üí G: 12.8639, D: 0.1802 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 33/100 | Time: 106.08s ---
  Train Loss ‚Üí G: 12.8336, D: 0.1735 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 34/100 | Time: 106.02s ---
  Train Loss ‚Üí G: 12.8965, D: 0.1749 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 35/100 | Time: 106.10s ---
  Train Loss ‚Üí G: 12.8919, D: 0.1748 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 36/100 | Time: 105.98s ---
  Train Loss ‚Üí G: 12.8569, D: 0.1727 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 37/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 12.8076, D: 0.1718 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 38/100 | Time: 106.22s ---
  Train Loss ‚Üí G: 12.8270, D: 0.1735 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 39/100 | Time: 106.10s ---
  Train Loss ‚Üí G: 12.7859, D: 0.1729 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 40/100 | Time: 106.05s ---
  Train Loss ‚Üí G: 12.8283, D: 0.1734 | LR: 2.0e-05
  Running evaluation for epoch 40...
  Val Metrics ‚Üí PSNR: 18.69 dB, SSIM: 0.6179
    üíæ Checkpoint saved to checkpoints-9_10_latest/checkpoint_part_4_epoch_40.pt
    üóëÔ∏è Removing old checkpoint: checkpoints-9_10_latest/checkpoint_part_3_epoch_80.pt


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 41/100 | Time: 106.17s ---
  Train Loss ‚Üí G: 12.8416, D: 0.1729 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 42/100 | Time: 106.12s ---
  Train Loss ‚Üí G: 12.8540, D: 0.1701 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 43/100 | Time: 106.16s ---
  Train Loss ‚Üí G: 12.8114, D: 0.1719 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 44/100 | Time: 106.17s ---
  Train Loss ‚Üí G: 12.8160, D: 0.1731 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 45/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 12.7743, D: 0.1707 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 46/100 | Time: 106.04s ---
  Train Loss ‚Üí G: 12.7838, D: 0.1716 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 47/100 | Time: 106.13s ---
  Train Loss ‚Üí G: 12.8499, D: 0.1706 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 48/100 | Time: 106.14s ---
  Train Loss ‚Üí G: 12.8390, D: 0.1724 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 49/100 | Time: 106.10s ---
  Train Loss ‚Üí G: 12.7931, D: 0.1758 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 50/100 | Time: 106.06s ---
  Train Loss ‚Üí G: 12.7599, D: 0.1749 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 51/100 | Time: 106.05s ---
  Train Loss ‚Üí G: 12.7685, D: 0.1735 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 52/100 | Time: 106.00s ---
  Train Loss ‚Üí G: 12.8043, D: 0.1733 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 53/100 | Time: 105.97s ---
  Train Loss ‚Üí G: 12.7852, D: 0.1772 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 54/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 12.7741, D: 0.1752 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 55/100 | Time: 106.10s ---
  Train Loss ‚Üí G: 12.7971, D: 0.1722 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 56/100 | Time: 105.96s ---
  Train Loss ‚Üí G: 12.7592, D: 0.1741 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 57/100 | Time: 106.08s ---
  Train Loss ‚Üí G: 12.7480, D: 0.1723 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 58/100 | Time: 106.05s ---
  Train Loss ‚Üí G: 12.8057, D: 0.1753 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 59/100 | Time: 106.06s ---
  Train Loss ‚Üí G: 12.7692, D: 0.1733 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 91/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 12.7646, D: 0.1763 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 92/100 | Time: 106.17s ---
  Train Loss ‚Üí G: 12.8635, D: 0.1828 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 93/100 | Time: 106.10s ---
  Train Loss ‚Üí G: 12.8274, D: 0.1802 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 94/100 | Time: 106.05s ---
  Train Loss ‚Üí G: 12.7940, D: 0.1798 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 95/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 12.7810, D: 0.1801 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 4/10 | Epoch 96/100 | Time: 106.19s ---
  Train Loss ‚Üí G: 12.7593, D: 0.1839 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 82/100 | Time: 106.14s ---
  Train Loss ‚Üí G: 12.8432, D: 0.1742 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 83/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 12.7912, D: 0.1738 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 84/100 | Time: 106.24s ---
  Train Loss ‚Üí G: 12.7891, D: 0.1716 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 85/100 | Time: 106.16s ---
  Train Loss ‚Üí G: 12.7366, D: 0.1784 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 86/100 | Time: 106.06s ---
  Train Loss ‚Üí G: 12.7546, D: 0.1700 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 87/100 | Time: 106.07s ---
  Train Loss ‚Üí G: 12.8061, D: 0.1767 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 88/100 | Time: 106.25s ---
  Train Loss ‚Üí G: 12.7480, D: 0.1758 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 89/100 | Time: 106.24s ---
  Train Loss ‚Üí G: 12.7081, D: 0.1748 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 90/100 | Time: 106.14s ---
  Train Loss ‚Üí G: 12.7519, D: 0.1715 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 91/100 | Time: 106.13s ---
  Train Loss ‚Üí G: 12.8252, D: 0.1738 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 92/100 | Time: 106.14s ---
  Train Loss ‚Üí G: 12.8210, D: 0.1709 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 93/100 | Time: 106.20s ---
  Train Loss ‚Üí G: 12.8851, D: 0.1777 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 94/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 12.8966, D: 0.1741 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 95/100 | Time: 106.09s ---
  Train Loss ‚Üí G: 12.8456, D: 0.1789 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 96/100 | Time: 106.21s ---
  Train Loss ‚Üí G: 12.8086, D: 0.1740 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 97/100 | Time: 106.20s ---
  Train Loss ‚Üí G: 12.7331, D: 0.1745 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 98/100 | Time: 106.16s ---
  Train Loss ‚Üí G: 12.8202, D: 0.1754 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 5/10 | Epoch 99/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 12.9313, D: 0.1756 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                         


--- Epoch Summary: Part 5/10 | Epoch 100/100 | Time: 106.15s ---
  Train Loss ‚Üí G: 12.9013, D: 0.1782 | LR: 2.0e-05
  Running evaluation for epoch 100...
  Val Metrics ‚Üí PSNR: 18.76 dB, SSIM: 0.6181
    üíæ Checkpoint saved to checkpoints-9_10_latest/checkpoint_part_5_epoch_100.pt
    üóëÔ∏è Removing old checkpoint: checkpoints-9_10_latest/checkpoint_part_5_epoch_40.pt

--- Starting Training on Part 6/10 ---
Dataset size: 13394 images
‚ú® Resetting optimizers and using MultiStepLR schedule.


                                                                                                                       


--- Epoch Summary: Part 6/10 | Epoch 1/100 | Time: 106.50s ---
  Train Loss ‚Üí G: 14.1000, D: 0.2380 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                       


--- Epoch Summary: Part 6/10 | Epoch 2/100 | Time: 106.09s ---
  Train Loss ‚Üí G: 13.7721, D: 0.2152 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                       


--- Epoch Summary: Part 6/10 | Epoch 3/100 | Time: 106.06s ---
  Train Loss ‚Üí G: 13.8178, D: 0.2117 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                       


--- Epoch Summary: Part 6/10 | Epoch 4/100 | Time: 106.14s ---
  Train Loss ‚Üí G: 13.8037, D: 0.2074 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                       


--- Epoch Summary: Part 6/10 | Epoch 5/100 | Time: 106.08s ---
  Train Loss ‚Üí G: 13.5427, D: 0.2149 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 17/100 | Time: 106.13s ---
  Train Loss ‚Üí G: 13.2684, D: 0.1960 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 18/100 | Time: 106.15s ---
  Train Loss ‚Üí G: 13.2438, D: 0.2033 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 19/100 | Time: 106.05s ---
  Train Loss ‚Üí G: 12.9470, D: 0.2224 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 20/100 | Time: 106.12s ---
  Train Loss ‚Üí G: 13.0519, D: 0.2136 | LR: 2.0e-04
  Running evaluation for epoch 20...
  Val Metrics ‚Üí PSNR: 18.73 dB, SSIM: 0.6229
    üíæ Checkpoint saved to checkpoints-9_10_latest/checkpoint_part_6_epoch_20.pt
    üóëÔ∏è Removing old checkpoint: checkpoints-9_10_latest/checkpoint_part_5_epoch_60.pt


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 21/100 | Time: 106.25s ---
  Train Loss ‚Üí G: 13.1909, D: 0.2086 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 22/100 | Time: 106.15s ---
  Train Loss ‚Üí G: 13.2935, D: 0.1898 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 23/100 | Time: 106.02s ---
  Train Loss ‚Üí G: 13.0641, D: 0.1883 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 24/100 | Time: 106.06s ---
  Train Loss ‚Üí G: 12.8849, D: 0.2171 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 25/100 | Time: 106.13s ---
  Train Loss ‚Üí G: 12.9718, D: 0.2074 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 26/100 | Time: 106.17s ---
  Train Loss ‚Üí G: 13.0796, D: 0.1946 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 27/100 | Time: 106.20s ---
  Train Loss ‚Üí G: 13.0054, D: 0.2033 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 28/100 | Time: 106.25s ---
  Train Loss ‚Üí G: 13.0201, D: 0.1934 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 29/100 | Time: 106.23s ---
  Train Loss ‚Üí G: 12.9597, D: 0.1959 | LR: 2.0e-04
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 30/100 | Time: 106.19s ---
  Train Loss ‚Üí G: 12.7926, D: 0.2078 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 31/100 | Time: 106.08s ---
  Train Loss ‚Üí G: 12.5291, D: 0.1742 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 32/100 | Time: 106.17s ---
  Train Loss ‚Üí G: 12.6733, D: 0.1732 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 33/100 | Time: 106.13s ---
  Train Loss ‚Üí G: 12.6773, D: 0.1757 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 34/100 | Time: 106.36s ---
  Train Loss ‚Üí G: 12.5549, D: 0.1694 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 35/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 12.7329, D: 0.1710 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 36/100 | Time: 106.13s ---
  Train Loss ‚Üí G: 12.7329, D: 0.1693 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 37/100 | Time: 106.30s ---
  Train Loss ‚Üí G: 12.6048, D: 0.1740 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 38/100 | Time: 106.10s ---
  Train Loss ‚Üí G: 12.6546, D: 0.1688 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 39/100 | Time: 106.04s ---
  Train Loss ‚Üí G: 12.6848, D: 0.1690 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 40/100 | Time: 106.05s ---
  Train Loss ‚Üí G: 12.6592, D: 0.1756 | LR: 2.0e-05
  Running evaluation for epoch 40...
  Val Metrics ‚Üí PSNR: 18.99 dB, SSIM: 0.6265
    üíæ Checkpoint saved to checkpoints-9_10_latest/checkpoint_part_6_epoch_40.pt
    üóëÔ∏è Removing old checkpoint: checkpoints-9_10_latest/checkpoint_part_5_epoch_80.pt


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 41/100 | Time: 106.10s ---
  Train Loss ‚Üí G: 12.5814, D: 0.1678 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 42/100 | Time: 106.06s ---
  Train Loss ‚Üí G: 12.6800, D: 0.1673 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 43/100 | Time: 106.08s ---
  Train Loss ‚Üí G: 12.7373, D: 0.1699 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 44/100 | Time: 106.11s ---
  Train Loss ‚Üí G: 12.7050, D: 0.1679 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


                                                                                                                        


--- Epoch Summary: Part 6/10 | Epoch 45/100 | Time: 106.09s ---
  Train Loss ‚Üí G: 12.7182, D: 0.1668 | LR: 2.0e-05
  Skipping evaluation, checkpointing, and graphing for this epoch.


Part 6/10 | Epoch 46/100:  76%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå  | 319/419 [01:20<00:25,  3.97it/s, D_Loss=0.1635, G_L1=11.5379, G_Loss=13.1101]

In [None]:
print("--- Loading best performing model for final evaluation ---")

# 1. Initialize a new generator
# Ensure the Generator class definition is available
try:
    best_gen = Generator(in_channels=1, features=64).to(config["DEVICE"])
except NameError:
    print("Error: The 'Generator' class is not defined.")
    print("Please ensure the cell containing the Generator class definition has been run.")
    # You might want to stop execution here or handle this error appropriately
    best_gen = None 

if best_gen is not None:
    # 2. Load the saved best checkpoint
    # We use .get() for safety, though we know 'gen_state_dict' exists
    try:
        checkpoint = torch.load(config["BEST_CHECKPOINT_PATH"], map_location=config["DEVICE"])
        best_gen.load_state_dict(checkpoint.get('gen_state_dict'))
    except FileNotFoundError:
        print(f"Error: Best checkpoint file not found at {config['BEST_CHECKPOINT_PATH']}")
        print("Please ensure training has run and 'checkpoints/best_model.pth' exists.")
        best_gen = None # Set to None so subsequent cells can check
    except Exception as e:
        print(f"Error loading best model: {e}")
        best_gen = None # Set to None so subsequent cells can check

    if best_gen is not None:
        # 3. Set the model to evaluation mode
        best_gen.eval()
        print("‚úÖ Best model loaded successfully and set to eval() mode.")

In [None]:
import numpy as np
import torch
from tqdm import tqdm # Make sure tqdm is imported if not already

# Note: We assume the torch-based metric functions (psnr_torch, ssim_torch, denorm01, create_window)
# and the loss function (L1_LOSS) are already defined in your notebook from previous cells.

if 'best_gen' in locals() and best_gen is not None:
    print("üöÄ Starting final evaluation on the validation dataset (using best model)...")

    # Set the generator to evaluation mode (good practice, though already done)
    best_gen.eval() 

    # Accumulators for our metrics
    total_l1_loss = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    total_samples = 0

    # We use config["DEVICE"] for consistency
    device = config["DEVICE"]

    # Pre-create the SSIM window and move it to the device once
    ssim_window = None
    try:
        # Assumes create_window is defined
        ssim_window = create_window(11, 3).float().to(device)
    except NameError:
        print("Warning: 'create_window' function is not defined. SSIM will not be calculated.")
    except Exception as e:
        print(f"Warning: Error creating SSIM window: {e}. SSIM will not be calculated.")

    # Disable gradients for evaluation
    with torch.no_grad():
        loop = tqdm(val_loader, desc="Evaluating Best Model")
        
        # Define flags to only show helper function errors once
        denorm_error_shown = False
        metrics_error_shown = False
        l1_error_shown = False

        for thermal_img, real_rgb in loop:
            # Move data to the device
            thermal_img, real_rgb = thermal_img.to(device), real_rgb.to(device)

            # Run the generator
            with torch.amp.autocast(device_type=device.type, enabled=(device.type == "cuda")):
                fake_rgb = best_gen(thermal_img) # Use best_gen
                
                # Calculate and accumulate L1 loss
                try:
                    loss_g_l1_raw = L1_LOSS(fake_rgb, real_rgb)
                    total_l1_loss += loss_g_l1_raw.item() * thermal_img.size(0)
                except NameError:
                    if not l1_error_shown:
                        print("Error: 'L1_LOSS' (nn.L1Loss) is not defined. L1 metric cannot be calculated.")
                        l1_error_shown = True
                except Exception as e:
                    if not l1_error_shown:
                        print(f"Error calculating L1 loss: {e}")
                        l1_error_shown = True

            # Denormalize images *on the GPU* for PSNR/SSIM
            try:
                real_rgb_denorm = denorm01(real_rgb).float()
                fake_rgb_denorm = denorm01(fake_rgb).float()
            except NameError:
                if not denorm_error_shown:
                    print("Error: 'denorm01' function not defined. PSNR/SSIM cannot be calculated.")
                    denorm_error_shown = True
                break # Stop evaluation loop if helpers are missing
            except Exception as e:
                if not denorm_error_shown:
                    print(f"Error during denorm: {e}")
                    denorm_error_shown = True
                break
                
            # Calculate metrics for the whole batch on the GPU
            try:
                total_psnr += psnr_torch(real_rgb_denorm, fake_rgb_denorm, data_range=1.0).item() * thermal_img.size(0)
                if ssim_window is not None:
                    total_ssim += ssim_torch(real_rgb_denorm, fake_rgb_denorm, window=ssim_window, val_range=1.0).item() * thermal_img.size(0)
            except NameError:
                if not metrics_error_shown:
                    print("Error: 'psnr_torch' or 'ssim_torch' not defined. Metrics cannot be calculated.")
                    metrics_error_shown = True
                break # Stop evaluation loop
            except Exception as e:
                if not metrics_error_shown:
                    print(f"Error calculating metrics: {e}")
                    metrics_error_shown = True
                break
                
            total_samples += thermal_img.size(0)

    # Calculate the final averages
    if total_samples > 0:
        avg_l1 = total_l1_loss / total_samples
        avg_psnr = total_psnr / total_samples
        avg_ssim = total_ssim / total_samples if ssim_window is not None else 0.0

        print("\n" + "="*30)
        print("‚úÖ Final 'Best Model' Results")
        print("="*30)
        print(f"Total Samples Evaluated: {total_samples}")
        if not l1_error_shown:
            print(f"Average L1 Loss:       {avg_l1:.4f}")
        if not metrics_error_shown and not denorm_error_shown:
            print(f"Average PSNR:          {avg_psnr:.2f} dB")
            if ssim_window is not None:
                print(f"Average SSIM:          {avg_ssim:.3f}")
        print("="*30)
    elif not (denorm_error_shown or metrics_error_shown):
        print("\nEvaluation did not run. Check val_loader or other errors.")
else:
    print("Skipping quantitative evaluation because 'best_gen' was not loaded successfully.")

In [None]:
import matplotlib.pyplot as plt
import torchvision.utils as vutils
import numpy as np

# This is the function definition
def show_test_images(gen, loader, device, num_images=5):
    """
    Visualizes the output of the generator on a few validation samples.
    """
    gen.eval()  # Set the generator to evaluation mode
    
    # Get a batch of validation data
    try:
        t, r = next(iter(loader))
    except StopIteration:
        print("Data loader is empty.")
        return
    except NameError:
        print("Error: 'val_loader' is not defined.")
        return

    t, r = t.to(device), r.to(device)
    
    with torch.no_grad():
        fake_rgb = gen(t).detach().cpu()

    # Denormalize images from [-1, 1] to [0, 1] for plotting
    t = (t.cpu() + 1) / 2
    r = (r.cpu() + 1) / 2
    fake_rgb = (fake_rgb + 1) / 2
    
    # Ensure we don't try to show more images than are in the batch
    num_images = min(num_images, len(t))

    plt.figure(figsize=(15, num_images * 5))
    plt.suptitle("Best Model Results", fontsize=20) # Updated title

    for i in range(num_images):
        # Plot Input Thermal Image (convert grayscale to RGB for consistent display)
        plt.subplot(num_images, 3, i * 3 + 1)
        plt.imshow(t[i].permute(1, 2, 0).squeeze(), cmap='gray')
        plt.title("Input Thermal")
        plt.axis('off')

        # Plot Generated RGB Image
        plt.subplot(num_images, 3, i * 3 + 2)
        plt.imshow(fake_rgb[i].permute(1, 2, 0))
        plt.title("Generated RGB (Best Model)")
        plt.axis('off')

        # Plot Ground Truth RGB Image
        plt.subplot(num_images, 3, i * 3 + 3)
        plt.imshow(r[i].permute(1, 2, 0))
        plt.title("Ground Truth RGB")
        plt.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# --- Call the function ---
if 'best_gen' in locals() and best_gen is not None:
    print("--- Showing visual results from BEST model ---")
    try:
        show_test_images(best_gen, val_loader, config["DEVICE"], num_images=5)
    except NameError as e:
        print(f"Error calling show_test_images: {e}")
        print("Please ensure 'val_loader' and 'config' are defined.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
else:
    print("Skipping visual evaluation because 'best_gen' was not loaded successfully.")