In [5]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

torch.manual_seed(1)

<torch._C.Generator at 0x7fe6281acf78>

# Load dataset

In [6]:
# MNIST dataset
dataset = torchvision.datasets.MNIST(root='./data',
                         train=True,
                         transform=transforms.ToTensor(),
                         download=True)

# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=100, 
                                          shuffle=True)

# Design linear model

In [7]:
def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)

# VAE model
class LinearVAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(LinearVAE, self).__init__()
        self.linear1 = nn.Linear(image_size, h_dim)
        self.linear2 = nn.Linear(h_dim, z_dim)
        self.linear3 = nn.Linear(z_dim, h_dim)
        self.linear4 = nn.Linear(h_dim, 784)
    
    def encoder(self, x):
        x = F.relu(self.linear1(x.view(-1, 28 * 28)))
        z_mean = self.linear2(x)
        z_log_var = self.linear2(x)
        return z_mean, z_log_var
    
    def decoder(self, z):
        z = F.relu(self.linear3(z))
        recon_x = F.sigmoid(self.linear4(z))
        return recon_x
    
    def reparameterize(self, mu, log_var):
        """"z = mean + eps * sigma where eps is sampled from N(0, 1)."""
        eps = to_var(torch.randn(mu.size(0), mu.size(1)))
        z = mu + eps * torch.exp(log_var/2)    # 2 for convert var to std
        return z
                     
    def forward(self, x):
        mu, log_var = self.encoder(x) # torch.chunk(h, 2, dim=1)  # mean and log variance.
        z = self.reparameterize(mu, log_var)
        out = self.decoder(z)
        return out, mu, log_var
    
    def sample(self, z):
        return self.decoder(z)

# Train linear model

In [8]:
#Define model
vae = LinearVAE()

if torch.cuda.is_available():
    vae.cuda()
    
optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)
iter_per_epoch = len(data_loader)
data_iter = iter(data_loader)

# fixed inputs for debugging
fixed_z = to_var(torch.randn(100, 20))
fixed_x, _ = next(data_iter)
torchvision.utils.save_image(fixed_x.cpu(), './data/real_images.png')
fixed_x = to_var(fixed_x.view(fixed_x.size(0), -1))

for epoch in range(50):
    for i, (images, _) in enumerate(data_loader):
        
        images = to_var(images.view(images.size(0), -1))
        out, mu, log_var = vae(images)
        
        # Compute reconstruction loss and kl divergence
        reconst_loss = F.binary_cross_entropy(out, images, size_average=False)
        kl_divergence = torch.sum(0.5 * (mu**2 + torch.exp(log_var) - log_var -1))
        
        # Backprop + Optimize
        total_loss = reconst_loss + kl_divergence
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        if i % 100 == 0:
            print ("Epoch[%d/%d], Step [%d/%d], Total Loss: %.4f, "
                   "Reconst Loss: %.4f, KL Div: %.7f" 
                   %(epoch+1, 50, i+1, iter_per_epoch, total_loss.data[0], 
                     reconst_loss.data[0], kl_divergence.data[0]))
    
    # Save the reconstructed images
    reconst_images, _, _ = vae(fixed_x)
    reconst_images = reconst_images.view(reconst_images.size(0), 1, 28, 28)
    torchvision.utils.save_image(reconst_images.data.cpu(), 
        './data/reconst_images_%d.png' %(epoch+1))

Epoch[1/50], Step [1/600], Total Loss: 55014.4961, Reconst Loss: 55004.5234, KL Div: 9.9731340
Epoch[1/50], Step [101/600], Total Loss: 19750.7656, Reconst Loss: 18899.0703, KL Div: 851.6956177
Epoch[1/50], Step [201/600], Total Loss: 16937.0508, Reconst Loss: 15538.3174, KL Div: 1398.7341309
Epoch[1/50], Step [301/600], Total Loss: 16740.2148, Reconst Loss: 14925.6260, KL Div: 1814.5885010
Epoch[1/50], Step [401/600], Total Loss: 15652.4707, Reconst Loss: 13794.6904, KL Div: 1857.7805176
Epoch[1/50], Step [501/600], Total Loss: 15243.4199, Reconst Loss: 13151.3213, KL Div: 2092.0991211
Epoch[2/50], Step [1/600], Total Loss: 15273.3564, Reconst Loss: 13122.3945, KL Div: 2150.9619141
Epoch[2/50], Step [101/600], Total Loss: 14789.0986, Reconst Loss: 12573.1396, KL Div: 2215.9587402
Epoch[2/50], Step [201/600], Total Loss: 15081.9258, Reconst Loss: 12814.1484, KL Div: 2267.7775879
Epoch[2/50], Step [301/600], Total Loss: 13760.2021, Reconst Loss: 11583.2031, KL Div: 2176.9992676
Epoch[2/

Epoch[14/50], Step [401/600], Total Loss: 12035.3916, Reconst Loss: 9454.2236, KL Div: 2581.1677246
Epoch[14/50], Step [501/600], Total Loss: 12843.8887, Reconst Loss: 10243.4277, KL Div: 2600.4609375
Epoch[15/50], Step [1/600], Total Loss: 12785.8125, Reconst Loss: 10126.5410, KL Div: 2659.2714844
Epoch[15/50], Step [101/600], Total Loss: 12760.8887, Reconst Loss: 10143.3965, KL Div: 2617.4921875
Epoch[15/50], Step [201/600], Total Loss: 13222.0977, Reconst Loss: 10345.2109, KL Div: 2876.8864746
Epoch[15/50], Step [301/600], Total Loss: 13480.0098, Reconst Loss: 10668.7920, KL Div: 2811.2172852
Epoch[15/50], Step [401/600], Total Loss: 13197.3887, Reconst Loss: 10623.4902, KL Div: 2573.8986816
Epoch[15/50], Step [501/600], Total Loss: 12778.7998, Reconst Loss: 10051.1367, KL Div: 2727.6633301
Epoch[16/50], Step [1/600], Total Loss: 12893.8945, Reconst Loss: 10210.1084, KL Div: 2683.7856445
Epoch[16/50], Step [101/600], Total Loss: 12732.3730, Reconst Loss: 10038.7021, KL Div: 2693.671

Epoch[28/50], Step [201/600], Total Loss: 12938.3232, Reconst Loss: 10142.3096, KL Div: 2796.0136719
Epoch[28/50], Step [301/600], Total Loss: 12757.0449, Reconst Loss: 10146.5615, KL Div: 2610.4831543
Epoch[28/50], Step [401/600], Total Loss: 12875.3535, Reconst Loss: 10225.3047, KL Div: 2650.0488281
Epoch[28/50], Step [501/600], Total Loss: 13061.0293, Reconst Loss: 10451.2451, KL Div: 2609.7846680
Epoch[29/50], Step [1/600], Total Loss: 12956.7695, Reconst Loss: 10242.5127, KL Div: 2714.2570801
Epoch[29/50], Step [101/600], Total Loss: 12595.5918, Reconst Loss: 9897.7070, KL Div: 2697.8845215
Epoch[29/50], Step [201/600], Total Loss: 12927.9512, Reconst Loss: 10122.5156, KL Div: 2805.4353027
Epoch[29/50], Step [301/600], Total Loss: 13048.3809, Reconst Loss: 10135.8301, KL Div: 2912.5507812
Epoch[29/50], Step [401/600], Total Loss: 12341.0234, Reconst Loss: 9700.5840, KL Div: 2640.4389648
Epoch[29/50], Step [501/600], Total Loss: 13058.6035, Reconst Loss: 10467.6074, KL Div: 2590.99

Epoch[42/50], Step [1/600], Total Loss: 12351.3057, Reconst Loss: 9723.6865, KL Div: 2627.6193848
Epoch[42/50], Step [101/600], Total Loss: 12989.3398, Reconst Loss: 10193.9453, KL Div: 2795.3950195
Epoch[42/50], Step [201/600], Total Loss: 12400.7617, Reconst Loss: 9858.6709, KL Div: 2542.0908203
Epoch[42/50], Step [301/600], Total Loss: 13046.2617, Reconst Loss: 10348.4268, KL Div: 2697.8354492
Epoch[42/50], Step [401/600], Total Loss: 12212.1562, Reconst Loss: 9645.1836, KL Div: 2566.9721680
Epoch[42/50], Step [501/600], Total Loss: 12921.2119, Reconst Loss: 10186.7949, KL Div: 2734.4167480
Epoch[43/50], Step [1/600], Total Loss: 12406.3867, Reconst Loss: 9676.5762, KL Div: 2729.8107910
Epoch[43/50], Step [101/600], Total Loss: 12313.7969, Reconst Loss: 9760.2305, KL Div: 2553.5666504
Epoch[43/50], Step [201/600], Total Loss: 12760.5859, Reconst Loss: 10142.8340, KL Div: 2617.7517090
Epoch[43/50], Step [301/600], Total Loss: 13110.9980, Reconst Loss: 10275.3096, KL Div: 2835.6884766

# Design conv model

In [9]:
def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)

# VAE model
class ConvVAE(nn.Module):
    def __init__(self, image_size=12*12*16, h_dim=400, z_dim=20):
        super(ConvVAE, self).__init__()
        self.linear1 = nn.Linear(image_size, h_dim)
        self.linear2 = nn.Linear(h_dim, z_dim)
        self.linear3 = nn.Linear(z_dim, h_dim)
        self.linear4 = nn.Linear(h_dim, 16*8*8)
        self.batch_norm = nn.BatchNorm2d(16)
        self.conv2D1 = nn.Conv2d(1, 16, 5)
        self.conv2D2 = nn.Conv2d(16, 16, 5)
        self.trans_conv2D_1 = nn.ConvTranspose2d(16, 16, 5)
        self.trans_conv2D_2 = nn.ConvTranspose2d(16, 1, 5)
    
    def encoder(self, x): #(batch, 28, 28, 1)
        x_conv = self.batch_norm(F.relu(self.conv2D1(x))) #(batch, 24, 24, 16)
        x_conv = self.batch_norm(F.relu(self.conv2D2(x_conv))) #(batch, 20, 20, 16)
        x_conv = self.batch_norm(F.relu(self.conv2D2(x_conv))) #(batch, 16, 16, 16)
        x_conv = self.batch_norm(F.relu(self.conv2D2(x_conv))) #(batch, 12, 12, 16)
        x_conv = F.relu(self.linear1(x_conv.view(-1, 12*12*16))) #(batch, 12*12*16) -- > #(batch, 400)
        z_mean = self.linear2(x_conv) #(batch, 20)
        z_log_var = self.linear2(x_conv) #(batch, 20)
        return z_mean, z_log_var
    
    def decoder(self, z):
        z = F.relu(self.linear3(z))
        z = F.relu(self.linear4(z))
        recon_x = z.resize(z.size(0), 16, 8, 8)
        recon_x = self.batch_norm(F.relu(self.trans_conv2D_1(recon_x)))
        recon_x = self.batch_norm(F.relu(self.trans_conv2D_1(recon_x)))
        recon_x = self.batch_norm(F.relu(self.trans_conv2D_1(recon_x)))
        recon_x = self.batch_norm(F.relu(self.trans_conv2D_1(recon_x)))
        recon_x = F.sigmoid(self.trans_conv2D_2(recon_x))
        return recon_x
    
    def reparameterize(self, mu, log_var):
        """"z = mean + eps * sigma where eps is sampled from N(0, 1)."""
        eps = to_var(torch.randn(mu.size(0), mu.size(1)))
        z = mu + eps * torch.exp(log_var/2)    # 2 for convert var to std
        return z
                     
    def forward(self, x):
        mu, log_var = self.encoder(x) # torch.chunk(h, 2, dim=1)  # mean and log variance.
        z = self.reparameterize(mu, log_var)
        out = self.decoder(z)
        return out, mu, log_var
    
    def sample(self, z):
        return self.decoder(z)

# Train conv model

In [11]:
#Define model
vae = ConvVAE()

if torch.cuda.is_available():
    vae.cuda()
    
optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)
iter_per_epoch = len(data_loader)
data_iter = iter(data_loader)

# fixed inputs for debugging
fixed_z = to_var(torch.randn(100, 20))
fixed_x, _ = next(data_iter)
torchvision.utils.save_image(fixed_x.cpu(), './data/real_images.png')
fixed_x = to_var(fixed_x)

for epoch in range(50):
    for i, (images, _) in enumerate(data_loader):
        
        images = to_var(images)
        out, mu, log_var = vae(images)
        
        # Compute reconstruction loss and kl divergence
        reconst_loss = F.binary_cross_entropy(out, images, size_average=False)
        kl_divergence = torch.sum(0.5 * (mu**2 + torch.exp(log_var) - log_var -1))
        
        # Backprop + Optimize
        total_loss = reconst_loss + kl_divergence
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        if i % 100 == 0:
            print ("Epoch[%d/%d], Step [%d/%d], Total Loss: %.4f, "
                   "Reconst Loss: %.4f, KL Div: %.7f" 
                   %(epoch+1, 50, i+1, iter_per_epoch, total_loss.data[0], 
                     reconst_loss.data[0], kl_divergence.data[0]))
    
    # Save the reconstructed images
    reconst_images, _, _ = vae(fixed_x)
    reconst_images = reconst_images.resize(reconst_images.size(0), 1, 28, 28)
    torchvision.utils.save_image(reconst_images.data.cpu(), 
        './data/Conv2D/reconst_images_%d.png' %(epoch+1))

Epoch[1/50], Step [1/600], Total Loss: 54833.7773, Reconst Loss: 54814.1133, KL Div: 19.6647511
Epoch[1/50], Step [101/600], Total Loss: 19597.5410, Reconst Loss: 18214.1035, KL Div: 1383.4370117
Epoch[1/50], Step [201/600], Total Loss: 15769.4365, Reconst Loss: 13937.5801, KL Div: 1831.8562012
Epoch[1/50], Step [301/600], Total Loss: 15296.2949, Reconst Loss: 13054.0576, KL Div: 2242.2368164
Epoch[1/50], Step [401/600], Total Loss: 15303.2285, Reconst Loss: 12995.6611, KL Div: 2307.5671387
Epoch[1/50], Step [501/600], Total Loss: 14686.3906, Reconst Loss: 12294.1777, KL Div: 2392.2126465
Epoch[2/50], Step [1/600], Total Loss: 14962.1816, Reconst Loss: 12686.5479, KL Div: 2275.6340332
Epoch[2/50], Step [101/600], Total Loss: 14555.4727, Reconst Loss: 12280.4082, KL Div: 2275.0642090
Epoch[2/50], Step [201/600], Total Loss: 13933.2549, Reconst Loss: 11559.5010, KL Div: 2373.7536621
Epoch[2/50], Step [301/600], Total Loss: 13660.2422, Reconst Loss: 11308.8398, KL Div: 2351.4025879
Epoch[

Epoch[14/50], Step [501/600], Total Loss: 12595.4951, Reconst Loss: 10182.8066, KL Div: 2412.6884766
Epoch[15/50], Step [1/600], Total Loss: 12977.0127, Reconst Loss: 10382.8408, KL Div: 2594.1721191
Epoch[15/50], Step [101/600], Total Loss: 12452.5986, Reconst Loss: 10004.2793, KL Div: 2448.3193359
Epoch[15/50], Step [201/600], Total Loss: 12509.1504, Reconst Loss: 9973.4902, KL Div: 2535.6601562
Epoch[15/50], Step [301/600], Total Loss: 12219.8555, Reconst Loss: 9708.3311, KL Div: 2511.5249023
Epoch[15/50], Step [401/600], Total Loss: 12611.0332, Reconst Loss: 10042.5293, KL Div: 2568.5041504
Epoch[15/50], Step [501/600], Total Loss: 12921.8936, Reconst Loss: 10416.3389, KL Div: 2505.5546875
Epoch[16/50], Step [1/600], Total Loss: 12675.0723, Reconst Loss: 10079.0371, KL Div: 2596.0356445
Epoch[16/50], Step [101/600], Total Loss: 12374.3887, Reconst Loss: 9881.1738, KL Div: 2493.2150879
Epoch[16/50], Step [201/600], Total Loss: 12270.0254, Reconst Loss: 9724.1709, KL Div: 2545.854492

Epoch[28/50], Step [301/600], Total Loss: 11750.4277, Reconst Loss: 9181.6914, KL Div: 2568.7365723
Epoch[28/50], Step [401/600], Total Loss: 12558.6416, Reconst Loss: 9988.2910, KL Div: 2570.3505859
Epoch[28/50], Step [501/600], Total Loss: 12728.8594, Reconst Loss: 10028.6777, KL Div: 2700.1816406
Epoch[29/50], Step [1/600], Total Loss: 12700.5273, Reconst Loss: 10100.7939, KL Div: 2599.7333984
Epoch[29/50], Step [101/600], Total Loss: 12655.4658, Reconst Loss: 10058.7598, KL Div: 2596.7062988
Epoch[29/50], Step [201/600], Total Loss: 12530.5781, Reconst Loss: 9836.0400, KL Div: 2694.5385742
Epoch[29/50], Step [301/600], Total Loss: 12462.0684, Reconst Loss: 9888.8564, KL Div: 2573.2119141
Epoch[29/50], Step [401/600], Total Loss: 12366.0469, Reconst Loss: 9808.0801, KL Div: 2557.9663086
Epoch[29/50], Step [501/600], Total Loss: 12365.9668, Reconst Loss: 9683.7119, KL Div: 2682.2543945
Epoch[30/50], Step [1/600], Total Loss: 11657.2686, Reconst Loss: 9065.0967, KL Div: 2592.1716309
E

Epoch[42/50], Step [201/600], Total Loss: 12549.2275, Reconst Loss: 9950.3555, KL Div: 2598.8720703
Epoch[42/50], Step [301/600], Total Loss: 12140.4268, Reconst Loss: 9475.9053, KL Div: 2664.5212402
Epoch[42/50], Step [401/600], Total Loss: 11913.6250, Reconst Loss: 9335.5068, KL Div: 2578.1186523
Epoch[42/50], Step [501/600], Total Loss: 12333.9092, Reconst Loss: 9592.9170, KL Div: 2740.9921875
Epoch[43/50], Step [1/600], Total Loss: 12447.9834, Reconst Loss: 9778.2012, KL Div: 2669.7824707
Epoch[43/50], Step [101/600], Total Loss: 12699.3164, Reconst Loss: 10057.1016, KL Div: 2642.2150879
Epoch[43/50], Step [201/600], Total Loss: 12614.4629, Reconst Loss: 10002.2061, KL Div: 2612.2570801
Epoch[43/50], Step [301/600], Total Loss: 11777.5518, Reconst Loss: 9342.7617, KL Div: 2434.7902832
Epoch[43/50], Step [401/600], Total Loss: 11665.8691, Reconst Loss: 9051.5928, KL Div: 2614.2761230
Epoch[43/50], Step [501/600], Total Loss: 12098.1592, Reconst Loss: 9479.6318, KL Div: 2618.5270996
