In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import os
import kagglehub

In [3]:
# Download the latest version of the CelebA-HQ (256x256) dataset from Kaggle.
path = kagglehub.dataset_download("badasstechie/celebahq-resized-256x256")
print("Path to dataset files:", path)

# Check where the dataset was downloaded
print(f"Dataset téléchargé dans : {path}")
# List the files/folders it contains
print("Fichiers dans le dossier :", os.listdir(path))

SEED = 42             # makes the train/test split reproducible across runs
TRAIN_FRACTION = 0.9  # 90% train / 10% test
BATCH_SIZE = 128
NUM_WORKERS = 4

# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # resize to 256x256 (in case some images differ)
    transforms.ToTensor(),          # convert to a float tensor (C, H, W) in [0, 1]
])

# ImageFolder expects one sub-folder per class; the dataset's single sub-folder
# ('celeba_hq_256') yields a dummy label that the autoencoder ignores.
dataset = datasets.ImageFolder(root=path, transform=transform)

train_size = int(TRAIN_FRACTION * len(dataset))
test_size = len(dataset) - train_size

# Seeded generator: without it the split differs on every run, so a saved model
# could later be evaluated on images it was trained on.
trainset, testset = random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(SEED)
)

# DataLoaders
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Check the split sizes
print(f"Train set size: {len(trainset)} images")
print(f"Test set size: {len(testset)} images")

# Sanity-check that one batch loads correctly
images, _ = next(iter(trainloader))
print(f"Batch size: {images.shape}")  # expected: torch.Size([BATCH_SIZE, 3, 256, 256])

Downloading from https://www.kaggle.com/api/v1/datasets/download/badasstechie/celebahq-resized-256x256?dataset_version_number=1...


100%|██████████| 283M/283M [00:53<00:00, 5.58MB/s] 

Extracting files...





Path to dataset files: /Users/edouard/.cache/kagglehub/datasets/badasstechie/celebahq-resized-256x256/versions/1
Dataset téléchargé dans : /Users/edouard/.cache/kagglehub/datasets/badasstechie/celebahq-resized-256x256/versions/1
Fichiers dans le dossier : ['celeba_hq_256']
Train set size: 27000 images
Test set size: 3000 images
Batch size: torch.Size([128, 3, 256, 256])


In [4]:
def show(images, nrows=5, ncols=5):
    """Plot the first ``nrows * ncols`` images of a batch in a grid.

    :param images: batch tensor of shape (N, C, H, W) with C == 1 or 3 and values
        in [0, 1] (e.g. the output of ``transforms.ToTensor``); may be on any device.
    :param nrows: number of grid rows (default 5).
    :param ncols: number of grid columns (default 5).

    If the batch holds fewer than ``nrows * ncols`` images, the remaining grid
    cells are left empty instead of raising ``IndexError``.
    """
    images = images.detach().cpu()
    fig, axes = plt.subplots(nrows, ncols, figsize=(6, 6))

    # plt.subplots returns a bare Axes (not an array) when nrows == ncols == 1.
    for i, ax in enumerate(np.atleast_1d(axes).flat):
        ax.axis('off')
        if i >= len(images):
            continue  # fewer images than grid cells: leave this cell blank
        img = images[i].permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C) for imshow
        if img.shape[-1] == 1:
            # imshow rejects (H, W, 1); draw single-channel images as grayscale.
            ax.imshow(img[..., 0], cmap='gray')
        else:
            ax.imshow(img)

    plt.tight_layout()
    plt.show()

In [5]:
# Preview one test batch: show() draws the first 25 images (5x5 grid); the labels are unused.
# NOTE(review): with num_workers=4, each iter(testloader) starts new worker processes
# (spawned, per the traceback below), which can be slow to start; the saved traceback
# comes from a worker interrupted while importing torchvision.
images, _ = next(iter(testloader))
show(images)

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/edouard/miniforge3/envs/IA_m1/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/edouard/miniforge3/envs/IA_m1/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
  File "/Users/edouard/miniforge3/envs/IA_m1/lib/python3.9/site-packages/torchvision/__init__.py", line 10, in <module>
    from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils  # usort:skip
  File "/Users/edouard/miniforge3/envs/IA_m1/lib/python3.9/site-packages/torchvision/datasets/__init__.py", line 1, in <module>
    from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
  File "/Users/edouard/miniforge3/envs/IA_m1/lib/python3.9/site-packages/torchvision/datasets/_optical_flow.py", line 10, in <module>
    from PIL import Image
  File "/Users/edouard

KeyboardInterrupt: 

In [None]:
# Select the compute device: CUDA GPU if present, then Apple-silicon MPS, else CPU.
# torch.cuda.is_available must be *called*: the bare function object is always
# truthy, which would select "cuda" even on machines without a GPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
device

In [None]:
def show_images_side_by_side(images, reconstructed_images, nrow=8, title="Original vs Reconstructed",
                             normalize=True, scale_each=True):
    """
    Display original images together with their reconstructions in one grid.

    The originals fill the first rows of the grid and the reconstructions the
    rows below them. Each reconstruction sits directly under its original only
    when N is a multiple of ``nrow``.

    :param images: tensor of original images, shape (N, C, H, W); any device.
    :param reconstructed_images: tensor of reconstructions with the same shape; any device.
    :param nrow: number of images per grid row.
    :param title: figure title.
    :param normalize: forwarded to ``torchvision.utils.make_grid``.
    :param scale_each: forwarded to ``make_grid``. With ``normalize=True``, each image
        is rescaled to its own min/max, which can hide brightness/contrast errors
        of the reconstructions; pass ``normalize=False`` for inputs already in [0, 1].
    :raises ValueError: if the two tensors do not have the same shape.
    """
    if images.shape != reconstructed_images.shape:
        raise ValueError("Les tenseurs d'images doivent avoir la même forme.")

    # Detach and move both batches to the CPU first: torch.cat needs a common
    # device, and .numpy() refuses tensors that require grad.
    originals = images.detach().cpu()
    reconstructions = reconstructed_images.detach().cpu()

    # Originals first, reconstructions after them.
    stacked_images = torch.cat((originals, reconstructions), dim=0)

    grid = torchvision.utils.make_grid(
        stacked_images, nrow=nrow, normalize=normalize, scale_each=scale_each
    )

    plt.figure(figsize=(12, 6))
    plt.imshow(grid.permute(1, 2, 0).numpy())  # (C, H, W) -> (H, W, C)
    plt.axis("off")
    plt.title(title)
    plt.show()

In [None]:

class AutoEncoder(nn.Module):
    """Small convolutional autoencoder for 3x256x256 images.

    The encoder halves the spatial size twice (256 -> 128 -> 64) and projects to
    ``latent_channels`` channels, so a 3x256x256 image is compressed into a
    ``latent_channels`` x 64x64 code. The decoder mirrors it back to 3x256x256
    and ends with a sigmoid, so outputs lie in [0, 1] like the ``ToTensor`` inputs.

    More generally, an HxW input (H and W divisible by 4) gives an (H/4)x(W/4)
    latent and an HxW reconstruction.

    :param latent_channels: number of channels of the latent code (default 16).
    """

    def __init__(self, latent_channels=16):
        super().__init__()

        # Encoder: 3x256x256 -> latent_channels x 64x64.
        # The layer order (indices 0/2/4) is kept stable so saved state_dicts still load.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),    # 64 x 128x128
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 128 x 64x64
            nn.ReLU(),
            nn.Conv2d(128, latent_channels, kernel_size=3, stride=1, padding=1),  # latent_channels x 64x64
        )

        # Decoder: latent_channels x 64x64 -> 3x256x256.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_channels, 128, kernel_size=4, stride=2, padding=1),  # 128 x 128x128
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 64 x 256x256
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),  # 3 x 256x256
            nn.Sigmoid(),  # squash outputs to [0, 1]
        )

    def encode(self, x):
        """Map images (N, 3, H, W) to latent codes (N, latent_channels, H/4, W/4)."""
        return self.encoder(x)

    def decode(self, latent):
        """Map latent codes back to images with values in [0, 1]."""
        return self.decoder(latent)

    def forward(self, x):
        """Reconstruct ``x``; equivalent to ``decode(encode(x))``."""
        return self.decode(self.encode(x))

In [None]:
class ResBlock(nn.Module):
    """Residual block: (conv3x3 -> BN -> ReLU -> conv3x3 -> BN) + shortcut.

    The spatial size is preserved (3x3 kernels with padding 1). When ``in_channels``
    differs from ``out_channels``, the shortcut is a 1x1 convolution so both
    branches have the same shape and can be summed; otherwise it is the identity.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
        )
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels else nn.Identity()
        )

    def forward(self, x):
        return self.shortcut(x) + self.block(x)


class ResAutoEncoder(nn.Module):
    """Deeper residual autoencoder, an experimental alternative to ``AutoEncoder``.

    The encoder downsamples three times (256 -> 128 -> 64 -> 32), so a 3x256x256
    image becomes a ``latent_channels`` x 32x32 code. The decoder mirrors it back
    to 3x256x256 and ends with a sigmoid (outputs in [0, 1]).

    It has its own name so that running this cell does not replace the
    ``AutoEncoder`` class used for training. To try it, instantiate
    ``ResAutoEncoder(...)`` instead of ``AutoEncoder(...)``.

    :param latent_channels: channels of the latent code (default 64, i.e. more
        latent channels for a richer code).
    """

    def __init__(self, latent_channels=64):
        super().__init__()

        # Encoder: 3x256x256 -> latent_channels x 32x32
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=4, stride=2, padding=1),  # 128x128
            nn.ReLU(inplace=True),
            ResBlock(128, 256),
            nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1),  # 64x64
            nn.ReLU(inplace=True),
            ResBlock(256, 512),
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),  # 32x32
            nn.ReLU(inplace=True),
            ResBlock(512, latent_channels),
            nn.Conv2d(latent_channels, latent_channels, kernel_size=3, stride=1, padding=1),  # latent_channels x 32x32
            nn.BatchNorm2d(latent_channels),
            nn.ReLU(inplace=True),
        )

        # Decoder: latent_channels x 32x32 -> 3x256x256
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_channels, 512, kernel_size=4, stride=2, padding=1),  # 64x64
            nn.ReLU(inplace=True),
            ResBlock(512, 512),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),  # 128x128
            nn.ReLU(inplace=True),
            ResBlock(256, 256),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 256x256
            nn.ReLU(inplace=True),
            ResBlock(128, 128),
            nn.Conv2d(128, 3, kernel_size=3, stride=1, padding=1),  # 3 output channels
            nn.Sigmoid(),  # squash outputs to [0, 1]
        )

    def forward(self, x):
        latent = self.encoder(x)
        reconstructed = self.decoder(latent)
        return reconstructed

In [None]:
# Build the autoencoder with a 2-channel latent code and move it to the selected device.
model = AutoEncoder(latent_channels=2)
model = model.to(device)  # nn.Module.to returns the module itself

# Total number of parameters of the model.
nb_params = sum(param.numel() for param in model.parameters())

# Pixel-wise reconstruction loss, optimized with Adam (created after moving the model).
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

print(f"Number of parameters {nb_params}")

In [None]:
nb_epochs = 10
display_interval = 5  # show test reconstructions every `display_interval` epochs

# Keep one fixed test batch so the reconstructions shown at different epochs are
# comparable, and so the test DataLoader is not re-iterated every epoch.
test_images, _ = next(iter(testloader))
test_images = test_images.to(device)

for epoch in range(nb_epochs):

    model.train()  # training mode
    total_loss = 0.0  # sum of per-image losses seen so far this epoch
    nb_seen = 0       # number of images seen so far this epoch
    progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1}/{nb_epochs}")

    for images, _ in progress_bar:
        images = images.to(device)
        reconstructed_images = model(images)

        loss = criterion(reconstructed_images, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the mean over the batch: weight it by the batch size so
        # the running average stays exact when the last batch is smaller.
        n_batch = images.size(0)
        total_loss += loss.item()
        total_loss += loss.item() * (n_batch - 1)
        nb_seen += n_batch
        progress_bar.set_postfix(loss=total_loss / nb_seen)

    avg_loss = total_loss / max(nb_seen, 1)
    print(f"Epoch {epoch+1}/{nb_epochs} - train loss: {avg_loss:.7f}")

    # Only run the test forward pass on epochs where the result is displayed.
    if (epoch + 1) % display_interval == 0:
        model.eval()
        with torch.no_grad():  # no gradients needed for visualisation
            reconstructed_test = model(test_images)
        show_images_side_by_side(test_images[:24].cpu(), reconstructed_test[:24].cpu())

# Leave the model in evaluation mode for the cells that follow.
model.eval()
print("Training terminé ! 🚀")

In [None]:
# Evaluate the reconstruction MSE over the whole test set.
model.eval()  # make the cell independent of the mode the model was left in
total_test_loss = 0.0  # sum of per-image losses
nb_test_images = 0
with torch.no_grad():
    for images, _ in tqdm(testloader):
        images = images.to(device)
        reconstructed_images = model(images)
        loss = criterion(reconstructed_images, images)
        # criterion returns the batch mean; weight by batch size so the smaller
        # last batch does not bias the per-image average.
        total_test_loss += loss.item() * images.size(0)
        nb_test_images += images.size(0)

avg_test_loss = total_test_loss / max(nb_test_images, 1)
print(f"Test Loss: {avg_test_loss:.7f}")

In [None]:
# Mount Google Drive when running on Colab; elsewhere (e.g. locally) there is
# nothing to mount, so skip instead of failing with ImportError.
try:
    from google.colab import drive
except ImportError:
    drive = None
    print("google.colab not available: skipping Google Drive mount.")
else:
    drive.mount('/content/drive')

In [None]:
# Target file on Google Drive. The name appears to encode 256x256x3 inputs and a
# 64x64x2 latent, which matches AutoEncoder(latent_channels=2).
save_path = "/content/drive/My Drive/DeepLearning/Projets/ae_256-3_64-2.pth"
if not os.path.isdir("/content/drive"):
    # Drive is not mounted (e.g. running locally): save in the working directory instead.
    save_path = os.path.basename(save_path)

# torch.save does not create missing parent directories.
save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

# Save the weights only (state_dict)
torch.save(model.state_dict(), save_path)
print(f"Modèle sauvegardé dans {save_path}")