In [75]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

import torch_directml


In [76]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        self.d1 = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.ReLU(True)
        )
        self.d2 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True)
        )

        self.d3 = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True)
        )

        self.u1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(256, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True)
        )

        self.u2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(256, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )

        self.u3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(128, 3, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, x):
        d1 = self.d1(x)
        d2 = self.d2(d1)
        d3 = self.d3(d2)

        # Decoder
        u1 = self.u1(d3)
        u1 = torch.cat([u1, d2], dim=1)

        u2 = self.u2(u1)
        u2 = torch.cat([u2, d1], dim=1)

        return self.u3(u2)

In [77]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1),   # image_détériorée + image
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 1, 3, 1, 1)
        )

    def forward(self, img_cond, img):
        x = torch.cat([img_cond, img], dim=1)
        return self.model(x)


# CHARGEMENT DATASET

In [78]:
batch_size = 32

In [79]:
from paired_dataset import PairedImageDataset

In [80]:
from torch.utils.data import DataLoader

dataset = PairedImageDataset("..\\data\\train2")

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=6,
    shuffle=True,
    pin_memory=True
)

In [81]:
degraded, target = next(iter(dataloader))
print(degraded.shape, target.shape)


torch.Size([32, 3, 128, 128]) torch.Size([32, 3, 128, 128])


# TRAINING

In [82]:
device = torch_directml.device()

output_dir = "output_gan"
os.makedirs(output_dir, exist_ok=True)

print("Using device:", device)


Using device: privateuseone:0


In [None]:
latent_dim = 100

gen = Generator().to(device)
disc = Discriminator().to(device)

opt_g = optim.Adam(gen.parameters(), lr=2e-4)
opt_d = optim.Adam(disc.parameters(), lr=2e-4)
criterion = torch.nn.BCEWithLogitsLoss()

In [84]:
gen

Generator(
  (d1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
  )
  (d2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (d3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (u1): Sequential(
    (0): Upsample(scale_factor=2.0, mode='nearest')
    (1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
  )
  (u2): Sequential(
    (0): Upsample(scale_factor=2.0, mode='nearest')
    (1): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2)

In [85]:
disc

Discriminator(
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

In [None]:
def tv_loss(x):
    return torch.mean(torch.abs(x[:, :, :-1] - x[:, :, 1:])) + \
           torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))

In [None]:
epochs = 50
save_every = 2

fixed_cond = next(iter(dataloader))[0][:25].to(device)

for epoch in range(1, epochs + 1):
    g_losses = []
    d_losses = []

    # barre de progression sur les batches
    pbar = tqdm(dataloader, desc=f"Epoch [{epoch}/{epochs}]", leave=False)

    for degraded, target in pbar:
        degraded = degraded.to(device)
        target = target.to(device)

        # ---- Discriminator ----

        
        opt_d.zero_grad()
        out_real = disc(degraded, target)
        real_labels = torch.full_like(out_real, 0.9, device=device)
        d_loss_real = criterion(out_real, real_labels)

        fake = gen(degraded)
        out_fake = disc(degraded, fake.detach())
        fake_labels = torch.zeros_like(out_fake, device=device)
        d_loss_fake = criterion(out_fake, fake_labels)

        d_loss = 0.5 * (d_loss_real + d_loss_fake)
        d_loss.backward()
        opt_d.step()


        # ---- Generator ----
        opt_g.zero_grad()
        out_fake_for_g = disc(degraded, fake)
        real_labels_g = torch.ones_like(out_fake_for_g)

        #On adapte le poids du L1 en fonction de l'epoch
        if epoch < 10:
            lambda_l1 = 100
        elif epoch < 25:
            lambda_l1 = 50
        else:
            lambda_l1 = 20
            
        adv_loss = criterion(out_fake_for_g, real_labels_g)
        l1_loss = torch.nn.functional.l1_loss(fake, target)

        if epoch <= 7 : #On warmup le generateur donc pas de adv_loss
            g_loss = lambda_l1 * l1_loss
        else :
            tv = tv_loss(fake)
            g_loss = adv_loss + lambda_l1 * l1_loss + 0.1 * tv

        g_loss.backward()
        opt_g.step()

        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

        # mise à jour de la barre
        pbar.set_postfix({
            "G_loss": f"{g_loss.item():.3f}",
            "D_loss": f"{d_loss.item():.3f}"
        })

    print(
        f"Epoch {epoch}: "
        f"gen_loss={np.mean(g_losses):.4f}, "
        f"disc_loss={np.mean(d_losses):.4f}"
    )

    if epoch % save_every == 0 or epoch in (1, epochs):
        with torch.no_grad():
            gen.eval()
            samples = gen(fixed_cond)
            samples = (samples + 1) / 2.0
            samples = samples.cpu() 
            grid = make_grid(samples, nrow=5)
            save_image(
                grid,
                os.path.join(output_dir, f"generated_epoch_{epoch}.png")
            )
            gen.train()


                                                                                           

Epoch 1: gen_loss=9.1804, disc_loss=0.2745


                                                                                           

Epoch 2: gen_loss=7.9605, disc_loss=0.1658


                                                                                           

Epoch 3: gen_loss=7.8073, disc_loss=0.1644


                                                                                           

Epoch 4: gen_loss=7.7088, disc_loss=0.1644


                                                                                           

Epoch 5: gen_loss=7.6379, disc_loss=0.1637


                                                                                           

Epoch 6: gen_loss=7.5853, disc_loss=0.1635


                                                                                           

Epoch 7: gen_loss=7.5318, disc_loss=0.1677


                                                                                            

Epoch 8: gen_loss=9.8184, disc_loss=0.6214


                                                                                            

Epoch 9: gen_loss=9.4038, disc_loss=0.4951


                                                                                            

Epoch 10: gen_loss=5.3691, disc_loss=0.6504


                                                                                            

Epoch 11: gen_loss=5.3269, disc_loss=0.6626


                                                                                            

Epoch 12: gen_loss=5.0606, disc_loss=0.5787


                                                                                            

Epoch 13: gen_loss=5.2898, disc_loss=0.6668


                                                                                            

Epoch 14: gen_loss=5.4931, disc_loss=0.6369


                                                                                            

Epoch 15: gen_loss=4.9348, disc_loss=0.6408


                                                                                            

Epoch 16: gen_loss=5.8730, disc_loss=0.5710


                                                                                            

Epoch 17: gen_loss=6.6343, disc_loss=0.4804


                                                                                            

Epoch 18: gen_loss=4.6004, disc_loss=0.6231


                                                                                            

Epoch 19: gen_loss=6.0366, disc_loss=0.5693


                                                                                            

Epoch 20: gen_loss=6.3529, disc_loss=0.5587


                                                                                            

Epoch 21: gen_loss=4.6785, disc_loss=0.6495


                                                                                            

Epoch 22: gen_loss=5.8868, disc_loss=0.6328


                                                                                            

Epoch 23: gen_loss=6.0813, disc_loss=0.6041


                                                                                            

Epoch 24: gen_loss=4.6171, disc_loss=0.6316


                                                                                            

Epoch 25: gen_loss=3.0663, disc_loss=0.6841


                                                                                            

Epoch 26: gen_loss=3.1418, disc_loss=0.6404


                                                                                            

Epoch 27: gen_loss=2.2031, disc_loss=0.7271


                                                      

KeyboardInterrupt: 

In [None]:
# récupérer le premier batch
first_batch = next(iter(dataloader))
degraded, target = first_batch  # (B, C, H, W)

# prendre les 25 premières images pour visualisation
fixed_cond = degraded[:25].to(device)

# si ton dataset a les chemins/fichiers
if hasattr(dataloader.dataset, 'samples'):  # typique ImageFolder
    print("25 premières images utilisées pour visualisation :")
    for i in range(25):
        print(dataloader.dataset.samples[i][0])  # chemin du fichier
else:
    print("fixed_cond shape :", fixed_cond.shape)
    print("Pas de chemin disponible, seulement les tensors")


fixed_cond shape : torch.Size([25, 3, 128, 128])
Pas de chemin disponible, seulement les tensors
