In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
from tqdm import tnrange
import numpy as np
%matplotlib inline

from vae import VAE

In [2]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")    

In [3]:
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [4]:
train = dsets.MNIST("./Mnist/",train=True,transform=img_transform,download=True)
test = dsets.MNIST("./Mnist/",train=False,transform=img_transform,download=True)

In [5]:
train_data = torch.utils.data.DataLoader(train,batch_size=128,shuffle=True)
test_data = torch.utils.data.DataLoader(test,batch_size=128,shuffle=False)

In [10]:
Model = VAE(nz = 20,device = device).to(device)

In [11]:
def lossfun(x_out,x_in,z_mu,z_logvar):
    bce_loss = F.binary_cross_entropy(x_out,x_in,size_average=False)
    kld_loss = -0.5 * torch.sum(1 + z_logvar - (z_mu ** 2) - torch.exp(z_logvar))
    loss = (bce_loss + kld_loss) / x_out.size(0) # normalize by batch size
    return loss

In [12]:
optimizer = torch.optim.Adam(Model.parameters())

In [13]:
for i in tnrange(100,desc = 'Epochs'):
    losses = []
    for j,(images,labels) in enumerate(train_data):
        images = Variable(images.view(-1,1,28,28)).to(device)
        
        out,z_mean,z_stddev = Model.forward(images)
        
        total_loss = lossfun(out,images,z_mean,z_stddev)
        
        optimizer.zero_grad()
        
        total_loss.backward()
        
        optimizer.step()
        
        losses.append(total_loss.item())
    training_loss = sum(losses)/len(losses)
    print("Epoch : {} Train_error : {}".format(i,training_loss) )

HBox(children=(IntProgress(value=0, description='Epochs', style=ProgressStyle(description_width='initial')), H…



Epoch : 0 Train_error : -18093.990177577492
Epoch : 1 Train_error : -25350.228861273987
Epoch : 2 Train_error : -27363.270872201494
Epoch : 3 Train_error : -28904.151244336354
Epoch : 4 Train_error : -30086.612410880865
Epoch : 5 Train_error : -30972.195204224412
Epoch : 6 Train_error : -31563.743041211354
Epoch : 7 Train_error : -31882.836341451228
Epoch : 8 Train_error : -32021.75579274387
Epoch : 9 Train_error : -32052.882862473347
Epoch : 10 Train_error : -32083.368599247067
Epoch : 11 Train_error : -32137.95954657516
Epoch : 12 Train_error : -32233.71772971082
Epoch : 13 Train_error : -32360.087661580492
Epoch : 14 Train_error : -32511.29450959488
Epoch : 15 Train_error : -32673.43857442697
Epoch : 16 Train_error : -32761.441277152186
Epoch : 17 Train_error : -32652.28536447228
Epoch : 18 Train_error : -32537.055982642592
Epoch : 19 Train_error : -32101.50544709488
Epoch : 20 Train_error : -31440.02249633529
Epoch : 21 Train_error : -30732.447119869405
Epoch : 22 Train_error : -29