In [8]:
import os
import json
import math
import numpy as np 
import pandas as pd

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline 
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()


## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import FashionMNIST
from torchvision import transforms

  set_matplotlib_formats('svg', 'pdf') # For export


In [9]:
# Prefer the first CUDA GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cpu


In [10]:
# Transformations applied on each image => only make them a tensor
transform = transforms.Compose([transforms.ToTensor()])

# Load the complete image-folder dataset; it is split into a training part
# and a held-out test part below.
complete_set = torchvision.datasets.ImageFolder("trafic_32", transform=transform)

# 84% / 16% split. The generator is seeded so the partition is identical on every run.
train_size = int(0.84 * len(complete_set))
test_size = len(complete_set) - train_size
train_set, test_set = data.random_split(
    complete_set,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

# Pinned host memory only speeds up host-to-GPU copies; requesting it on a
# CPU-only machine just triggers a warning, so enable it only when CUDA exists.
use_pinned_memory = torch.cuda.is_available()

# We define a set of data loaders that we can use for various purposes later.
# Training drops the last partial batch so every optimisation step sees 1000 images.
train_loader = data.DataLoader(
    train_set,
    batch_size=1000,
    shuffle=True,
    drop_last=True,
    pin_memory=use_pinned_memory,
    num_workers=2,
)
# Evaluation keeps the final partial batch: dropping it would silently exclude
# up to 999 held-out images from every reported test metric.
test_loader = data.DataLoader(
    test_set,
    batch_size=1000,
    shuffle=False,
    drop_last=False,
    num_workers=2,
)

def get_train_images(num):
    """Stack `num` consecutive images, starting at index 10, into one batch.

    Note that, despite the function name, the images are read from the
    module-level `test_set`, not from `train_set`.

    Args:
        num: number of images to collect.

    Returns:
        A tensor holding the selected images stacked along a new first dimension.
    """
    first_index = 10
    images = []
    for idx in range(first_index, first_index + num):
        image, _label = test_set[idx]
        images.append(image)
    return torch.stack(images, dim=0)

In [11]:
# One value per line: dataset size, then the length at each nesting level of the
# first training sample, and finally the number of batches in `train_loader`.
print(
    len(train_set),
    len(train_set[0]),
    len(train_set[0][0]),
    len(train_set[0][0][0]),
    len(train_set[0][0][0][0]),
    len(train_loader),
    sep="\n",
)

32935
2
3
32
32
32


In [12]:
class Encoder(nn.Module):
    """Fully-connected encoder producing the parameters of the approximate posterior q.

    The input batch is flattened to `(batch, input_dim)`, passed through two
    LeakyReLU hidden layers, and mapped by two separate linear heads to the
    mean and the log-variance of a diagonal normal distribution.

    Args:
        input_dim: number of features per sample after flattening.
        hidden_dim: width of both hidden layers.
        latent_dim: size of the mean and log-variance vectors.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()

        # Hidden trunk shared by both output heads.
        self.fc_1 = nn.Linear(input_dim, hidden_dim)
        self.fc_2 = nn.Linear(hidden_dim, hidden_dim)

        # Output heads: one for the mean, one for the log of the variance.
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)
        # nn.Module.__init__ already puts the module in training mode.

    def forward(self, x):
        """Return `(mean, log_var)` for the batch `x`."""
        flat = torch.flatten(x, 1)  # keep the batch dimension, merge the rest
        hidden = self.LeakyReLU(self.fc_1(flat))
        hidden = self.LeakyReLU(self.fc_2(hidden))
        return self.fc_mean(hidden), self.fc_var(hidden)

In [13]:
class Decoder(nn.Module):
    """Fully-connected decoder mapping latent vectors back to images.

    Two LeakyReLU hidden layers are followed by a linear layer and a sigmoid,
    so every output value lies in [0, 1]. The flat output is then reshaped to
    `(batch, *output_shape)`.

    Args:
        latent_dim: size of the latent vectors fed to the decoder.
        hidden_dim: width of both hidden layers.
        output_dim: number of values produced per sample.
        output_shape: per-sample shape the flat output is reshaped to; its
            element count must equal `output_dim`. Defaults to `(3, 32, 32)`,
            the shape that was previously hard-coded.

    Raises:
        ValueError: if `output_shape` does not contain exactly `output_dim` elements.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim, output_shape=(3, 32, 32)):
        super().__init__()

        output_shape = tuple(output_shape)
        # Fail early: with a mismatching shape the reshape in forward() would
        # either crash or silently regroup values across samples.
        if math.prod(output_shape) != output_dim:
            raise ValueError(
                f"output_shape {output_shape} has {math.prod(output_shape)} elements, "
                f"but output_dim is {output_dim}"
            )
        self.output_shape = output_shape

        self.fc_1 = nn.Linear(latent_dim, hidden_dim)
        self.fc_2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_3 = nn.Linear(hidden_dim, output_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Decode latent batch `x` into a tensor of shape `(batch, *output_shape)`."""
        h = self.LeakyReLU(self.fc_1(x))
        h = self.LeakyReLU(self.fc_2(h))
        x_hat = torch.sigmoid(self.fc_3(h))
        return x_hat.view(-1, *self.output_shape)

In [14]:
class VAE(nn.Module):
    """Variational autoencoder built from the `Encoder` and `Decoder` modules.

    Args:
        x_dim: number of values per input sample (encoder input / decoder output size).
        hidden_dim: width of the hidden layers in both sub-networks.
        latent_dim: dimensionality of the latent space.
    """

    def __init__(self, x_dim, hidden_dim, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
        self.decoder = Decoder(latent_dim=latent_dim, hidden_dim=hidden_dim, output_dim=x_dim)

    def reparameterization(self, mean, var):
        """Draw `z ~ N(mean, var)` via the reparameterisation trick.

        Args:
            mean: mean of the normal distribution.
            var: variance (not the standard deviation) of the distribution,
                same shape as `mean`.

        Returns:
            `mean + sqrt(var) * epsilon` with `epsilon` drawn from a standard normal.
        """
        # The noise must be standard normal (randn), not uniform on [0, 1) (rand).
        epsilon = torch.randn_like(mean)
        return mean + torch.sqrt(var) * epsilon

    def forward(self, x):
        """Encode `x`, sample a latent vector and decode it.

        Returns:
            Tuple `(x_hat, mean, log_var)`: the reconstruction and the encoder outputs.
        """
        mean, log_var = self.encoder(x)
        # exp(log_var) is the variance; reparameterization() takes its square
        # root itself, so the standard deviation must not be passed in here.
        z = self.reparameterization(mean, torch.exp(log_var))
        x_hat = self.decoder(z)
        return x_hat, mean, log_var

In [15]:
def vae_loss_function(x,x_hat,mean,log_var):
    """Return the VAE training loss: summed squared error plus KL divergence.

    Args:
        x: original batch.
        x_hat: reconstructed batch, same shape as `x`.
        mean: mean of the approximate posterior.
        log_var: log-variance of the approximate posterior.

    Returns:
        Scalar tensor. Both terms are summed (not averaged) over all elements.
    """
    # Squared reconstruction error, summed over the whole batch.
    squared_error = nn.functional.mse_loss(x_hat, x, reduction='sum')

    # Closed-form KL divergence between N(mean, exp(log_var)) and N(0, 1).
    kl_terms = 1 + log_var - mean.pow(2) - log_var.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)

    return squared_error + kl_divergence

In [16]:
# 3072 values per sample matches the decoder's reshape to 3 x 32 x 32.
vae = VAE(
    x_dim=3072,
    hidden_dim=1000,
    latent_dim=32,
).to(device)

In [17]:
# Adam over all VAE parameters; the learning rate is multiplied by 0.99
# each time `scheduler.step()` is called.
optimizer = optim.Adam(params=vae.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

In [19]:
num_epochs = 20
for n in range(num_epochs):
    # ---- training pass ----
    vae.train()  # the plotting helpers switch the model to eval mode
    losses_epoch = []
    for x, _ in train_loader:
        x = x.to(device)
        # Clear gradients before the backward pass, so that gradients left over
        # from an interrupted run of this cell can never leak into a step.
        optimizer.zero_grad()
        out, means, log_var = vae(x)
        loss = vae_loss_function(x, out, means, log_var)
        loss.backward()
        optimizer.step()
        losses_epoch.append(loss.item())

    # ---- evaluation pass ----
    vae.eval()
    L1_list = []                    # per-batch mean absolute error
    abs_error_sum, n_values = 0.0, 0
    with torch.no_grad():           # no autograd graph is needed for evaluation
        for x, _ in test_loader:
            x = x.to(device)
            out, _, _ = vae(x)
            abs_error = torch.abs(out - x)
            L1_list.append(torch.mean(abs_error).item())
            abs_error_sum += torch.sum(abs_error).item()
            n_values += x.numel()
    # Average over all values, so a smaller final batch is weighted correctly.
    test_l1 = abs_error_sum / n_values if n_values else float("nan")

    print(f"Epoch {n} loss {np.mean(np.array(losses_epoch))}, test L1 = {test_l1}")
    scheduler.step()

KeyboardInterrupt: 

In [None]:
def visualize_reconstructions(model, input_imgs, device):
    """Plot reconstructions produced by `model` and return them.

    Args:
        model: object exposing `encoder`, `decoder`, `latent_dim` and `eval()`.
        input_imgs: either a batch of images to reconstruct, or (legacy call
            style) an int `n`, in which case `n` random latent vectors are
            decoded instead of reconstructing real images.
        device: device on which the model is evaluated.

    Returns:
        CPU tensor with the reconstructed (or, for an int, decoded) images.
    """
    model.eval()
    with torch.no_grad():
        if isinstance(input_imgs, int):
            # Legacy path: only a count was given, so decode random latents.
            latents = torch.randn([input_imgs, model.latent_dim]).to(device)
            reconst_imgs = model.decoder(latents).cpu()
            shown = reconst_imgs
        else:
            # Decode the posterior mean so the reconstruction is deterministic.
            posterior_mean, _ = model.encoder(input_imgs.to(device))
            reconst_imgs = model.decoder(posterior_mean).cpu()
            # Interleave original / reconstruction so each pair sits side by side.
            shown = torch.stack([input_imgs.cpu(), reconst_imgs], dim=1).flatten(0, 1)

    # Nothing to draw for an empty batch (make_grid cannot lay out zero images).
    if shown.shape[0] == 0:
        return reconst_imgs

    # Plotting
    images_per_row = 4
    grid = torchvision.utils.make_grid(shown, nrow=images_per_row, normalize=False)
    grid = grid.permute(1, 2, 0)
    # Scale the figure height with the number of rows, capped to stay renderable.
    n_rows = math.ceil(shown.shape[0] / images_per_row)
    plt.figure(figsize=(10, min(2.5 * n_rows, 100)))
    plt.title("Reconstructions")
    plt.imshow(grid)
    plt.axis('off')
    plt.show()
    return reconst_imgs

In [None]:
def generate_images(model, n_imgs, device):
    """Decode `n_imgs` random latent vectors, plot them in a grid and return them.

    Args:
        model: object exposing `decoder`, `latent_dim` and `eval()`.
        n_imgs: number of images to generate.
        device: device on which the decoder is evaluated.

    Returns:
        CPU tensor with the generated images.
    """
    # Generate images
    model.eval()
    with torch.no_grad():
        latents = torch.randn([n_imgs, model.latent_dim]).to(device)
        generated_imgs = model.decoder(latents)
    generated_imgs = generated_imgs.cpu()

    # Nothing to draw for an empty batch (make_grid cannot lay out zero images).
    if generated_imgs.shape[0] == 0:
        return generated_imgs

    images_per_row = 4
    grid = torchvision.utils.make_grid(generated_imgs, nrow=images_per_row, normalize=False)
    grid = grid.permute(1, 2, 0)
    # Size the figure from the number of generated images (not from an unrelated
    # global), and cap the height so large batches stay renderable.
    n_rows = math.ceil(generated_imgs.shape[0] / images_per_row)
    plt.figure(figsize=(10, min(2.5 * n_rows, 100)))
    plt.title("Generations")
    plt.imshow(grid)
    plt.axis('off')
    plt.show()
    return generated_imgs

In [None]:
# Sample ten images from the decoder; the result is kept in `reconst_imgs`.
reconst_imgs = generate_images(model=vae, n_imgs=10, device=device)

In [None]:
# Persist the images currently held in `reconst_imgs` as a CPU tensor.
torch.save(reconst_imgs.detach().cpu(), "piatek_brus_maj.pt")

In [None]:
input_imgs = get_train_images(num=1000)
# NOTE(review): the call below passes the literal 1000, not the `input_imgs`
# batch built on the previous line -- confirm which one is intended.
reconst_imgs = visualize_reconstructions(model=vae, input_imgs=1000, device=device)

In [None]:
def generate_images(model, n_imgs, device):
    """Decode `n_imgs` random latent vectors, plot them in a grid and return them.

    This cell redefines `generate_images` from an earlier cell; calls made
    after this cell runs use this definition.

    Args:
        model: object exposing `decoder`, `latent_dim` and `eval()`.
        n_imgs: number of images to generate.
        device: device on which the decoder is evaluated.

    Returns:
        CPU tensor with the generated images.
    """
    # Generate images
    model.eval()
    with torch.no_grad():
        latents = torch.randn([n_imgs, model.latent_dim]).to(device)
        generated_imgs = model.decoder(latents)
    generated_imgs = generated_imgs.cpu()

    # Nothing to draw for an empty batch (make_grid cannot lay out zero images).
    if generated_imgs.shape[0] == 0:
        return generated_imgs

    # Plotting
    images_per_row = 4
    grid = torchvision.utils.make_grid(generated_imgs, nrow=images_per_row, normalize=False)
    grid = grid.permute(1, 2, 0)
    # Size the figure from the number of generated images (not from an unrelated
    # global), and cap the height so large batches stay renderable.
    n_rows = math.ceil(generated_imgs.shape[0] / images_per_row)
    plt.figure(figsize=(10, min(2.5 * n_rows, 100)))
    # These images are sampled from the prior, not reconstructed from inputs.
    plt.title("Generations")
    plt.imshow(grid)
    plt.axis('off')
    plt.show()
    return generated_imgs

In [None]:
# Plot 1000 generated images; the trailing semicolon keeps the cell from
# echoing the call's return value.
generate_images(model=vae, n_imgs=1000, device=device);

In [None]:
# NOTE(review): the preceding cell does not assign the result of its
# generate_images call, so this saves whatever `reconst_imgs` holds from an
# earlier cell -- confirm that this is the tensor meant to be exported.
torch.save(reconst_imgs.detach().cpu(), "piatek_brus_maj.pt")