In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
"""Get the device"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
"""A Convolutional Variational Autoencoder"""
class CVAE(nn.Module):
  def __init__(self, channels=1, features=32*20*20, z=256):
    super(CVAE, self).__init__()
    
    self.enc_conv1 = nn.Conv2d(channels, 16, 5)
    self.enc_conv2 = nn.Conv2d(16, 32, 5)
    self.enc_fc1 = nn.Linear(features, z)
    self.enc_fc2 = nn.Linear(features, z)
    
    self.dec_fc1 = nn.Linear(z, features)
    self.dec_conv1 = nn.ConvTranspose2d(32, 16, 5)
    self.dec_conv2 = nn.ConvTranspose2d(16, channels, 5)
  
  def encoder(self, x):
    x = F.relu(self.enc_conv1(x))
    x = F.relu(self.enc_conv2(x))
    x = x.view(-1, 32*20*20)
    mu = self.enc_fc1(x)
    log_var = self.enc_fc2(x)
    return mu, log_var
  
  def reparameterize(self, mu, logvar):
      std = torch.exp(0.5 * logvar)
      eps = torch.randn_like(std)
      return mu + (eps * std)
  
  def decoder(self, z):
      z = F.relu(self.dec_fc1(z))
      z = z.view(-1, 32, 20, 20)
      z = F.relu(self.dec_conv1(z))
      z = torch.sigmoid(self.dec_conv2(z))
      return z
  
  def forward(self, x):
      mu, logvar = self.encoder(x)
      z = self.reparameterize(mu, logvar)
      out = self.decoder(z)
      return out, mu, logvar

In [4]:
"""Hyperparameters"""
batch_size = 128
lr = 1e-3
epochs = 50

In [5]:
"""Dataloaders"""
# train and validation data
train_loader = torch.utils.data.DataLoader(
  datasets.MNIST(
    root='input/data',
    train=True,
    download=True,
    transform=transforms.ToTensor()
  ), batch_size=batch_size, shuffle=True)


test_loader = torch.utils.data.DataLoader(
  datasets.MNIST(
    root='input/data',
    train=False,
    download=True,
    transform=transforms.ToTensor()
  ), batch_size=batch_size, shuffle=False)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [6]:
"""Init network and optimizer"""
model = CVAE().to(device)
optim = torch.optim.Adam(model.parameters(), lr=lr)

In [8]:
for epoch in range(epochs):
  for idx, data in tqdm(enumerate(train_loader, 0), total=int(len(train_loader.dataset) / train_loader.batch_size)):
      imgs, _ = data
      imgs = imgs.to(device)
      
      out, mu, log_var = model(imgs)
      
      KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
      loss = F.binary_cross_entropy(out, imgs, reduction='sum') + KLD
      print(loss)
      optim.zero_grad()
      loss.backward()
      optim.step()
  print(f"Epoch {epoch}/{epochs}: Loss {loss:.4f}")

  0%|          | 1/468 [00:00<01:22,  5.66it/s]

tensor(15899.1514, grad_fn=<AddBackward0>)
tensor(15919.6562, grad_fn=<AddBackward0>)


  1%|          | 3/468 [00:00<01:05,  7.09it/s]

tensor(16005.9688, grad_fn=<AddBackward0>)
tensor(15904.4580, grad_fn=<AddBackward0>)


  1%|          | 5/468 [00:00<01:03,  7.26it/s]

tensor(15718.1006, grad_fn=<AddBackward0>)
tensor(15817.6592, grad_fn=<AddBackward0>)


  1%|▏         | 7/468 [00:00<01:02,  7.36it/s]

tensor(15592.0029, grad_fn=<AddBackward0>)
tensor(16155.1289, grad_fn=<AddBackward0>)


  2%|▏         | 9/468 [00:01<01:01,  7.43it/s]

tensor(16041.6289, grad_fn=<AddBackward0>)
tensor(15810.9893, grad_fn=<AddBackward0>)


  2%|▏         | 11/468 [00:01<01:05,  6.99it/s]

tensor(15373.3994, grad_fn=<AddBackward0>)
tensor(16035.8848, grad_fn=<AddBackward0>)


  3%|▎         | 13/468 [00:01<01:05,  6.95it/s]

tensor(15780.9570, grad_fn=<AddBackward0>)
tensor(15671.7236, grad_fn=<AddBackward0>)


  3%|▎         | 15/468 [00:02<01:03,  7.16it/s]

tensor(15708.7432, grad_fn=<AddBackward0>)
tensor(16314.8916, grad_fn=<AddBackward0>)


  4%|▎         | 17/468 [00:02<01:05,  6.94it/s]

tensor(15832.1250, grad_fn=<AddBackward0>)
tensor(16492.9551, grad_fn=<AddBackward0>)


  4%|▍         | 19/468 [00:02<01:05,  6.87it/s]

tensor(15948.4141, grad_fn=<AddBackward0>)
tensor(15422.1289, grad_fn=<AddBackward0>)


  4%|▍         | 21/468 [00:02<01:04,  6.98it/s]

tensor(15409.9707, grad_fn=<AddBackward0>)
tensor(16053.3984, grad_fn=<AddBackward0>)


  5%|▍         | 23/468 [00:03<01:02,  7.16it/s]

tensor(16562.4961, grad_fn=<AddBackward0>)
tensor(15927.4150, grad_fn=<AddBackward0>)


  5%|▌         | 24/468 [00:03<01:03,  7.02it/s]


KeyboardInterrupt: 

In [None]:
import random
model.eval()
with torch.no_grad():
  for data in random.sample(list(test_loader), 1):
    imgs, _ = data
    imgs = imgs.to(device)
    img = np.transpose(imgs[0].cpu().numpy(), [1, 2, 0])
    plt.subplot(121)
    plt.imshow(np.squeeze(img))
    out, mu, log_var = model(imgs)
    reconstruction = np.transpose(out[0].cpu().numpy(), [1, 2, 0])
    plt.subplot(122)
    plt.imshow(np.squeeze(reconstruction))