In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; later cells save and reload the model
# checkpoint from a folder under this mount point.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import imageio
import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torchvision.utils import save_image
# Transform object that converts image tensors to PIL images.
# NOTE(review): not referenced by any of the cells visible here.
to_pil_image = transforms.ToPILImage()

def save_reconstructed_images(recon_images, epoch):
    """Write a batch of reconstructed images to ``./outputs/output{epoch}.jpg``.

    Args:
        recon_images: Image tensor batch; it is moved to the CPU before saving.
        epoch: Value interpolated into the output file name.
    """
    import os  # local import: `os` is only imported in a later cell

    # save_image does not create missing directories, so make sure it exists.
    os.makedirs("./outputs", exist_ok=True)
    save_image(recon_images.cpu(), f"./outputs/output{epoch}.jpg")
def save_loss_plot(train_loss, valid_loss):
    """Plot training and validation loss curves, save them and show the figure.

    The figure is written to ``./outputs/loss.jpg`` (the directory is created
    when missing) and then displayed.

    Args:
        train_loss: Sequence of training losses, one value per epoch.
        valid_loss: Sequence of validation losses, one value per epoch.
    """
    import os  # local import: `os` is only imported in a later cell

    # savefig fails if the target directory does not exist.
    os.makedirs('./outputs', exist_ok=True)

    plt.figure(figsize=(10, 7))
    plt.plot(train_loss, color='orange', label='train loss')
    plt.plot(valid_loss, color='red', label='validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    # Save before show(): showing may release the figure.
    plt.savefig('./outputs/loss.jpg')
    plt.show()

In [None]:
from tqdm import tqdm
import torch
def final_loss(bce_loss, mu, logvar):
    """Combine the reconstruction loss with the KL-divergence term.

    Args:
        bce_loss: Reconstruction loss already computed by the caller.
        mu: Tensor of latent means.
        logvar: Tensor of latent log-variances (same shape as ``mu``).

    Returns:
        ``bce_loss`` plus ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return bce_loss + kl_divergence

In [None]:
def train(model, dataloader, dataset, device, optimizer, criterion):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module whose forward pass returns ``(reconstruction, mu, logvar)``.
        dataloader: Iterable of batches; element ``[0]`` of each batch is the input.
        dataset: Unused; kept so existing callers do not have to change.
        device: Device the inputs are moved to.
        optimizer: Optimizer stepped once per batch.
        criterion: Reconstruction loss called as ``criterion(reconstruction, inputs)``.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0.0
    counter = 0
    # len(dataloader) also counts the final partial batch, unlike
    # int(len(dataset) / batch_size), so the progress bar total is exact.
    for batch_data in tqdm(dataloader, total=len(dataloader)):
        counter += 1
        inputs = batch_data[0].to(device)  # drop the labels, keep the images
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(inputs)
        bce_loss = criterion(reconstruction, inputs)
        loss = final_loss(bce_loss, mu, logvar)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    train_loss = running_loss / counter
    return train_loss

In [None]:
def validate(model, dataloader, dataset, device, criterion):
    """Evaluate the model for one epoch without gradient tracking.

    Args:
        model: Module whose forward pass returns ``(reconstruction, mu, logvar)``.
        dataloader: Iterable of batches; element ``[0]`` of each batch is the input.
        dataset: Unused; kept so existing callers do not have to change.
        device: Device the inputs are moved to.
        criterion: Reconstruction loss called as ``criterion(reconstruction, inputs)``.

    Returns:
        Tuple ``(val_loss, recon_images)``: the mean per-batch loss and the
        reconstructions of the last batch (``None`` is never returned for a
        non-empty dataloader).
    """
    model.eval()
    running_loss = 0.0
    counter = 0
    recon_images = None
    with torch.no_grad():
        # len(dataloader) also counts the final partial batch.
        for batch_data in tqdm(dataloader, total=len(dataloader)):
            counter += 1
            inputs = batch_data[0].to(device)  # drop the labels, keep the images
            reconstruction, mu, logvar = model(inputs)
            bce_loss = criterion(reconstruction, inputs)
            loss = final_loss(bce_loss, mu, logvar)
            running_loss += loss.item()

            # Rebinding on every step is free and guarantees that, after the
            # loop, this holds the output of the true last batch even when the
            # dataset size is not a multiple of the batch size.
            recon_images = reconstruction
    val_loss = running_loss / counter
    return val_loss, recon_images

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Architecture settings for the convolutional VAE defined in the following cells.
kernel_size = 4 # every (transposed) convolution uses a (4, 4) kernel
init_channels = 8 # filters of the first convolution; later layers use multiples of it
image_channels = 1 # MNIST images are grayscale
latent_dim = 16 # size of the latent vector that is sampled

# Encoder class

In [None]:
class VAE_Encoder(nn.Module):
  """Convolutional encoder that maps an image batch to latent statistics.

  Four stride-2 convolutions are followed by global average pooling, a hidden
  linear layer of width 128 and two linear heads producing the latent mean and
  log-variance.

  Args:
    image_channels: Number of channels of the input images.
    init_channels: Number of filters of the first convolution.
    kernel_size: Kernel size shared by all convolutions.
    latent_dim: Size of the latent mean / log-variance vectors.
  """

  def __init__(self, image_channels, init_channels, kernel_size, latent_dim):
    super(VAE_Encoder, self).__init__()

    self.conv1 = nn.Conv2d(
      in_channels=image_channels, out_channels=init_channels, kernel_size=kernel_size,
      stride=2, padding=1
    )
    self.conv2 = nn.Conv2d(
      in_channels=init_channels, out_channels=init_channels * 2, kernel_size=kernel_size,
      stride=2, padding=1
    )
    self.conv3 = nn.Conv2d(
      in_channels=init_channels * 2, out_channels=init_channels * 4, kernel_size=kernel_size,
      stride=2, padding=1
    )
    self.conv4 = nn.Conv2d(
      in_channels=init_channels * 4, out_channels=64, kernel_size=kernel_size,
      stride=2, padding=0
    )

    # conv4 emits 64 channels, which become 64 features after global pooling.
    self.fc1 = nn.Linear(64, 128)
    self.fc_mu = nn.Linear(128, latent_dim)
    self.fc_log_var = nn.Linear(128, latent_dim)

  def forward(self, x):
    """Encode a 4-D image batch.

    Args:
      x: Tensor of shape (batch, image_channels, height, width).

    Returns:
      Tuple ``(mu, log_var)``, each of shape (batch, latent_dim).
    """
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    x = F.relu(self.conv3(x))
    x = F.relu(self.conv4(x))
    batch, _, _, _ = x.shape
    # Average over the remaining spatial positions and flatten to (batch, 64).
    x = F.adaptive_avg_pool2d(x, 1).reshape(batch, -1)
    hidden = F.relu(self.fc1(x))
    mu = self.fc_mu(hidden)
    log_var = self.fc_log_var(hidden)
    return mu, log_var

# Decoder class

In [None]:
class VAE_Decoder(nn.Module):
  """Transposed-convolution decoder that maps latent vectors to images.

  A linear layer expands the latent vector to 64 features, which are reshaped
  to a 64-channel 1x1 map and upsampled by four transposed convolutions. The
  last layer is followed by a sigmoid, so outputs lie in [0, 1].

  Args:
    image_channels: Number of channels of the generated images.
    init_channels: Base filter count; the layers use multiples of it.
    kernel_size: Kernel size shared by all transposed convolutions.
    latent_dim: Size of the latent input vector.
  """

  def __init__(self, image_channels, init_channels, kernel_size, latent_dim):
    super(VAE_Decoder, self).__init__()

    self.fc2 = nn.Linear(latent_dim, 64)
    self.deconv1 = nn.ConvTranspose2d(
      in_channels=64, out_channels=init_channels * 8, kernel_size=kernel_size,
      stride=1, padding=0
    )
    self.deconv2 = nn.ConvTranspose2d(
      in_channels=init_channels * 8, out_channels=init_channels * 4, kernel_size=kernel_size,
      stride=2, padding=1
    )
    self.deconv3 = nn.ConvTranspose2d(
      in_channels=init_channels * 4, out_channels=init_channels * 2, kernel_size=kernel_size,
      stride=2, padding=1
    )
    self.deconv4 = nn.ConvTranspose2d(
      in_channels=init_channels * 2, out_channels=image_channels, kernel_size=kernel_size,
      stride=2, padding=1
    )

  def forward(self, z):
    """Decode latent vectors.

    Args:
      z: Tensor of shape (batch, latent_dim).

    Returns:
      Tensor of reconstructed images with ``image_channels`` channels and
      values in [0, 1].
    """
    z = F.relu(self.fc2(z))
    # Treat the 64 features as a 64-channel 1x1 feature map.
    z = z.view(-1, 64, 1, 1)
    x = F.relu(self.deconv1(z))
    x = F.relu(self.deconv2(x))
    x = F.relu(self.deconv3(x))
    reconstruction = torch.sigmoid(self.deconv4(x))
    return reconstruction

# Model class (encoder + decoder)

In [None]:
# define a Conv VAE
class ConvVAE(nn.Module):
    """Variational autoencoder assembled from an encoder and a decoder module.

    Args:
        encoder: Module mapping an input batch to ``(mu, log_var)``.
        decoder: Module mapping a latent batch to a reconstruction.
    """

    def __init__(self, encoder, decoder):
        super(ConvVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def sample(self, mu, log_var):
        """Draw a latent sample with the reparameterization trick.

        Args:
            mu: Tensor of latent means.
            log_var: Tensor of latent log-variances (same shape as ``mu``).

        Returns:
            ``mu + eps * std`` where ``eps`` is standard normal noise and
            ``std = exp(0.5 * log_var)``.
        """
        std = torch.exp(0.5 * log_var)  # standard deviation
        eps = torch.randn_like(std)  # noise with the same shape/device as std
        return mu + (eps * std)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        # All layer logic lives in the encoder/decoder submodules; this class
        # only wires them together around the sampling step.
        mu, log_var = self.encoder(x)
        z = self.sample(mu, log_var)
        reconstruction = self.decoder(z)
        return reconstruction, mu, log_var

In [None]:
import os
import torch
import torch.optim as optim
import torch.nn as nn
#import model
import torchvision.transforms as transforms
import torchvision
import matplotlib
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
matplotlib.style.use('ggplot')

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# initialize the model: ConvVAE requires an encoder and a decoder, both built
# from the architecture settings defined earlier in the notebook
encoder = VAE_Encoder(image_channels, init_channels, kernel_size, latent_dim)
decoder = VAE_Decoder(image_channels, init_channels, kernel_size, latent_dim)
model = ConvVAE(encoder, decoder).to(device)
# set the learning parameters
lr = 0.001
epochs = 100
batch_size = 64
optimizer = optim.Adam(model.parameters(), lr=lr)
# summed (not averaged) binary cross-entropy over all elements of a batch
criterion = nn.BCELoss(reduction='sum')
# a list to save all the reconstructed images in PyTorch grid format
grid_images = []

In [None]:
import torch
import torch.nn as nn
from torchsummary import summary

# Print a layer-by-layer table (output shapes and parameter counts) for a
# single-channel 32x32 input.
summary(model=model, input_size=(1, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 16, 16]             136
            Conv2d-2             [-1, 16, 8, 8]           2,064
            Conv2d-3             [-1, 32, 4, 4]           8,224
            Conv2d-4             [-1, 64, 1, 1]          32,832
            Linear-5                  [-1, 128]           8,320
            Linear-6                   [-1, 16]           2,064
            Linear-7                   [-1, 16]           2,064
            Linear-8                   [-1, 64]           1,088
   ConvTranspose2d-9             [-1, 64, 4, 4]          65,600
  ConvTranspose2d-10             [-1, 32, 8, 8]          32,800
  ConvTranspose2d-11           [-1, 16, 16, 16]           8,208
  ConvTranspose2d-12            [-1, 1, 32, 32]             257
Total params: 163,657
Trainable params: 163,657
Non-trainable params: 0
-------------------------------

In [None]:
# Preprocessing: resize every image to 32x32, then convert it to a tensor.
transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])

# MNIST training split, reshuffled on every pass.
trainset = torchvision.datasets.MNIST(root='../input', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

# MNIST test split, used for validation and iterated in a fixed order.
testset = torchvision.datasets.MNIST(root='../input', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

In [None]:
# Loss history, one entry per epoch.
train_loss, valid_loss = [], []

for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")

    # One pass over the training data, then one over the validation data.
    train_epoch_loss = train(model, trainloader, trainset, device, optimizer, criterion)
    valid_epoch_loss, recon_images = validate(model, testloader, testset, device, criterion)
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)

    # Writing the reconstructions to disk is currently disabled:
    #save_reconstructed_images(recon_images, epoch+1)

    # Keep this epoch's validation reconstructions as a single image grid.
    image_grid = make_grid(recon_images.detach().cpu())
    grid_images.append(image_grid)

    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {valid_epoch_loss:.4f}")

# Save the trained weights to Google Drive.
checkpoint_dir = '/content/drive/MyDrive/VAE/model'
os.makedirs(checkpoint_dir, exist_ok=True)
model_save_path = os.path.join(checkpoint_dir, 'final_model.pth')
torch.save(model.state_dict(), model_save_path)

In [None]:
# Reload the weights saved at the end of training.
file_path = '/content/drive/MyDrive/VAE/model/final_model.pth'
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU runtime also loads on a CPU-only runtime.
checkpoint = torch.load(file_path, map_location=device)
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [None]:
def generate_new_image(num_images=1):
  """Generate new images by decoding random latent vectors.

  Latent vectors are drawn from a standard normal distribution and passed
  through the decoder of the module-level `model`; the latent size is taken
  from the module-level `latent_dim`.

  Args:
    num_images: Number of images to generate.

  Returns:
    Tensor holding `num_images` decoded images.
  """
  # Create the latent vectors on the same device/dtype as the model weights.
  reference = next(model.parameters())
  z_sample = torch.randn(
    num_images, latent_dim, device=reference.device, dtype=reference.dtype
  )

  # Pass through the decoder; no gradients are needed for generation.
  with torch.no_grad():
    generated_image = model.decoder(z_sample)
  return generated_image

In [None]:
# generate_new_image is a module-level function, not a method of the model.
generate_new_image()