In [None]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

In [None]:
class Generator(nn.Module):
    """Encoder-decoder generator mapping a degraded RGB image to a restored one.

    Two stride-2 convolutions halve the spatial size twice (H -> H/2 -> H/4),
    then two stride-2 transposed convolutions bring it back (H/4 -> H/2 -> H).
    The output only has exactly the input's height/width when both are
    divisible by 4. The final Tanh bounds every output value to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        encoder = [
            nn.Conv2d(3, 64, 4, 2, 1),       # 3 -> 64 channels, H -> H/2
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),     # 64 -> 128 channels, H/2 -> H/4
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        ]
        decoder = [
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # H/4 -> H/2
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),    # H/2 -> H, back to 3 (RGB) channels
            nn.Tanh(),
        ]
        # One flat Sequential keeps parameter names (model.0, model.2, ...)
        # identical to the original layout, so saved state_dicts still load.
        self.model = nn.Sequential(*encoder, *decoder)

    def forward(self, x):
        """Run the encoder-decoder on a batch of 3-channel images ``x``."""
        return self.model(x)
    

In [None]:
class Discriminator(nn.Module):
    """Conditional patch discriminator scoring (condition, image) pairs.

    The conditioning (degraded) image and the candidate image are stacked on
    the channel axis (3 + 3 = 6 channels), passed through three stride-2
    4x4 convolutions and a final stride-1 3x3 convolution that emits a
    1-channel map of raw logits (no sigmoid is applied here).
    """

    def __init__(self):
        super().__init__()

        def down(c_in, c_out, batch_norm):
            # Conv (stride 2) -> optional BatchNorm -> LeakyReLU(0.2)
            layers = [nn.Conv2d(c_in, c_out, 4, 2, 1)]
            if batch_norm:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *down(6, 64, batch_norm=False),     # degraded image + candidate image
            *down(64, 128, batch_norm=True),
            *down(128, 256, batch_norm=True),
            nn.Conv2d(256, 1, 3, 1, 1),         # 1-channel logit map, size preserved
        )

    def forward(self, img_cond, img):
        """Return the logit map for ``img`` conditioned on ``img_cond``.

        The two tensors are concatenated on dim 1, so they must agree in
        batch and spatial size and together provide 6 channels.
        """
        stacked = torch.cat((img_cond, img), dim=1)
        return self.model(stacked)


# CHARGEMENT DATASET

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class PairedImageDataset(Dataset):
    """Dataset of (degraded, target) image pairs read from disk.

    Expected layout under ``root``::

        root/images/<name>                    clean target images
        root/degraded_images/degraded_<name>  matching degraded images

    Items are returned as ``(degraded, target)`` float tensors produced by
    ``ToTensor`` from RGB images. With ``normalize=True`` (default) values are
    rescaled from [0, 1] to [-1, 1], the range of the generator's Tanh output,
    so real and generated images share the same range. Pass
    ``normalize=False`` to keep the raw [0, 1] range.
    """

    # Extensions treated as images. Other entries in images/ (e.g. macOS
    # ".DS_Store", notes, sub-directories) are skipped instead of making
    # __getitem__ fail on a missing "degraded_" counterpart.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tif", ".tiff", ".webp")

    def __init__(self, root, normalize=True):
        self.target_dir = os.path.join(root, "images")
        self.deg_dir = os.path.join(root, "degraded_images")
        self.normalize = normalize

        self.target_names = self._list_images(self.target_dir)
        self.to_tensor = transforms.ToTensor()

    @classmethod
    def _list_images(cls, directory):
        """Return sorted names of regular, non-hidden image files in ``directory``."""
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".")
            and name.lower().endswith(cls.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB image and close the file handle afterwards."""
        # convert() loads the pixel data into a new image, so it stays valid
        # after the context manager closes the file.
        with Image.open(path) as img:
            return img.convert("RGB")

    @staticmethod
    def _rescale(tensor):
        """Map values from [0, 1] to [-1, 1] (same as Normalize(mean=0.5, std=0.5))."""
        return tensor * 2.0 - 1.0

    def __len__(self):
        return len(self.target_names)

    def __getitem__(self, idx):
        target_name = self.target_names[idx]

        # Build the degraded file name:
        # image_00183.jpg -> degraded_image_00183.jpg
        deg_name = "degraded_" + target_name

        target_path = os.path.join(self.target_dir, target_name)
        deg_path = os.path.join(self.deg_dir, deg_name)

        if not os.path.exists(deg_path):
            raise FileNotFoundError(deg_path)

        target = self.to_tensor(self._load_rgb(target_path))
        degraded = self.to_tensor(self._load_rgb(deg_path))

        if self.normalize:
            target = self._rescale(target)
            degraded = self._rescale(degraded)

        return degraded, target


In [None]:
# Number of (degraded, target) image pairs per optimisation step.
batch_size: int = 256

In [None]:
from torch.utils.data import DataLoader

# Portable path to the data root: the former "..\\data" literal only works on
# Windows; on macOS/Linux it is one file name containing a backslash.
dataset = PairedImageDataset(os.path.join("..", "data"))

# drop_last=True discards the final partial batch, so a dataset smaller than
# one batch would yield zero batches and fail later with a bare StopIteration.
if len(dataset) < batch_size:
    raise ValueError(
        f"Dataset has {len(dataset)} samples, fewer than batch_size={batch_size}; "
        "lower batch_size or add data."
    )

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)


In [None]:
# Sanity-check one batch before training. Discriminator concatenates the two
# tensors on the channel axis, so their shapes must match, and Generator only
# reproduces H x W exactly when both are divisible by 4.
degraded, target = next(iter(dataloader))
print(degraded.shape, target.shape)
assert degraded.shape == target.shape, (
    f"degraded {tuple(degraded.shape)} vs target {tuple(target.shape)}"
)
assert degraded.dim() == 4 and degraded.shape[1] == 3, "expected (B, 3, H, W) batches"
assert degraded.shape[-2] % 4 == 0 and degraded.shape[-1] % 4 == 0, (
    "H and W must be divisible by 4 for Generator to preserve the image size"
)


# TRAINING

In [None]:
# Pick the fastest available backend: CUDA, then Apple MPS, then CPU.
# torch.backends.mps.is_available() is the documented MPS availability check.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

output_dir = "output_gan"
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Not used by Generator (it takes an image, not a noise vector); kept only in
# case other cells still reference it.
latent_dim = 100

# The models must live on the same device as the batches moved there in the
# training loop; otherwise the first forward pass fails with a device mismatch.
gen = Generator().to(device)
disc = Discriminator().to(device)

opt_g = optim.Adam(gen.parameters(), lr=1e-4)
opt_d = optim.Adam(disc.parameters(), lr=1e-4)
# The discriminator emits raw logits, so the sigmoid is applied inside the loss.
criterion = torch.nn.BCEWithLogitsLoss()

In [None]:
# Display the generator's module summary (its nn.Module repr).
gen

In [None]:
# Display the discriminator's module summary (its nn.Module repr).
disc

In [None]:
epochs = 200
save_every = 5
# Weight of the L1 reconstruction term in the generator loss (pix2pix uses 100).
# Without it the generator is only pushed to fool the discriminator, never to
# stay close to the clean target pixel-wise. Set to 0 for the pure
# adversarial objective.
lambda_l1 = 100.0
l1_criterion = torch.nn.L1Loss()

# Fixed batch of degraded inputs, reused at every snapshot so progress is
# comparable across epochs.
fixed_cond = next(iter(dataloader))[0][:25].to(device)

for epoch in range(1, epochs + 1):
    g_losses, l1_losses, d_losses = [], [], []
    for degraded, target in dataloader:   # dataloader yields (degraded, target)
        degraded = degraded.to(device)
        target = target.to(device)

        # ---- Discriminator: real pairs -> 1, generated pairs -> 0 ----
        opt_d.zero_grad()
        out_real = disc(degraded, target)                # logit map
        real_labels = torch.ones_like(out_real)
        d_loss_real = criterion(out_real, real_labels)

        fake = gen(degraded)
        # detach(): the discriminator step must not backpropagate into gen.
        out_fake = disc(degraded, fake.detach())
        fake_labels = torch.zeros_like(out_fake)
        d_loss_fake = criterion(out_fake, fake_labels)

        d_loss = (d_loss_real + d_loss_fake) * 0.5
        d_loss.backward()
        opt_d.step()

        # ---- Generator: make D predict "real" and stay close to target ----
        opt_g.zero_grad()
        out_fake_for_g = disc(degraded, fake)
        g_adv = criterion(out_fake_for_g, real_labels)
        g_l1 = l1_criterion(fake, target)
        g_loss = g_adv + lambda_l1 * g_l1
        g_loss.backward()
        opt_g.step()

        g_losses.append(g_loss.item())
        l1_losses.append(g_l1.item())
        d_losses.append(d_loss.item())

    # gen_loss is the total generator loss (adversarial + weighted L1).
    print(
        f"Epoch {epoch}: gen_loss={np.mean(g_losses):.4f}, "
        f"l1={np.mean(l1_losses):.4f}, disc_loss={np.mean(d_losses):.4f}"
    )

    if epoch % save_every == 0 or epoch in (1, epochs):
        gen.eval()   # BatchNorm uses running statistics for the snapshot
        with torch.no_grad():
            samples = gen(fixed_cond)
        gen.train()
        samples = (samples + 1) / 2.0   # Tanh range [-1, 1] -> [0, 1] for saving
        grid = make_grid(samples, nrow=5)
        save_image(grid, os.path.join(output_dir, f"generated_epoch_{epoch}.png"))