In [134]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

import torch_directml


In [135]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        self.d1 = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.ReLU(True)
        )
        self.d2 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.InstanceNorm2d(128, affine=True),
            nn.ReLU(True)
        )

        self.d3 = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.InstanceNorm2d(256, affine=True),
            nn.ReLU(True)
        )

        self.u1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(256, 128, 3, 1, 1),
            nn.InstanceNorm2d(128, affine=True),
            nn.ReLU(True)
        )

        self.u2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(256, 64, 3, 1, 1),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(True)
        )

        self.u3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(128, 3, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, x):
        d1 = self.d1(x)
        d2 = self.d2(d1)
        d3 = self.d3(d2)

        # Decoder
        u1 = self.u1(d3)
        u1 = torch.cat([u1, d2], dim=1)

        u2 = self.u2(u1)
        u2 = torch.cat([u2, d1], dim=1)

        return self.u3(u2)

In [136]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1),   # image_détériorée + image
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1),
            nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1),
            nn.InstanceNorm2d(256, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 1, 3, 1, 1)
        )

    def forward(self, img_cond, img):
        x = torch.cat([img_cond, img], dim=1)
        return self.model(x)


# CHARGEMENT DATASET

In [137]:
batch_size = 32

In [138]:
from paired_dataset import PairedImageDataset

In [139]:
from torch.utils.data import DataLoader

dataset = PairedImageDataset("..\\data\\train2")

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=6,
    shuffle=True,
    pin_memory=True
)

In [140]:
degraded, target = next(iter(dataloader))
print(degraded.shape, target.shape)


torch.Size([32, 3, 128, 128]) torch.Size([32, 3, 128, 128])


# TRAINING

In [141]:
device = torch_directml.device()

output_dir = "output_gan"
os.makedirs(output_dir, exist_ok=True)

print("Using device:", device)


Using device: privateuseone:0


In [142]:
latent_dim = 100

gen = Generator().to(device)
disc = Discriminator().to(device)

opt_g = optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = optim.Adam(disc.parameters(), lr=5e-5, betas=(0.5, 0.999))

criterion = torch.nn.BCEWithLogitsLoss()

In [143]:
gen

Generator(
  (d1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
  )
  (d2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (2): ReLU(inplace=True)
  )
  (d3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (2): ReLU(inplace=True)
  )
  (u1): Sequential(
    (0): Upsample(scale_factor=2.0, mode='nearest')
    (1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (3): ReLU(inplace=True)
  )
  (u2): Sequential(
    (0): Upsample(scale_factor=2.0, mode='nearest')
    (1): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1,

In [144]:
disc

Discriminator(
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

In [145]:
def tv_loss(x):
    return torch.mean(torch.abs(x[:, :, :-1] - x[:, :, 1:])) + \
           torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))

In [146]:
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

In [147]:
epochs = 100
save_every = 2

fixed_cond = next(iter(dataloader))[0][:25].to(device)

for epoch in range(1, epochs + 1):
    g_losses = []
    d_losses = []

    # barre de progression sur les batches
    pbar = tqdm(dataloader, desc=f"Epoch [{epoch}/{epochs}]", leave=False)

    for degraded, target in pbar:
        degraded = degraded.to(device)
        target = target.to(device)

        # ---- Discriminator ----

        
        opt_d.zero_grad()
        out_real = disc(degraded, target)
        real_labels = torch.full_like(out_real, 0.9, device=device)
        d_loss_real = criterion(out_real, real_labels)

        fake = gen(degraded)
        out_fake = disc(degraded, fake.detach())
        fake_labels = torch.zeros_like(out_fake, device=device)
        d_loss_fake = criterion(out_fake, fake_labels)

        d_loss = 0.5 * (d_loss_real + d_loss_fake)
        d_loss.backward()
        opt_d.step()


        # ---- Generator ----
        opt_g.zero_grad()
        out_fake_for_g = disc(degraded, fake)
        real_labels_g = torch.ones_like(out_fake_for_g)

        #On adapte le poids du L1 en fonction de l'epoch
        if epoch < 7:
            lambda_l1 = 40
        elif epoch < 25:
            lambda_l1 = 15
        else:
            lambda_l1 = 7

            
        adv_loss = criterion(out_fake_for_g, real_labels_g)
        l1_loss = torch.nn.functional.l1_loss(fake, target)

        if epoch <= 7 : #On warmup le generateur donc pas de adv_loss
            g_loss = lambda_l1 * l1_loss
        else :
            tv = tv_loss(fake)
            g_loss = adv_loss + lambda_l1 * l1_loss + 0.005 * tv

        g_loss.backward()
        opt_g.step()

        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

        # mise à jour de la barre
        pbar.set_postfix({
            "G_loss": f"{g_loss.item():.3f}",
            "D_loss": f"{d_loss.item():.3f}"
        })

    print(
        f"Epoch {epoch}: "
        f"gen_loss={np.mean(g_losses):.4f}, "
        f"disc_loss={np.mean(d_losses):.4f}"
    )

    if epoch % save_every == 0 or epoch in (1, epochs):
        with torch.no_grad():
            gen.eval()
            samples = gen(fixed_cond)
            samples = (samples + 1) / 2.0
            samples = samples.cpu() 
            grid = make_grid(samples, nrow=5)
            save_image(
                grid,
                os.path.join(output_dir, f"generated_epoch_{epoch}.png")
            )
            gen.train()
    
    if epoch % 10 == 0 or epoch == epochs:
        torch.save(
            {
                "epoch": epoch,
                "gen_state_dict": gen.state_dict(),
                "disc_state_dict": disc.state_dict(),
                "opt_g_state_dict": opt_g.state_dict(),
                "opt_d_state_dict": opt_d.state_dict(),
            },
            os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
        )



                                                                                            

Epoch 1: gen_loss=3.9322, disc_loss=0.4574


                                                                                            

Epoch 2: gen_loss=3.2601, disc_loss=0.2523


                                                                                            

Epoch 3: gen_loss=3.1733, disc_loss=0.2168


                                                                                            

Epoch 4: gen_loss=3.1270, disc_loss=0.2045


                                                                                            

Epoch 5: gen_loss=3.0917, disc_loss=0.1955


                                                                                            

Epoch 6: gen_loss=3.0623, disc_loss=0.1886


                                                                                            

Epoch 7: gen_loss=1.1277, disc_loss=0.1803


                                                                                            

Epoch 8: gen_loss=2.4544, disc_loss=0.6737


                                                                                            

Epoch 9: gen_loss=2.5086, disc_loss=0.6434


                                                                                             

Epoch 10: gen_loss=2.5488, disc_loss=0.6288


                                                                                             

Epoch 11: gen_loss=2.6813, disc_loss=0.5959


                                                                                             

Epoch 12: gen_loss=2.8152, disc_loss=0.5592


                                                                                             

Epoch 13: gen_loss=2.8182, disc_loss=0.5536


                                                                                             

Epoch 14: gen_loss=2.7583, disc_loss=0.5576


                                                                                             

Epoch 15: gen_loss=2.6819, disc_loss=0.5656


                                                                                             

Epoch 16: gen_loss=2.7009, disc_loss=0.5679


                                                                                             

Epoch 17: gen_loss=2.7438, disc_loss=0.5572


                                                                                             

Epoch 18: gen_loss=2.6812, disc_loss=0.5646


                                                                                             

Epoch 19: gen_loss=2.7387, disc_loss=0.5549


                                                                                             

Epoch 20: gen_loss=2.7893, disc_loss=0.5356


                                                                                             

Epoch 21: gen_loss=2.8410, disc_loss=0.5309


                                                                                             

Epoch 22: gen_loss=2.8537, disc_loss=0.5283


                                                                                             

Epoch 23: gen_loss=2.8814, disc_loss=0.5231


                                                                                             

Epoch 24: gen_loss=2.8768, disc_loss=0.5130


                                                                                             

Epoch 25: gen_loss=2.0850, disc_loss=0.5196


                                                                                             

Epoch 26: gen_loss=2.1831, disc_loss=0.5006


                                                                                             

Epoch 27: gen_loss=2.2643, disc_loss=0.4847


                                                                                             

Epoch 28: gen_loss=2.2832, disc_loss=0.4812


                                                                                             

Epoch 29: gen_loss=2.3089, disc_loss=0.4813


                                                                                             

Epoch 30: gen_loss=2.2895, disc_loss=0.4786


                                                                                             

Epoch 31: gen_loss=2.2950, disc_loss=0.4775


                                                                                             

Epoch 32: gen_loss=2.2889, disc_loss=0.4749


                                                                                             

Epoch 33: gen_loss=2.2496, disc_loss=0.4748


                                                                                             

Epoch 34: gen_loss=2.2957, disc_loss=0.4671


                                                                                             

Epoch 35: gen_loss=2.2664, disc_loss=0.4658


                                                                                             

Epoch 36: gen_loss=2.2990, disc_loss=0.4599


                                                                                             

Epoch 37: gen_loss=2.2731, disc_loss=0.4603


                                                                                             

Epoch 38: gen_loss=2.3634, disc_loss=0.4472


                                                                                             

Epoch 39: gen_loss=2.3338, disc_loss=0.4497


                                                                                             

Epoch 40: gen_loss=2.3529, disc_loss=0.4445


                                                                                             

Epoch 41: gen_loss=2.3824, disc_loss=0.4395


                                                                                             

Epoch 42: gen_loss=2.3986, disc_loss=0.4358


                                                                                             

Epoch 43: gen_loss=2.4287, disc_loss=0.4329


                                                                                             

Epoch 44: gen_loss=2.3900, disc_loss=0.4360


                                                                                             

Epoch 45: gen_loss=2.4513, disc_loss=0.4289


                                                                                             

Epoch 46: gen_loss=2.4622, disc_loss=0.4259


                                                                                             

Epoch 47: gen_loss=2.4508, disc_loss=0.4264


                                                                                             

Epoch 48: gen_loss=2.5120, disc_loss=0.4210


                                                                                             

Epoch 49: gen_loss=2.4877, disc_loss=0.4205


                                                                                             

Epoch 50: gen_loss=2.4793, disc_loss=0.4214


                                                                                             

Epoch 51: gen_loss=2.5024, disc_loss=0.4200


                                                                                             

Epoch 52: gen_loss=2.5065, disc_loss=0.4183


                                                                                             

Epoch 53: gen_loss=2.5237, disc_loss=0.4166


                                                                                             

Epoch 54: gen_loss=2.5334, disc_loss=0.4108


                                                                                             

Epoch 55: gen_loss=2.5421, disc_loss=0.4125


                                                                                             

Epoch 56: gen_loss=2.5433, disc_loss=0.4128


                                                                                             

Epoch 57: gen_loss=2.5551, disc_loss=0.4094


                                                                                             

Epoch 58: gen_loss=2.5751, disc_loss=0.4061


                                                                                             

Epoch 59: gen_loss=2.5630, disc_loss=0.4086


                                                                                             

Epoch 60: gen_loss=2.5757, disc_loss=0.4060


                                                                                             

Epoch 61: gen_loss=2.6270, disc_loss=0.4012


                                                                                             

Epoch 62: gen_loss=2.6070, disc_loss=0.4010


                                                                                             

Epoch 63: gen_loss=2.6019, disc_loss=0.3990


                                                                                             

Epoch 64: gen_loss=2.6219, disc_loss=0.4009


                                                                                             

Epoch 65: gen_loss=2.6304, disc_loss=0.3966


                                                                                             

Epoch 66: gen_loss=2.6378, disc_loss=0.3974


                                                                                             

KeyboardInterrupt: 

In [None]:
# récupérer le premier batch
first_batch = next(iter(dataloader))
degraded, target = first_batch  # (B, C, H, W)

# prendre les 25 premières images pour visualisation
fixed_cond = degraded[:25].to(device)

# si ton dataset a les chemins/fichiers
if hasattr(dataloader.dataset, 'samples'):  # typique ImageFolder
    print("25 premières images utilisées pour visualisation :")
    for i in range(25):
        print(dataloader.dataset.samples[i][0])  # chemin du fichier
else:
    print("fixed_cond shape :", fixed_cond.shape)
    print("Pas de chemin disponible, seulement les tensors")


fixed_cond shape : torch.Size([25, 3, 128, 128])
Pas de chemin disponible, seulement les tensors
