# Variational AutoEncoders

This project is an experimentation to implement a variational autoencoder.

The focus here is not on implementation but on the mathematics of these tools.

We focus on images. Our model will be tested on two well-known dataset: the FashionMNIST and the MNIST datasets. Our goal is to new generate images form the training data.

Here the structure of a variational autoencoder (credits to Wikipedia):

![Autoencoder Schema Wikipedia](Reparameterized_Variational_Autoencoder.png)

Notes on **hyperparameter tuning**:
- H_DIM at least 150, more than 300 doesn't seems to have effects
- Z_DIM 20 is fine, if lower than 15 the generation will produce far worse results
- NUM_EPOCHS at least 10, 20 is good
- LR_RATE 3e-4 is fine, you can double it if you increase the number of epochs
- ALPHA regulates the mixture of losses, if alpha=1 all the generated images will be the same

I learned some ideas from https://github.com/karpathy/minGPT

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim


import torchvision.datasets as datasets

from tqdm import tqdm

from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader


In [3]:

if torch.backends.mps.is_available():
        DEVICE = torch.device('mps')
elif torch.cuda.is_available():
        DEVICE = torch.device('cuda')
else:
        DEVICE = torch.device('cpu')



In [4]:
# Input img -> hidden dim -> mean, std -> parametrization trick -> deocder -> output img

class VariationalAutoEncoder(nn.Module):
    """
    Here is the pipeline:
    Get the image (relu, linear)
    Dense layer - encoder(linear)
    Split in mu and sigma (two linears)
    Add noise, aka parametrization trick, get the code
    Dense layer - decoder (linear)
    Output the image (linear, sigmoid)
    """
    def __init__(self, input_dim, h_dim = 200, z_dim =20):
        super().__init__()
        
        # Encoder
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)
        
        # Decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)
        
        self.relu = nn.ReLU()
        
    def encode(self, x):
        h = F.relu(self.img_2hid(x))
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
        
        return mu, sigma
        
    
    def decode(self, z):
        h = self.relu(self.z_2hid(z))
        
        return torch.sigmoid(self.hid_2img(h)) # we must be sure that values are from 0 to 1, because we normalized the image.
        
    def forward(self, x):
         
        mu, sigma = self.encode(x)
        epsilon = torch.randn_like(sigma)
        
        z_reparametrized = mu + sigma*epsilon
        x_reconstructed = self.decode(z_reparametrized)
        
        return x_reconstructed, mu, sigma
    
    


In [5]:
# Hyperparameters

DEVICE = torch.device('mps')

# FashionMNIST and MNIST have 28x28 dimension
INPUT_DIM = 784

H_DIM = 300        # more power in terms of understanding the features in the image
Z_DIM = 20         # more compression

NUM_EPOCHS = 20
BATCH_SIZE = 32

LR_RATE = 3e-4 

ALPHA = .5




In [6]:
def train(model, num_epochs, train_loader, optimizer, loss_fn, alpha):
    for epoch in range(num_epochs):
        loop = tqdm(enumerate(train_loader))
        for i, (x, _) in loop:
            
            
            x=x.to(DEVICE).view(x.shape[0], INPUT_DIM)
            x_reconstructed, mu, sigma = model(x)
            
            # compute loss
            reconstruction_loss =loss_fn(x_reconstructed, x)   # help in reconstruciton
            kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2)) # kl divergence  # shirnk to gaussian
            loss = alpha*reconstruction_loss + (1-alpha)*kl_div

            # backprop
            optimizer.zero_grad()
            loss.backward()
            #torch.nn.utils.clip_grad_norm(model.parameters(p)
            optimizer.step()
            loop.set_postfix(loss=loss.item())
            
    return model

Kullback-Leibler Divergence
$$ \text{KL}\left(q(z | x) \| p(z)\right) = -\frac{1}{2} \sum_{i=1}^{N} \left(1 + \log(\sigma_i^2) - \mu_i^2 - \sigma_i^2\right) $$

In [7]:
def inference(model, labels, folder, dataset, n_examples=1):
    
    len_labels = len(labels)
    
    for label in labels:
    
        images = []
        idx = 0
        for x, y in dataset:
            if y == idx:
                images.append(x)
                idx += 1
            if idx == len_labels:
                break
        
        encodings_digit = []
        
        for d in range(len_labels):
            
            with torch.no_grad():
                mu, sigma = model.encode(images[d].view(1, 784).to(DEVICE))
            
            encodings_digit.append((mu, sigma))

        mu, sigma = encodings_digit[label]
        
        for example in range(n_examples):
            epsilon = torch.randn_like(sigma)
            z = mu + sigma * epsilon
            out = model.decode(z)
            out = out.view(-1, 1, 28, 28)
            save_image(out, folder + f"/label_{label}_{example}.png")


## Fashion MNIST dataset

In [8]:
# import data from standard datasets of pytorch
dataset = datasets.FashionMNIST(root="dataset/", train = True, transform = transforms.ToTensor(), download = True)
train_loader = DataLoader(dataset=dataset, batch_size = BATCH_SIZE, shuffle=True) #not using workers!

# instantiate the model we build
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)

# set the optimizer ( one here can add also the scheduler)
optimizer = optim.Adam(model.parameters(), lr = LR_RATE)

# set the loss function
loss_fn = nn.BCELoss(reduction='sum')


In [9]:
model = train(model=model, num_epochs=NUM_EPOCHS, train_loader=train_loader, optimizer=optimizer, loss_fn=loss_fn, alpha=ALPHA)

        

0it [00:00, ?it/s]

1875it [00:23, 80.23it/s, loss=4.64e+3]
1875it [00:22, 83.61it/s, loss=4.12e+3]
1875it [00:22, 83.10it/s, loss=4.44e+3]
1875it [00:22, 84.04it/s, loss=4.11e+3]
1875it [00:22, 84.39it/s, loss=4.04e+3]
1875it [00:22, 82.82it/s, loss=4.18e+3]
1875it [00:22, 83.61it/s, loss=4.15e+3]
1875it [00:22, 82.48it/s, loss=4.17e+3]
1875it [00:22, 82.81it/s, loss=4.51e+3]
1875it [00:23, 79.00it/s, loss=4.4e+3] 
1875it [00:22, 82.12it/s, loss=4.24e+3]
1875it [00:21, 87.08it/s, loss=4.11e+3]
1875it [00:22, 84.15it/s, loss=4.56e+3]
1875it [00:22, 83.45it/s, loss=4.38e+3]
1875it [00:23, 80.05it/s, loss=4.67e+3]
1875it [00:23, 79.20it/s, loss=4.84e+3]
1875it [00:22, 83.05it/s, loss=4.04e+3]
1875it [00:21, 85.63it/s, loss=4.48e+3]
1875it [00:22, 84.80it/s, loss=4.28e+3]
1875it [00:21, 86.05it/s, loss=4.41e+3]


In [10]:
# Get all possible labels from dataset

labels =  []
for _, y in dataset:
    labels.append(y)
    
labels = list(set(labels))



In [11]:
# Generate new images

inference(model=model, labels=labels, dataset=dataset, folder= "generated/fashion", n_examples=10)


## MNIST dataset

In [12]:
dataset = datasets.MNIST(root="dataset/", train = True, transform = transforms.ToTensor(), download = True)
train_loader = DataLoader(dataset=dataset, batch_size = BATCH_SIZE, shuffle=True) #not using workers!

model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr = LR_RATE)

loss_fn = nn.BCELoss(reduction='sum')

In [13]:
model = train(model=model, num_epochs=NUM_EPOCHS, train_loader=train_loader, optimizer=optimizer, loss_fn=loss_fn, alpha=ALPHA)

        

1875it [00:22, 85.13it/s, loss=2.39e+3]
1875it [00:21, 86.08it/s, loss=2.23e+3]
1875it [00:22, 84.43it/s, loss=2.08e+3]
1875it [00:21, 86.39it/s, loss=2.2e+3] 
1875it [00:22, 84.61it/s, loss=1.99e+3]
1875it [00:22, 85.09it/s, loss=2.11e+3]
1875it [00:21, 85.56it/s, loss=1.92e+3]
1875it [00:22, 84.46it/s, loss=2.04e+3]
1875it [00:21, 86.35it/s, loss=2.21e+3]
1875it [00:22, 84.19it/s, loss=2.1e+3] 
1875it [00:23, 78.81it/s, loss=2.04e+3]
1875it [00:22, 81.99it/s, loss=2.06e+3]
1875it [00:22, 84.14it/s, loss=2.05e+3]
1875it [00:22, 81.81it/s, loss=2.16e+3]
1875it [00:22, 83.71it/s, loss=2.02e+3]
1875it [00:21, 85.46it/s, loss=1.92e+3]
257it [00:03, 84.49it/s, loss=1.91e+3]

In [None]:
# Get all possible labels from dataset

labels =  []
for _, y in dataset:
    labels.append(y)
    
labels = list(set(labels))



[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [None]:
# Generate new images

inference(model=model, labels=labels, dataset=dataset, folder= "generated/mnist", n_examples=10)
