In [31]:
%load_ext autoreload
%autoreload 2
import time
import numpy as np
import torch
from torch.utils.data import DataLoader
import torchvision
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
cpu


In [17]:
from torchvision import datasets, transforms
batch_size = 128

# Data loading
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.FashionMNIST(root='data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='data', train=False, transform=transform, download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [18]:
# Checking the dataset
for images, labels in train_loader:  
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break

Image batch dimensions: torch.Size([128, 1, 28, 28])
Image label dimensions: torch.Size([128])


In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def to_onehot(labels, num_classes, device):
    labels_onehot = torch.zeros(labels.size(0), num_classes).to(device)
    labels_onehot.scatter_(1, labels.view(-1, 1), 1)
    return labels_onehot

class ConditionalVariationalAutoencoder(nn.Module):

    def __init__(self, num_features, num_latent, num_classes):
        super(ConditionalVariationalAutoencoder, self).__init__()
        
        self.num_classes = num_classes

        ##########################
        # ENCODER
        ##########################
        self.encoder = nn.Sequential(
            nn.Conv2d(1 + num_classes, 16, kernel_size=6, stride=2, padding=0),  # Output: 16, 14, 14
            nn.LeakyReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=0),  # Output: 32, 6, 6
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=2, stride=2, padding=0),  # Output: 64, 2, 2
            nn.LeakyReLU()
        )

        self.fc_z_mean = nn.Linear(64 * 2 * 2, num_latent)
        self.fc_z_log_var = nn.Linear(64 * 2 * 2, num_latent)

        ##########################
        # DECODER
        ##########################
        self.fc_decode_input = nn.Linear(num_latent + num_classes, 64 * 2 * 2)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2, padding=0),  # Output: 32, 4, 4
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=3, padding=1),  # Output: 16, 12, 12
            nn.LeakyReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=6, stride=3, padding=4),  # Output: 1, 28, 28
            nn.Sigmoid()
        )

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encoder_step(self, x, labels):
        # Convertir les labels en one-hot et les ajouter comme condition
        onehot_labels = to_onehot(labels, self.num_classes, x.device).view(-1, self.num_classes, 1, 1)
        onehot_labels = onehot_labels.expand(-1, self.num_classes, x.size(2), x.size(3))
        x = torch.cat((x, onehot_labels), dim=1)
        
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        z_mean = self.fc_z_mean(x)
        z_log_var = self.fc_z_log_var(x)
        return z_mean, z_log_var

    def decoder_step(self, z, labels):
        # Convertir les labels en one-hot et les ajouter comme condition
        onehot_labels = to_onehot(labels, self.num_classes, z.device)
        z = torch.cat((z, onehot_labels), dim=1)
        
        x = self.fc_decode_input(z)
        x = x.view(-1, 64, 2, 2)
        x = self.decoder(x)
        return x

    def forward(self, x, labels):
        z_mean, z_log_var = self.encoder_step(x, labels)
        z = self.reparameterize(z_mean, z_log_var)
        decoded = self.decoder_step(z, labels)
        return z_mean, z_log_var, z, decoded

# Paramètres
device = 'cuda' if torch.cuda.is_available() else 'cpu'
random_seed = 42
torch.manual_seed(random_seed)

num_features = 28 * 28  # Pour des images 28x28
num_latent = 10  # Taille de l'espace latent
num_classes = 10  # CIFAR-10 par exemple
learning_rate = 0.001

# Instancier le modèle
model = ConditionalVariationalAutoencoder(num_features, num_latent, num_classes)
model = model.to(device)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


In [32]:
start_time = time.time()
num_epochs = 5
for epoch in range(num_epochs):
    for batch_idx, (features, targets) in enumerate(train_loader):
        
        features = features.to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        z_mean, z_log_var, encoded, decoded = model(features, targets)

        # cost = reconstruction loss + Kullback-Leibler divergence
        kl_divergence = (0.5 * (z_mean**2 + 
                                torch.exp(z_log_var) - z_log_var - 1)).sum()
        
        
        ### Add condition
        onehot_targets = to_onehot(targets, num_classes, device)
        onehot_targets = onehot_targets.view(-1, num_classes, 1, 1)
        
        ones = torch.ones(features.size()[0], 
                          num_classes,
                          features.size()[2], 
                          features.size()[3], 
                          dtype=features.dtype).to(device)
        ones = ones * onehot_targets
        x_con = torch.cat((features, ones), dim=1)
        
        
        ### Compute loss
        pixelwise_bce = F.binary_cross_entropy(decoded, x_con, reduction='sum')
        cost = kl_divergence + pixelwise_bce
        
        ### UPDATE MODEL PARAMETERS
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()
        
        ### LOGGING
        if not batch_idx % 50:
            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' 
                   %(epoch+1, num_epochs, batch_idx, 
                     len(train_loader), cost))
            
    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))
    
print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))

torch.Size([128, 512])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x512 and 1152x10)