In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms

from context import rf_pool

**Load MNIST**

In [None]:
# MNIST train and test splits under ../data (downloaded on first run).
# ToTensor turns each image into a float tensor scaled to [0, 1], which is
# what the BCELoss used for training below requires of its targets.
transform = transforms.Compose([transforms.ToTensor()])
mnist_kwargs = {'root': '../data', 'download': True, 'transform': transform}
trainset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
testset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

In [None]:
# Seed torch's global RNG before anything random happens: the shuffling done
# by these loaders, the weight initialisation of torch.nn.Linear in the model
# cells below, and the torch.randn latent draws in Metrics.show_samples.
# Without this, Restart & Run All gives a different run every time.
SEED = 0
torch.manual_seed(SEED)

# Both splits are served in batches of 100. shuffle=True is kept on the test
# split too, so each pass over testloader sees the test images in a new order.
BATCH_SIZE = 100
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=True, num_workers=2)

**Build VAE**

In [None]:
# Construct the rf_pool VAE with its default arguments; the encoder layer and
# the latent output branch are added in the following cells.
model = rf_pool.models.VAE()

In [None]:
# Layer '0': an rf_pool Autoencoder wrapping a 784 -> 1024 linear map with
# LeakyReLU(0.2). input_shape=(-1, 784) presumably flattens each 1x28x28 image
# to a 784-vector (confirm in rf_pool.modules.Autoencoder). The Sigmoid
# reconstruction activation presumably keeps reconstructions in (0, 1), as
# BCELoss requires.
n_pixels = 28 * 28
encoder_layer = rf_pool.modules.Autoencoder(
    input_shape=(-1, n_pixels),
    linear=torch.nn.Linear(n_pixels, 1024),
    activation=torch.nn.LeakyReLU(0.2),
    reconstruct_activation=torch.nn.Sigmoid(),
)
model.append('0', encoder_layer)

In [None]:
# Add the output branch that models the mean (mu) and variance of the latent
# variable z. The argument 1024 matches the output width of the
# Linear(28*28, 1024) layer appended above; presumably it is the branch's
# input size (TODO confirm against rf_pool.models.VAE.add_output_branch).
model.add_output_branch(1024)

In [None]:
# Print the layer structure, then run the shape check for one MNIST-sized
# input (batch 1, 1 channel, 28x28). The shape check stays as the cell's
# last expression so Jupyter displays whatever it returns.
print(model)
example_input_shape = (1, 1, 28, 28)
model.output_shapes(example_input_shape)

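**Define metrics**: visual monitors (reconstructions and random samples from the model) that `train_model` can display during training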

In [None]:
class Metrics(object):
    """Visual monitors handed to ``model.train_model(metrics=...)``.

    Each method renders a batch of images with
    ``rf_pool.utils.visualize.show_images`` and returns whatever that call
    returns. The forward passes run under ``torch.no_grad()`` because the
    images are only displayed, never backpropagated through.
    """

    def show_recon(self, dataloader, model=None, cmap='gray'):
        """Show the model's reconstructions of the first batch from ``dataloader``.

        Parameters
        ----------
        dataloader : iterable
            Yields indexable batches; only element 0 (the inputs) is used.
        model : object
            Must provide ``forward`` and ``reconstruct``. Required despite the
            ``None`` default, which is kept for keyword-style calls.
        cmap : str
            Colormap name forwarded to ``show_images``.

        Returns
        -------
        The return value of ``rf_pool.utils.visualize.show_images``.

        Raises
        ------
        ValueError
            If ``model`` is None.
        """
        if model is None:
            raise ValueError('show_recon requires a model')
        # Python 3 iterators have no ``.next()`` method (recent PyTorch
        # DataLoader iterators dropped the Python 2 alias), so use next().
        x = next(iter(dataloader))[0]
        # NOTE(review): x stays on the loader's device (CPU here); a model on
        # another device would need x moved first.
        with torch.no_grad():
            recon = model.reconstruct(model.forward(x))
        return rf_pool.utils.visualize.show_images(recon, cmap=cmap)

    def show_samples(self, n_samples=1, model=None, cmap='gray', latent_dim=1024):
        """Decode ``n_samples`` standard-normal latent vectors and show them.

        Parameters
        ----------
        n_samples : int
            Number of latent vectors to draw.
        model : object
            Must provide ``reconstruct``. Required despite the ``None``
            default.
        cmap : str
            Colormap name forwarded to ``show_images``.
        latent_dim : int
            Size of each latent vector. Presumably it must match the VAE's
            latent size; the default 1024 preserves the previous hard-coded
            value (confirm against ``add_output_branch``).

        Returns
        -------
        The return value of ``rf_pool.utils.visualize.show_images``.

        Raises
        ------
        ValueError
            If ``model`` is None.
        """
        if model is None:
            raise ValueError('show_samples requires a model')
        with torch.no_grad():
            z = torch.randn(n_samples, latent_dim)
            recon = model.reconstruct(z)
        return rf_pool.utils.visualize.show_images(recon, cmap=cmap)

**Train VAE**

In [None]:
# Adam over every parameter reported by model.parameters(), learning rate 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Train for 10 epochs (presumably the meaning of the first argument) with
# summed binary cross-entropy, presumably the reconstruction term; any KL term
# is assumed to be added inside train_model (TODO confirm). With monitor=100,
# train_model presumably calls Metrics().show_samples(n_samples=100,
# model=model) periodically to display decoded random latent samples.
reconstruction_loss = torch.nn.BCELoss(reduction='sum')
loss_history = model.train_model(
    10,
    trainloader,
    reconstruction_loss,
    optimizer=optimizer,
    monitor=100,
    metrics=Metrics(),
    show_samples={'n_samples': 100, 'model': model},
)