In [7]:
!git clone https://github.com/dalabis/VAE_1

Cloning into 'VAE_1'...
remote: Enumerating objects: 11, done.[K
remote: Counting objects: 100% (11/11), done.[K
remote: Compressing objects: 100% (8/8), done.[K
remote: Total 11 (delta 3), reused 11 (delta 3), pack-reused 0[K
Unpacking objects: 100% (11/11), done.


In [0]:
import os
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

from VAE_1.VAE import VAE
from VAE_1.trainer import train, test

In [0]:
# configuration
CUDA = True
SEED = 1
BATCH_SIZE = 128
LOG_INTERVAL = 10
EPOCHS = 30
# connections through the autoencoder bottleneck
ZDIMS = 10

In [0]:
torch.manual_seed(SEED)
if CUDA:
    torch.cuda.manual_seed(SEED)
    
# DataLoader instances will load tensors directly into GPU memory
kwargs = {'num_workers': 1, 'pin_memory':True} if CUDA else {}

In [0]:
# Download or load downloaded MNIST dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE,
    shuffle=True,
    **kwargs)

# Same for test data
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE,
    shuffle=True,
    **kwargs)

In [0]:
model = VAE(ZDIMS)
if CUDA:
    model.cuda()

optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [19]:
from VAE_1.loss import loss_function

def train(epoch):
    # toggle model to train mode
    model.train()
    train_loss = 0
    # in the case of MNIST, len(train_loader.dataset) is 60000
    # each 'data' is of BATCH_SIZE samples and has shape [128, 1, 28, 28]
    for batch_idx, (data, _) in enumerate(train_loader):
        data = Variable(data)
        if CUDA:
            data = data.cuda()
        optimizer.zero_grad()
        
        # push whole batch of data through VAE.forvard() to get recon_loss
        recon_batch, mu, logvar = model(data)
        # calculate scalar loss
        loss = loss_function(recon_batch, data, mu, logvar, BATCH_SIZE)
        # calculate the gradient of the loss w.r.t. the graph leaves
        loss.backward()
        train_loss += loss.data.item()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]]tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.data.item() / len(data)))
            
    print('====> Epoch: {} Average loss: {:.5f}'.format(
        epoch, train_loss / len(train_loader.dataset)))

def test(epoch):
    # toggle model to test / inference mode
    model.eval()
    test_loss = 0

    # each data is of BATCH_SIZE (default 128) samples
    for i, (data, _) in enumerate(test_loader):
        if CUDA:
            # make sure this lives on the GPU
            data = data.cuda()

        # we're only going to infer, so no autograd at all required: volatile=True
        data = Variable(data, volatile=True)
        recon_batch, mu, logvar = model(data)
        test_loss += loss_function(recon_batch, data, mu, logvar, BATCH_SIZE).item()
        if i == 0:
          n = min(data.size(0), 8)
          # for the first 128 batch of the epoch, show the first 8 input digits
          # with right below them the reconstructed output digits
          comparison = torch.cat([data[:n],
                                  recon_batch.view(BATCH_SIZE, 1, 28, 28)[:n]])
          save_image(comparison.data.cpu(),
                     'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.5f}'.format(test_loss))
                                

for epoch in range(1, EPOCHS + 1):
    train(epoch)
    test(epoch)
    
    # 64 sets of random ZDIMS-float vectors, i.e. 64 locations / MNIST digits in latent space
    sample = Variable(torch.randn(64, ZDIMS))
    if CUDA:
        sample = sample.cuda()
    sample = model.decode(sample).cpu()
    
    # save out as an 8x8 matrix of MNIST digits
    save_image(sample.data.view(64, 1, 28, 28), 'results/sample ' + str(epoch) + '.png')

====> Epoch: 1 Average loss: 0.00163




====> Test set loss: 0.00125
====> Epoch: 2 Average loss: 0.00125
====> Test set loss: 0.00116
====> Epoch: 3 Average loss: 0.00119
====> Test set loss: 0.00113
====> Epoch: 4 Average loss: 0.00117
====> Test set loss: 0.00111
====> Epoch: 5 Average loss: 0.00115
====> Test set loss: 0.00109
====> Epoch: 6 Average loss: 0.00114
====> Test set loss: 0.00108
====> Epoch: 7 Average loss: 0.00113
====> Test set loss: 0.00108
====> Epoch: 8 Average loss: 0.00112
====> Test set loss: 0.00107
====> Epoch: 9 Average loss: 0.00111
====> Test set loss: 0.00107
====> Epoch: 10 Average loss: 0.00111
====> Test set loss: 0.00106
====> Epoch: 11 Average loss: 0.00110
====> Test set loss: 0.00106
====> Epoch: 12 Average loss: 0.00110
====> Test set loss: 0.00105
====> Epoch: 13 Average loss: 0.00109
====> Test set loss: 0.00105
====> Epoch: 14 Average loss: 0.00109
====> Test set loss: 0.00105
====> Epoch: 15 Average loss: 0.00109
====> Test set loss: 0.00105
====> Epoch: 16 Average loss: 0.00108
===

In [20]:
!zip -r /content/result.zip /content/results

updating: content/results/ (stored 0%)
updating: content/results/sample 3.png (deflated 3%)
updating: content/results/reconstruction_9.png (deflated 3%)
updating: content/results/reconstruction_3.png (deflated 3%)
updating: content/results/reconstruction_6.png (deflated 3%)
updating: content/results/sample 5.png (deflated 3%)
updating: content/results/reconstruction_8.png (deflated 4%)
updating: content/results/reconstruction_2.png (deflated 3%)
updating: content/results/.ipynb_checkpoints/ (stored 0%)
updating: content/results/sample 4.png (deflated 3%)
updating: content/results/sample 2.png (deflated 2%)
updating: content/results/sample 6.png (deflated 4%)
updating: content/results/reconstruction_1.png (deflated 3%)
updating: content/results/reconstruction_10.png (deflated 3%)
updating: content/results/reconstruction_5.png (deflated 3%)
updating: content/results/reconstruction_7.png (deflated 3%)
updating: content/results/sample 10.png (deflated 4%)
updating: content/results/sample 1