In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
from torch.utils.data import DataLoader
import torchvision
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [17]:
from torchvision import datasets, transforms
batch_size = 128

# Data loading
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.FashionMNIST(root='data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='data', train=False, transform=transform, download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [18]:
# Checking the dataset
for images, labels in train_loader:  
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break

Image batch dimensions: torch.Size([128, 1, 28, 28])
Image label dimensions: torch.Size([128])


In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class CVAE(nn.Module):
    def __init__(self, input_dim, latent_dim=10, num_classes=10, image_channels=1):
        super(CVAE, self).__init__()
        
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.input_dim = input_dim
        self.image_channels = image_channels

        # Encodeur
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(image_channels + num_classes, 32, kernel_size=4, stride=2, padding=0),  # Sortie: 14x14
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),  # Sortie: 7x7
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0),  # Sortie: 4x4
            nn.ReLU(),
            nn.Flatten()
        )
        self.fc_mu = nn.Linear(128 * 3 * 3, latent_dim)  # Projection vers l'espace latent (moyenne)
        self.fc_logvar = nn.Linear(128 * 3 * 3, latent_dim)  # Projection vers l'espace latent (log variance)

        # Décodeur
        self.decoder_fc = nn.Linear(latent_dim, 128 * 3 * 3)
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=0),  # Sortie: 7x7
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=0),  # Sortie: 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=4, stride=2, padding=0),  # Sortie: 28x28
            nn.Sigmoid()  # Pour normaliser la sortie entre 0 et 1
        )

    def encode(self, x, labels):
        # Convertir les labels en one-hot et les concaténer au canal d'image
        labels_onehot = F.one_hot(labels, self.num_classes).float()
        labels_onehot = labels_onehot.view(-1, self.num_classes, 1, 1).repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat([x, labels_onehot], dim=1)

        # Encoder l'image + label
        h = self.encoder_conv(x)
        print(h.shape)
        mu = self.fc_mu(h)
        log_var = self.fc_logvar(h)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, labels):
        # Convertir les labels en one-hot et les concaténer à l'espace latent
        labels_onehot = F.one_hot(labels, self.num_classes).float()
        z = torch.cat([z, labels_onehot], dim=1)
        h = self.decoder_fc(z)
        h = h.view(-1, 128, 4, 4)  # Remodeler en 128 canaux de taille 3x3
        x_recon = self.decoder_conv(h)
        return x_recon

    def forward(self, x, labels):
        # Encoder
        mu, log_var = self.encode(x, labels)
        z = self.reparameterize(mu, log_var)
        # Décodeur
        x_recon = self.decode(z, labels)
        return x_recon, mu, log_var

    def loss_function(self, recon_x, x, mu, log_var):
        # Binary Cross Entropy pour la reconstruction
        BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
        # KL Divergence pour la régularisation
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return BCE + KLD


latent_dim = 10
num_classes = 10


In [23]:
# Initialisation du modèle
model = CVAE(input_dim=28*28, latent_dim=latent_dim, num_classes=num_classes, image_channels=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Boucle d'entraînement
epochs = 5

for epoch in range(epochs):
    model.train()
    train_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data = data.to(device)  # Déplacer les données sur GPU si disponible
        labels = labels.to(device)
        optimizer.zero_grad()

        recon_batch, mu, log_var = model(data, labels)
        loss = model.loss_function(recon_batch, data, mu, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print(f'Epoch {epoch + 1}, Loss: {train_loss / len(train_loader.dataset)}')

print("Entraînement terminé")


RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x512 and 2048x10)