In [1]:
import sys
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader

from tqdm import tqdm

In [2]:
# Number of samples per mini-batch; passed to both DataLoaders below.
batch_size = 100

# Layer sizes handed to Net(input_dim, hidden_dim, latent_dim).
# input_dim must equal the length of each flattened image fed to the model
# (presumably 28 * 28 MNIST pixels -- TODO confirm).
input_dim = 784
hidden_dim = 256
latent_dim = 2
# Number of full passes over the training set.
epochs = 50
# NOTE(review): epsilon_std is not referenced in any of the visible cells.
epsilon_std = 1.0

# Learning rate given to the Adam optimizer.
lr = 1e-4

In [3]:
class Net(nn.Module):
  """Fully connected variational autoencoder.

  The encoder maps an input vector to the mean and log-variance of a
  diagonal Gaussian over the latent space; the decoder maps a latent sample
  back to a vector of per-element probabilities (sigmoid output).

  Args:
    input_dim: length of each flattened input vector.
    hidden_dim: width of the single hidden layer in encoder and decoder.
    latent_dim: dimensionality of the latent code.
  """

  def __init__(self, input_dim, hidden_dim, latent_dim):
    super().__init__()
    # Layer sizes are passed to nn.Linear as plain ints; wrapping them in a
    # tensor would leave tensor-valued in_features/out_features on the layers.
    self.enc1 = nn.Linear(input_dim, hidden_dim)
    self.enc21 = nn.Linear(hidden_dim, latent_dim)  # -> mu
    self.enc22 = nn.Linear(hidden_dim, latent_dim)  # -> log-variance

    self.dec1 = nn.Linear(latent_dim, hidden_dim)
    self.dec2 = nn.Linear(hidden_dim, input_dim)

  def reparameterize(self, mu, logvar):
    """Draw z = mu + eps * std with eps ~ N(0, I).

    Args:
      mu: tensor of latent means.
      logvar: tensor of latent log-variances, same shape as ``mu``.

    Returns:
      A tensor shaped like ``mu`` that is differentiable w.r.t. both inputs.
    """
    std = torch.exp(0.5 * logvar)
    # Standard-normal noise (randn), not uniform (rand): the KL term used in
    # the loss assumes a Gaussian posterior sample.
    eps = torch.randn_like(std)
    return mu + (eps * std)

  def forward(self, x):
    """Encode ``x``, sample a latent code, and decode it.

    Args:
      x: tensor whose last dimension equals ``input_dim``.

    Returns:
      Tuple ``(reconstruction, mu, logvar)`` where ``reconstruction`` has
      values in [0, 1] and the same trailing dimension as ``x``.
    """
    # Encoder
    hidden = F.relu(self.enc1(x))
    mu = self.enc21(hidden)
    logvar = self.enc22(hidden)

    # Stochastic latent sample
    z = self.reparameterize(mu, logvar)

    # Decoder
    hidden = F.relu(self.dec1(z))
    reconstruction = torch.sigmoid(self.dec2(hidden))
    return reconstruction, mu, logvar

In [4]:
# Instantiate the VAE with the layer sizes from the configuration cell.
model = Net(input_dim, hidden_dim, latent_dim)

In [5]:
def final_loss(reconstruction, data, mu, log_var):
    """Return the summed VAE loss: reconstruction BCE plus KL divergence.

    Args:
        reconstruction: decoder output with values in [0, 1].
        data: target tensor with the same shape as ``reconstruction``.
        mu: latent means.
        log_var: latent log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor; both terms are summed (not averaged) over all elements.
    """
    # Reconstruction term: element-wise binary cross-entropy, summed.
    recon_term = F.binary_cross_entropy(reconstruction, data, reduction='sum')
    # KL divergence between N(mu, exp(log_var)) and N(0, I), closed form.
    kl_inner = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_term = -0.5 * torch.sum(kl_inner)
    return recon_term + kl_term

In [6]:
# Adam optimizer over every parameter of the model, using the configured lr.
opt = torch.optim.Adam(model.parameters(), lr=lr)

In [7]:
# Transform applied to every image: conversion to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Datasets: MNIST training split and held-out split, downloaded on demand
# into the same root directory.
train_data = datasets.MNIST(root='input/data', train=True, download=True, transform=transform)
val_data = datasets.MNIST(root='input/data', train=False, download=True, transform=transform)

# Loaders: training batches are reshuffled every epoch, validation batches
# keep their fixed order.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [8]:
def log_to_message(log, precision=4):
  """Render a mapping of metric names to numbers as one line of text.

  Each entry becomes ``"name: value"`` with ``value`` in fixed-point notation
  using ``precision`` decimals; entries are joined with ``" | "`` in the
  mapping's iteration order.
  """
  template = f"{{0}}: {{1:.{precision}f}}"
  parts = []
  for name, value in log.items():
    parts.append(template.format(name, value))
  return " | ".join(parts)

In [9]:
class ProgressBar():
  """Minimal text progress bar written to stdout.

  Args:
    n: total number of steps; values below 1 are treated as 1 so that the
      percentage computation never divides by zero.
    length: width of the bar in characters.
  """

  def __init__(self, n, length=40):
    # protect against division by zero: every derived quantity uses the
    # clamped value, not the raw argument
    self.n = max(1, n)
    self.nf = float(self.n)
    self.length = length
    # precalculate the i values that should trigger a write operation
    # (roughly one per percent), and always include the final step
    self.ticks = set([round(i / 100.0 * self.n) for i in range(101)])
    self.ticks.add(self.n - 1)

  def bar(self, i, message=""):
    """Redraw the bar for step ``i`` if ``i`` is one of the tick steps.

    Assumes i ranges through [0, n - 1]. The line is rewritten in place using
    a leading carriage return, followed by the percentage and ``message``.
    """
    if i in self.ticks:
      done = (i + 1) / self.nf
      b = int(np.ceil(done * self.length))
      sys.stdout.write("\r[{0}{1}] {2}%\t{3}".format(
        "="*b, " " * (self.length - b), int(100 * done), message))
      sys.stdout.flush()

  def close(self, message=""):
    """Move the bar to 100%, then append ``message`` and a newline."""
    self.bar(self.n-1)
    sys.stdout.write("{0}\n".format(message))
    sys.stdout.flush()

In [10]:
# Training driver: for each epoch, run one optimisation pass over the training
# loader, then one evaluation pass over the validation loader, and print both
# per-sample average losses on the progress-bar line.
# NOTE(review): this import sits mid-notebook; consider moving it to the
# import cell at the top.
from collections import OrderedDict 
# One bar sized to the number of training batches, reused for every epoch.
p = ProgressBar(len(train_loader))
for epoch in range(epochs):
  model.train()
  # Metrics collected for this epoch, in insertion order.
  log = OrderedDict()
  print(f"Epoch {epoch + 1}/{epochs}")
  running_loss = 0.0
  # Train
  for i, data in enumerate(train_loader):
      # Each batch is an (images, labels) pair; the labels are not used.
      data, _ = data
      # Flatten every image to a single vector per sample.
      data = data.view(data.size(0), -1)
      opt.zero_grad()
      reconstruction, mu, log_var = model(data)
      loss = final_loss(reconstruction, data, mu, log_var)
      # final_loss is summed over the batch, so this accumulates a dataset-wide sum.
      running_loss += loss.item()
      p.bar(i)
      loss.backward()
      opt.step()
  # Summed loss divided by the number of samples -> average loss per sample.
  log['Train Loss'] = running_loss / len(train_loader.dataset)
  # Re-emits the bar for the last training batch index.
  p.bar(i)
  
  # validate
  model.eval()
  running_loss = 0.0
  # No gradients are needed while evaluating.
  with torch.no_grad():
    for i, data in enumerate(val_loader):
        data, _ = data
        data = data.view(data.size(0), -1)
        reconstruction, mu, log_var = model(data)
        loss = final_loss(reconstruction, data, mu, log_var)
        running_loss += loss.item()
    log['Validation Loss'] = running_loss / len(val_loader.dataset)
  # Finish the bar line and append the formatted metrics for this epoch.
  p.close(log_to_message(log))

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50
