# Eaxmple training CAE and VAE on MNIST


In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import torch

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from Autoencoders.encoders import Encoder2DConv, VAEEncoder2DConv
from Autoencoders.decoders import Decoder2DConv
from Autoencoders.autoencoders import Autoencoder, VAE
from Autoencoders.losses import vae_loss

## Load FashionMNIST data and create a dataloader

In [2]:
"""MNIST = datasets.MNIST('./sampledata/MNIST', download=True, train=True, transform=transforms.ToTensor())
dataloader = DataLoader(MNIST, batch_size=32, num_workers=2)"""

"MNIST = datasets.MNIST('./sampledata/MNIST', download=True, train=True, transform=transforms.ToTensor())\ndataloader = DataLoader(MNIST, batch_size=32, num_workers=2)"

In [3]:
traindata = datasets.FashionMNIST('./sampledata/FashionMNIST', download=True, train=True, transform=transforms.ToTensor())
trainloader = DataLoader(traindata, batch_size=32, num_workers=2)

testdata = datasets.FashionMNIST('./sampledata/FashionMNIST', download=True, train=False, transform=transforms.ToTensor())
testloader = DataLoader(testdata, batch_size=32, num_workers=2)

In [4]:
for data, _ in trainloader:
    print(data.size())
    break

torch.Size([32, 1, 28, 28])


## Create the autoencoders


In [5]:
inputdims = (28,28)
latentdims = 32
nlayers = 2

# CAE
cae_encoder = Encoder2DConv(inputdims, latentdims, nlayers=nlayers, use_batchnorm=True)
cae_decoder = Decoder2DConv(inputdims, latentdims, nlayers=nlayers, use_batchnorm=True)
cae = Autoencoder(cae_encoder, cae_decoder)
cae_loss = torch.nn.functional.mse_loss
cae_optimizer = torch.optim.Adam(cae.parameters())

#VAE
vae_encoder = VAEEncoder2DConv(inputdims, latentdims, use_batchnorm=True)
vae_decoder = Decoder2DConv(inputdims, latentdims, use_batchnorm=True)
vae = VAE(vae_encoder, vae_decoder)
#vae_loss
vae_optimizer = torch.optim.Adam(vae.parameters())

## Train the CAE

In [None]:
epochs = 2

def train_cae(epochs):
    cae.train()
    train_loss = 0
    for batch_idx, (x, _) in enumerate(trainloader):
        cae_optimizer.zero_grad()
        recon_x = cae(x)
        loss = cae_loss(recon_x, x, reduction='sum')
        loss.backward()
        train_loss += loss.item()
        cae_optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(x), len(trainloader.dataset),
                100. * batch_idx / len(trainloader),
                loss.item() / len(x)),
                end="\r", flush=True)

    print('\n====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(trainloader.dataset)))
    
for epoch in range(epochs):
    train_cae(epoch)
    



## Train the VAE

In [7]:
epochs = 2

def train_vae(epochs):
    cae.train()
    train_loss = 0
    recon_loss = 0
    kld_loss = 0
    for batch_idx, (x, _) in enumerate(trainloader):
        vae_optimizer.zero_grad()
        recon_x, mu, logvar = vae(x)
        loss, rloss, kloss = vae_loss(recon_x, x, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        recon_loss += rloss.item()
        kld_loss += kloss.item()
        vae_optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tRecon Loss: {:.6f}\tKLD loss: {:.6f}'.format(
                    epoch, batch_idx * len(x), 
                    len(trainloader.dataset),
                    100. * batch_idx / len(trainloader),
                    loss.item() / len(x),
                    rloss.item() / len(x),
                    kloss.item() / len(x)),
                end="\r", flush=True)

    print('\n====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(trainloader.dataset)))
    
for epoch in range(epochs):
    train_vae(epoch)

====> Epoch: 0 Average loss: 33.7637Loss: 29.769445	Recon Loss: 20.544529	KLD loss: 9.224916
====> Epoch: 1 Average loss: 28.2514Loss: 28.627825	Recon Loss: 19.285494	KLD loss: 9.342331
