In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from torchvision.utils import save_image
from PIL import Image
import os
from tqdm import tqdm
import sys
import pandas as pd

In [20]:
# Enable the cuDNN auto-tuner so convolution algorithms are benchmarked and the
# fastest one is reused (per the PyTorch docs this helps when input shapes stay
# constant between iterations; it only has an effect on CUDA/cuDNN runs).
torch.backends.cudnn.benchmark = True

In [21]:
# --- 1. PERCEPTUAL LOSS (VGG) ---
class VGGPerceptualLoss(nn.Module):
    """Perceptual loss: L1 distance between frozen VGG16 feature maps.

    Only the first 16 layers of the ImageNet-pretrained VGG16 ``features``
    stack are kept (roughly up to relu3_3). The extractor is put in eval mode
    and its parameters are excluded from gradient computation.
    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg16(weights="IMAGENET1K_V1")
        # Truncated feature extractor, switched to inference mode.
        self.vgg = backbone.features[:16].eval()
        for weight in self.vgg.parameters():
            weight.requires_grad = False

    def forward(self, x, y):
        """Return the mean absolute difference between the features of ``x`` and ``y``."""
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        return nn.functional.l1_loss(feats_x, feats_y)

In [22]:
# --- VAE-UNET HYBRIDE ---
class VAE_UNet(nn.Module):
    """U-Net encoder/decoder whose bottleneck is a variational latent vector.

    The encoder applies four conv-block + 2x2 max-pool stages followed by a
    1024-channel bottleneck block. The bottleneck is flattened and projected to
    ``mu`` / ``logvar``; a sampled latent vector is projected back to the
    bottleneck shape and decoded with transposed convolutions, concatenating
    the encoder activations of each stage (skip connections).

    Args:
        latent_dim: Size of the latent vector.
        input_size: Height/width of the (square) input images. The four
            pooling stages divide it by 16, so it must be a positive multiple
            of 16. The default (128) reproduces the original 8x8 bottleneck.

    Raises:
        ValueError: If ``input_size`` is not a positive multiple of 16.
    """

    def __init__(self, latent_dim=64, input_size=128):
        # Validate before building any layer so a bad size fails fast.
        if input_size <= 0 or input_size % 16 != 0:
            raise ValueError(
                f"input_size must be a positive multiple of 16, got {input_size}"
            )
        super().__init__()

        self.latent_dim = latent_dim
        self.input_size = input_size
        # Spatial side of the bottleneck feature map after four 2x2 poolings.
        self.bottleneck_size = input_size // 16
        flat_features = 1024 * self.bottleneck_size * self.bottleneck_size

        # Encoder; the output of each block is reused as a skip connection.
        self.enc1 = self.conv_block(3, 64)
        self.pool1 = nn.MaxPool2d(2, 2)

        self.enc2 = self.conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.enc3 = self.conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.enc4 = self.conv_block(256, 512)
        self.pool4 = nn.MaxPool2d(2, 2)

        # Bottleneck feeding the VAE latent space.
        self.bottleneck = self.conv_block(512, 1024)

        # VAE heads (mean and log-variance).
        self.flatten = nn.Flatten()
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # Projection from the latent vector back to the bottleneck shape.
        self.fc_decode = nn.Linear(latent_dim, flat_features)
        self.unflatten = nn.Unflatten(
            1, (1024, self.bottleneck_size, self.bottleneck_size)
        )

        # Decoder; each block takes the upsampled map concatenated with the skip.
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = self.conv_block(1024, 512)  # 1024 = 512 (up) + 512 (skip)

        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = self.conv_block(512, 256)  # 512 = 256 (up) + 256 (skip)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = self.conv_block(256, 128)  # 256 = 128 (up) + 128 (skip)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = self.conv_block(128, 64)  # 128 = 64 (up) + 64 (skip)

        # Final 1x1 projection to 3 channels.
        self.out = nn.Conv2d(64, 3, 1)

    def conv_block(self, in_ch, out_ch):
        """Build two (3x3 conv -> BatchNorm -> LeakyReLU(0.2)) layers in sequence."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, x):
        """Run the encoder.

        Returns:
            ``(mu, logvar, e1, e2, e3, e4)`` where ``e1..e4`` are the pre-pooling
            activations of the four encoder stages, kept for the skip connections.
        """
        # Shape comments are (H x W x C) for the default 128x128 input.
        e1 = self.enc1(x)          # 128x128x64
        p1 = self.pool1(e1)

        e2 = self.enc2(p1)         # 64x64x128
        p2 = self.pool2(e2)

        e3 = self.enc3(p2)         # 32x32x256
        p3 = self.pool3(e3)

        e4 = self.enc4(p3)         # 16x16x512
        p4 = self.pool4(e4)

        b = self.bottleneck(p4)    # 8x8x1024

        # Latent distribution parameters.
        flat = self.flatten(b)
        mu = self.fc_mu(flat)
        logvar = self.fc_logvar(flat)

        return mu, logvar, e1, e2, e3, e4

    def decode(self, z, e1, e2, e3, e4):
        """Decode latent vector ``z`` using the encoder skips; output is tanh-bounded to [-1, 1]."""
        # Start again from the latent vector.
        x = self.fc_decode(z)
        x = self.unflatten(x)

        d4 = self.up4(x)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.dec4(d4)

        d3 = self.up3(d4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)

        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        out = self.out(d1)
        return torch.tanh(out)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            ``(recon, mu, logvar)``.
        """
        mu, logvar, e1, e2, e3, e4 = self.encode(x)
        # Sampling step: this is where the VAE noise is injected.
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z, e1, e2, e3, e4)

        return recon, mu, logvar

In [23]:
# --- 3. LOSS AVEC ANNEALING ET VGG ---
def hybrid_loss(recon, target, mu, logvar, epoch, perceptual_model, beta_max=1e-4,
                anneal_epochs=20, weights=(1.0, 0.5, 0.1)):
    """Combined reconstruction + perceptual + KL loss with linear KL annealing.

    Args:
        recon: Reconstructed batch produced by the model.
        target: Ground-truth batch, same shape as ``recon``.
        mu: Latent means.
        logvar: Latent log-variances.
        epoch: Zero-based epoch index, used for the KL warm-up.
        perceptual_model: Callable ``(recon, target) -> scalar tensor``.
        beta_max: KL weight reached at the end of the warm-up.
        anneal_epochs: Number of epochs over which the KL weight ramps linearly
            up to ``beta_max``. A value <= 0 disables the warm-up.
        weights: ``(l1, mse, perceptual)`` weights of the reconstruction terms.

    Returns:
        ``(total_loss, l1, perceptual, kld)`` where ``total_loss`` is a tensor
        and the three others are Python floats for logging.
    """
    # KL annealing: beta grows linearly during the first `anneal_epochs` epochs.
    if anneal_epochs > 0:
        beta = min(beta_max, (epoch + 1) / anneal_epochs * beta_max)
    else:
        beta = beta_max

    w_l1, w_mse, w_vgg = weights

    l1 = nn.functional.l1_loss(recon, target)
    mse = nn.functional.mse_loss(recon, target)
    vgg_loss = perceptual_model(recon, target)
    # KL divergence to N(0, I), averaged over batch and latent dimensions.
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    total_loss = w_l1 * l1 + w_mse * mse + w_vgg * vgg_loss + beta * kld
    return total_loss, l1.item(), vgg_loss.item(), kld.item()

In [24]:
class RestorationDataset(Dataset):
    """Paired (degraded, clean) image dataset.

    A degraded file is kept only when the file obtained by removing
    ``"degraded_"`` from its name exists in ``clean_dir``.

    Args:
        degraded_dir: Directory containing the degraded images.
        clean_dir: Directory containing the clean reference images.
        transform: Optional callable applied to both images of a pair. Random
            transforms receive the same random draws for both images. When
            ``None``, the RGB PIL images are returned untouched.
    """

    def __init__(self, degraded_dir, clean_dir, transform=None):
        self.degraded_dir = degraded_dir
        self.clean_dir = clean_dir
        self.transform = transform

        print("üîç Indexation des fichiers en cours...")
        # A set makes the membership test O(1) per file.
        clean_names = set(os.listdir(clean_dir))
        # Sorted so that indices (and therefore validation samples) are
        # reproducible; os.listdir order is arbitrary.
        self.filenames = sorted(
            name for name in os.listdir(degraded_dir)
            if self._clean_name(name) in clean_names
        )

        print(f"Indexation termin√©e : {len(self.filenames)} paires d'images trouv√©es.")

    @staticmethod
    def _clean_name(degraded_name):
        """Map a degraded file name to the name of its clean counterpart."""
        return degraded_name.replace("degraded_", "")

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return the ``(degraded, clean)`` pair at ``idx`` after the shared transform."""
        deg_fn = self.filenames[idx]
        cln_fn = self._clean_name(deg_fn)

        # convert() loads the pixel data, so the files can be closed right away.
        with Image.open(os.path.join(self.degraded_dir, deg_fn)) as handle:
            deg_img = handle.convert('RGB')
        with Image.open(os.path.join(self.clean_dir, cln_fn)) as handle:
            cln_img = handle.convert('RGB')

        if self.transform is None:
            return deg_img, cln_img

        # Replay the same random draws for both images by restoring the RNG
        # state between the two calls. Unlike reseeding with torch.seed(), this
        # keeps any seed set by the caller meaningful.
        rng_state = torch.get_rng_state()
        deg_img = self.transform(deg_img)
        torch.set_rng_state(rng_state)
        cln_img = self.transform(cln_img)

        return deg_img, cln_img

In [25]:
def vae_loss_function(recon, target, mu, logvar, beta=0.0001):
    """Sum-reduced VAE objective: ``L1 + 0.5 * MSE + beta * KL``.

    - L1 term: sharpness
    - MSE term: structure
    - KL divergence: latent regularisation

    Returns:
        ``(total, l1, mse, kld)``; ``total`` is a tensor, the rest are floats.
    """
    functional = nn.functional
    abs_err = functional.l1_loss(recon, target, reduction='sum')
    sq_err = functional.mse_loss(recon, target, reduction='sum')

    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_div = -0.5 * torch.sum(kl_terms)

    objective = abs_err + 0.5 * sq_err + beta * kl_div
    return objective, abs_err.item(), sq_err.item(), kl_div.item()




In [26]:
def save_samples(model, dataloader, epoch, device, output_dir="samples", num_samples=8):
    """Save a degraded / restored / clean comparison grid for the first batch.

    The grid has three rows (degraded inputs, model outputs, clean targets) and
    is written to ``<output_dir>/epoch_<epoch:03d>.png``.

    Args:
        model: Network returning ``(restored, mu, logvar)``.
        dataloader: Iterable yielding ``(degraded, clean)`` batches.
        epoch: Epoch number used in the output file name.
        device: Device on which inference is run.
        output_dir: Destination directory (created if missing).
        num_samples: Maximum number of images taken from the first batch.

    The model's train/eval mode is restored to what it was on entry, even if
    inference or saving raises.
    """
    os.makedirs(output_dir, exist_ok=True)

    def denorm(x):
        # Map tensors from [-1, 1] back to [0, 1] so the image is viewable.
        return (x * 0.5) + 0.5

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            degraded, clean = next(iter(dataloader))
            degraded = degraded[:num_samples].to(device)
            clean = clean[:num_samples].to(device)
            restored, _, _ = model(degraded)

            comparison = torch.cat(
                [denorm(degraded), denorm(restored), denorm(clean)], dim=0
            )

            # One row per image kind: use the real sample count so the rows
            # stay aligned when the batch is smaller than num_samples.
            # normalize=False because the tensors were already de-normalised.
            save_image(comparison,
                       os.path.join(output_dir, f"epoch_{epoch:03d}.png"),
                       nrow=max(1, degraded.size(0)),
                       normalize=False)
    finally:
        # Put the model back in the mode the caller had it in.
        model.train(was_training)

In [27]:
# --- 5. MAIN TRAINING LOOP ---
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model, frozen perceptual network, optimizer and LR scheduler.
    model = VAE_UNet(latent_dim=64).to(device)
    perceptual_fn = VGGPerceptualLoss().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    # --- RESUME LOGIC ---
    last_ckpt_path = "vae_unet_last.pth"
    best_ckpt_path = "vae_unet_best.pth"
    checkpoint_path = last_ckpt_path if os.path.exists(last_ckpt_path) else best_ckpt_path
    start_epoch = 0
    best_val_loss = float('inf')

    if os.path.exists(checkpoint_path):
        print(f"Chargement du checkpoint : {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print("√âtat de l'optimiseur charg√©.")
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        start_epoch = checkpoint.get('epoch', 0)

        # The 'loss' field of a "last" checkpoint may hold a training batch loss
        # (mid-epoch save) or a placeholder (Ctrl+C save), so it must not be used
        # as the best validation loss. Prefer the dedicated field and fall back
        # to the legacy "best" checkpoint, whose 'loss' is a validation loss.
        if 'best_val_loss' in checkpoint:
            best_val_loss = checkpoint['best_val_loss']
        elif checkpoint_path == best_ckpt_path:
            best_val_loss = checkpoint.get('loss', float('inf'))
        elif os.path.exists(best_ckpt_path):
            best_val_loss = torch.load(best_ckpt_path, map_location='cpu').get('loss', float('inf'))
        print(f"Reprise du travail √† partir de l'√©poque {start_epoch + 1}")
    else:
        print("Aucun checkpoint trouv√©. D√©marrage d'un nouvel entra√Ænement.")

    # --- HISTORY RESUME ---
    history_path = "training_history.csv"
    history_columns = ('epoch', 'train_loss', 'val_loss', 'l1', 'vgg', 'kld')
    history = {column: [] for column in history_columns}
    if os.path.exists(history_path):
        try:
            loaded = pd.read_csv(history_path).to_dict(orient='list')
            n_rows = len(loaded['epoch'])
            # Pad columns missing from an older CSV with NaN so every list has
            # the same length (otherwise pd.DataFrame(history) raises ValueError).
            history = {
                column: list(loaded.get(column, [float('nan')] * n_rows))
                for column in history_columns
            }
            print(f"Historique charg√© ({len(history['epoch'])} √©poques trouv√©es).")
        except Exception as exc:
            # Best effort: an unreadable history must not block training.
            print(f"Historique illisible ({exc}) : nouvel historique.")
            history = {column: [] for column in history_columns}

    epochs_to_add = 27
    total_epochs = start_epoch + epochs_to_add

    # --- DATALOADERS ---
    # The model ends with tanh and save_samples() / the evaluation cell assume
    # tensors in [-1, 1], so inputs and targets are normalised accordingly.
    # Resize matches the evaluation cell and the 128x128 input the model expects.
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    train_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        normalize,
    ])

    dataset = RestorationDataset("../../data/train/degraded_images/", "../../data/train/images/", train_transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0, pin_memory=True)
    val_dataset = RestorationDataset("../../data/test/degraded_images/", "../../data/test/images/", val_transform)
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)

    if len(dataset) == 0 or len(val_dataset) == 0:
        # Fail early with a clear message instead of a ZeroDivisionError later.
        raise RuntimeError("Aucune paire d'images : impossible de lancer l'entrainement.")

    os.makedirs("samples", exist_ok=True)

    def _checkpoint_state(resume_epoch, loss_value):
        """Snapshot everything needed to resume training from ``resume_epoch``."""
        return {
            'epoch': resume_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': loss_value,
            'best_val_loss': best_val_loss,
        }

    # Defined up front so the Ctrl+C handler never hits an unbound name.
    epoch = start_epoch
    avg_train_loss = float('nan')

    try:
        for epoch in range(start_epoch, total_epochs):
            model.train()
            train_epoch_loss, epoch_l1, epoch_vgg, epoch_kld = 0, 0, 0, 0
            train_bar = tqdm(dataloader, desc=f"√âpoque {epoch+1}/{total_epochs} [Train]", file=sys.stdout)

            for batch_idx, (degraded, clean) in enumerate(train_bar):
                degraded, clean = degraded.to(device), clean.to(device)
                optimizer.zero_grad()
                restored, mu, logvar = model(degraded)
                loss, l1, vgg, kld = hybrid_loss(restored, clean, mu, logvar, epoch, perceptual_fn)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                train_epoch_loss += loss.item()
                epoch_l1 += l1
                epoch_vgg += vgg
                epoch_kld += kld

                if batch_idx % 10 == 0:
                    train_bar.set_postfix({'L1': f'{l1:.4f}', 'VGG': f'{vgg:.4f}'})

                # Periodic safety checkpoint; 'epoch' is the unfinished epoch so
                # that a resume restarts it from the beginning.
                if batch_idx % 500 == 0 and batch_idx > 0:
                    torch.save(_checkpoint_state(epoch, loss.item()), last_ckpt_path)

            avg_train_loss = train_epoch_loss / len(dataloader)

            # Validation
            model.eval()
            val_epoch_loss = 0
            val_bar = tqdm(val_dataloader, desc=f"√âpoque {epoch+1}/{total_epochs} [Val]", file=sys.stdout, leave=False)
            with torch.no_grad():
                for deg_v, cln_v in val_bar:
                    deg_v, cln_v = deg_v.to(device), cln_v.to(device)
                    res_v, mu_v, logvar_v = model(deg_v)
                    v_loss, _, _, _ = hybrid_loss(res_v, cln_v, mu_v, logvar_v, epoch, perceptual_fn)
                    val_epoch_loss += v_loss.item()

            avg_val_loss = val_epoch_loss / len(val_dataloader)
            scheduler.step(avg_val_loss)

            print(f"\nFIN √âPOQUE {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

            # --- SAVE STATE ---
            is_best = avg_val_loss < best_val_loss
            if is_best:
                best_val_loss = avg_val_loss

            save_dict = _checkpoint_state(epoch + 1, avg_val_loss)
            torch.save(save_dict, last_ckpt_path)

            if is_best:
                torch.save(save_dict, best_ckpt_path)
                print(f"Nouveau record ! Mod√®le sauvegard√©.")

            # --- HISTORY UPDATE ---
            history['epoch'].append(epoch + 1)
            history['train_loss'].append(avg_train_loss)
            history['val_loss'].append(avg_val_loss)
            history['l1'].append(epoch_l1 / len(dataloader))
            history['vgg'].append(epoch_vgg / len(dataloader))
            history['kld'].append(epoch_kld / len(dataloader))

            pd.DataFrame(history).to_csv(history_path, index=False)

            if (epoch + 1) % 2 == 0:
                save_samples(model, val_dataloader, epoch + 1, device)

    except KeyboardInterrupt:
        print("\nArr√™t manuel (Ctrl+C). Sauvegarde de s√©curit√©...")
        # 'loss' is informational only here; the best validation loss travels in
        # its own field, so a resume is not polluted by this value.
        torch.save(_checkpoint_state(epoch, avg_train_loss), last_ckpt_path)
        print("√âtat sauvegard√©. √Ä bient√¥t !")

    print("\nS√©ance termin√©e !")

Aucun checkpoint trouv√©. D√©marrage d'un nouvel entra√Ænement.
üîç Indexation des fichiers en cours...
Indexation termin√©e : 66226 paires d'images trouv√©es.
üîç Indexation des fichiers en cours...
Indexation termin√©e : 16557 paires d'images trouv√©es.
√âpoque 1/27 [Train]:   2%|‚ñè         | 33/2070 [00:55<56:39,  1.67s/it, L1=0.1234, VGG=0.6356] 

Arr√™t manuel (Ctrl+C). Sauvegarde de s√©curit√©...
√âtat sauvegard√©. √Ä bient√¥t !

S√©ance termin√©e !


In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import numpy as np

In [None]:
# --- TEST CONFIGURATION ---
test_degraded_dir = "../../data/test/degraded_images/"
test_clean_dir = "../../data/test/images/"
output_dir = "test_results_metrics"
checkpoint_path = "vae_unet_best.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs(output_dir, exist_ok=True)

# 1. Same latent_dim as training (64).
model = VAE_UNet(latent_dim=64).to(device)
checkpoint = torch.load(checkpoint_path, map_location=device)
# Accept both a full training checkpoint and a bare state_dict.
model.load_state_dict(checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint)
model.eval()

# 2. Inputs are resized to 128x128 and mapped to [-1, 1].
test_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


def denorm(x):
    """Map tensors from [-1, 1] back to [0, 1] for PSNR/SSIM and saving."""
    return (x * 0.5) + 0.5


# --- TEST LOOP ---
# Sorted and case-insensitive so the run is reproducible and .JPG/.jpeg files
# are not silently skipped.
filenames = sorted(
    f for f in os.listdir(test_degraded_dir)
    if f.lower().endswith(('.jpg', '.jpeg', '.png'))
)
all_psnr, all_ssim = [], []

with torch.no_grad():
    for filename in tqdm(filenames):
        deg_path = os.path.join(test_degraded_dir, filename)
        clean_filename = filename.replace("degraded_", "")
        clean_path = os.path.join(test_clean_dir, clean_filename)

        if not os.path.exists(clean_path):
            continue

        # Load and normalise both images (batch of one).
        img_deg = test_transform(Image.open(deg_path).convert('RGB')).unsqueeze(0).to(device)
        img_clean = test_transform(Image.open(clean_path).convert('RGB')).unsqueeze(0).to(device)

        # Inference from the posterior mean: model(x) samples z even in eval
        # mode, which would make the reported metrics change from run to run.
        mu, _, e1, e2, e3, e4 = model.encode(img_deg)
        img_restored = model.decode(mu, e1, e2, e3, e4)

        # 3. De-normalise to [0, 1] and move to HWC numpy arrays for the metrics.
        clean_np = denorm(img_clean).squeeze(0).cpu().permute(1, 2, 0).numpy()
        restored_np = denorm(img_restored).squeeze(0).cpu().permute(1, 2, 0).numpy()

        current_psnr = psnr(clean_np, restored_np, data_range=1.0)
        current_ssim = ssim(clean_np, restored_np, data_range=1.0, channel_axis=2)

        all_psnr.append(current_psnr)
        all_ssim.append(current_ssim)

        # De-normalised visual comparison: degraded | restored | clean.
        comparison = torch.cat([denorm(img_deg), denorm(img_restored), denorm(img_clean)], dim=0)
        save_image(comparison, os.path.join(output_dir, f"score_{current_psnr:.2f}_{filename}"), nrow=3)

if all_psnr:
    print(f"\nPSNR Moyen : {np.mean(all_psnr):.2f} dB | SSIM Moyen : {np.mean(all_ssim):.4f}")
else:
    # np.mean([]) would return NaN with a RuntimeWarning; report the cause instead.
    print("\nAucune image de test disponible : PSNR/SSIM non calculables.")