In [3]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

In [4]:
class Generator(nn.Module):
    """Convolutional encoder-decoder turning a 1-channel image into a 3-channel one.

    Two stride-2 convolutions halve the spatial size twice and two stride-2
    transposed convolutions double it twice (e.g. 28 -> 14 -> 7 -> 14 -> 28).
    The closing Tanh bounds every output value to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Shared geometry: a 4x4 kernel with stride 2 and padding 1 halves
        # (Conv2d) or doubles (ConvTranspose2d) each spatial dimension.
        resample = dict(kernel_size=4, stride=2, padding=1)

        # Encoder
        encoder = [
            nn.Conv2d(1, 64, **resample),             # 28 -> 14
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, **resample),           # 14 -> 7
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        ]

        # Decoder
        decoder = [
            nn.ConvTranspose2d(128, 64, **resample),  # 7 -> 14
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, **resample),    # 14 -> 28, 3 output channels (RGB)
            nn.Tanh(),
        ]

        # Flat Sequential so sub-module indices (and state_dict keys) run 0..9.
        self.model = nn.Sequential(*encoder, *decoder)

    def forward(self, x):
        """Run ``x`` (batch, 1, H, W) through the encoder-decoder stack."""
        restored = self.model(x)
        return restored
    

In [5]:
class Discriminator(nn.Module):
    """Conditional discriminator scoring a (conditioning image, image) pair.

    The two inputs are concatenated along the channel axis (4 channels in
    total) and passed through three stride-2 convolutions followed by a
    3x3 convolution with a single output channel. No activation follows that
    last layer, so the returned map holds unbounded scores (logits).
    """

    def __init__(self):
        super().__init__()

        halve = dict(kernel_size=4, stride=2, padding=1)

        stages = [
            # 4 input channels = degraded image + image
            nn.Conv2d(4, 64, **halve),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, **halve),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, **halve),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            # One score per spatial location of the final feature map.
            nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, img_cond, img):
        """Score ``img`` given ``img_cond``; both must share batch and spatial size."""
        pair = torch.cat((img_cond, img), dim=1)
        return self.model(pair)


# CHARGEMENT DATASET

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class PairedImageDataset(Dataset):
    """Dataset of (degraded, target) image pairs matched by file name.

    ``root`` must contain two sub-directories, ``degraded_images`` and
    ``images``; a pair is formed by the files carrying the same name in both.

    Args:
        root: Directory holding the two sub-directories.
        degraded_mode: PIL mode the degraded image is converted to. Defaults to
            ``"L"`` (one channel) because the Generator in this notebook starts
            with ``Conv2d(1, ...)`` and the Discriminator expects 1 + 3 = 4
            input channels; loading the degraded image as RGB would make both
            fail with a channel mismatch.
        target_mode: PIL mode the target image is converted to (default RGB).

    Raises:
        FileNotFoundError: if a degraded file has no same-named target file.

    Each item is a ``(degraded, target)`` tuple of tensors produced by
    ``transforms.ToTensor()`` (channel-first, values in [0, 1]); no other
    transform is applied.
    """

    def __init__(self, root, degraded_mode="L", target_mode="RGB"):
        self.deg_dir = os.path.join(root, "degraded_images")
        self.tgt_dir = os.path.join(root, "images")
        self.degraded_mode = degraded_mode
        self.target_mode = target_mode

        # Keep regular files only: os.listdir also returns sub-directories
        # (e.g. Jupyter's ".ipynb_checkpoints"), which cannot be opened as images.
        self.names = sorted(
            name
            for name in os.listdir(self.deg_dir)
            if os.path.isfile(os.path.join(self.deg_dir, name))
        )

        # Fail at construction time rather than in the middle of an epoch.
        missing = [
            name
            for name in self.names
            if not os.path.isfile(os.path.join(self.tgt_dir, name))
        ]
        if missing:
            raise FileNotFoundError(
                f"{len(missing)} degraded image(s) have no target in "
                f"{self.tgt_dir!r}, e.g. {missing[0]!r}"
            )

        self.to_tensor = transforms.ToTensor()  # nothing else

    def __len__(self):
        return len(self.names)

    def _load(self, directory, name, mode):
        """Open ``directory/name``, convert it to ``mode`` and return a tensor."""
        # The context manager closes the underlying file deterministically.
        with Image.open(os.path.join(directory, name)) as img:
            return self.to_tensor(img.convert(mode))

    def __getitem__(self, idx):
        name = self.names[idx]

        degraded = self._load(self.deg_dir, name, self.degraded_mode)
        target = self._load(self.tgt_dir, name, self.target_mode)

        return degraded, target


# TRAINING

In [6]:
# Mini-batch size. No cell visible here reads it; presumably it is passed to
# the DataLoader that builds `dataloader` — TODO confirm.
batch_size = 256

In [9]:
# Not read by any code in the visible cells (the Generator is conditioned on an
# image, not on a latent vector); kept so later cells that reference it still run.
latent_dim = 100

gen = Generator()
disc = Discriminator()

# One Adam optimizer per network, same learning rate for both.
opt_g = optim.Adam(gen.parameters(), lr=1e-4)
opt_d = optim.Adam(disc.parameters(), lr=1e-4)

# The Discriminator ends with a bare Conv2d, so it outputs unbounded logits.
# nn.BCELoss only accepts probabilities in [0, 1] and raises on such input;
# BCEWithLogitsLoss applies the sigmoid internally (and is numerically more
# stable), giving the same binary cross-entropy objective.
criterion = nn.BCEWithLogitsLoss()

In [10]:
gen

Generator(
  (model): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU(inplace=True)
    (5): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU(inplace=True)
    (8): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): Tanh()
  )
)

In [11]:
disc

Discriminator(
  (model): Sequential(
    (0): Conv2d(4, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

In [None]:
epochs = 200
save_every = 5   # also saves after the first and the last epoch

# NOTE(review): `device`, `dataloader` and `output_dir` are not defined in this
# cell; they must already exist in the kernel when it runs.

# The batches are moved to `device` below, so the networks have to live there
# too. Module.to() moves parameters in place (the optimizers keep pointing at
# them) and is a no-op when the module is already on that device.
gen.to(device)
disc.to(device)

# save_image() does not create a missing directory.
os.makedirs(output_dir, exist_ok=True)

# Fixed conditioning batch (up to 25 degraded images) reused at every
# checkpoint so the saved grids are comparable from epoch to epoch.
fixed_cond = next(iter(dataloader))[0][:25].to(device)

for epoch in range(1, epochs + 1):
    g_losses, d_losses = [], []
    for degraded, target in dataloader:   # each batch is (degraded, target)
        degraded = degraded.to(device)
        # The generator ends with Tanh, so fakes live in [-1, 1]. Map the real
        # images to the same range; otherwise the discriminator can separate
        # real from fake on value range alone.
        # Assumes targets arrive in [0, 1] (ToTensor output) — confirm if the
        # dataloader applies its own normalisation.
        target = target.to(device) * 2.0 - 1.0

        # ---- Discriminator step ----
        opt_d.zero_grad()
        out_real = disc(degraded, target)                # one score map per sample
        real_labels = torch.ones_like(out_real)
        d_loss_real = criterion(out_real, real_labels)

        fake = gen(degraded)
        # detach(): the discriminator update must not back-propagate into gen.
        out_fake = disc(degraded, fake.detach())
        fake_labels = torch.zeros_like(out_fake)
        d_loss_fake = criterion(out_fake, fake_labels)

        d_loss = (d_loss_real + d_loss_fake) * 0.5
        d_loss.backward()
        opt_d.step()

        # ---- Generator step ----
        opt_g.zero_grad()
        out_fake_for_g = disc(degraded, fake)            # want D to predict "real"
        g_loss = criterion(out_fake_for_g, real_labels)
        g_loss.backward()
        opt_g.step()

        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

    print(f"Epoch {epoch}: gen_loss={np.mean(g_losses):.4f}, disc_loss={np.mean(d_losses):.4f}")

    if epoch % save_every == 0 or epoch in (1, epochs):
        gen.eval()
        try:
            with torch.no_grad():
                samples = gen(fixed_cond)         # (N, 3, H, W) in [-1, 1]
            samples = (samples + 1) / 2.0         # back to [0, 1] for saving
            grid = make_grid(samples, nrow=5)
            save_image(grid, os.path.join(output_dir, f"generated_epoch_{epoch}.png"))
        finally:
            # Restore training mode even if sampling or saving failed.
            gen.train()