In [41]:
import matplotlib.pyplot as plt # plotting library
import numpy as np # this module is useful to work with numerical arrays
import pandas as pd # this module is useful to work with tabular data
import random # this module will be used to select random samples from a collection
import os # this module will be used just to create directories in the local filesystem
from tqdm import tqdm # this module is useful to plot progress bars
from scipy.stats import loguniform
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn, optim
from torch.nn import functional as F
from torchvision.utils import save_image

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
train_data = torchvision.datasets.FashionMNIST('dataset', train=True, download=True,
                            transform=transforms.ToTensor())
test_data = torchvision.datasets.FashionMNIST('dataset', train=False,
                           transform=transforms.ToTensor())

In [6]:
kwargs = {'num_workers': 1, 'pin_memory': True} if device == 'cuda' else {}

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=128, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=128, shuffle=True, **kwargs)

In [7]:
Z_DIM = 9
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 500)
        self.fc21 = nn.Linear(500, Z_DIM)  # fc21 for mean of Z
        self.fc22 = nn.Linear(500, Z_DIM)  # fc22 for log variance of Z
        self.fc3 = nn.Linear(Z_DIM, 500)
        self.fc4 = nn.Linear(500, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        mu = self.fc21(h1)
        logvar = self.fc22(h1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.rand_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        # x: [batch size, 1, 28,28] -> x: [batch size, 784]
        x = x.view(-1, 784)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar



In [42]:
lr = 1e-3
net = VAE().to(device)
optimizer = optim.Adam(net.parameters(), lr=lr)

In [43]:
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)

    return BCE + KLD

In [44]:
def train(epoch):
    net.train()
    train_loss = 0
    for batch_idx, (data, label) in enumerate(train_loader):
        # data: [batch size, 1, 28, 28]
        # label: [batch size] -> we don't use
        optimizer.zero_grad()
        data = data.to(device)
        recon_data, mu, logvar = net(data)
        loss = loss_function(recon_data, data, mu, logvar)
        loss.backward()
        cur_loss = loss.item()
        train_loss += cur_loss
        optimizer.step()

    print('====> Epoch: {} Average train loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)
    ))

In [45]:
def test(epoch):
    net.eval()
    test_loss = 0
    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(test_loader):
            data = data.to(device)
            recon_data, mu, logvar = net(data)
            cur_loss = loss_function(recon_data, data, mu, logvar).item()
            test_loss += cur_loss
            

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [46]:
def save_generated_img(image, name, epoch, nrow=8):
    if not os.path.exists('results'):
        os.makedirs('results')

    if epoch % 5 == 0:
        save_path = 'results/'+name+'_'+str(epoch)+'.png'
        save_image(image, save_path, nrow=nrow)

In [47]:
def sample_from_model(epoch):
    with torch.no_grad():
        # p(z) = N(0,I), this distribution is used when calculating KLD. So we can sample z from N(0,I)
        sample = torch.randn(64, Z_DIM).to(device)
        sample = net.decode(sample).cpu().view(64, 1, 28, 28)
        save_generated_img(sample, 'sample', epoch)

In [48]:
for epoch in range(1, 21):
        train(epoch)
        test(epoch)
        sample_from_model(epoch)

====> Epoch: 1 Average train loss: 261.4599
====> Test set loss: 244.2838
====> Epoch: 2 Average train loss: 239.9300
====> Test set loss: 239.7411
====> Epoch: 3 Average train loss: 237.0445
====> Test set loss: 237.9378
====> Epoch: 4 Average train loss: 235.5700
====> Test set loss: 237.0128
====> Epoch: 5 Average train loss: 234.5605
====> Test set loss: 235.9473
====> Epoch: 6 Average train loss: 233.7970
====> Test set loss: 235.5838
====> Epoch: 7 Average train loss: 233.1818
====> Test set loss: 234.8398
====> Epoch: 8 Average train loss: 232.6991
====> Test set loss: 234.5191
====> Epoch: 9 Average train loss: 232.3266
====> Test set loss: 234.0180
====> Epoch: 10 Average train loss: 231.9380
====> Test set loss: 233.9178
====> Epoch: 11 Average train loss: 231.6364
====> Test set loss: 233.6476
====> Epoch: 12 Average train loss: 231.3661
====> Test set loss: 233.3229
====> Epoch: 13 Average train loss: 231.1427
====> Test set loss: 233.1493
====> Epoch: 14 Average train loss