In [None]:
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from tqdm import tqdm
from torchvision import models, transforms

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# Enable the cuDNN benchmark mode globally (set once, right after the imports).
# NOTE(review): this is usually only beneficial when input shapes stay constant
# between iterations — confirm that the dataset yields fixed-size images.
torch.backends.cudnn.benchmark = True 

In [None]:
import torch

# Quick environment report. The device name is only queried when CUDA is
# actually usable, so this cell also runs on CPU-only machines.
print(torch.version.cuda)          # CUDA version in use
print(torch.cuda.is_available())   # True if a GPU is recognised
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of the first GPU
else:
    print("No CUDA device available")


In [None]:
class Generator(nn.Module):
    """U-Net style encoder/decoder mapping a 3-channel image to a 3-channel image.

    Three stride-2 convolutions downsample the input (3 -> 64 -> 128 -> 256
    channels); three nearest-neighbour upsample + 3x3 convolution stages bring
    it back up. The outputs of the first two encoder stages are concatenated
    onto the decoder features (skip connections), which is why ``u2`` and
    ``u3`` take 256 and 128 input channels. The final ``Tanh`` bounds the
    output to [-1, 1].

    Inputs whose height/width are not multiples of 8 are supported: decoder
    features are resized to the matching skip tensor before concatenation and
    the output is resized back to the input resolution. For sizes that are
    multiples of 8 no resizing takes place.
    """

    def __init__(self):
        super().__init__()

        # ---- Encoder: each stage halves the spatial resolution ----
        self.d1 = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.ReLU(True)
        )
        self.d2 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.InstanceNorm2d(128, affine=True),
            nn.ReLU(True)
        )

        self.d3 = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.InstanceNorm2d(256, affine=True),
            nn.ReLU(True)
        )

        # ---- Decoder: each stage doubles the spatial resolution ----
        self.u1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(256, 128, 3, 1, 1),
            nn.InstanceNorm2d(128, affine=True),
            nn.ReLU(True)
        )

        # 256 input channels = 128 (u1 output) + 128 (d2 skip)
        self.u2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(256, 64, 3, 1, 1),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(True)
        )

        # 128 input channels = 64 (u2 output) + 64 (d1 skip)
        self.u3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(128, 3, 3, 1, 1),
            nn.Tanh()
        )

    @staticmethod
    def _match_spatial(feat, reference):
        """Return ``feat`` resized (nearest) to the H x W of ``reference``; no-op if equal."""
        if feat.shape[-2:] == reference.shape[-2:]:
            return feat
        return nn.functional.interpolate(feat, size=reference.shape[-2:], mode='nearest')

    def forward(self, x):
        """Run the generator on ``x`` (N, 3, H, W) and return a tensor of the same shape."""
        d1 = self.d1(x)
        d2 = self.d2(d1)
        d3 = self.d3(d2)

        # Decoder with skip connections. Odd intermediate sizes make the
        # upsampled map one pixel smaller than the skip tensor, so align first.
        u1 = self._match_spatial(self.u1(d3), d2)
        u1 = torch.cat([u1, d2], dim=1)

        u2 = self._match_spatial(self.u2(u1), d1)
        u2 = torch.cat([u2, d1], dim=1)

        # Guarantee the output has the same resolution as the input.
        return self._match_spatial(self.u3(u2), x)

In [None]:
class Discriminator(nn.Module):
    """Conditional convolutional discriminator.

    The conditioning image and the candidate image are stacked along the
    channel axis (3 + 3 = 6 channels) and passed through three stride-2
    convolutions followed by a 3x3 convolution that emits a single-channel
    map of raw scores (no activation is applied to the output).
    """

    def __init__(self):
        super().__init__()

        negative_slope = 0.2
        stages = [
            # First stage has no normalisation layer.
            nn.Conv2d(6, 64, 4, 2, 1),
            nn.LeakyReLU(negative_slope, inplace=True),
        ]
        for in_ch, out_ch in ((64, 128), (128, 256)):
            stages += [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1),
                nn.InstanceNorm2d(out_ch, affine=True),
                nn.LeakyReLU(negative_slope, inplace=True),
            ]
        # Score head: keeps the spatial size, reduces to one channel.
        stages.append(nn.Conv2d(256, 1, 3, 1, 1))

        self.model = nn.Sequential(*stages)

    def forward(self, img_cond, img):
        """Score ``img`` given ``img_cond``; both are concatenated on dim 1."""
        stacked = torch.cat((img_cond, img), dim=1)
        return self.model(stacked)

In [None]:
class VGGPerceptualLoss(nn.Module):
    """Perceptual loss: L1 distance between frozen VGG-16 feature maps.

    Only the first 16 modules of the ImageNet-pretrained ``vgg16().features``
    stack are kept. They are put in eval mode and excluded from gradient
    updates, so the network acts as a fixed feature extractor.
    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg16(weights="IMAGENET1K_V1")
        self.vgg = backbone.features[:16].eval()
        # Freeze every weight of the extractor.
        self.vgg.requires_grad_(False)

    def forward(self, x, y):
        """Return the mean absolute difference between the features of ``x`` and ``y``."""
        # NOTE(review): inputs are passed to VGG as-is (no ImageNet mean/std
        # normalisation) — confirm this is intended for the caller's value range.
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        return nn.functional.l1_loss(feats_x, feats_y)

In [None]:
def tv_loss(x):
    """Total-variation penalty of a 4-D tensor.

    Sums the mean absolute difference between neighbouring elements along
    dim 2 and the mean absolute difference between neighbouring elements
    along dim 3 (assumes x is laid out as (N, C, H, W) — TODO confirm).
    """
    diff_dim2 = x[:, :, 1:] - x[:, :, :-1]
    diff_dim3 = x[:, :, :, 1:] - x[:, :, :, :-1]
    return diff_dim2.abs().mean() + diff_dim3.abs().mean()

In [None]:
def denorm01(x):
    """Affine map x -> x * 0.5 + 0.5, i.e. the range [-1, 1] becomes [0, 1]."""
    half = 0.5
    scaled = x * half
    return scaled + half

In [None]:
def lr_lambda(epoch):
    """Learning-rate multiplier: constant 1.0, then linear decay.

    The factor stays at 1.0 while ``epoch`` is at most 50 and then decreases
    linearly by 1/51 per epoch (the schedule is laid out for 100 epochs).
    """
    total_epochs = 100
    decay_start = 50
    epochs_into_decay = max(0, epoch - decay_start)
    decay_span = float(total_epochs - decay_start + 1)
    return 1.0 - epochs_into_decay / decay_span

In [None]:
def batch_psnr_ssim(restored, target):
    """Average PSNR and SSIM over a batch.

    Both inputs are torch tensors of shape (B, 3, H, W) with values in
    [-1, 1]. They are mapped to [0, 1], clamped, moved to the CPU and
    converted to channel-last NumPy arrays before the per-image metrics are
    computed.

    Returns:
        (mean_psnr, mean_ssim) as Python floats.
    """
    def to_hwc_numpy(batch):
        # (B, C, H, W) in [-1, 1] -> (B, H, W, C) NumPy array in [0, 1]
        return denorm01(batch).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()

    restored_np = to_hwc_numpy(restored)
    target_np = to_hwc_numpy(target)

    psnr_values = []
    ssim_values = []
    for idx in range(restored_np.shape[0]):
        reference, candidate = target_np[idx], restored_np[idx]
        psnr_values.append(psnr(reference, candidate, data_range=1.0))
        ssim_values.append(ssim(reference, candidate, data_range=1.0, channel_axis=2))

    mean_psnr = float(np.mean(psnr_values))
    mean_ssim = float(np.mean(ssim_values))
    return mean_psnr, mean_ssim

In [None]:
def save_val_comparison(epoch, generator, val_loader, device, out_dir):
    """Save a degraded / restored / target comparison image for one validation batch.

    Takes the first batch of ``val_loader``, keeps at most 8 pairs, runs the
    generator on the degraded images and writes ``val_epoch_{epoch}.png`` into
    ``out_dir``. The image has three rows (degraded, restored, target) with one
    column per sample.

    Args:
        epoch: epoch number used in the output file name.
        generator: model mapping degraded images to restored images.
        val_loader: iterable yielding ``(degraded, target)`` batches, or None
            (in which case nothing is done).
        device: device the batch is moved to before inference.
        out_dir: existing directory receiving the PNG file.

    The generator's train/eval mode is restored to what it was on entry, even
    if inference or saving raises.
    """
    if val_loader is None:
        return

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            degraded_v, target_v = next(iter(val_loader))
            degraded_v = degraded_v[:8].to(device)
            target_v = target_v[:8].to(device)
            restored_v = generator(degraded_v)

            # Map [-1, 1] -> [0, 1] with the shared helper, then stack the three groups.
            comparison = torch.cat(
                [denorm01(degraded_v), denorm01(restored_v), denorm01(target_v)], dim=0
            ).clamp(0, 1)

            # One column per sample so each group occupies exactly one row,
            # including when the batch holds fewer than 8 images.
            n_cols = max(1, degraded_v.size(0))
            save_image(comparison, os.path.join(out_dir, f"val_epoch_{epoch}.png"), nrow=n_cols)
    finally:
        # Put the model back in the mode the caller had it in.
        generator.train(was_training)

In [None]:
def save_grid_samples(epoch, generator, context_images, out_dir):
    """Save a grid of generator outputs for a fixed set of conditioning images.

    Runs ``generator`` on ``context_images`` without gradients, maps the result
    from [-1, 1] to [0, 1], clamps it, and writes ``generated_epoch_{epoch}.png``
    (5 images per row) into ``out_dir``.

    The generator's train/eval mode is restored to what it was on entry, even
    if inference or saving raises.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            samples = denorm01(generator(context_images)).clamp(0, 1).cpu()
            grid = make_grid(samples, nrow=5)
            save_image(grid, os.path.join(out_dir, f"generated_epoch_{epoch}.png"))
    finally:
        # Put the model back in the mode the caller had it in.
        generator.train(was_training)

# TRAINING

In [None]:
if __name__ == "__main__":
    # Training script for the conditional GAN: builds the data loaders, models,
    # optimizers and schedulers, optionally resumes from the last checkpoint,
    # then alternates discriminator / generator updates for a fixed number of
    # epochs with per-epoch validation and checkpointing.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    if device.type == "cuda":
        print("GPU:", torch.cuda.get_device_name(0))

    from paired_dataset import PairedImageDataset
    batch_size = 128

    # --- DATALOADERS ---
    # NOTE(review): backslash-separated relative paths only resolve on Windows.
    train_dir = r"..\..\data\train_10k"
    val_dir   = r"..\..\data\test"
    train_dataset = PairedImageDataset(train_dir)
    val_dataset   = PairedImageDataset(val_dir) if os.path.exists(val_dir) else None

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=(device.type == "cuda"))
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=4) if val_dataset else None

    # --- MODELS & OPTIMIZERS ---
    gen = Generator().to(device)
    disc = Discriminator().to(device)
    vgg_fn = VGGPerceptualLoss().to(device)

    # 4:1 learning-rate ratio to help the generator keep up with the discriminator
    opt_g = optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_d = optim.Adam(disc.parameters(), lr=5e-5, betas=(0.5, 0.999))

    scheduler_g = optim.lr_scheduler.LambdaLR(opt_g, lr_lambda=lr_lambda)
    scheduler_d = optim.lr_scheduler.LambdaLR(opt_d, lr_lambda=lr_lambda)

    criterion = nn.BCEWithLogitsLoss()

    # --- CHECKPOINTING & RESUME ---
    output_dir, ckpt_dir = "output_gan", "checkpoints_gan"
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)

    last_path = os.path.join(ckpt_dir, "gan_last.pth")
    best_path = os.path.join(ckpt_dir, "gan_best.pth")

    start_epoch, best_score = 1, -1e9

    if os.path.exists(last_path):
        print(f"Resume depuis {last_path}")
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints produced by this script / from a trusted source.
        ckpt = torch.load(last_path, map_location=device, weights_only=False)
        gen.load_state_dict(ckpt["gen_state_dict"])
        disc.load_state_dict(ckpt["disc_state_dict"])
        opt_g.load_state_dict(ckpt["opt_g_state_dict"])
        opt_d.load_state_dict(ckpt["opt_d_state_dict"])

        # Restore the schedulers so the LR decay continues where it stopped.
        if "sched_g_state_dict" in ckpt:
            scheduler_g.load_state_dict(ckpt["sched_g_state_dict"])
            scheduler_d.load_state_dict(ckpt["sched_d_state_dict"])

        start_epoch = ckpt.get("epoch", 0) + 1
        best_score = ckpt.get("best_score", best_score)

        # At the start of epoch E each scheduler must already have stepped
        # E - 1 times. Checkpoints written before the end-of-epoch step (or
        # without scheduler state) are behind; fast-forward them so the LR
        # schedule stays aligned with the epoch counter.
        for sched in (scheduler_g, scheduler_d):
            while sched.last_epoch < start_epoch - 1:
                sched.step()

    # Conditioning images reused for every sample grid of this run.
    fixed_cond = next(iter(train_loader))[0][:25].to(device)
    epochs, save_every = 100, 2
    warmup_epochs = 7  # epochs during which the generator is trained with L1 only

    # AMP (Automatic Mixed Precision), enabled on CUDA only
    use_amp = (device.type == "cuda")
    scaler_g = torch.amp.GradScaler('cuda', enabled=use_amp)
    scaler_d = torch.amp.GradScaler('cuda', enabled=use_amp)

    # --- MAIN LOOP ---
    for epoch in range(start_epoch, epochs + 1):
        gen.train(); disc.train()
        g_losses, d_losses = [], []
        pbar = tqdm(train_loader, desc=f"Epoch [{epoch}/{epochs}]", file=sys.stdout)

        # Loss weights only depend on the epoch, so they are set once per epoch.
        in_warmup = epoch <= warmup_epochs
        l_adv = 1.0
        l_l1 = 40.0 if in_warmup else 20.0
        l_vgg = 0.0 if in_warmup else 0.5

        for degraded, target in pbar:
            degraded = degraded.to(device); target = target.to(device)

            # ---- 1. TRAIN DISCRIMINATOR (once) ----
            opt_d.zero_grad(set_to_none=True)
            with torch.amp.autocast('cuda', enabled=use_amp):
                out_real = disc(degraded, target)
                # One-sided label smoothing: real targets are 0.9 instead of 1.0.
                d_loss_real = criterion(out_real, torch.full_like(out_real, 0.9))

                fake = gen(degraded)
                # detach(): the discriminator step must not back-propagate into the generator.
                out_fake = disc(degraded, fake.detach())
                d_loss_fake = criterion(out_fake, torch.zeros_like(out_fake))
                d_loss = 0.5 * (d_loss_real + d_loss_fake)

            scaler_d.scale(d_loss).backward()
            scaler_d.step(opt_d)
            scaler_d.update()
            d_losses.append(d_loss.item())

            # ---- 2. TRAIN GENERATOR (twice) ----
            for _ in range(2):  # two generator updates per discriminator update
                opt_g.zero_grad(set_to_none=True)
                with torch.amp.autocast('cuda', enabled=use_amp):
                    # Regenerate 'fake' on every pass because the weights have changed.
                    fake = gen(degraded)
                    l1_l = torch.nn.functional.l1_loss(fake, target)

                    if in_warmup:
                        # Warm-up: pure L1 reconstruction; the adversarial term is
                        # unused, so the discriminator forward pass is skipped.
                        g_loss = l_l1 * l1_l
                    else:
                        out_fake_for_g = disc(degraded, fake)
                        adv_loss = criterion(out_fake_for_g, torch.ones_like(out_fake_for_g))
                        vgg_l = vgg_fn(fake, target)
                        g_loss = (l_adv * adv_loss) + (l_l1 * l1_l) + (l_vgg * vgg_l) + (0.005 * tv_loss(fake))

                scaler_g.scale(g_loss).backward()
                scaler_g.step(opt_g)
                scaler_g.update()

            # Only the loss of the last generator pass is logged.
            g_losses.append(g_loss.item())
            pbar.set_postfix({"G": f"{g_loss.item():.2f}", "D": f"{d_loss.item():.2f}"})

        # --- END OF EPOCH: VALIDATION & SAVING ---
        train_g, train_d = np.mean(g_losses), np.mean(d_losses)

        if val_loader:
            gen.eval()
            psnrs, ssims = [], []
            with torch.no_grad():
                for dv, tv in val_loader:
                    dv, tv = dv.to(device), tv.to(device)
                    res = gen(dv)
                    p, s = batch_psnr_ssim(res, tv)
                    psnrs.append(p); ssims.append(s)
            val_p, val_s = np.mean(psnrs), np.mean(ssims)
            # Model-selection score: PSNR dominates, SSIM acts as a tie-breaker.
            score = val_p + (0.01 * val_s)
            print(f"\nEpoch {epoch} | G={train_g:.3f} D={train_d:.3f} | PSNR={val_p:.2f} SSIM={val_s:.4f}")
        else:
            # Without validation data, a lower generator loss counts as better.
            score = -train_g

        if epoch % save_every == 0 or epoch == 1:
            save_grid_samples(epoch, gen, fixed_cond, output_dir)
            save_val_comparison(epoch, gen, val_loader, device, output_dir)

        # "Best" bookkeeping
        is_best = score > best_score
        if is_best: best_score = score

        # --- LEARNING-RATE UPDATE ---
        # Step BEFORE building the checkpoint: the saved scheduler state and
        # optimizer LR then describe the next epoch, which is exactly where a
        # resumed run restarts. Saving first would make every resume lag the
        # decay schedule by one epoch.
        scheduler_g.step()
        scheduler_d.step()

        # --- CHECKPOINT ---
        save_dict = {
            "epoch": epoch,
            "best_score": best_score,
            "gen_state_dict": gen.state_dict(),
            "disc_state_dict": disc.state_dict(),
            "opt_g_state_dict": opt_g.state_dict(),
            "opt_d_state_dict": opt_d.state_dict(),
            "sched_g_state_dict": scheduler_g.state_dict(),
            "sched_d_state_dict": scheduler_d.state_dict(),
        }
        torch.save(save_dict, last_path)
        if is_best:
            torch.save(save_dict, best_path)
            print(f"Record battu ! Sauvegardé dans gan_best.pth")

        # Small log to check that the learning rate is decreasing as expected
        curr_lr_g = opt_g.param_groups[0]['lr']
        print(f"Fin Époque {epoch} | LR Générateur: {curr_lr_g:.7f}")

    print("\nTraining terminé avec succès.")