<a   href="https://colab.research.google.com/github/eduardojdiniz/Buzznauts/blob/master/scripts/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install duecredit --quiet
!git clone https://github.com/eduardojdiniz/Buzznauts --quiet

In [10]:
# install pytorch (http://pytorch.org/) if run from Google Colaboratory
# Imports
import torch
import random
import nltk

import numpy as np
import matplotlib.pylab as plt
from sklearn.decomposition import PCA

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets, transforms


from tqdm.notebook import tqdm, trange

from google.colab import drive
drive.mount("/content/drive")

import os.path as op

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# @title Download MNIST and CIFAR10 datasets
import tarfile, requests, os

# Each dataset is fetched as a tarball and unpacked into a directory named
# after the dataset. An existing directory is taken to mean the dataset is
# already available, so neither the download nor the extraction is repeated.

fname = 'MNIST.tar.gz'
name = 'mnist'
url = 'https://osf.io/y2fj6/download'

if not os.path.exists(name):
  print('\nDownloading MNIST dataset...')
  r = requests.get(url, allow_redirects=True)
  # Fail loudly on an HTTP error instead of saving the error page as a tarball
  r.raise_for_status()
  with open(fname, 'wb') as fh:
    fh.write(r.content)
  print('\nDownloading MNIST completed..\n')

  with tarfile.open(fname) as tar:
    tar.extractall(name)
  # Remove the archive only after the tar handle has been closed
  os.remove(fname)
else:
  print('MNIST dataset has been dowloaded.\n')


fname = 'cifar-10-python.tar.gz'
name = 'cifar10'
url = 'https://osf.io/jbpme/download'

if not os.path.exists(name):
  print('\nDownloading CIFAR10 dataset...')
  r = requests.get(url, allow_redirects=True)
  # Fail loudly on an HTTP error instead of saving the error page as a tarball
  r.raise_for_status()
  with open(fname, 'wb') as fh:
    fh.write(r.content)
  print('\nDownloading CIFAR10 completed.')

  with tarfile.open(fname) as tar:
    tar.extractall(name)
  # Remove the archive only after the tar handle has been closed
  os.remove(fname)
else:
  print('CIFAR10 dataset has been dowloaded.')
  

# @markdown Load MNIST and CIFAR10 image datasets
# See https://pytorch.org/docs/stable/torchvision/datasets.html

# MNIST: the train=True split first, then the train=False split
mnist, mnist_val = (
    datasets.MNIST('./mnist/',
                   train=use_train_split,
                   transform=transforms.ToTensor(),
                   download=False)
    for use_train_split in (True, False)
)

# CIFAR 10: same pair of splits, built in the same order
cifar10, cifar10_val = (
    datasets.CIFAR10('./cifar10/',
                     train=use_train_split,
                     transform=transforms.ToTensor(),
                     download=False)
    for use_train_split in (True, False)
)

def get_data(name='mnist'):
  """Select one of the preloaded datasets and its shape metadata by name.

  Args:
    name (str): dataset key, either 'mnist' or 'cifar10'.

  Returns:
    tuple: (training dataset, display name, per-image shape as
    (channels, height, width), flattened per-image size, validation dataset).

  Raises:
    ValueError: if `name` is not one of the supported dataset keys.
  """
  if name == 'mnist':
    my_dataset_name = "MNIST"
    my_dataset = mnist
    my_valset = mnist_val
    my_dataset_shape = (1, 28, 28)
    my_dataset_size = 28 * 28
  elif name == 'cifar10':
    my_dataset_name = "CIFAR10"
    my_dataset = cifar10
    my_valset = cifar10_val
    my_dataset_shape = (3, 32, 32)
    my_dataset_size = 3 * 32 * 32
  else:
    # Without this branch an unknown key fell through to the return statement
    # and surfaced as an UnboundLocalError.
    raise ValueError(
        "Unknown dataset name {!r}; expected 'mnist' or 'cifar10'.".format(name))

  return my_dataset, my_dataset_name, my_dataset_shape, my_dataset_size, my_valset


# Unpack the MNIST datasets and their shape metadata into module-level names;
# `data_shape` is read by ConvVarAutoEncoder when it sizes its layers.
train_set, dataset_name, data_shape, data_size, valid_set = get_data(name='mnist')

MNIST dataset has been dowloaded.

CIFAR10 dataset has been dowloaded.


In [3]:
class BiasLayer(nn.Module):
  """Learnable additive offset, initialised to all zeros.

  Args:
    shape: shape of the bias tensor. It is added to the input in `forward`,
      so it has to broadcast against that input.
  """

  def __init__(self, shape):
    super().__init__()
    # nn.Parameter tracks gradients by default, so the bias is trainable
    self.bias = nn.Parameter(torch.zeros(shape))

  def forward(self, x):
    """Return `x` shifted by the learned bias."""
    shifted = x + self.bias
    return shifted


def cout(x, layer):
  """Calculate the output depth, height and width of a Conv2D layer.

  NOTE(review): `cout` is defined a second time further down in this
  notebook; the later definition is the one in effect once the cell has run.

  Args:
    x (tuple): input size (depth, height, width)
    layer (nn.Conv2d): the Conv2D layer

  Returns:
    tuple: (out_depth, out_height, out_width), following the shape formula
    given in [Ref]

  Ref:
    https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
  """
  assert isinstance(layer, nn.Conv2d)

  def as_tuple(value):
    # Hyper-parameters may be stored as a scalar or as a (height, width) tuple
    return value if isinstance(value, tuple) else (value,)

  _, in_height, in_width = x
  out_depth = layer.out_channels

  padding = layer.padding
  if isinstance(padding, str):
    if padding == 'same':
      # 'same' keeps the spatial size; PyTorch only accepts it with stride 1
      return (out_depth, in_height, in_width)
    # The remaining string mode, 'valid', means no padding at all
    padding = 0

  p = as_tuple(padding)
  k = as_tuple(layer.kernel_size)
  d = as_tuple(layer.dilation)
  s = as_tuple(layer.stride)
  # Index 0 is the height entry and index -1 the width entry, which also
  # works when a hyper-parameter was wrapped into a 1-tuple above.
  out_height = 1 + (in_height + 2 * p[0] - (k[0] - 1) * d[0] - 1) // s[0]
  out_width = 1 + (in_width + 2 * p[-1] - (k[-1] - 1) * d[-1] - 1) // s[-1]
  return (out_depth, out_height, out_width)


# @title Helper functions

#@title Helper functions

def image_moments(image_batches, n_batches=None):
  """
  Compute mean and covariance of all pixels from batches of images

  Args:
    image_batches (iterable): yields image tensors whose first dimension is
      the batch dimension; all remaining dimensions are flattened into pixels.
    n_batches (int, optional): number of batches, used only as the progress
      bar total.

  Returns:
    tuple: (mean, covariance) of the flattened pixels, both moved to the CPU.

  Raises:
    ValueError: if `image_batches` yields no images at all.
  """
  m1, m2 = torch.zeros((), device=DEVICE), torch.zeros((), device=DEVICE)
  n = 0
  for im in tqdm(image_batches, total=n_batches, leave=False,
                 desc='Computing pixel mean and covariance...'):
    im = im.to(DEVICE)
    b = im.size()[0]
    im = im.reshape(b, -1)
    m1 = m1 + im.sum(dim=0)
    # The sum over the batch of per-image outer products equals im^T @ im;
    # the matrix product avoids building a [b, p, p] intermediate tensor.
    m2 = m2 + im.T @ im
    n += b
  if n == 0:
    # Dividing the accumulators by zero would silently return NaN moments
    raise ValueError('image_batches yielded no images; cannot compute moments.')
  m1, m2 = m1/n, m2/n
  # cov = E[x x^T] - E[x] E[x]^T
  cov = m2 - m1.view(-1,1)*m1.view(1,-1)
  return m1.cpu(), cov.cpu()


def interpolate(A, B, num_interps):
  """Linearly blend from A to B in `num_interps` evenly spaced steps.

  Args:
    A, B: same-shaped arrays; the first returned entry is A, the last is B.
    num_interps (int): number of blends to produce, endpoints included.

  Returns:
    np.ndarray: the blends stacked along a new leading axis.

  Raises:
    ValueError: if A and B differ in shape.
  """
  if A.shape != B.shape:
    raise ValueError('A and B must have the same shape to interpolate.')
  blends = []
  for weight in np.linspace(0, 1, num_interps):
    blends.append((1-weight)*A + weight*B)
  return np.array(blends)


def kl_q_p(zs, phi):
  """Monte Carlo estimate of KL(q||p) from samples of z drawn from q.

  Args:
    zs: samples of size [b,n,k].
    phi: posterior parameters of size [b,k+1]; the first k columns are the
      mean of q and the last column is its log standard deviation.

  The prior uses mu_p = 0 and sigma_p = 1, so its log density reduces to
  -1/2*(zs**2) once the constants shared with log q are dropped.

  Returns:
    Scalar tensor: summed over [k], averaged over [b,n].
  """
  b, _, k = zs.size()
  mean_q = phi[:, :-1].view(b, 1, k)
  log_std_q = phi[:, -1]
  std_q = log_std_q.exp().view(b, 1, 1)

  log_prior = -0.5*(zs**2)
  log_posterior = -0.5*(zs - mean_q)**2 / std_q**2 - log_std_q.view(b, 1, -1)

  # Both log densities are [b,n,k]: sum over latent dims, then average the rest
  per_sample_kl = (log_posterior - log_prior).sum(dim=2)
  return per_sample_kl.mean(dim=(0, 1))


def log_p_x(x, mu_xs, sig_x):
  """Pixel-wise Gaussian log probability of x under its reconstructions.

  Args:
    x: inputs of size [batch, ...].
    mu_xs: reconstructions of size [batch, n, ...].
    sig_x: tensor holding the pixel standard deviation.

  Returns:
    Scalar tensor: summed over pixels, averaged over batch and samples.
  """
  b, n = mu_xs.size()[:2]
  # Flatten the pixels and insert a singleton sample axis so the target
  # broadcasts against every one of the n reconstructions
  target = x.reshape(b, 1, -1)
  recon = mu_xs.view(b, n, -1)
  squared_error = (target - recon)**2 / (2*sig_x**2)

  # [b,n,p] -> sum over pixels [p] (the log prob), then mean over samples [n]
  # and over the batch [b] so different batch sizes are comparable.
  negative_log_prob = (squared_error + torch.log(sig_x)).sum(dim=2).mean(dim=(0, 1))
  return -negative_log_prob


def pca_encoder_decoder(mu, cov, k):
  """
  Build the encode/decode functions for PCA dimensionality reduction

  Args:
    mu: mean image (any shape; it is flattened to one row).
    cov: pixel covariance matrix.
    k (int): number of components requested from the low-rank SVD.

  Returns:
    tuple: (pca_encode, pca_decode) closures over the fitted matrices.
  """
  mean_row = mu.view(1, -1)
  n_features = mean_row.numel()
  left, sing_vals, right = torch.svd_lowrank(cov, q=k)
  scale = torch.sqrt(sing_vals)
  W_encode = right / scale
  W_decode = left * scale

  def pca_encode(x):
    # Encoder: centre on the mean image, then project onto the top K
    # eigenvectors of the data covariance
    centred = x.view(-1, n_features) - mean_row
    return centred @ W_encode

  def pca_decode(h):
    # Decoder: map back to pixel space and restore the mean
    return (h @ W_decode.T) + mean_row

  return pca_encode, pca_decode


def cout(x, layer):
  """Calculate the output depth, height and width of a Conv2D layer.

  NOTE(review): this re-defines the `cout` declared earlier in this notebook
  and shadows it.

  Args:
    x (tuple): input size (depth, height, width)
    layer (nn.Conv2d): the Conv2D layer

  Returns:
    tuple: (out_depth, out_height, out_width), following the shape formula
    given in [Ref]

  Ref:
    https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
  """
  assert isinstance(layer, nn.Conv2d)

  def as_tuple(value):
    # Hyper-parameters may be stored as a scalar or as a (height, width) tuple
    return value if isinstance(value, tuple) else (value,)

  _, in_height, in_width = x
  out_depth = layer.out_channels

  padding = layer.padding
  if isinstance(padding, str):
    if padding == 'same':
      # 'same' keeps the spatial size; PyTorch only accepts it with stride 1
      return (out_depth, in_height, in_width)
    # The remaining string mode, 'valid', means no padding at all
    padding = 0

  p = as_tuple(padding)
  k = as_tuple(layer.kernel_size)
  d = as_tuple(layer.dilation)
  s = as_tuple(layer.stride)
  # Index 0 is the height entry and index -1 the width entry, which also
  # works when a hyper-parameter was wrapped into a 1-tuple above.
  out_height = 1 + (in_height + 2 * p[0] - (k[0] - 1) * d[0] - 1) // s[0]
  out_width = 1 + (in_width + 2 * p[-1] - (k[-1] - 1) * d[-1] - 1) // s[-1]
  return (out_depth, out_height, out_width)

Convolutional Variational Auto Encoder [FULL]


In [4]:
# Size of the latent code; passed as K when ConvVarAutoEncoder is instantiated
K_VAE = 1024

class ConvVarAutoEncoder(nn.Module):
  """Convolutional variational autoencoder.

  Encoder: learnable input bias -> five unpadded, stride-1 Conv2d + ReLU
  layers -> linear map to phi, a vector of K+1 numbers per input (K posterior
  means followed by one log standard deviation, as consumed by `rsample`).
  Decoder: linear upsampling + ReLU -> five ConvTranspose2d layers (ReLU on
  all but the last) -> learnable output bias.

  The input image shape is read from the module-level `data_shape`
  (channels, height, width).

  Args:
    K (int): size of the latent code z.
    num_filters (list of int): output channels of the five encoder conv
      layers; the decoder mirrors them in reverse order.
    filter_size (int): kernel size shared by every (de)convolution.
  """

  def __init__(self, K, num_filters=[192, 256, 384, 512, 768], filter_size=3):
    super(ConvVarAutoEncoder, self).__init__()
    ## 5 Conv Layers
    # Every unpadded, stride-1 convolution shrinks height and width by
    # (filter_size - 1), so five of them remove 5 * (filter_size - 1) pixels.
    total_shrink = 5 * (filter_size - 1)
    # Size of the last conv feature map seen by infer() (0 until it is called)
    self.fs = 0

    self.shape_after_conv = (num_filters[4],
                             data_shape[1] - total_shrink,
                             data_shape[2] - total_shrink)

    # Number of features entering / leaving the fully connected layers
    flat_size_after_conv = (self.shape_after_conv[0]
                            * self.shape_after_conv[1]
                            * self.shape_after_conv[2])

    # ENCODER
    self.q_bias = BiasLayer(data_shape)
    self.q_conv_1 = nn.Conv2d(data_shape[0], num_filters[0], filter_size)
    self.q_conv_2 = nn.Conv2d(num_filters[0], num_filters[1], filter_size)
    self.q_conv_3 = nn.Conv2d(num_filters[1], num_filters[2], filter_size)
    self.q_conv_4 = nn.Conv2d(num_filters[2], num_filters[3], filter_size)
    self.q_conv_5 = nn.Conv2d(num_filters[3], num_filters[4], filter_size)
    self.q_flatten = nn.Flatten()
    self.q_fc_phi = nn.Linear(flat_size_after_conv, K+1)

    # DECODER
    self.p_fc_upsample = nn.Linear(K, flat_size_after_conv)
    self.p_unflatten = nn.Unflatten(-1, self.shape_after_conv)
    self.p_deconv_1 = nn.ConvTranspose2d(num_filters[4], num_filters[3], filter_size)
    self.p_deconv_2 = nn.ConvTranspose2d(num_filters[3], num_filters[2], filter_size)
    self.p_deconv_3 = nn.ConvTranspose2d(num_filters[2], num_filters[1], filter_size)
    self.p_deconv_4 = nn.ConvTranspose2d(num_filters[1], num_filters[0], filter_size)
    self.p_deconv_5 = nn.ConvTranspose2d(num_filters[0], data_shape[0], filter_size)

    self.p_bias = BiasLayer(data_shape)

    # Define a special extra parameter to learn scalar sig_x for all pixels
    self.log_sig_x = nn.Parameter(torch.zeros(()))


  def infer(self, x):
    """Map (batch of) x to (batch of) phi which can then be passed to
    rsample to get z

    Side effect: stores the size of the last conv feature map in `self.fs`.
    """
    s = self.q_bias(x)
    s = F.relu(self.q_conv_1(s))
    s = F.relu(self.q_conv_2(s))
    s = F.relu(self.q_conv_3(s))
    s = F.relu(self.q_conv_4(s))
    s = F.relu(self.q_conv_5(s))
    self.fs = s.size()
    phi = self.q_fc_phi(self.q_flatten(s))
    return phi


  def generate(self, zs):
    """Map [b,n,k] sized samples of z to [b,n,p] sized images
    """
    # Note that for the purposes of passing through the generator, we need
    # to reshape zs to be size [b*n,k]
    b, n, k = zs.size()
    s = zs.view(b*n, -1)
    s = F.relu(self.p_fc_upsample(s)).view((b*n,) + self.shape_after_conv)
    s = F.relu(self.p_deconv_1(s))
    s = F.relu(self.p_deconv_2(s))
    s = F.relu(self.p_deconv_3(s))
    s = F.relu(self.p_deconv_4(s))
    # No ReLU on the last layer: its output (plus bias) is the pixel mean
    s = self.p_deconv_5(s)
    s = self.p_bias(s)
    mu_xs = s.view(b, n, -1)
    return mu_xs


  def decode(self, zs):
    """Decode latents by adding a leading singleton axis and calling generate."""
    # Included for compatibility with conv-AE code
    return self.generate(zs.unsqueeze(0))


  def forward(self, x):
    """Reconstruct x from a single sample z ~ q; output has the size of x."""
    # VAE.forward() is not used for training, but we'll treat it like a
    # classic autoencoder by taking a single sample of z ~ q
    phi = self.infer(x)
    zs = rsample(phi, 1)
    return self.generate(zs).view(x.size())


  def elbo(self, x, n=1):
    """Run input end to end through the VAE and compute the ELBO using n
    samples of z

    Returns the reconstruction log probability minus the KL(q||p) estimate.
    """
    phi = self.infer(x)
    zs = rsample(phi, n)
    mu_xs = self.generate(zs)
    return log_p_x(x, mu_xs, self.log_sig_x.exp()) - kl_q_p(zs, phi)


def expected_z(phi):
  """Return the mean part of phi, i.e. every column except the last one."""
  mean_q = phi[:, :-1]
  return mean_q


def rsample(phi, n_samples):
  """Draw samples z ~ q(z;phi)

  Args:
    phi: posterior parameters with shape [b,K+1]. The first K entries of
      each row are the mean of q; phi[:,-1] is the log standard deviation.
    n_samples (int): number of samples drawn per row of phi.

  Returns:
    Tensor of size [b,n_samples,K], computed as mean + std * noise with
    standard normal noise.
  """
  b, width = phi.size()
  k = width - 1
  mean_q = phi[:, :-1].view(b, 1, k)
  std_q = phi[:, -1].exp().view(b, 1, 1)
  noise = torch.randn(b, n_samples, k, device=phi.device)
  return noise*std_q + mean_q


def train_vae(vae, dataset, epochs=10, n_samples=1000):
  """Train a VAE by maximising its ELBO with Adam.

  The model is moved to DEVICE for training and returned to the CPU in eval
  mode before the function returns.

  Args:
    vae: model exposing `elbo(batch)`, `parameters()`, `to()`, `train()` and
      `eval()`.
    dataset: dataset yielding (image, label) pairs; the labels are ignored.
    epochs (int): number of passes over the dataset.
    n_samples (int): currently unused; `elbo` is called with its default
      number of samples. NOTE(review): confirm whether this should be
      forwarded as `vae.elbo(im, n=n_samples)`.

  Returns:
    list of float: the ELBO value of every processed batch, in order.
  """
  batch_size = 10
  opt = torch.optim.Adam(vae.parameters(), lr=1e-3, weight_decay=0)
  elbo_vals = []
  vae.to(DEVICE)
  vae.train()
  loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
  for epoch in trange(epochs, desc='Epochs'):
    # len(loader) also counts a final partial batch, which the previous
    # len(dataset) // 10 total left out of the progress bar.
    for im, _ in tqdm(loader, total=len(loader), desc='Batches', leave=False):
      im = im.to(DEVICE)
      opt.zero_grad()
      # Minimising the negative ELBO maximises the ELBO
      loss = -vae.elbo(im)
      loss.backward()
      opt.step()

      elbo_vals.append(-loss.item())
  vae.to('cpu')
  vae.eval()
  return elbo_vals


# Build the model with a K_VAE-sized latent code (weights are untrained here)
convVAE = ConvVarAutoEncoder(K=K_VAE)

In [5]:
convVAE

ConvVarAutoEncoder(
  (q_bias): BiasLayer()
  (q_conv_1): Conv2d(1, 192, kernel_size=(3, 3), stride=(1, 1))
  (q_conv_2): Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1))
  (q_conv_3): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1))
  (q_conv_4): Conv2d(384, 512, kernel_size=(3, 3), stride=(1, 1))
  (q_conv_5): Conv2d(512, 768, kernel_size=(3, 3), stride=(1, 1))
  (q_flatten): Flatten(start_dim=1, end_dim=-1)
  (q_fc_phi): Linear(in_features=248832, out_features=1025, bias=True)
  (p_fc_upsample): Linear(in_features=1024, out_features=248832, bias=True)
  (p_unflatten): Unflatten(dim=-1, unflattened_size=(768, 18, 18))
  (p_deconv_1): ConvTranspose2d(768, 512, kernel_size=(3, 3), stride=(1, 1))
  (p_deconv_2): ConvTranspose2d(512, 384, kernel_size=(3, 3), stride=(1, 1))
  (p_deconv_3): ConvTranspose2d(384, 256, kernel_size=(3, 3), stride=(1, 1))
  (p_deconv_4): ConvTranspose2d(256, 192, kernel_size=(3, 3), stride=(1, 1))
  (p_deconv_5): ConvTranspose2d(192, 1, kernel_size=(3, 3)

In [6]:
# Device that image_moments and train_vae move their tensors / model to
DEVICE = 'cpu'

#trained_CVAE = train_vae(convVAE, train_set, epochs = 1, n_samples = 10)

In [7]:
# Count every scalar entry across all of the model's parameter tensors
p = sum(param.numel() for param in convVAE.parameters())

p

523386147