In [4]:
import warnings
warnings.filterwarnings('ignore')

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
import argparse
import matplotlib
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image

import random

matplotlib.style.use('ggplot')

In [21]:
features = 10  # default size of the latent vector


# A simple fully-connected VAE; no beta weighting happens inside the model itself.
class LinearVAE(nn.Module):
    """Fully-connected variational autoencoder with one hidden layer on each side.

    The encoder emits ``2 * latent_features`` values per sample, which are split
    into a mean and a log-variance vector. The decoder maps a sampled latent
    vector back to ``in_features`` values squashed into [0, 1] by a sigmoid.

    :param in_features: size of each flattened input sample (default 784)
    :param hidden_features: width of the hidden layer in encoder and decoder (default 512)
    :param latent_features: size of the latent vector; ``None`` means "use the
        module-level ``features`` value at construction time"
    """

    def __init__(self, in_features=784, hidden_features=512, latent_features=None):
        super(LinearVAE, self).__init__()

        if latent_features is None:
            latent_features = features

        # Remember the sizes so `forward` does not depend on a global that may
        # have been changed after this model was built.
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.latent_features = latent_features

        # encoder
        self.enc1 = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.enc2 = nn.Linear(in_features=hidden_features, out_features=latent_features*2)

        # decoder
        self.dec1 = nn.Linear(in_features=latent_features, out_features=hidden_features)
        self.dec2 = nn.Linear(in_features=hidden_features, out_features=in_features)

    def reparameterize(self, mu, log_var):
        """
        Draw a latent sample as ``mu + eps * std`` with ``eps ~ N(0, 1)``.

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        :return: a sampled tensor with the same shape as ``mu``
        """
        std = torch.exp(0.5*log_var) # standard deviation
        eps = torch.randn_like(std) # `randn_like` as we need the same size
        sample = mu + (eps * std) # shift and scale the unit-normal noise
        return sample

    def forward(self, x):
        """
        Encode ``x``, sample a latent vector and decode it.

        :param x: flattened input whose last dimension is ``in_features``
        :return: tuple ``(reconstruction, mu, log_var)``
        """
        # encoding
        x = F.relu(self.enc1(x))
        x = self.enc2(x).view(-1, 2, self.latent_features)

        # get `mu` and `log_var`
        mu = x[:, 0, :] # the first half of the encoder output is the mean
        log_var = x[:, 1, :] # the second half is the log variance

        # get the latent vector through reparameterization
        z = self.reparameterize(mu, log_var)

        # decoding
        x = F.relu(self.dec1(z))
        reconstruction = torch.sigmoid(self.dec2(x))
        return reconstruction, mu, log_var

Parameters for training

In [78]:
# learning parameters
epochs = 10  # number of passes over the training set in the training loop
train_games = 30  # NOTE(review): not referenced in the visible cells - confirm it is still needed
val_games = 10  # NOTE(review): not referenced in the visible cells - confirm it is still needed
batch_size = 64  # batch size given to both DataLoaders
beta = 150  # multiplier applied to the KL term inside final_loss
lr = 0.0001  # learning rate handed to the Adam optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use CUDA when available, else CPU

In [75]:
def final_loss(bce_loss, mu, logvar, beta, kl_wheight):
    """
    Combine the reconstruction loss with a weighted KL-divergence term.

    KL-Divergence = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

    :param bce_loss: reconstruction loss computed by the caller (any
        reconstruction criterion works; the value is used as-is)
    :param mu: the mean from the latent vector
    :param logvar: log variance from the latent vector
    :param beta: multiplier applied to the KL term
    :param kl_wheight: additional scale factor applied to the KL term
    :return: ``bce_loss + beta * kl_wheight * KL-Divergence``
    """
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(per_element)
    return bce_loss + beta*kl_wheight*kl_divergence

In [56]:
# Image-to-tensor transform.
# NOTE(review): the MNIST datasets in this notebook are built with ToTensor() directly,
# so `transform` is not used in the visible cells - confirm whether it can be removed.
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [36]:
import torchvision.datasets as datasets
from torchvision.transforms import ToTensor

In [37]:
# MNIST training split, downloaded to ./data on first use and converted to tensors.
train_data = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=ToTensor(),
)
# MNIST test split, used here as the validation set.
val_data = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=ToTensor(),
)

In [38]:
# Training batches are reshuffled on every pass; validation keeps the dataset order.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [76]:
# Build the VAE on the selected device and optimize all of its parameters with Adam.
model = LinearVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
# Reconstruction criterion: squared error summed over all elements of the batch.
# Alternative kept for reference: nn.BCELoss(reduction='sum')
# NOTE: fit/validate store this criterion's result in a variable named `bce_loss`
# even though the active criterion is MSE.
criterion = torch.nn.MSELoss(reduction = 'sum')
print(model)

LinearVAE(
  (enc1): Linear(in_features=784, out_features=512, bias=True)
  (enc2): Linear(in_features=512, out_features=20, bias=True)
  (dec1): Linear(in_features=10, out_features=512, bias=True)
  (dec2): Linear(in_features=512, out_features=784, bias=True)
)


Training loop (we train the autoencoder on one image from the buffer, not on the whole buffer; training on the whole buffer could also be a nice feature).

In [70]:
def fit(model, dataloader):
    """
    Train `model` for one full pass over `dataloader`.

    Uses the notebook-level `optimizer`, `criterion`, `device`, `beta` and
    `final_loss`. The KL term is scaled by `batch_size / dataset size`, both
    taken from the loader that was passed in.

    :param model: the VAE to train; called as `model(batch)` and expected to
        return `(reconstruction, mu, logvar)`
    :param dataloader: loader yielding `(images, labels)` batches
    :return: summed batch losses divided by the number of samples in the dataset
    """
    model.train()
    running_loss = 0.0
    num_samples = len(dataloader.dataset)
    kl_weight = dataloader.batch_size/num_samples
    # len(dataloader) is the true number of batches, including a final partial
    # one, so the progress bar no longer overruns its total.
    for data, _ in tqdm(dataloader, total=len(dataloader)):
        data = data.to(device)
        data = data.view(data.size(0), -1)  # flatten each image to one vector
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(data)
        bce_loss = criterion(reconstruction, data)
        loss = final_loss(bce_loss, mu, logvar, beta, kl_wheight = kl_weight)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

    train_loss = running_loss/num_samples
    return train_loss

In [71]:
def validate(model, dataloader):
    """
    Evaluate `model` on `dataloader` without gradient tracking and save a
    comparison image for the final batch.

    Uses the notebook-level `criterion`, `device`, `beta` and `final_loss`, and
    reads the notebook-level `epoch` variable (set by the training loop) to
    name the image file `MNISToutput{epoch}.png`.

    :param model: the VAE to evaluate; called as `model(batch)` and expected to
        return `(reconstruction, mu, logvar)`
    :param dataloader: loader yielding `(images, labels)` batches of 28x28 images
    :return: summed batch losses divided by the number of samples in the dataset
    """
    model.eval()
    running_loss = 0.0
    num_samples = len(dataloader.dataset)
    num_batches = len(dataloader)  # true batch count, including a partial last batch
    kl_weight = dataloader.batch_size/num_samples
    num_rows = 8  # maximum number of images shown per row of the saved grid
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=num_batches):
            data, _ = data
            data = data.to(device)
            data = data.view(data.size(0), -1)  # flatten each image to one vector
            reconstruction, mu, logvar = model(data)
            bce_loss = criterion(reconstruction, data)
            loss = final_loss(bce_loss, mu, logvar, beta, kl_wheight = kl_weight)
            running_loss += loss.item()

            # save the last batch input and output of every epoch
            if i == num_batches - 1:
                # view(-1, ...) infers the batch dimension, so a partial final
                # batch or a loader with a different batch size still works
                originals = data.view(-1, 1, 28, 28)[:num_rows]
                reconstructions = reconstruction.view(-1, 1, 28, 28)[:num_rows]
                both = torch.cat((originals, reconstructions))
                # one row of inputs above one row of reconstructions
                save_image(both.cpu(), f"MNISToutput{epoch}.png", nrow=originals.size(0))
    val_loss = running_loss/num_samples
    return val_loss

In [79]:
# Loss history: one value per epoch is appended to each list.
train_loss = []
val_loss = []
# NOTE: `epoch` is also read as a global by validate() to name the saved image file.
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    # one training pass, then one evaluation pass over the validation loader
    train_epoch_loss = fit(model, train_loader)
    val_epoch_loss = validate(model, val_loader)
    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {val_epoch_loss:.4f}")

Epoch 1 of 10


938it [00:12, 74.00it/s]                                                                                               
157it [00:01, 100.16it/s]                                                                                              


Train Loss: 27.9726
Val Loss: 47.2227
Epoch 2 of 10


938it [00:13, 69.96it/s]                                                                                               
157it [00:01, 99.71it/s]                                                                                               


Train Loss: 24.3459
Val Loss: 43.6279
Epoch 3 of 10


938it [00:13, 69.63it/s]                                                                                               
157it [00:01, 85.53it/s]                                                                                               


Train Loss: 22.9887
Val Loss: 41.5026
Epoch 4 of 10


938it [00:13, 71.81it/s]                                                                                               
157it [00:01, 102.17it/s]                                                                                              


Train Loss: 22.0560
Val Loss: 40.1989
Epoch 5 of 10


938it [00:12, 75.12it/s]                                                                                               
157it [00:01, 87.73it/s]                                                                                               


Train Loss: 21.3658
Val Loss: 39.2782
Epoch 6 of 10


938it [00:13, 69.13it/s]                                                                                               
157it [00:01, 98.73it/s]                                                                                               


Train Loss: 20.8052
Val Loss: 38.5122
Epoch 7 of 10


938it [00:13, 67.96it/s]                                                                                               
157it [00:01, 96.86it/s]                                                                                               


Train Loss: 20.3612
Val Loss: 38.2274
Epoch 8 of 10


938it [00:13, 69.51it/s]                                                                                               
157it [00:01, 96.84it/s]                                                                                               


Train Loss: 19.9902
Val Loss: 37.7074
Epoch 9 of 10


938it [00:12, 73.63it/s]                                                                                               
157it [00:01, 101.17it/s]                                                                                              


Train Loss: 19.6773
Val Loss: 37.4635
Epoch 10 of 10


938it [00:12, 74.31it/s]                                                                                               
157it [00:01, 96.86it/s]                                                                                               

Train Loss: 19.4062
Val Loss: 37.2062



