In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

In [2]:
# Device configuration: use the GPU when PyTorch can see one, else the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's RNGs so weight init, DataLoader shuffling, reparameterization
# noise and the sampled images repeat on Restart & Run All
# (up to any nondeterministic GPU kernels).
seed = 0
torch.manual_seed(seed)

# Output directory for sampled / reconstructed images (created if missing;
# exist_ok makes re-running this cell harmless)
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)

# Hyper-parameters
image_size = 784      # flattened MNIST image: 28 * 28 pixels
h_dim = 400           # hidden-layer width
z_dim = 20            # latent dimensionality
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

In [3]:
# MNIST training split; each PIL image is converted to a tensor by ToTensor.
# download=True fetches the files into ../../data on first run.
dataset = torchvision.datasets.MNIST(
    root='../../data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Mini-batch iterator over the training set, reshuffled every epoch
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)


In [4]:
# VAE model
class VAE(nn.Module):
    """Fully-connected variational auto-encoder.

    Encoder: image_size -> h_dim (ReLU) -> two heads of size z_dim (mu, log_var).
    Decoder: z_dim -> h_dim (ReLU) -> image_size (sigmoid, so outputs lie in [0, 1]).

    Args:
        image_size: length of each flattened input vector.
        h_dim: width of the hidden layer in encoder and decoder.
        z_dim: dimensionality of the latent code.
    """

    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(image_size, h_dim)   # encoder hidden layer
        self.fc2 = nn.Linear(h_dim, z_dim)        # encoder head: mean
        self.fc3 = nn.Linear(h_dim, z_dim)        # encoder head: log-variance
        self.fc4 = nn.Linear(z_dim, h_dim)        # decoder hidden layer
        self.fc5 = nn.Linear(h_dim, image_size)   # decoder output layer

    def encode(self, x):
        """Map x (..., image_size) to (mu, log_var), each (..., z_dim)."""
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu, log_var):
        """Draw z = mu + eps * std with eps ~ N(0, I).

        Sampling eps separately keeps z differentiable w.r.t. mu and log_var.
        """
        std = torch.exp(log_var / 2)   # sigma = exp(log(sigma^2) / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map a latent code z (..., z_dim) to pixel values in [0, 1]."""
        h = F.relu(self.fc4(z))
        # torch.sigmoid is the canonical API; F.sigmoid is deprecated in its favour.
        return torch.sigmoid(self.fc5(h))

    def forward(self, x):
        """Encode, sample a latent code, decode.

        Returns:
            (x_reconst, mu, log_var)
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

# Build the model from the hyper-parameter cell. A bare VAE() would silently use
# the class defaults, so edits to image_size / h_dim / z_dim would be ignored.
model = VAE(image_size=image_size, h_dim=h_dim, z_dim=z_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [5]:
# Start training
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward pass: flatten each image into a vector of length image_size
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)

        # Compute reconstruction loss and KL divergence, both summed over the batch.
        # reduction='sum' replaces the deprecated size_average=False.
        # For the closed-form KL term, see Appendix B of the VAE paper
        # (Kingma & Welling, "Auto-Encoding Variational Bayes") or
        # http://yunjey47.tistory.com/43
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        # Backprop and optimize the sum of both terms
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 10 == 0:
            print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                  .format(epoch + 1, num_epochs, i + 1, len(data_loader),
                          reconst_loss.item(), kl_div.item()))

    with torch.no_grad():
        # Save images decoded from standard-normal latent codes. The noise is
        # allocated directly on `device` instead of being created on CPU and copied.
        z = torch.randn(batch_size, z_dim, device=device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch + 1)))

        # Save reconstructions of this epoch's last batch: originals and
        # reconstructions side by side (concatenated along the width axis).
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch + 1)))



Epoch[1/15], Step [10/469], Reconst Loss: 37079.0781, KL Div: 3871.6084
Epoch[1/15], Step [20/469], Reconst Loss: 29926.7949, KL Div: 954.2637
Epoch[1/15], Step [30/469], Reconst Loss: 28103.2969, KL Div: 1145.8278
Epoch[1/15], Step [40/469], Reconst Loss: 26669.2852, KL Div: 694.5894
Epoch[1/15], Step [50/469], Reconst Loss: 25759.6875, KL Div: 745.4025
Epoch[1/15], Step [60/469], Reconst Loss: 24631.3672, KL Div: 915.0233
Epoch[1/15], Step [70/469], Reconst Loss: 25130.7383, KL Div: 914.4310
Epoch[1/15], Step [80/469], Reconst Loss: 23800.0410, KL Div: 954.4619
Epoch[1/15], Step [90/469], Reconst Loss: 24106.8301, KL Div: 1122.5366
Epoch[1/15], Step [100/469], Reconst Loss: 22789.3613, KL Div: 1320.8110
Epoch[1/15], Step [110/469], Reconst Loss: 22088.5840, KL Div: 1353.2518
Epoch[1/15], Step [120/469], Reconst Loss: 21427.8555, KL Div: 1439.4177
Epoch[1/15], Step [130/469], Reconst Loss: 20197.9727, KL Div: 1664.0808
Epoch[1/15], Step [140/469], Reconst Loss: 19475.9492, KL Div: 172

Epoch[3/15], Step [210/469], Reconst Loss: 11299.6523, KL Div: 2942.2620
Epoch[3/15], Step [220/469], Reconst Loss: 11406.1777, KL Div: 3020.8474
Epoch[3/15], Step [230/469], Reconst Loss: 11675.7891, KL Div: 3063.2966
Epoch[3/15], Step [240/469], Reconst Loss: 11567.3652, KL Div: 3155.5940
Epoch[3/15], Step [250/469], Reconst Loss: 11369.1729, KL Div: 3072.6624
Epoch[3/15], Step [260/469], Reconst Loss: 12318.1611, KL Div: 3097.6685
Epoch[3/15], Step [270/469], Reconst Loss: 11274.5605, KL Div: 2928.3696
Epoch[3/15], Step [280/469], Reconst Loss: 11449.5156, KL Div: 3079.3726
Epoch[3/15], Step [290/469], Reconst Loss: 11667.5713, KL Div: 3172.6001
Epoch[3/15], Step [300/469], Reconst Loss: 11331.6562, KL Div: 3055.3020
Epoch[3/15], Step [310/469], Reconst Loss: 11258.3369, KL Div: 2924.6538
Epoch[3/15], Step [320/469], Reconst Loss: 11145.0215, KL Div: 3132.3491
Epoch[3/15], Step [330/469], Reconst Loss: 11869.5176, KL Div: 3156.2590
Epoch[3/15], Step [340/469], Reconst Loss: 11348.56

Epoch[5/15], Step [410/469], Reconst Loss: 11175.7988, KL Div: 3190.4648
Epoch[5/15], Step [420/469], Reconst Loss: 11042.8350, KL Div: 3219.3569
Epoch[5/15], Step [430/469], Reconst Loss: 10477.7246, KL Div: 3167.5459
Epoch[5/15], Step [440/469], Reconst Loss: 10937.5908, KL Div: 3063.1104
Epoch[5/15], Step [450/469], Reconst Loss: 10433.0127, KL Div: 3179.4331
Epoch[5/15], Step [460/469], Reconst Loss: 10567.0195, KL Div: 3307.3638
Epoch[6/15], Step [10/469], Reconst Loss: 10277.8701, KL Div: 3107.6472
Epoch[6/15], Step [20/469], Reconst Loss: 10736.3896, KL Div: 3215.6562
Epoch[6/15], Step [30/469], Reconst Loss: 10381.4922, KL Div: 3111.1646
Epoch[6/15], Step [40/469], Reconst Loss: 10541.6162, KL Div: 3251.4485
Epoch[6/15], Step [50/469], Reconst Loss: 11202.1729, KL Div: 3195.9629
Epoch[6/15], Step [60/469], Reconst Loss: 11120.1133, KL Div: 3203.3801
Epoch[6/15], Step [70/469], Reconst Loss: 10566.9072, KL Div: 3228.8340
Epoch[6/15], Step [80/469], Reconst Loss: 11102.8262, KL D

Epoch[8/15], Step [150/469], Reconst Loss: 10581.7822, KL Div: 3277.4067
Epoch[8/15], Step [160/469], Reconst Loss: 10696.2998, KL Div: 3176.2520
Epoch[8/15], Step [170/469], Reconst Loss: 10841.8955, KL Div: 3247.9673
Epoch[8/15], Step [180/469], Reconst Loss: 10614.6348, KL Div: 3285.1055
Epoch[8/15], Step [190/469], Reconst Loss: 10458.0625, KL Div: 3225.2378
Epoch[8/15], Step [200/469], Reconst Loss: 10663.6475, KL Div: 3250.8474
Epoch[8/15], Step [210/469], Reconst Loss: 10759.5781, KL Div: 3276.9731
Epoch[8/15], Step [220/469], Reconst Loss: 10547.7236, KL Div: 3295.5410
Epoch[8/15], Step [230/469], Reconst Loss: 10610.0605, KL Div: 3211.9624
Epoch[8/15], Step [240/469], Reconst Loss: 10583.4365, KL Div: 3265.1924
Epoch[8/15], Step [250/469], Reconst Loss: 10705.7217, KL Div: 3163.0298
Epoch[8/15], Step [260/469], Reconst Loss: 10623.8379, KL Div: 3358.5610
Epoch[8/15], Step [270/469], Reconst Loss: 10312.5693, KL Div: 3161.5137
Epoch[8/15], Step [280/469], Reconst Loss: 10444.04

Epoch[10/15], Step [350/469], Reconst Loss: 9753.1230, KL Div: 3118.3948
Epoch[10/15], Step [360/469], Reconst Loss: 10843.2363, KL Div: 3350.0386
Epoch[10/15], Step [370/469], Reconst Loss: 10069.0068, KL Div: 3288.1152
Epoch[10/15], Step [380/469], Reconst Loss: 10314.9043, KL Div: 3176.8750
Epoch[10/15], Step [390/469], Reconst Loss: 10388.3965, KL Div: 3238.1863
Epoch[10/15], Step [400/469], Reconst Loss: 10685.2920, KL Div: 3280.2310
Epoch[10/15], Step [410/469], Reconst Loss: 10286.0938, KL Div: 3290.4988
Epoch[10/15], Step [420/469], Reconst Loss: 10190.5059, KL Div: 3300.4453
Epoch[10/15], Step [430/469], Reconst Loss: 9976.7441, KL Div: 3155.2769
Epoch[10/15], Step [440/469], Reconst Loss: 10132.3604, KL Div: 3182.5500
Epoch[10/15], Step [450/469], Reconst Loss: 10496.7734, KL Div: 3255.4468
Epoch[10/15], Step [460/469], Reconst Loss: 10102.9082, KL Div: 3291.2695
Epoch[11/15], Step [10/469], Reconst Loss: 10224.4102, KL Div: 3300.2705
Epoch[11/15], Step [20/469], Reconst Loss

Epoch[13/15], Step [80/469], Reconst Loss: 10517.2100, KL Div: 3236.8618
Epoch[13/15], Step [90/469], Reconst Loss: 10518.3818, KL Div: 3371.1323
Epoch[13/15], Step [100/469], Reconst Loss: 9966.4336, KL Div: 3212.1768
Epoch[13/15], Step [110/469], Reconst Loss: 10587.7227, KL Div: 3331.4158
Epoch[13/15], Step [120/469], Reconst Loss: 10077.3799, KL Div: 3224.2639
Epoch[13/15], Step [130/469], Reconst Loss: 10293.9961, KL Div: 3272.3022
Epoch[13/15], Step [140/469], Reconst Loss: 10057.9736, KL Div: 3231.3379
Epoch[13/15], Step [150/469], Reconst Loss: 10035.7061, KL Div: 3244.3662
Epoch[13/15], Step [160/469], Reconst Loss: 10268.5605, KL Div: 3288.2234
Epoch[13/15], Step [170/469], Reconst Loss: 10341.8672, KL Div: 3279.3843
Epoch[13/15], Step [180/469], Reconst Loss: 10165.0889, KL Div: 3214.9832
Epoch[13/15], Step [190/469], Reconst Loss: 10008.9863, KL Div: 3257.1475
Epoch[13/15], Step [200/469], Reconst Loss: 10407.5166, KL Div: 3310.4790
Epoch[13/15], Step [210/469], Reconst Los

Epoch[15/15], Step [270/469], Reconst Loss: 10070.1602, KL Div: 3202.6008
Epoch[15/15], Step [280/469], Reconst Loss: 9670.2266, KL Div: 3101.6777
Epoch[15/15], Step [290/469], Reconst Loss: 10122.2920, KL Div: 3314.1934
Epoch[15/15], Step [300/469], Reconst Loss: 10149.4854, KL Div: 3250.3789
Epoch[15/15], Step [310/469], Reconst Loss: 10218.1260, KL Div: 3171.2886
Epoch[15/15], Step [320/469], Reconst Loss: 10129.1729, KL Div: 3312.7554
Epoch[15/15], Step [330/469], Reconst Loss: 10249.3652, KL Div: 3352.6230
Epoch[15/15], Step [340/469], Reconst Loss: 10151.9746, KL Div: 3166.7422
Epoch[15/15], Step [350/469], Reconst Loss: 9736.6914, KL Div: 3234.3789
Epoch[15/15], Step [360/469], Reconst Loss: 9945.6211, KL Div: 3223.7300
Epoch[15/15], Step [370/469], Reconst Loss: 10025.1953, KL Div: 3223.1973
Epoch[15/15], Step [380/469], Reconst Loss: 10390.7354, KL Div: 3240.0039
Epoch[15/15], Step [390/469], Reconst Loss: 10423.6377, KL Div: 3269.2993
Epoch[15/15], Step [400/469], Reconst Los