In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import random
import os
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision

import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

import os
from google.colab import files

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import seaborn as sns
import math
from tqdm.autonotebook import tqdm

In [None]:
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=False)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw
Processing...
Done!




In [None]:
def train(epoch, device, weigth = 1):
    vae.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        
        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var, weigth)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))

In [None]:
def test(device, weigth = 1):
    vae.eval()
    test_loss= 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon, mu, log_var = vae(data)
            
            # sum up batch loss
            test_loss += loss_function(recon, data, mu, log_var, weigth).item()
        
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

# Task 1: Design the autoencoder structured network for MNIST

In [None]:
# YOUR CODE!!

class VAE(nn.Module):
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()
        
        # encoder part
        self.z_dim = z_dim
        self.enc_fc1 = nn.Linear(784, h_dim1)
        self.enc_fc2 = nn.Linear(h_dim1, h_dim2)
        self.enc_fc31 = nn.Linear(h_dim2, z_dim)    # mu
        self.enc_fc32 = nn.Linear(h_dim2, z_dim)    # log_var
        
        # sampling z

        # decoder part   
        self.dec_fc1 = nn.Linear(z_dim, h_dim2)
        self.dec_fc2 = nn.Linear(h_dim2, h_dim1)
        self.dec_fc3 = nn.Linear(h_dim1, 784)
        
    def encoder(self, x):
        # return mu, log_var
        out = F.relu(self.enc_fc1(x))
        out = F.relu(self.enc_fc2(out))
        return self.enc_fc31(out), self.enc_fc32(out)
    
    def sampling(self, mu, log_var):
        # return z sample
        # use reparametrization trick!
        sigma = torch.exp(log_var)
        std_z = torch.torch.from_numpy(np.random.normal(0, 1, size=sigma.size())).float().to(device)
        return mu + sigma * Variable(std_z, requires_grad=False)
        
    def decoder(self, z):
        # return generated img
        out = F.relu(self.dec_fc1(z))
        out = F.relu(self.dec_fc2(out))
        return torch.sigmoid(self.dec_fc3(out))
    
    def forward(self, x):
        mu, log_var = self.encoder(x.view(-1, 784))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
vae = VAE(x_dim=784, h_dim1= 512, h_dim2=256, z_dim=50)
if torch.cuda.is_available():
    vae.to(device)

# Task 2: Design the loss function for autoencoder with weight of KLD term

In [None]:
optimizer = optim.Adam(vae.parameters())
def loss_function(recon_x, x, mu, log_var, weight = 1):
    # return reconstruction error + weight * KL divergence losses
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = weight * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE - KLD

In [None]:
from tqdm.autonotebook import tqdm

for epoch in tqdm(range(1, 51)):
    train(epoch, device)
    test(device)

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

====> Epoch: 1 Average loss: 188.2943
====> Test set loss: 149.8646
====> Epoch: 2 Average loss: 136.9308
====> Test set loss: 127.3043
====> Epoch: 3 Average loss: 123.9106
====> Test set loss: 119.3393
====> Epoch: 4 Average loss: 118.1875
====> Test set loss: 115.2025
====> Epoch: 5 Average loss: 114.5394
====> Test set loss: 111.9647
====> Epoch: 6 Average loss: 111.6079
====> Test set loss: 109.9293
====> Epoch: 7 Average loss: 109.7665
====> Test set loss: 108.9745
====> Epoch: 8 Average loss: 108.4524
====> Test set loss: 107.7843
====> Epoch: 9 Average loss: 107.4887
====> Test set loss: 107.4524
====> Epoch: 10 Average loss: 106.6609
====> Test set loss: 106.0942
====> Epoch: 11 Average loss: 105.8767
====> Test set loss: 105.8695
====> Epoch: 12 Average loss: 105.3583
====> Test set loss: 105.3749
====> Epoch: 13 Average loss: 104.7979
====> Test set loss: 104.9743
====> Epoch: 14 Average loss: 104.3604
====> Test set loss: 104.3585
====> Epoch: 15 Average loss: 103.9747
====

# Task 3

In [None]:
with torch.no_grad():
    # random sampled latent points
    z_dim = 50
    z = torch.randn(64, z_dim)     
    z = torch.tensor(z).to(device)
    sample = vae.decoder(z.float())
    if not os.path.exists('./samples'):
        os.makedirs('./samples')
    save_image(sample.view(64, 1, 28, 28), './samples/sample' + '.png')

  """


In [None]:
!ls -la ./samples/

total 308
drwxr-xr-x 2 root root  4096 Jun 22 15:54 .
drwxr-xr-x 1 root root  4096 Jun 22 15:19 ..
-rw-r--r-- 1 root root 61740 Jun 22 15:36 sample_10.png
-rw-r--r-- 1 root root 63600 Jun 22 15:45 sample_25.png
-rw-r--r-- 1 root root 63451 Jun 22 15:28 sample_2.png
-rw-r--r-- 1 root root 62783 Jun 22 15:54 sample_50.png
-rw-r--r-- 1 root root 40983 Jun 22 15:19 sample.png


In [None]:
# files.download('./samples/sample.png')

# Homework 4

### Functions

In [None]:
def train_vae(vae, weigth=1):
    optimizer = optim.Adam(vae.parameters())
    vae.train()
    for epoch in tqdm(range(1, 51)):
        train_epoch(vae, epoch, device, optimizer, weigth=weigth)
        test_epoch(vae, device, weigth=weigth)

In [None]:
def random_samples_image(vae, n_samples, z_dim, img_name):
    with torch.no_grad():
        vae.eval()
        z = torch.randn(n_samples, z_dim)
        z = torch.tensor(z).to(device)
        samples = vae.decoder(z.float())
        if not os.path.exists('./samples'):
            os.makedirs('./samples')
        print("Generating image...")
        save_image(samples.view(n_samples, 1, 28, 28), './samples/' + img_name + '.png', nrow = int(math.sqrt(n_samples)))

In [21]:
def grid_samples_image(vae, n_samples, img_name):
    """
    z_dim must be equal to 2
    """
    lat = int(math.sqrt(n_samples))
    with torch.no_grad():
        vae.eval()
        span = torch.linspace(start = -4, end = 4, steps = lat)
        elems = []
        for a in range(lat):
            for b in range(lat):
                elems.append(torch.cat([span[a].unsqueeze(dim=0).unsqueeze(dim=1), span[b].unsqueeze(dim=0).unsqueeze(dim=1)], dim = 1))
        z1 = torch.cat([elem for elem in elems], dim = 0)
        z = torch.tensor(z1).to(device)
        samples = vae.decoder(z.float())
        if not os.path.exists('./samples'):
            os.makedirs('./samples')
        print("Generating image...")
        save_image(samples.view(n_samples, 1, 28, 28), './samples/' + img_name + '.png', nrow = int(math.sqrt(n_samples)))

In [None]:
def show_samples(images, row, col, image_shape, name="Unknown", save=False, shift=False):
    sns.axes_style("white")
    num_images = row*col
    if shift:
        images = (images+1.)/2.
    fig = plt.figure(figsize=(col, row))
    grid = ImageGrid(fig, 111,
                     nrows_ncols=(row, col),
                     axes_pad=0.)
    for i in range(num_images):
        im = images[i].reshape(image_shape)
        axis = grid[i]
        axis.axis('off')
        axis.imshow(im)
    plt.axis('off')
    plt.tight_layout()
    if save:
        fig.savefig('./samples/'+name+'.png', bbox_inches="tight", pad_inches=0, format='png')
    else:
        plt.show()

In [None]:
def train_epoch(vae, epoch, device, optimizer, weigth = 1):
    vae.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        
        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var, weigth)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))

In [None]:
def test_epoch(vae, device, weigth = 1):
    vae.eval()
    test_loss= 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon, mu, log_var = vae(data)
            
            # sum up batch loss
            test_loss += loss_function(recon, data, mu, log_var, weigth).item()
        
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

## Problem2

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
configurations = [2, 10, 25, 50]

for conf in configurations:
    vae_c = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=conf)
    if torch.cuda.is_available():
        vae_c.to(device)
    train_vae(vae_c)
    random_samples_image(vae_c, 100, conf, 'sample_'+str(conf))
    # files.download('./samples/' + 'sample_' + str(conf) + '.png')

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

====> Epoch: 1 Average loss: 188.9437
====> Test set loss: 167.2820
====> Epoch: 2 Average loss: 162.0612
====> Test set loss: 158.2659
====> Epoch: 3 Average loss: 155.5265
====> Test set loss: 154.1741
====> Epoch: 4 Average loss: 151.5720
====> Test set loss: 149.9337
====> Epoch: 5 Average loss: 148.7195
====> Test set loss: 148.3980
====> Epoch: 6 Average loss: 146.8747
====> Test set loss: 146.9166
====> Epoch: 7 Average loss: 145.5129
====> Test set loss: 145.8176
====> Epoch: 8 Average loss: 144.4485
====> Test set loss: 145.1548
====> Epoch: 9 Average loss: 143.7400
====> Test set loss: 144.1086
====> Epoch: 10 Average loss: 142.9228
====> Test set loss: 143.3332
====> Epoch: 11 Average loss: 142.2911
====> Test set loss: 142.9987
====> Epoch: 12 Average loss: 141.7257
====> Test set loss: 142.5257
====> Epoch: 13 Average loss: 141.2695
====> Test set loss: 142.6247
====> Epoch: 14 Average loss: 140.8339
====> Test set loss: 141.8141
====> Epoch: 15 Average loss: 140.4779
====

  """


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

====> Epoch: 1 Average loss: 182.6478
====> Test set loss: 147.9590
====> Epoch: 2 Average loss: 139.5121
====> Test set loss: 133.1990
====> Epoch: 3 Average loss: 130.7646
====> Test set loss: 126.7415
====> Epoch: 4 Average loss: 124.3272
====> Test set loss: 121.3614
====> Epoch: 5 Average loss: 119.9804
====> Test set loss: 117.4320
====> Epoch: 6 Average loss: 116.2300
====> Test set loss: 114.3505
====> Epoch: 7 Average loss: 113.6362
====> Test set loss: 112.0421
====> Epoch: 8 Average loss: 111.2544
====> Test set loss: 109.9776
====> Epoch: 9 Average loss: 109.5525
====> Test set loss: 108.9124
====> Epoch: 10 Average loss: 108.2868
====> Test set loss: 107.5888
====> Epoch: 11 Average loss: 107.2867
====> Test set loss: 107.1606
====> Epoch: 12 Average loss: 106.3884
====> Test set loss: 106.2211
====> Epoch: 13 Average loss: 105.7226
====> Test set loss: 105.5373
====> Epoch: 14 Average loss: 105.1294
====> Test set loss: 105.4979
====> Epoch: 15 Average loss: 104.5571
====

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

====> Epoch: 1 Average loss: 182.5759
====> Test set loss: 141.9503
====> Epoch: 2 Average loss: 131.6153
====> Test set loss: 123.1641
====> Epoch: 3 Average loss: 121.1146
====> Test set loss: 117.6153
====> Epoch: 4 Average loss: 116.7961
====> Test set loss: 114.4916
====> Epoch: 5 Average loss: 113.8389
====> Test set loss: 112.1025
====> Epoch: 6 Average loss: 111.5679
====> Test set loss: 109.9739
====> Epoch: 7 Average loss: 109.7984
====> Test set loss: 108.9351
====> Epoch: 8 Average loss: 108.3127
====> Test set loss: 107.6030
====> Epoch: 9 Average loss: 107.2656
====> Test set loss: 106.5361
====> Epoch: 10 Average loss: 106.4299
====> Test set loss: 105.9099
====> Epoch: 11 Average loss: 105.7120
====> Test set loss: 105.4084
====> Epoch: 12 Average loss: 105.0970
====> Test set loss: 105.2605
====> Epoch: 13 Average loss: 104.6125
====> Test set loss: 104.4380
====> Epoch: 14 Average loss: 104.1038
====> Test set loss: 104.7487
====> Epoch: 15 Average loss: 103.6634
====

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

====> Epoch: 1 Average loss: 190.5092
====> Test set loss: 155.3668
====> Epoch: 2 Average loss: 139.1445
====> Test set loss: 128.3710
====> Epoch: 3 Average loss: 125.3655
====> Test set loss: 121.7343
====> Epoch: 4 Average loss: 120.7008
====> Test set loss: 117.8550
====> Epoch: 5 Average loss: 117.4849
====> Test set loss: 115.2343
====> Epoch: 6 Average loss: 114.2258
====> Test set loss: 111.8053
====> Epoch: 7 Average loss: 111.3277
====> Test set loss: 109.8322
====> Epoch: 8 Average loss: 109.5953
====> Test set loss: 108.6706
====> Epoch: 9 Average loss: 108.3141
====> Test set loss: 107.6325
====> Epoch: 10 Average loss: 107.3550
====> Test set loss: 106.7271
====> Epoch: 11 Average loss: 106.5032
====> Test set loss: 106.1911
====> Epoch: 12 Average loss: 105.8251
====> Test set loss: 105.6138
====> Epoch: 13 Average loss: 105.2146
====> Test set loss: 105.2452
====> Epoch: 14 Average loss: 104.7468
====> Test set loss: 104.8349
====> Epoch: 15 Average loss: 104.2516
====

In [None]:
for conf in configurations:
    files.download('./samples/' + 'sample_' + str(conf) + '.png')

## Problem3

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
configurations = [1, 5, 10, 40]
for conf in configurations:
    vae_c = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)
    if torch.cuda.is_available():
        vae_c.to(device)
    train_vae(vae_c, weigth=conf)
    grid_samples_image(vae_c, 100, 'sample_prob3_'+str(conf))
    # files.download('./samples/' + 'sample_prob3_' + str(conf) + '.png')


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

====> Epoch: 1 Average loss: 190.6835
====> Test set loss: 167.1185
====> Epoch: 2 Average loss: 161.3808
====> Test set loss: 157.3404
====> Epoch: 3 Average loss: 155.2925
====> Test set loss: 153.3622
====> Epoch: 4 Average loss: 151.8807
====> Test set loss: 151.0202
====> Epoch: 5 Average loss: 149.2092
====> Test set loss: 148.3842
====> Epoch: 6 Average loss: 147.3971
====> Test set loss: 146.7403
====> Epoch: 7 Average loss: 145.9970
====> Test set loss: 146.0556
====> Epoch: 8 Average loss: 145.1154
====> Test set loss: 145.0354
====> Epoch: 9 Average loss: 144.0939
====> Test set loss: 144.1388
====> Epoch: 10 Average loss: 143.4040
====> Test set loss: 143.6863
====> Epoch: 11 Average loss: 142.8317
====> Test set loss: 143.8473
====> Epoch: 12 Average loss: 142.2215
====> Test set loss: 142.7365
====> Epoch: 13 Average loss: 141.6711
====> Test set loss: 142.3826
====> Epoch: 14 Average loss: 141.0656
====> Test set loss: 141.6166
====> Epoch: 15 Average loss: 140.5619
====

  


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

====> Epoch: 1 Average loss: 205.0168
====> Test set loss: 197.1689
====> Epoch: 2 Average loss: 193.1473
====> Test set loss: 185.7422
====> Epoch: 3 Average loss: 180.5852
====> Test set loss: 176.5383
====> Epoch: 4 Average loss: 174.8278
====> Test set loss: 173.4302
====> Epoch: 5 Average loss: 172.5125
====> Test set loss: 171.8625
====> Epoch: 6 Average loss: 170.8609
====> Test set loss: 170.6834
====> Epoch: 7 Average loss: 169.3240
====> Test set loss: 169.0537
====> Epoch: 8 Average loss: 168.2442
====> Test set loss: 167.9535
====> Epoch: 9 Average loss: 167.4120
====> Test set loss: 167.2114
====> Epoch: 10 Average loss: 166.6061
====> Test set loss: 166.5927
====> Epoch: 11 Average loss: 166.0976
====> Test set loss: 165.9626
====> Epoch: 12 Average loss: 165.6000
====> Test set loss: 165.7482
====> Epoch: 13 Average loss: 165.0908
====> Test set loss: 164.9074
====> Epoch: 14 Average loss: 164.6162
====> Test set loss: 164.9282
====> Epoch: 15 Average loss: 164.2298
====

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

====> Epoch: 1 Average loss: 207.5477
====> Test set loss: 200.3913
====> Epoch: 2 Average loss: 200.1381
====> Test set loss: 198.8195
====> Epoch: 3 Average loss: 198.8852
====> Test set loss: 197.8037
====> Epoch: 4 Average loss: 198.1436
====> Test set loss: 197.3728
====> Epoch: 5 Average loss: 197.9421
====> Test set loss: 197.5303
====> Epoch: 6 Average loss: 197.9079
====> Test set loss: 197.5847
====> Epoch: 7 Average loss: 197.8090
====> Test set loss: 197.5202
====> Epoch: 8 Average loss: 197.6560
====> Test set loss: 197.5633
====> Epoch: 9 Average loss: 197.6496
====> Test set loss: 197.4079
====> Epoch: 10 Average loss: 197.6509
====> Test set loss: 197.4668
====> Epoch: 11 Average loss: 197.4093
====> Test set loss: 196.9697
====> Epoch: 12 Average loss: 197.0112
====> Test set loss: 196.2944
====> Epoch: 13 Average loss: 195.6960
====> Test set loss: 195.3329
====> Epoch: 14 Average loss: 194.5814
====> Test set loss: 193.9885
====> Epoch: 15 Average loss: 194.0621
====

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

====> Epoch: 1 Average loss: 212.0386
====> Test set loss: 206.1523
====> Epoch: 2 Average loss: 206.2594
====> Test set loss: 206.0388
====> Epoch: 3 Average loss: 206.1101
====> Test set loss: 205.7875
====> Epoch: 4 Average loss: 206.0452
====> Test set loss: 205.8435
====> Epoch: 5 Average loss: 206.0496
====> Test set loss: 205.7815
====> Epoch: 6 Average loss: 205.8700
====> Test set loss: 205.5692
====> Epoch: 7 Average loss: 205.8948
====> Test set loss: 205.4356
====> Epoch: 8 Average loss: 205.8604
====> Test set loss: 205.8674
====> Epoch: 9 Average loss: 205.8137
====> Test set loss: 205.1000
====> Epoch: 10 Average loss: 205.8199
====> Test set loss: 205.6185
====> Epoch: 11 Average loss: 205.8536
====> Test set loss: 205.6247
====> Epoch: 12 Average loss: 205.7695
====> Test set loss: 205.4526
====> Epoch: 13 Average loss: 205.8425
====> Test set loss: 205.4761
====> Epoch: 14 Average loss: 205.7388
====> Test set loss: 205.7068
====> Epoch: 15 Average loss: 205.7242
====

In [24]:
for conf in configurations:
    files.download('./samples/' + 'sample_prob3_' + str(conf) + '.png')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>