In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import os
from tqdm import tqdm
import sys

In [None]:
# Let cuDNN auto-tune and cache its convolution algorithms; worthwhile here because
# the images fed to the network are resized to one fixed resolution further down.
torch.backends.cudnn.benchmark = True

In [None]:
# --- HYBRID VAE / U-NET ---
class VAE_UNet(nn.Module):
    """U-Net encoder/decoder whose bottleneck is a variational (VAE) latent space.

    The encoder halves the spatial resolution four times, the bottleneck feature
    map is flattened and projected to ``mu`` / ``logvar``, a latent vector is
    sampled with the reparameterization trick, and the decoder upsamples back to
    the input resolution while concatenating the encoder feature maps (skip
    connections). The output goes through a sigmoid, so it lies in [0, 1].

    Note: ``reparameterize`` draws fresh noise on every call, including when the
    module is in ``eval()`` mode, so two forward passes on the same input can
    return slightly different reconstructions.

    Args:
        latent_dim: Size of the latent vector ``z``.
        input_size: Height and width of the (square) input images. Must be a
            positive multiple of 16 because of the four 2x2 poolings. The default
            of 128 yields the original 8x8 bottleneck, so checkpoints trained
            with the previous hard-coded layout still load.

    Raises:
        ValueError: If ``input_size`` is not a positive integer multiple of 16.
    """

    def __init__(self, latent_dim=256, input_size=128):
        super(VAE_UNet, self).__init__()

        # Validate before allocating any (large) layer.
        if not isinstance(input_size, int) or input_size <= 0 or input_size % 16 != 0:
            raise ValueError(
                f"input_size must be a positive multiple of 16, got {input_size!r}"
            )

        self.latent_dim = latent_dim
        self.input_size = input_size
        # Four MaxPool2d(2, 2) stages divide the resolution by 2**4 = 16.
        self.bottleneck_size = input_size // 16
        flat_features = 1024 * self.bottleneck_size * self.bottleneck_size

        # ENCODER (each encN output is kept as a skip connection)
        self.enc1 = self.conv_block(3, 64)
        self.pool1 = nn.MaxPool2d(2, 2)

        self.enc2 = self.conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.enc3 = self.conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.enc4 = self.conv_block(256, 512)
        self.pool4 = nn.MaxPool2d(2, 2)

        # BOTTLENECK - feeds the VAE latent space
        self.bottleneck = self.conv_block(512, 1024)

        # VAE heads (mu and logvar)
        self.flatten = nn.Flatten()
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # Projection from the latent vector back to a bottleneck-sized feature map
        self.fc_decode = nn.Linear(latent_dim, flat_features)
        self.unflatten = nn.Unflatten(
            1, (1024, self.bottleneck_size, self.bottleneck_size)
        )

        # DECODER with skip connections (U-Net style)
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = self.conv_block(1024, 512)  # 1024 = 512 (up) + 512 (skip)

        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = self.conv_block(512, 256)  # 512 = 256 (up) + 256 (skip)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = self.conv_block(256, 128)  # 256 = 128 (up) + 128 (skip)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = self.conv_block(128, 64)  # 128 = 64 (up) + 64 (skip)

        # Final 1x1 convolution down to 3 output channels
        self.out = nn.Conv2d(64, 3, 1)

    def conv_block(self, in_ch, out_ch):
        """Return two (3x3 conv -> BatchNorm -> ReLU) stages that keep H and W."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (always stochastic)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, x):
        """Encode ``x`` and return ``(mu, logvar, e1, e2, e3, e4)``.

        ``e1``..``e4`` are the pre-pooling encoder feature maps, reused by
        ``decode`` as skip connections. Shapes below use S = ``input_size``.
        """
        e1 = self.enc1(x)          # 64 channels,  S x S
        p1 = self.pool1(e1)

        e2 = self.enc2(p1)         # 128 channels, S/2 x S/2
        p2 = self.pool2(e2)

        e3 = self.enc3(p2)         # 256 channels, S/4 x S/4
        p3 = self.pool3(e3)

        e4 = self.enc4(p3)         # 512 channels, S/8 x S/8
        p4 = self.pool4(e4)

        # Bottleneck
        b = self.bottleneck(p4)    # 1024 channels, S/16 x S/16

        # VAE latent space
        flat = self.flatten(b)
        mu = self.fc_mu(flat)
        logvar = self.fc_logvar(flat)

        return mu, logvar, e1, e2, e3, e4

    def decode(self, z, e1, e2, e3, e4):
        """Decode latent ``z`` using the encoder skips; returns an image in [0, 1]."""
        # Rebuild a bottleneck-sized feature map from the latent vector
        x = self.fc_decode(z)
        x = self.unflatten(x)      # 1024 channels, S/16 x S/16

        # Each stage: upsample x2, concatenate the matching skip, refine
        d4 = self.up4(x)           # 512 channels, S/8 x S/8
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.dec4(d4)

        d3 = self.up3(d4)          # 256 channels, S/4 x S/4
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)

        d2 = self.up2(d3)          # 128 channels, S/2 x S/2
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.up1(d2)          # 64 channels, S x S
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        out = self.out(d1)         # 3 channels, S x S
        return torch.sigmoid(out)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch of images ``x``."""
        mu, logvar, e1, e2, e3, e4 = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z, e1, e2, e3, e4)
        return recon, mu, logvar


In [None]:
class RestorationDataset(Dataset):
    """Paired dataset of degraded images and their clean counterparts.

    A file in ``degraded_dir`` is kept only if a file with the same name minus
    the ``"degraded_"`` substring exists in ``clean_dir``.

    Args:
        degraded_dir: Directory holding the degraded images.
        clean_dir: Directory holding the clean reference images.
        transform: Optional callable applied independently to both images of a
            pair. NOTE(review): a randomized transform would therefore be
            sampled twice and could desynchronize the pair.

    Attributes:
        filenames: Sorted list of the degraded filenames that have a clean match.
    """

    def __init__(self, degraded_dir, clean_dir, transform=None):
        self.degraded_dir = degraded_dir
        self.clean_dir = clean_dir
        self.transform = transform

        clean_files = set(os.listdir(clean_dir))

        # Sorted so that index -> file is stable from one run to the next;
        # os.listdir() order is arbitrary and set iteration order is not fixed.
        self.filenames = sorted(
            deg_file
            for deg_file in os.listdir(degraded_dir)
            if deg_file.replace("degraded_", "") in clean_files
        )

        print(f"{len(self.filenames)} paires d'images trouvées")

    def __len__(self):
        """Number of (degraded, clean) pairs."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return the ``(degraded, clean)`` pair at ``idx``, transformed if requested."""
        degraded_filename = self.filenames[idx]
        clean_filename = degraded_filename.replace("degraded_", "")

        degraded_path = os.path.join(self.degraded_dir, degraded_filename)
        clean_path = os.path.join(self.clean_dir, clean_filename)

        # convert() returns an independent in-memory image, so the underlying
        # file can be closed deterministically as soon as the block exits.
        with Image.open(degraded_path) as img:
            degraded_img = img.convert('RGB')
        with Image.open(clean_path) as img:
            clean_img = img.convert('RGB')

        if self.transform:
            degraded_img = self.transform(degraded_img)
            clean_img = self.transform(clean_img)

        return degraded_img, clean_img

In [None]:
def vae_loss_function(recon, target, mu, logvar, beta=0.0001):
    """
    Combined VAE objective, every term summed (not averaged) over the batch:
    - L1 reconstruction error (sharpness)
    - MSE reconstruction error (structure), weighted by 0.5
    - KL divergence between N(mu, exp(logvar)) and N(0, I), weighted by beta

    Returns (total, l1, mse, kld): `total` is the tensor to backpropagate,
    the three others are plain Python floats meant for logging.
    """
    functional = nn.functional
    recon_l1 = functional.l1_loss(recon, target, reduction='sum')
    recon_mse = functional.mse_loss(recon, target, reduction='sum')

    # Closed-form KL for a diagonal Gaussian against the standard normal.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_div = -0.5 * torch.sum(kl_terms)

    objective = recon_l1 + 0.5 * recon_mse + beta * kl_div

    logged = tuple(term.item() for term in (recon_l1, recon_mse, kl_div))
    return (objective, *logged)




In [None]:
def save_samples(model, dataloader, epoch, device, output_dir="samples", num_samples=8):
    """Run the model on one batch and save a degraded / restored / clean grid.

    The first batch yielded by ``dataloader`` is truncated to at most
    ``num_samples`` images, restored with gradients disabled, and written to
    ``<output_dir>/epoch_<epoch:03d>.png`` with one row per image kind.

    The model is switched to eval mode for the forward pass and then put back
    in whatever mode (train or eval) it was in when the function was called,
    even if the forward pass or the save raises.

    Args:
        model: Module whose forward returns ``(restored, mu, logvar)``.
        dataloader: Iterable yielding ``(degraded, clean)`` tensor batches.
        epoch: Integer used in the output filename.
        device: Device the batch is moved to before inference.
        output_dir: Directory for the PNG; created if missing.
        num_samples: Maximum number of images per row.
    """
    os.makedirs(output_dir, exist_ok=True)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            degraded, clean = next(iter(dataloader))
            degraded = degraded[:num_samples].to(device)
            clean = clean[:num_samples].to(device)

            restored, _, _ = model(degraded)

            # Grid rows: [degraded | restored | clean]
            comparison = torch.cat([degraded, restored, clean], dim=0)
            save_image(
                comparison,
                os.path.join(output_dir, f"epoch_{epoch:03d}.png"),
                # Use the real image count so the three rows stay aligned when
                # the batch holds fewer than num_samples images.
                nrow=degraded.size(0),
                normalize=False,
            )
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        model.train(was_training)



In [None]:
if __name__ == '__main__':
    # Training entry point: builds the model and data pipeline, trains for
    # num_epochs, periodically writes sample grids and checkpoints.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Utilisation de : {device}")

    # VAE-UNet model
    model = VAE_UNet(latent_dim=256).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-5)

    # Halve the learning rate once the average epoch loss plateaus (patience=3).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )

    # 128x128 matches the model's default bottleneck layout; ToTensor scales to [0, 1].
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])

    path_degraded = "../data/train/degraded_images/"
    path_clean = "../data/train/images/"

    dataset = RestorationDataset(path_degraded, path_clean, transform)
    if len(dataset) == 0:
        # Fail early with an actionable message rather than a sampler error
        # when the DataLoader is built on an empty dataset.
        raise ValueError(
            f"Aucune paire d'images trouvée dans {path_degraded} et {path_clean}"
        )

    dataloader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        num_workers=0,
        pin_memory=torch.cuda.is_available()
    )

    os.makedirs("samples", exist_ok=True)

    num_epochs = 50
    best_loss = float('inf')

    print("Génération des échantillons initiaux...")
    save_samples(model, dataloader, 0, device)

    for epoch in range(num_epochs):
        model.train()
        # Running sums over the epoch; the loss terms are batch sums, so these
        # are divided by the number of samples below.
        epoch_loss = 0
        epoch_l1 = 0
        epoch_mse = 0
        epoch_kld = 0

        progress_bar = tqdm(dataloader, desc=f"Époque {epoch+1}/{num_epochs}", file=sys.stdout)

        for batch_idx, (degraded, clean) in enumerate(progress_bar):
            degraded, clean = degraded.to(device), clean.to(device)

            optimizer.zero_grad()
            restored, mu, logvar = model(degraded)
            loss, l1, mse, kld = vae_loss_function(restored, clean, mu, logvar)
            loss.backward()

            # Clip the global gradient norm before stepping.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Read the scalar once; each .item() forces a device synchronization.
            batch_loss = loss.item()
            epoch_loss += batch_loss
            epoch_l1 += l1
            epoch_mse += mse
            epoch_kld += kld

            if batch_idx % 10 == 0:
                batch_size = len(degraded)
                progress_bar.set_postfix({
                    'loss': f'{batch_loss/batch_size:.4f}',
                    'l1': f'{l1/batch_size:.2f}',
                    'kld': f'{kld/batch_size:.2f}'
                })

        avg_loss = epoch_loss / len(dataset)
        scheduler.step(avg_loss)

        print(f"\nÉpoque {epoch+1} | Loss: {avg_loss:.4f} | "
              f"L1: {epoch_l1/len(dataset):.4f} | "
              f"MSE: {epoch_mse/len(dataset):.4f} | "
              f"KLD: {epoch_kld/len(dataset):.4f}")

        # Write a sample grid every 2 epochs
        if (epoch + 1) % 2 == 0:
            print("Génération d'échantillons...")
            save_samples(model, dataloader, epoch + 1, device)

        # Keep the checkpoint with the lowest average training loss.
        # Note: 'epoch' stored here is the 0-based loop index.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss,
            }, "vae_unet_best.pth")
            print(f"Meilleur modèle sauvegardé (loss: {best_loss:.4f})")

        # Regular weights-only checkpoints every 10 epochs (1-based in the filename)
        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), f"vae_unet_epoch_{epoch+1}.pth")

    print("\nEntraînement terminé !")
    print("Les échantillons sont dans le dossier 'samples/'")
    print(f"Meilleur modèle : vae_unet_best.pth (loss: {best_loss:.4f})")

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import numpy as np

In [None]:
# --- CONFIGURATION ---
test_degraded_dir = "../data/test/degraded_images/"
test_clean_dir = "../data/test/images/"
output_dir = "test_results_metrics"
checkpoint_path = "vae_unet_best.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs(output_dir, exist_ok=True)

# --- MODEL LOADING ---
model = VAE_UNet(latent_dim=256).to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust
# (consider weights_only=True where the installed PyTorch supports it).
checkpoint = torch.load(checkpoint_path, map_location=device)
# Accept both the full training checkpoint (dict with 'model_state_dict') and a
# bare state_dict such as the periodic vae_unet_epoch_N.pth files.
model.load_state_dict(checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint)
model.eval()

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# --- TEST LOOP ---
# Sorted for a reproducible evaluation order (os.listdir order is arbitrary);
# the extension check is case-insensitive and also accepts .jpeg.
filenames = sorted(
    f for f in os.listdir(test_degraded_dir)
    if f.lower().endswith(('.jpg', '.jpeg', '.png'))
)
all_psnr = []
all_ssim = []

print(f"Évaluation de {len(filenames)} images de test...")

with torch.no_grad():
    for filename in tqdm(filenames):
        # 1. Load the image pair
        deg_path = os.path.join(test_degraded_dir, filename)
        clean_filename = filename.replace("degraded_", "")
        clean_path = os.path.join(test_clean_dir, clean_filename)

        # Degraded images without a clean reference cannot be scored.
        if not os.path.exists(clean_path):
            continue

        img_deg = transform(Image.open(deg_path).convert('RGB')).unsqueeze(0).to(device)
        img_clean = transform(Image.open(clean_path).convert('RGB')).unsqueeze(0).to(device)

        # 2. Inference
        img_restored, _, _ = model(img_deg)

        # 3. Convert for the metrics (tensor -> NumPy HWC in [0, 1]).
        # squeeze(0) drops only the batch axis, never a size-1 image axis.
        clean_np = img_clean.squeeze(0).cpu().permute(1, 2, 0).numpy()
        restored_np = img_restored.squeeze(0).cpu().permute(1, 2, 0).numpy()

        # 4. Scores; data_range=1.0 because pixel values lie in [0, 1]
        current_psnr = psnr(clean_np, restored_np, data_range=1.0)
        current_ssim = ssim(clean_np, restored_np, data_range=1.0, channel_axis=2)

        all_psnr.append(current_psnr)
        all_ssim.append(current_ssim)

        # 5. Visual comparison (optional): degraded | restored | clean
        comparison = torch.cat([img_deg, img_restored, img_clean], dim=0)
        save_image(comparison, os.path.join(output_dir, f"score_{current_psnr:.2f}_{filename}"), nrow=3)

# --- FINAL RESULTS ---
print("\n" + "="*30)
print("RÉSULTATS DU MODÈLE SUR LE TEST SET")
if all_psnr:
    print(f"PSNR Moyen : {np.mean(all_psnr):.2f} dB")
    print(f"SSIM Moyen : {np.mean(all_ssim):.4f}")
else:
    # Avoid np.mean([]) which yields nan and a RuntimeWarning.
    print("Aucune paire d'images évaluée.")
print("="*30)