In [3]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

In [None]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()


        self.model = nn.Sequential(
            # Encodeur
            nn.Conv2d(1, 64, 4, 2, 1),    # 28 → 14
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1), # 14 → 7
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # Décodeur
            nn.ConvTranspose2d(128, 64, 4, 2, 1), # 7 → 14
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 3, 4, 2, 1),   # 14 → 28 (RGB output)
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)
    

In [None]:
class Discriminator(nn.Module):
    """Conditional convolutional discriminator scoring an (img_cond, img) pair.

    The conditioning image and the candidate image are concatenated along the
    channel axis and passed through three stride-2 convolutions followed by a
    3x3 convolution that produces one output channel. The result is a spatial
    map of raw scores: no sigmoid is applied inside this module. For 28x28
    inputs the map is 3x3 (28 -> 14 -> 7 -> 3 -> 3).

    Args:
        cond_channels: Channels of the conditioning image. Defaults to 1.
        img_channels: Channels of the image being judged. Defaults to 3.
            The defaults sum to 4, the input width that was previously
            hard-coded.
    """

    def __init__(self, cond_channels: int = 1, img_channels: int = 3):
        super().__init__()

        # Layer order (and therefore the Sequential indices used as
        # state_dict keys) is kept stable so saved weights remain loadable.
        self.model = nn.Sequential(
            # Input is the degraded (conditioning) image stacked with the image.
            nn.Conv2d(cond_channels + img_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            # Kernel 3, stride 1, padding 1 keeps the spatial size and
            # reduces the features to a single score channel.
            nn.Conv2d(256, 1, 3, 1, 1),
        )

    def forward(self, img_cond: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
        """Score an image given its conditioning image.

        Args:
            img_cond: Tensor of shape ``(N, cond_channels, H, W)``.
            img: Tensor of shape ``(N, img_channels, H, W)``; batch and
                spatial sizes must match ``img_cond`` for the concatenation.

        Returns:
            Tensor of shape ``(N, 1, H', W')`` holding unnormalised scores.
        """
        x = torch.cat([img_cond, img], dim=1)
        return self.model(x)
