In [None]:
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# Let cuDNN time the available convolution algorithms and reuse the fastest
# one; this mainly helps when input shapes stay the same between iterations.
torch.backends.cudnn.benchmark = True 

In [None]:
import torch

# Environment sanity check. The device name is only queried when a GPU is
# actually present, so this cell also runs on CPU-only machines.
print(torch.version.cuda)          # CUDA version this PyTorch build targets (None on CPU builds)
print(torch.cuda.is_available())   # True if a GPU is usable
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of GPU 0
else:
    print("No CUDA device detected - running on CPU")


In [None]:
class Generator(nn.Module):
    """Small U-Net style encoder/decoder mapping a 3-channel image to a 3-channel image.

    Encoder: three stride-2 convolutions (64, 128, 256 channels).
    Decoder: three nearest-neighbour upsample + 3x3 convolution stages; the
    outputs of the first two are concatenated with the matching encoder
    feature maps (skip connections). The last stage ends in ``Tanh``, so the
    output lies in [-1, 1].

    Inputs whose height/width are not multiples of 8 are supported: decoder
    features are resized (nearest neighbour) to the skip tensor's spatial size
    before concatenation, and the final output is resized to the input size.
    For sizes that are multiples of 8 no resizing takes place.
    """

    def __init__(self):
        super().__init__()

        # --- Encoder (each stage halves the spatial resolution) ---
        self.d1 = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.ReLU(True)
        )
        self.d2 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.InstanceNorm2d(128, affine=True),
            nn.ReLU(True)
        )

        self.d3 = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.InstanceNorm2d(256, affine=True),
            nn.ReLU(True)
        )

        # --- Decoder (each stage doubles the spatial resolution) ---
        self.u1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(256, 128, 3, 1, 1),
            nn.InstanceNorm2d(128, affine=True),
            nn.ReLU(True)
        )

        # Input is 256 channels: 128 from u1 + 128 from the d2 skip.
        self.u2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(256, 64, 3, 1, 1),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(True)
        )

        # Input is 128 channels: 64 from u2 + 64 from the d1 skip.
        self.u3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(128, 3, 3, 1, 1),
            nn.Tanh()
        )

    @staticmethod
    def _match_size(feat, size):
        """Return ``feat`` resized (nearest) to spatial ``size``; no-op if it already matches."""
        if feat.shape[-2:] == size:
            return feat
        return nn.functional.interpolate(feat, size=size, mode='nearest')

    def forward(self, x):
        """Run the generator.

        Args:
            x: tensor of shape (B, 3, H, W).

        Returns:
            Tensor of shape (B, 3, H, W) with values in [-1, 1].
        """
        d1 = self.d1(x)
        d2 = self.d2(d1)
        d3 = self.d3(d2)

        # Odd intermediate sizes make the 2x upsample one pixel short of the
        # skip tensor; align before concatenating along the channel axis.
        u1 = self._match_size(self.u1(d3), d2.shape[-2:])
        u1 = torch.cat([u1, d2], dim=1)

        u2 = self._match_size(self.u2(u1), d1.shape[-2:])
        u2 = torch.cat([u2, d1], dim=1)

        return self._match_size(self.u3(u2), x.shape[-2:])

In [None]:
class Discriminator(nn.Module):
    """Conditional patch discriminator.

    The conditioning image and the candidate image are stacked along the
    channel axis (3 + 3 = 6 channels) and passed through three stride-2
    convolutions followed by a 3x3 convolution that emits one raw logit per
    spatial location (no sigmoid is applied here).
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, use_instance_norm) for each stride-2 stage.
        stage_specs = [(6, 64, False), (64, 128, True), (128, 256, True)]

        layers = []
        for c_in, c_out, with_norm in stage_specs:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            if with_norm:
                layers.append(nn.InstanceNorm2d(c_out, affine=True))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Final 1-channel logit map.
        layers.append(nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img_cond, img):
        """Score ``img`` given ``img_cond``; both are concatenated on dim 1."""
        pair = torch.cat((img_cond, img), dim=1)
        return self.model(pair)

In [None]:
def tv_loss(x):
    """Total-variation penalty of a 4-D tensor.

    Sums the mean absolute difference between neighbours along dim 2 and the
    mean absolute difference between neighbours along dim 3 (presumably the
    height and width axes of a (B, C, H, W) batch).
    """
    step_dim2 = x[:, :, 1:] - x[:, :, :-1]
    step_dim3 = x[:, :, :, 1:] - x[:, :, :, :-1]
    return step_dim2.abs().mean() + step_dim3.abs().mean()

In [None]:
def denorm01(x):
    """Affinely map values from the [-1, 1] range to the [0, 1] range."""
    half = 0.5
    return x * half + half

In [None]:
def batch_psnr_ssim(restored, target):
    """Average PSNR and SSIM over a batch.

    Args:
        restored: torch tensor of shape (B, 3, H, W) in [-1, 1].
        target: torch tensor of shape (B, 3, H, W) in [-1, 1].

    Returns:
        ``(mean_psnr, mean_ssim)`` as Python floats, computed per image on
        [0, 1] channel-last arrays and averaged over the batch.
    """

    def _to_hwc_arrays(batch):
        # [-1, 1] -> [0, 1], clipped, moved to host memory, channels last.
        return denorm01(batch).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()

    restored_np = _to_hwc_arrays(restored)
    target_np = _to_hwc_arrays(target)

    indices = range(restored_np.shape[0])
    psnr_vals = [
        psnr(target_np[k], restored_np[k], data_range=1.0)
        for k in indices
    ]
    ssim_vals = [
        ssim(target_np[k], restored_np[k], data_range=1.0, channel_axis=2)
        for k in indices
    ]
    return float(np.mean(psnr_vals)), float(np.mean(ssim_vals))

# TRAINING

In [None]:
if __name__ == "__main__":

    # ---- Device ----------------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    if device.type == "cuda":
        print("GPU:", torch.cuda.get_device_name(0))

    # Project-local dataset; each item is unpacked below as (degraded, target).
    from paired_dataset import PairedImageDataset

    batch_size = 32

    # ---- Data ------------------------------------------------------------
    # Paths are built with os.path.join so they resolve on both Windows and
    # POSIX. With hard-coded backslashes the validation directory is never
    # found on POSIX and validation is silently skipped.
    train_dir = os.path.join("..", "data", "train")
    val_dir = os.path.join("..", "data", "test")

    train_dataset = PairedImageDataset(train_dir)
    val_dataset = PairedImageDataset(val_dir) if os.path.exists(val_dir) else None
    if val_dataset is None:
        print(f"Validation directory not found ({val_dir}); validation metrics disabled.")

    num_workers = 0
    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = device.type == "cuda"

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=pin_memory
    )

    val_loader = None
    if val_dataset is not None:
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,
            pin_memory=pin_memory
        )

    # Sanity check: pull one batch and show its shapes.
    degraded, target = next(iter(train_loader))
    print("Batch shapes:", degraded.shape, target.shape)

    # ---- Models, optimizers, loss ----------------------------------------
    gen = Generator().to(device)
    disc = Discriminator().to(device)

    # The discriminator uses a lower learning rate than the generator.
    opt_g = optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_d = optim.Adam(disc.parameters(), lr=5e-5, betas=(0.5, 0.999))

    # The discriminator outputs raw logits, hence the "WithLogits" variant.
    criterion = nn.BCEWithLogitsLoss()

    # ---- Output / checkpoint locations -----------------------------------
    output_dir = "output_gan"
    os.makedirs(output_dir, exist_ok=True)

    ckpt_dir = "checkpoints_gan"
    os.makedirs(ckpt_dir, exist_ok=True)

    last_path = os.path.join(ckpt_dir, "gan_last.pth")
    best_path = os.path.join(ckpt_dir, "gan_best.pth")

    start_epoch = 1
    best_score = -1e9  # the score is maximised (PSNR-based, or SSIM)

    # ---- Resume ----------------------------------------------------------
    if os.path.exists(last_path):
        print(f"♻️ Resume depuis {last_path}")
        # NOTE: torch.load unpickles the file; only resume from checkpoints
        # written by this notebook (trusted source).
        ckpt = torch.load(last_path, map_location=device)
        gen.load_state_dict(ckpt["gen_state_dict"])
        disc.load_state_dict(ckpt["disc_state_dict"])
        opt_g.load_state_dict(ckpt["opt_g_state_dict"])
        opt_d.load_state_dict(ckpt["opt_d_state_dict"])
        start_epoch = ckpt.get("epoch", 0) + 1
        best_score = ckpt.get("best_score", best_score)

    # Conditioning images reused for every sample grid (up to 25, shown 5 per
    # row). They come from the shuffled train loader, so the selection is not
    # the same across runs/resumes.
    fixed_cond = next(iter(train_loader))[0][:25].to(device, non_blocking=pin_memory)


    epochs = 100
    save_every = 2

    # AMP (optional; enabled only on CUDA).
    # NOTE: recent PyTorch releases deprecate torch.cuda.amp.* in favour of
    # torch.amp.*; the older spelling is kept for compatibility.
    use_amp = (device.type == "cuda")
    scaler_g = torch.cuda.amp.GradScaler(enabled=use_amp)
    scaler_d = torch.cuda.amp.GradScaler(enabled=use_amp)

    def save_grid_samples(epoch):
        """Save the generator's output on ``fixed_cond`` as a 5-column image grid.

        The file is written to ``output_dir`` as ``generated_epoch_<epoch>.png``.
        The generator is switched to eval mode for the forward pass and put
        back in train mode afterwards.
        """
        out_path = os.path.join(output_dir, f"generated_epoch_{epoch}.png")
        gen.eval()
        with torch.no_grad():
            generated = gen(fixed_cond)
            # [-1, 1] -> [0, 1], clipped, on CPU for saving.
            images01 = denorm01(generated).clamp(0, 1).cpu()
            save_image(make_grid(images01, nrow=5), out_path)
        gen.train()

    def save_val_comparison(epoch):
        """Save a degraded / restored / target comparison for the first validation batch.

        Uses at most the first 8 images of the first batch from ``val_loader``
        and writes ``val_epoch_<epoch>.png`` to ``output_dir`` with 8 images
        per row. Does nothing when there is no validation loader. The
        generator is left in train mode on exit.
        """
        if val_loader is None:
            return

        out_path = os.path.join(output_dir, f"val_epoch_{epoch}.png")
        gen.eval()
        with torch.no_grad():
            first_inputs, first_targets = next(iter(val_loader))
            inputs8 = first_inputs[:8].to(device, non_blocking=pin_memory)
            targets8 = first_targets[:8].to(device, non_blocking=pin_memory)
            outputs8 = gen(inputs8)

            # Stack the three groups along the batch axis: inputs, outputs, targets.
            groups = [denorm01(t) for t in (inputs8, outputs8, targets8)]
            panel = torch.cat(groups, dim=0).clamp(0, 1)
            save_image(panel, out_path, nrow=8)
        gen.train()



    def _atomic_save(obj, path):
        """Write ``obj`` with torch.save to a temp file, then rename it over ``path``.

        An interrupted write therefore never leaves a truncated checkpoint at
        ``path`` (which would break the resume logic).
        """
        tmp_path = path + ".tmp"
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)

    for epoch in range(start_epoch, epochs + 1):
        gen.train()
        disc.train()

        # L1 weight schedule. It only depends on the epoch, so it is computed
        # once per epoch rather than once per batch.
        if epoch < 7:
            lambda_l1 = 40
        elif epoch < 25:
            lambda_l1 = 15
        else:
            lambda_l1 = 7

        # Warm-up: through epoch 7 the generator is trained on L1 only.
        # NOTE(review): the L1 weight already drops to 15 at epoch 7 (`< 7`)
        # while the warm-up lasts through epoch 7 (`<= 7`), so epoch 7 is
        # L1-only with weight 15. Behaviour kept as is - confirm it is intended.
        use_adversarial = epoch > 7

        g_losses, d_losses = [], []

        pbar = tqdm(train_loader, desc=f"Epoch [{epoch}/{epochs}]", file=sys.stdout)

        for degraded, target in pbar:
            degraded = degraded.to(device, non_blocking=pin_memory)
            target = target.to(device, non_blocking=pin_memory)

            # ---------------- Discriminator step ----------------
            opt_d.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=use_amp):
                out_real = disc(degraded, target)
                # One-sided label smoothing: real targets are 0.9 instead of 1.
                real_labels = torch.full_like(out_real, 0.9, device=device)
                d_loss_real = criterion(out_real, real_labels)

                fake = gen(degraded)
                # detach(): the discriminator loss must not back-propagate
                # into the generator.
                out_fake = disc(degraded, fake.detach())
                fake_labels = torch.zeros_like(out_fake, device=device)
                d_loss_fake = criterion(out_fake, fake_labels)

                d_loss = 0.5 * (d_loss_real + d_loss_fake)

            scaler_d.scale(d_loss).backward()
            scaler_d.step(opt_d)
            scaler_d.update()

            # ---------------- Generator step ----------------
            opt_g.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=use_amp):
                l1_loss = torch.nn.functional.l1_loss(fake, target)

                if use_adversarial:
                    # The discriminator forward pass is only needed when the
                    # adversarial term is part of the loss.
                    out_fake_for_g = disc(degraded, fake)
                    real_labels_g = torch.ones_like(out_fake_for_g, device=device)
                    adv_loss = criterion(out_fake_for_g, real_labels_g)
                    tv = tv_loss(fake)
                    g_loss = adv_loss + lambda_l1 * l1_loss + 0.005 * tv
                else:
                    g_loss = lambda_l1 * l1_loss

            scaler_g.scale(g_loss).backward()
            scaler_g.step(opt_g)
            scaler_g.update()

            # .item() forces a host/device sync; read each loss once.
            g_val = g_loss.item()
            d_val = d_loss.item()
            g_losses.append(g_val)
            d_losses.append(d_val)

            pbar.set_postfix({
                "G": f"{g_val:.3f}",
                "D": f"{d_val:.3f}"
            })

        train_g = float(np.mean(g_losses))
        train_d = float(np.mean(d_losses))

        # ---------------- Validation ----------------
        val_psnr, val_ssim = None, None
        if val_loader is not None:
            gen.eval()
            # batch_psnr_ssim returns per-batch means; weight them by batch
            # size so a smaller last batch does not skew the epoch average.
            psnr_sum, ssim_sum, n_images = 0.0, 0.0, 0
            with torch.no_grad():
                vbar = tqdm(val_loader, desc=f"Epoch [{epoch}/{epochs}] Val", leave=False, file=sys.stdout)
                for deg_v, tgt_v in vbar:
                    deg_v = deg_v.to(device, non_blocking=pin_memory)
                    tgt_v = tgt_v.to(device, non_blocking=pin_memory)
                    res_v = gen(deg_v)
                    p, s = batch_psnr_ssim(res_v, tgt_v)
                    n_batch = deg_v.size(0)
                    psnr_sum += p * n_batch
                    ssim_sum += s * n_batch
                    n_images += n_batch

            # An empty validation loader leaves the metrics at None, so the
            # score below falls back to the training loss instead of NaN.
            if n_images > 0:
                val_psnr = float(psnr_sum / n_images)
                val_ssim = float(ssim_sum / n_images)

        # Model-selection score (higher is better).
        if val_psnr is not None:
            score = val_psnr + 0.01 * val_ssim
        else:
            score = -train_g

        print(f"\nEpoch {epoch} | G={train_g:.4f} D={train_d:.4f}" +
              (f" | Val PSNR={val_psnr:.2f} SSIM={val_ssim:.4f}" if val_psnr is not None else ""))

        # Periodic visual dumps (also on the first and the last epoch).
        if epoch % save_every == 0 or epoch in (1, epochs):
            save_grid_samples(epoch)
            save_val_comparison(epoch)

        is_best = score > best_score
        if is_best:
            best_score = score

        save_obj = {
            "epoch": epoch,
            "best_score": best_score,
            "gen_state_dict": gen.state_dict(),
            "disc_state_dict": disc.state_dict(),
            "opt_g_state_dict": opt_g.state_dict(),
            "opt_d_state_dict": opt_d.state_dict(),
            "train_g": train_g,
            "train_d": train_d,
            "val_psnr": val_psnr,
            "val_ssim": val_ssim,
        }

        # Always refresh the resume checkpoint; keep a separate copy of the best.
        _atomic_save(save_obj, last_path)

        if is_best:
            _atomic_save(save_obj, best_path)
            print(f"BEST updated -> saved to {best_path}")

    print("\nTraining terminé.")