In [1]:
import torch
import torch.nn.functional as F
from pyPDMP.models import VAE
from pyPDMP.utils import loss_function

import matplotlib.pyplot as plt
import torchvision
import torchvision.datasets as datasets

In [2]:
# Seed torch's global RNG before any random draws. nn.Linear weight init and
# DataLoader shuffling (when no explicit generator is given) both use it.
torch.manual_seed(0)

# Use the GPU when one is available, otherwise fall back to CPU, so the
# notebook also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
m = VAE().to(device)

In [3]:
# Display the model architecture.
# NOTE(review): the recorded repr below shows fc1 in_features=2 and
# fc4 out_features=2, but train() flattens each image to 28*28 = 784 features.
# Either this output is stale or VAE() defaults to a 2-D input. Confirm VAE's
# input size before trusting the training results.
m

VAE(
  (fc1): Linear(in_features=2, out_features=400, bias=True)
  (fc21): Linear(in_features=400, out_features=20, bias=True)
  (fc22): Linear(in_features=400, out_features=20, bias=True)
  (fc3): Linear(in_features=20, out_features=400, bias=True)
  (fc4): Linear(in_features=400, out_features=2, bias=True)
)

## Training the VAE

In [4]:
# Single preprocessing step: ToTensor turns each PIL image into a float
# tensor with values scaled to [0, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# MNIST training split; downloaded into ./data if it is not already there.
dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

In [5]:
# Reshuffled every epoch; mini-batches of 32 images.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
)

In [8]:
size = 28*28  # number of pixels in one MNIST image, flattened


def train(m, loader, epoch, optimizer=None, *, input_dim=size, loss_fn=None,
          lr=1e-3, log_interval=100):
    """Run one training epoch of a VAE and return the average loss.

    Args:
        m: Module whose forward returns ``(recon_batch, mu, logvar)``.
        loader: DataLoader yielding ``(x, label)`` batches; labels are ignored.
        epoch: Epoch number. It is used only as a label in the printed log
            lines; each call performs exactly one pass over ``loader``.
        optimizer: Optimizer to step. Pass the same instance on every call so
            its internal state (e.g. Adam moment estimates) survives across
            epochs. If ``None``, a fresh ``torch.optim.Adam(lr=lr)`` is created
            for this call only.
        input_dim: Number of features each sample is flattened to
            (defaults to the module-level ``size``, 28*28).
        loss_fn: Callable ``loss_fn(recon_batch, x, mu, logvar)`` returning a
            scalar tensor. Defaults to ``loss_function`` imported from
            ``pyPDMP.utils``.
        lr: Learning rate, used only when ``optimizer`` is ``None``.
        log_interval: Print a progress line every this many batches.

    Returns:
        The sum of all batch losses divided by ``len(loader.dataset)``, or
        ``nan`` when the dataset is empty. NOTE(review): this is a per-sample
        average only if ``loss_fn`` sums over the batch; confirm for
        ``loss_function``.
    """
    if loss_fn is None:
        loss_fn = loss_function
    if optimizer is None:
        optimizer = torch.optim.Adam(m.parameters(), lr=lr)

    # Send batches to wherever the model lives instead of hard-coding CUDA.
    device = next(m.parameters()).device
    m.train()

    n_samples = len(loader.dataset)
    n_batches = len(loader)
    train_loss = 0.0
    seen = 0  # samples processed before the current batch (exact even for a short last batch)
    for batch_idx, (x, _) in enumerate(loader):
        # flatten each image in the batch to a vector of input_dim features
        x = x.to(device).view(-1, input_dim)
        optimizer.zero_grad()
        recon_batch, mu, logvar = m(x)
        loss = loss_fn(recon_batch, x, mu, logvar)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, n_samples,
                100. * batch_idx / n_batches,
                batch_loss / len(x)))
        seen += len(x)

    # Guard against an empty dataset instead of raising ZeroDivisionError.
    avg_loss = train_loss / n_samples if n_samples else float('nan')
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))
    return avg_loss

In [9]:
# NOTE: `epoch` only labels the log lines; this call performs a single pass
# over data_loader. Loop over epochs to train for longer.
train(m, loader=data_loader, epoch=10);

====> Epoch: 10 Average loss: 280.7677
