In [None]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

import torch_directml


In [None]:
class Generator(nn.Module):
    """U-Net-style image-to-image generator with two skip connections.

    Encoder: three stride-2 4x4 convolutions (3 -> 64 -> 128 -> 256 channels),
    each halving the spatial size. Decoder: three stages of nearest-neighbour
    2x upsampling followed by a 3x3 convolution. After the first and second
    decoder stages the matching encoder features (``d2`` then ``d1``) are
    concatenated along the channel axis. The output has 3 channels squashed
    by ``tanh`` into [-1, 1].

    NOTE(review): assumes input of shape (N, 3, H, W) with H and W divisible
    by 8, otherwise the skip concatenations get mismatched spatial sizes.
    """

    def __init__(self):
        super().__init__()

        # ---- encoder (each stage halves H and W) ----
        self.d1 = nn.Sequential(nn.Conv2d(3, 64, 4, 2, 1), nn.ReLU(True))
        self.d2 = self._down(64, 128)
        self.d3 = self._down(128, 256)

        # ---- decoder (each stage doubles H and W) ----
        # Input channel counts include the concatenated skip features.
        self.u1 = self._up(256, 128)
        self.u2 = self._up(128 + 128, 64)
        self.u3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(64 + 64, 3, 3, 1, 1),
            nn.Tanh(),
        )

    @staticmethod
    def _down(c_in, c_out):
        """Stride-2 4x4 conv -> affine InstanceNorm -> ReLU."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 4, 2, 1),
            nn.InstanceNorm2d(c_out, affine=True),
            nn.ReLU(True),
        )

    @staticmethod
    def _up(c_in, c_out):
        """Nearest 2x upsample -> 3x3 conv -> affine InstanceNorm -> ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(c_in, c_out, 3, 1, 1),
            nn.InstanceNorm2d(c_out, affine=True),
            nn.ReLU(True),
        )

    def forward(self, x):
        # Encoder features kept for the skip connections.
        skip_hi = self.d1(x)
        skip_lo = self.d2(skip_hi)
        bottleneck = self.d3(skip_lo)

        # Decoder: upsample, then concatenate the encoder feature map of equal size.
        h = torch.cat((self.u1(bottleneck), skip_lo), dim=1)
        h = torch.cat((self.u2(h), skip_hi), dim=1)
        return self.u3(h)

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator scoring an image against its degraded input.

    The conditioning image and the candidate image are concatenated along
    the channel axis (3 + 3 = 6 channels). Three stride-2 4x4 convolutions
    with LeakyReLU(0.2) follow; the last two also use affine InstanceNorm.
    A final 3x3 convolution maps to a single channel. The result is a
    spatial map of raw scores (no sigmoid) at 1/8 of the input resolution
    rather than one scalar per image (PatchGAN-style).
    """

    def __init__(self):
        super().__init__()

        # First stage: no normalisation. Input is image_cond + image (6 channels).
        layers = [nn.Conv2d(6, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True)]

        # Two normalised downsampling stages.
        for c_in, c_out in ((64, 128), (128, 256)):
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.InstanceNorm2d(c_out, affine=True),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Projection to a 1-channel score map.
        layers.append(nn.Conv2d(256, 1, 3, 1, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, img_cond, img):
        paired = torch.cat((img_cond, img), dim=1)
        return self.model(paired)


# CHARGEMENT DATASET

In [None]:
# Number of (degraded, target) pairs per mini-batch in the training DataLoader.
batch_size = 32

In [None]:
from paired_dataset import PairedImageDataset

In [None]:
from torch.utils.data import DataLoader

# Build the dataset path with os.path.join so the notebook also runs on
# POSIX systems (a literal "..\\..\\data\\train" only works on Windows).
train_dir = os.path.join("..", "..", "data", "train")
dataset = PairedImageDataset(train_dir)

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    # Up to 6 worker processes, never more than the machine's CPU count.
    num_workers=min(6, os.cpu_count() or 1),
    shuffle=True,
    pin_memory=True,
)

In [None]:
# Sanity check: pull one batch and show the shapes of its two tensors.
first_batch = next(iter(dataloader))
degraded, target = first_batch
print(*(tensor.shape for tensor in first_batch))


# TRAINING

In [None]:
# Folder where the per-epoch sample grids are written.
output_dir = "output_gan"
os.makedirs(output_dir, exist_ok=True)

# DirectML device (this notebook targets an AMD GPU).
device = torch_directml.device()
print("Using device:", device)


In [None]:
# NOTE(review): not referenced by any cell shown here; Generator is
# conditioned on an image, not a noise vector.
latent_dim = 100

# Instantiate both networks directly on the DirectML device.
gen, disc = Generator().to(device), Discriminator().to(device)

# Adam with beta1 = 0.5 for both networks. The discriminator learns 4x more
# slowly than the generator (5e-5 vs 2e-4) to keep it from overpowering G.
adam_betas = (0.5, 0.999)
opt_g = optim.Adam(gen.parameters(), lr=2e-4, betas=adam_betas)
opt_d = optim.Adam(disc.parameters(), lr=5e-5, betas=adam_betas)

# The discriminator outputs raw scores (no sigmoid), hence the logits variant.
criterion = nn.BCEWithLogitsLoss()

In [None]:
# Display the generator's module tree.
gen

In [None]:
# Display the discriminator's module tree.
disc

In [None]:
def tv_loss(x):
    """Anisotropic total-variation penalty.

    Sums the mean absolute difference between vertically adjacent pixels
    (dim 2) and the mean absolute difference between horizontally adjacent
    pixels (dim 3). Requires a tensor with at least 4 dimensions,
    presumably (N, C, H, W).
    """
    vertical = (x[:, :, 1:] - x[:, :, :-1]).abs().mean()
    horizontal = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()
    return vertical + horizontal

In [None]:
# Folder receiving the periodic training checkpoints (.pth files).
checkpoint_dir = "checkpoints"
os.makedirs(name=checkpoint_dir, exist_ok=True)

In [None]:
# This training loop is adapted for an AMD GPU (DirectML).
# The final model was NOT trained with this notebook but with ganGpu_model.ipynb
# (the two files differ slightly, e.g. the GPU version can resume training from a .pth).

epochs = 100
save_every = 2          # write a sample grid every `save_every` epochs (plus first and last)
checkpoint_every = 10   # write a full checkpoint every `checkpoint_every` epochs (plus last)
gen_steps = 2           # generator updates per discriminator update (helps stabilise training)
warmup_epochs = 7       # epochs 1..warmup_epochs train G with L1 only (no adversarial term)
tv_weight = 0.005       # weight of the total-variation regulariser after warm-up

# Fixed batch of degraded inputs, reused at every save so sample grids are
# comparable across epochs.
fixed_cond = next(iter(dataloader))[0][:25].to(device)

for epoch in range(1, epochs + 1):
    g_losses = []
    d_losses = []

    # The L1 weight decays as training progresses. It depends only on the
    # epoch, so compute it once here instead of inside the batch loop.
    if epoch < 7:
        lambda_l1 = 40
    elif epoch < 25:
        lambda_l1 = 15
    else:
        lambda_l1 = 7
    warmup = epoch <= warmup_epochs

    # Progress bar over the batches of this epoch.
    pbar = tqdm(dataloader, desc=f"Epoch [{epoch}/{epochs}]", leave=False)

    for degraded, target in pbar:
        degraded = degraded.to(device)
        target = target.to(device)

        # ---- Discriminator ----
        opt_d.zero_grad()
        out_real = disc(degraded, target)
        # One-sided label smoothing: real targets are 0.9 rather than 1.0.
        real_labels = torch.full_like(out_real, 0.9)
        d_loss_real = criterion(out_real, real_labels)

        # The D step never backpropagates into G, so no graph is needed here.
        with torch.no_grad():
            fake = gen(degraded)
        out_fake = disc(degraded, fake)
        fake_labels = torch.zeros_like(out_fake)
        d_loss_fake = criterion(out_fake, fake_labels)

        d_loss = 0.5 * (d_loss_real + d_loss_fake)
        d_loss.backward()
        opt_d.step()

        # ---- Generator ----
        # `fake` must be recomputed on every generator step. Reusing one tensor
        # across steps would backpropagate through a graph already freed by the
        # previous backward() (RuntimeError), built from pre-step weights.
        for _ in range(gen_steps):
            opt_g.zero_grad()
            fake = gen(degraded)
            l1_loss = torch.nn.functional.l1_loss(fake, target)

            if warmup:
                # Generator warm-up: pure L1 reconstruction; D is not consulted.
                g_loss = lambda_l1 * l1_loss
            else:
                out_fake_for_g = disc(degraded, fake)
                real_labels_g = torch.ones_like(out_fake_for_g)
                adv_loss = criterion(out_fake_for_g, real_labels_g)
                tv = tv_loss(fake)  # total-variation regularisation
                g_loss = adv_loss + lambda_l1 * l1_loss + tv_weight * tv

            g_loss.backward()
            opt_g.step()

        # Record losses for display (the generator value is from its last update).
        g_val = g_loss.item()
        d_val = d_loss.item()
        g_losses.append(g_val)
        d_losses.append(d_val)

        # Update the progress bar.
        pbar.set_postfix({
            "G_loss": f"{g_val:.3f}",
            "D_loss": f"{d_val:.3f}"
        })

    print(
        f"Epoch {epoch}: "
        f"gen_loss={np.mean(g_losses):.4f}, "
        f"disc_loss={np.mean(d_losses):.4f}"
    )

    # Save a grid of samples generated from the fixed conditioning batch.
    if epoch % save_every == 0 or epoch in (1, epochs):
        gen.eval()
        try:
            with torch.no_grad():
                samples = gen(fixed_cond)
            samples = (samples + 1) / 2.0  # tanh range [-1, 1] -> [0, 1]
            grid = make_grid(samples.cpu(), nrow=5)
            save_image(
                grid,
                os.path.join(output_dir, f"generated_epoch_{epoch}.png")
            )
        finally:
            # Always return to training mode, even if saving fails.
            gen.train()

    # Save a full training checkpoint every `checkpoint_every` epochs and at the end.
    if epoch % checkpoint_every == 0 or epoch == epochs:
        torch.save(
            {
                "epoch": epoch,
                "gen_state_dict": gen.state_dict(),
                "disc_state_dict": disc.state_dict(),
                "opt_g_state_dict": opt_g.state_dict(),
                "opt_d_state_dict": opt_d.state_dict(),
            },
            os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
        )



