In [13]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

In [14]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        self.d1 = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.ReLU(True)
        )
        self.d2 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True)
        )

        self.d3 = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True)
        )

        self.u1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True)
        )

        self.u2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )

        self.u3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 3, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, x):
        d1 = self.d1(x)
        d2 = self.d2(d1)
        d3 = self.d3(d2)

        # Decoder
        u1 = self.u1(d3)
        u1 = torch.cat([u1, d2], dim=1)

        u2 = self.u2(u1)
        u2 = torch.cat([u2, d1], dim=1)

        return self.u3(u2)

In [15]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1),   # image_détériorée + image
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 1, 3, 1, 1)
        )

    def forward(self, img_cond, img):
        x = torch.cat([img_cond, img], dim=1)
        return self.model(x)


# CHARGEMENT DATASET

In [16]:
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class PairedImageDataset(Dataset):
    def __init__(self, root):
        self.target_dir = os.path.join(root, "images")
        self.deg_dir = os.path.join(root, "degraded_images")

        self.target_names = sorted(os.listdir(self.target_dir))
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.target_names)

    def __getitem__(self, idx):
        target_name = self.target_names[idx]

        # --- construire le nom de l'image dégradée ---
        # image_00183.jpg -> degraded_image_00183.jpg
        deg_name = "degraded_" + target_name

        target_path = os.path.join(self.target_dir, target_name)
        deg_path = os.path.join(self.deg_dir, deg_name)

        if not os.path.exists(deg_path):
            raise FileNotFoundError(deg_path)

        target = Image.open(target_path).convert("RGB")
        degraded = Image.open(deg_path).convert("RGB")

        target = self.to_tensor(target)
        degraded = self.to_tensor(degraded)

        return degraded, target


In [17]:
batch_size = 64

In [18]:
from torch.utils.data import DataLoader

dataset = PairedImageDataset("..\\data")

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True
)


In [19]:
degraded, target = next(iter(dataloader))
print(degraded.shape, target.shape)


torch.Size([64, 3, 128, 128]) torch.Size([64, 3, 128, 128])


# TRAINING

In [20]:
device = torch.device("mps" if torch.mps.is_available() else "cpu")

output_dir = "output_gan"
os.makedirs(output_dir, exist_ok=True)

In [21]:
gen = Generator()
disc = Discriminator()

opt_g = optim.Adam(gen.parameters(), lr=1e-4)
opt_d = optim.Adam(disc.parameters(), lr=2e-4)
criterion = torch.nn.BCEWithLogitsLoss()

In [22]:
gen

Generator(
  (d1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
  )
  (d2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (d3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (u1): Sequential(
    (0): Upsample(scale_factor=2.0, mode='bilinear')
    (1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
  )
  (u2): Sequential(
    (0): Upsample(scale_factor=2.0, mode='bilinear')
    (1): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (

In [23]:
disc

Discriminator(
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

In [None]:
epochs = 20
save_every = 2

fixed_cond = next(iter(dataloader))[0][:25].to(device)  # exemple pour visualisation

for epoch in range(1, epochs+1):
    g_losses = []; d_losses = []
    for degraded, target in dataloader:   # dataloader -> (degraded, target)
        degraded = degraded.to(device)   # ex (B, C_cond, 128,128)
        target = target.to(device)       # ex (B, C_out, 128,128)

        # ---- Discriminator ----
        opt_d.zero_grad()
        out_real = disc(degraded, target)
        real_labels = torch.full_like(out_real, 0.9, device=device)
        # print(out_real.shape)
        # print(real_labels.shape)
        d_loss_real = criterion(out_real, real_labels)

        fake = gen(degraded)                             # (B,C_out,128,128)
        out_fake = disc(degraded, fake.detach())
        fake_labels = torch.zeros_like(out_fake, device=device)
        d_loss_fake = criterion(out_fake, fake_labels)

        d_loss = (d_loss_real + d_loss_fake) * 0.5
        d_loss.backward()
        opt_d.step()

        # ---- Generator ----
        opt_g.zero_grad()
        out_fake_for_g = disc(degraded, fake)            # want D to predict real

        #on fait un nouveau calcul de loss
        lambda_l1 = 30

        adv_loss = criterion(out_fake_for_g, real_labels)
        l1_loss = torch.nn.functional.l1_loss(fake, target)

        g_loss = adv_loss + lambda_l1 * l1_loss

        g_loss.backward()
        opt_g.step()

        g_losses.append(g_loss.item()); d_losses.append(d_loss.item())

    print(f"Epoch {epoch}: gen_loss={np.mean(g_losses):.4f}, disc_loss={np.mean(d_losses):.4f}")

    if epoch % save_every == 0 or epoch in (1, epochs):
        with torch.no_grad():
            gen.eval()
            samples = gen(fixed_cond)             # (N, C_out, H, W)
            samples = (samples + 1) / 2.0         # if Tanh used during training
            grid = make_grid(samples, nrow=5)
            save_image(grid, os.path.join(output_dir, f"generated_epoch_{epoch}.png"))
            gen.train()

torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
Epoch 1: gen_loss=6.9329, disc_loss=0.6933
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch.Size([64, 1, 16, 16])
torch

KeyboardInterrupt: 