# Convolutional VAE on MNIST

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
from tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
% matplotlib inline

## Set hyperparameters

In [2]:
# ----- Hyperparameters -----
batch_size = 64        # samples per mini-batch (DataLoader)
learning_rate = 0.002  # not used in the cells shown; presumably for the optimizer later
num_epoch = 15         # not used in the cells shown; presumably the training-loop length
leak = 0.05            # negative slope of the LeakyReLU activations
drop_rate = 0.02       # dropout probability
z_dim = 20             # dimensionality of the latent vector z

# Seed the RNGs so weight initialization, training-set shuffling and the
# reparameterization noise are reproducible across Restart & Run All.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

## Data: MNIST

In [3]:
# Fetch the MNIST train and test splits into the current directory (downloaded on first run),
# converting each image to a tensor with transforms.ToTensor().
mnist_train = dset.MNIST(root="./", train=True, download=True,
                         transform=transforms.ToTensor(), target_transform=None)
mnist_test = dset.MNIST(root="./", train=False, download=True,
                        transform=transforms.ToTensor(), target_transform=None)

In [4]:
# Training data is reshuffled every epoch; the test data is not shuffled, for reproducibility.
# drop_last=True discards the final incomplete batch, so every batch holds exactly batch_size samples.
# NOTE(review): on the test loader this silently skips the last partial batch of test images;
# consider drop_last=False once the model accepts variable batch sizes.
train_loader = DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)
test_loader = DataLoader(
    mnist_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    drop_last=True,
)

## Model


### Encoder
* Two outputs, the parameters of the latent Gaussian:
    * z_mu: mean
    * z_logvar: log-variance

In [5]:
class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of the latent Gaussian.

    Input: a batch of single-channel images, shape (N, 1, 28, 28). The two
    2x2 max-pools reduce 28 -> 7, matching the 7*7*28 input of fc_layer.
    Output: ``(z_mu, z_logvar)``, each of shape (N, latent_dim).

    Args:
        latent_dim: size of z. Defaults to the notebook-level ``z_dim``.
        leak_slope: negative slope of the LeakyReLU layers. Defaults to ``leak``.
        dropout_p: dropout probability. Defaults to ``drop_rate``.
    """

    def __init__(self, latent_dim=None, leak_slope=None, dropout_p=None):
        super().__init__()
        # Read the hyperparameter cell only when no explicit value is given,
        # so Encoder() behaves as before while the class stays reusable.
        latent_dim = z_dim if latent_dim is None else latent_dim
        leak_slope = leak if leak_slope is None else leak_slope
        dropout_p = drop_rate if dropout_p is None else dropout_p

        # Conv2d output size = (in + 2*padding - kernel) / stride + 1
        self.conv_layer = nn.Sequential(
            nn.Conv2d(1, 10, 3, padding=1),   # (28 + 2 - 3)/1 + 1 = 28
            nn.ReLU(),
            nn.BatchNorm2d(10),
            nn.Dropout(p=dropout_p, inplace=True),
            # -> N x 10 x 28 x 28

            nn.Conv2d(10, 16, 3, padding=1),  # (28 + 2 - 3)/1 + 1 = 28
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(p=dropout_p, inplace=True),
            nn.MaxPool2d(2, 2),
            # -> N x 16 x 14 x 14

            nn.Conv2d(16, 28, 3, padding=1),  # (14 + 2 - 3)/1 + 1 = 14
            nn.ReLU(),
            nn.BatchNorm2d(28),
            nn.MaxPool2d(2, 2),
            # -> N x 28 x 7 x 7
        )

        self.fc_layer = nn.Sequential(
            nn.Linear(7 * 7 * 28, 512),
            nn.LeakyReLU(leak_slope),
            nn.Linear(512, 128),
            nn.LeakyReLU(leak_slope),
        )

        # Two heads on the shared 128-d feature vector:
        # mean (z_mu) and log-variance (z_logvar) of the latent distribution.
        self.z_mu = nn.Sequential(
            nn.Linear(128, latent_dim),
            nn.LeakyReLU(leak_slope),
        )
        self.z_logvar = nn.Sequential(
            nn.Linear(128, latent_dim),
            nn.LeakyReLU(leak_slope),
        )

    def forward(self, x):
        """Encode ``x`` into ``(z_mu, z_logvar)``."""
        features = self.conv_layer(x)
        # Flatten per sample using the real batch size (not the global
        # batch_size), so partial batches and single images also work.
        features = features.view(features.size(0), -1)
        features = self.fc_layer(features)
        return self.z_mu(features), self.z_logvar(features)

### Decoder
* input: sampled_z
* output: reconstructed image

In [6]:
# torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, 
#   output_padding=0, groups=1, bias=True, dilation=1)

class Decoder(nn.Module):
    """Decoder mapping a latent vector back to a single-channel 28x28 image.

    fc_layer expands z to 784 = 16 * 7 * 7 features. These are reshaped into a
    16-channel 7x7 map and upsampled 7 -> 14 -> 28 by transposed convolutions.

    Input: ``sampled_z`` of shape (N, latent_dim).
    Output: shape (N, 1, 28, 28). A final Sigmoid keeps values in (0, 1),
    the range transforms.ToTensor() gives the MNIST inputs.

    Args:
        latent_dim: size of z. Defaults to the notebook-level ``z_dim``.
        leak_slope: negative slope of the LeakyReLU layers. Defaults to ``leak``.
        dropout_p: dropout probability. Defaults to ``drop_rate``.
    """

    # fc_layer output (784 features) is viewed as this (C, H, W) seed map.
    _SEED_SHAPE = (16, 7, 7)

    def __init__(self, latent_dim=None, leak_slope=None, dropout_p=None):
        super().__init__()
        # Read the hyperparameter cell only when no explicit value is given.
        latent_dim = z_dim if latent_dim is None else latent_dim
        leak_slope = leak if leak_slope is None else leak_slope
        dropout_p = drop_rate if dropout_p is None else dropout_p

        self.fc_layer = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(leak_slope),
            nn.Dropout(p=dropout_p, inplace=True),

            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(leak_slope),
            nn.Dropout(p=dropout_p, inplace=True),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(leak_slope),
            nn.Dropout(p=dropout_p, inplace=True),

            nn.Linear(512, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(leak_slope),
        )

        # ConvTranspose2d output size =
        #   (in - 1)*stride - 2*padding + kernel + output_padding
        self.transConv_layer = nn.Sequential(
            nn.ConvTranspose2d(16, 64, 3, 2, 1, output_padding=1),  # (7-1)*2 - 2 + 3 + 1 = 14
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, 2, 1, output_padding=1),  # (14-1)*2 - 2 + 3 + 1 = 28
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, 1, 1),                     # (28-1)*1 - 2 + 3 = 28
            nn.Sigmoid(),
        )

    def forward(self, sampled_z):
        """Decode ``sampled_z`` into an image batch of shape (N, 1, 28, 28)."""
        out = self.fc_layer(sampled_z)
        # ConvTranspose2d needs a 4-D (N, C, H, W) tensor: (N, 784) -> (N, 16, 7, 7).
        out = out.view(out.size(0), *self._SEED_SHAPE)
        return self.transConv_layer(out)

### The model as a whole
* encoder & decoder
* sample z with the reparameterization trick

In [7]:
class VAE(nn.Module):
    """Variational autoencoder: encode x, sample z by reparameterization, decode z.

    After each forward pass, the encoder outputs are kept on ``self.z_mu`` and
    ``self.z_logvar``. Presumably this is for a KL term computed in the
    training loop; confirm against the caller.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def sample_z(self, z_mu, z_logvar):
        """Draw z ~ N(z_mu, exp(z_logvar)) with the reparameterization trick.

        z_logvar is a log-variance, so the standard deviation is
        exp(0.5 * z_logvar). epsilon is created like z_mu (same shape, dtype
        and device), so any batch size and latent size work.
        """
        epsilon = torch.randn_like(z_mu)  # epsilon ~ N(0, 1), carries no gradient
        std = torch.exp(0.5 * z_logvar)
        return z_mu + std * epsilon

    def forward(self, x):
        """Encode ``x``, sample z from the encoded Gaussian, and decode it."""
        z_mu, z_logvar = self.encoder(x)
        # Exposed as attributes for use outside forward (e.g. the loss).
        self.z_mu = z_mu
        self.z_logvar = z_logvar
        sampled_z = self.sample_z(z_mu, z_logvar)
        return self.decoder(sampled_z)

In [8]:
# Assemble the full VAE from a freshly initialized encoder and decoder.
model = VAE(encoder=Encoder(), decoder=Decoder())