In [None]:
!nvidia-smi

Sat Aug  8 20:15:33 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.57       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   70C    P8    12W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import cv2
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
import torch.nn.functional as F
from torchvision.utils import save_image

In [None]:
# Use the CUDA GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
class depthwise_separable_conv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    The depthwise stage convolves each of the ``nin`` input channels separately
    (``groups=nin``) with ``kernels_per_layer`` filters, giving
    ``nin * kernels_per_layer`` channels; the pointwise 1x1 stage then mixes them
    down to ``nout`` channels. Spatial size (H, W) is preserved.

    :param nin: number of input channels
    :param kernels_per_layer: depthwise channel multiplier (filters per input channel)
    :param nout: number of output channels
    :param kernel_size: spatial size of the depthwise kernel (default 5, the previous
        hard-coded value). Must be odd so that ``padding = kernel_size // 2`` keeps H and W.
    :raises ValueError: if ``kernel_size`` is even
    """

    def __init__(self, nin, kernels_per_layer, nout, kernel_size=5):
        super(depthwise_separable_conv, self).__init__()
        if kernel_size % 2 != 1:
            raise ValueError("kernel_size must be odd to preserve spatial size, got %r" % (kernel_size,))
        # "Same" padding for an odd kernel with stride 1.
        padding = kernel_size // 2
        self.depthwise = nn.Conv2d(nin, nin * kernels_per_layer, kernel_size=kernel_size,
                                   padding=padding, groups=nin)
        self.pointwise = nn.Conv2d(nin * kernels_per_layer, nout, kernel_size=1)

    def forward(self, x):
        """Apply the depthwise then the pointwise convolution."""
        return self.pointwise(self.depthwise(x))

In [None]:
def swish(x):
    """Swish activation, applied elementwise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate

In [None]:
class ChannelSELayer(nn.Module):
    """
    Re-implementation of Squeeze-and-Excitation (SE) block described in:
        *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*

    Each channel is multiplied by a gate in (0, 1), computed by a two-layer MLP
    from that channel's spatial mean.
    """

    def __init__(self, num_channels, reduction_ratio=2):
        """
        :param num_channels: number of input (and output) channels
        :param reduction_ratio: factor by which the MLP bottleneck shrinks num_channels;
            the bottleneck is clamped to at least one unit
        """
        super(ChannelSELayer, self).__init__()
        # Without the clamp, num_channels < reduction_ratio gives a zero-width Linear layer
        # and the gate would no longer depend on the input.
        num_channels_reduced = max(1, num_channels // reduction_ratio)
        self.reduction_ratio = reduction_ratio
        self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
        self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, H, W)
        :return: X with every channel scaled by its excitation gate; same shape as X
        """
        batch_size, num_channels, _, _ = input_tensor.size()
        # Squeeze: average over the spatial positions of each channel. reshape (not view)
        # so non-contiguous inputs such as transposed tensors are accepted.
        squeeze_tensor = input_tensor.reshape(batch_size, num_channels, -1).mean(dim=2)

        # Excitation: bottleneck MLP -> per-channel gate in (0, 1).
        fc_out_1 = self.relu(self.fc1(squeeze_tensor))
        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

        return torch.mul(input_tensor, fc_out_2.view(batch_size, num_channels, 1, 1))

In [None]:
class dec_res(nn.Module):
    """Decoder residual cell.

    Pipeline: BN -> 1x1 conv (expand to 2x channels) -> BN -> swish ->
    depthwise-separable conv (channel multiplier 3) -> BN -> swish ->
    1x1 conv (project back) -> BN -> squeeze-excitation. The result is added to the input.

    :param in_channel: channel count of the input; the output has the input's shape.
    """

    def __init__(self, in_channel):
        super(dec_res, self).__init__()
        hidden = 2 * in_channel  # expanded width inside the cell
        # Submodules are registered in the same order as before, so parameter order,
        # initialisation and state_dict keys are unchanged.
        self.bn1 = nn.BatchNorm2d(in_channel)
        self.c1 = nn.Conv2d(in_channels=in_channel, out_channels=hidden, kernel_size=1, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(hidden)
        self.dc1 = depthwise_separable_conv(nin=hidden, kernels_per_layer=3, nout=hidden)
        self.bn3 = nn.BatchNorm2d(hidden)
        self.c2 = nn.Conv2d(in_channels=hidden, out_channels=in_channel, kernel_size=1, stride=1, padding=0)
        self.bn4 = nn.BatchNorm2d(in_channel)
        self.SE = ChannelSELayer(in_channel)

    def forward(self, x1):
        """Return ``x1 + cell(x1)``."""
        expanded = swish(self.bn2(self.c1(self.bn1(x1))))
        mixed = swish(self.bn3(self.dc1(expanded)))
        projected = self.bn4(self.c2(mixed))
        return x1 + self.SE(projected)


In [None]:
class enc_res(nn.Module):
    """Encoder residual cell.

    Pipeline: (BN -> swish -> 3x3 conv) to 2x channels, (BN -> swish -> 3x3 conv) back
    to the input width, then squeeze-excitation. The result is added to the input.

    :param in_channel: channel count of the input; the output has the input's shape.
    """

    def __init__(self, in_channel):
        super(enc_res, self).__init__()
        wide = 2 * in_channel
        self.bn1 = nn.BatchNorm2d(in_channel)
        self.c1 = nn.Conv2d(in_channels=in_channel, out_channels=wide, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(wide)
        self.c2 = nn.Conv2d(in_channels=wide, out_channels=in_channel, kernel_size=3, stride=1, padding=1)
        # NOTE(review): bn3 is registered (it shows up in parameters()/state_dict) but is
        # never used in forward(). It is kept so existing checkpoints still load.
        self.bn3 = nn.BatchNorm2d(in_channel)
        self.SE = ChannelSELayer(in_channel)

    def forward(self, x1):
        """Return ``x1 + cell(x1)``."""
        h = self.c1(swish(self.bn1(x1)))
        h = self.c2(swish(self.bn2(h)))
        return x1 + self.SE(h)


In [None]:
class NVAE(nn.Module):
    """Small two-level hierarchical VAE with NVAE-style residual posteriors.

    The encoder produces z1 features (8 x D/2 x D/2) and z0 features (8 x D/4 x D/4).
    - Top latent z0: prior N(0, I), posterior N(qmu0, exp(qvar0)).
    - Lower latent z1: prior N(pmu1, exp(pvar1)), where pmu1 and pvar1 are computed
      from the decoded z0 sample. Its posterior is residual:
      N(pmu1 + qmu1, exp(pvar1) * exp(qvar1)).
    All ``*var*`` tensors are log-variances.

    :param start_channel: number of image channels (1 for MNIST)
    :param original_dim: image height == width; must be divisible by 4
    :raises ValueError: if ``original_dim`` is not divisible by 4
    """

    def __init__(self, start_channel, original_dim):
        super(NVAE, self).__init__()
        if original_dim % 4 != 0:
            raise ValueError("original_dim must be divisible by 4, got %r" % (original_dim,))
        self.original_dim = original_dim
        z1_size = original_dim * original_dim * 2   # 8 channels x (D/2)^2, flattened
        z0_size = original_dim * original_dim // 2  # 8 channels x (D/4)^2, flattened

        # Encoder: image -> z1 features (stride 2) -> z0 features (stride 2).
        self.conv1 = nn.Conv2d(in_channels=start_channel, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.encblock1 = enc_res(8)
        self.dsconv1 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=2, stride=2, padding=0)
        self.encblock2 = enc_res(8)
        self.dsconv2 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=2, stride=2, padding=0)

        # Posterior heads (mean, log-variance). Registration order is unchanged so the
        # parameter order and state_dict keys match earlier versions.
        self.qmu1 = nn.Linear(z1_size, z1_size)
        self.qvar1 = nn.Linear(z1_size, z1_size)

        self.qmu0 = nn.Linear(z0_size, z0_size)
        self.qvar0 = nn.Linear(z0_size, z0_size)

        # Prior heads for z1, conditioned on the decoded top latent.
        self.pmu1 = nn.Linear(z1_size, z1_size)
        self.pvar1 = nn.Linear(z1_size, z1_size)

        # Decoder: z0 -> ez1 features; concat([ez1, z1]) (16 channels) -> image.
        self.decblock1 = dec_res(8)
        self.usconv1 = nn.ConvTranspose2d(in_channels=8, out_channels=8, kernel_size=2, stride=2, padding=0)
        self.decblock2 = dec_res(16)
        self.usconv2 = nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=2, stride=2, padding=0)
        self.decblock3 = dec_res(16)
        self.finconv = nn.Conv2d(in_channels=16, out_channels=start_channel, kernel_size=3, stride=1, padding=1)

    def _decode_top(self, z0):
        # Top decoder path: z0 (8 x D/4 x D/4) -> ez1 features (8 x D/2 x D/2).
        return self.usconv1(self.decblock1(z0))

    def _prior1(self, ez1):
        # Prior mean and log-variance of z1, given the decoded top features.
        flat = ez1.reshape(ez1.shape[0], self.original_dim * self.original_dim * 2)
        return self.pmu1(flat), self.pvar1(flat)

    def _decode_bottom(self, ez1, z1):
        # Concatenate top features with the z1 sample (8 + 8 channels) and decode to pixels in (0, 1).
        final = torch.cat((ez1, z1), 1)
        return torch.sigmoid(self.finconv(self.decblock3(self.usconv2(self.decblock2(final)))))

    def forward(self, x):
        """Encode ``x``, draw reparameterised samples of both latents, and decode.

        :param x: image batch, presumably (batch, start_channel, D, D) with values in
            [0, 1] (``loss`` applies BCE to it) — TODO confirm for non-MNIST data
        :return: (qmu0, qvar0, qmu1, qvar1, pmu1, pvar1, recons). The latent statistics
            are flattened to (batch, latent_size); recons has the decoder's image shape.
        """
        batch = x.shape[0]
        d = self.original_dim
        z1 = self.dsconv1(self.encblock1(self.conv1(x)))
        z0 = self.dsconv2(self.encblock2(z1))

        z0_flat = z0.reshape(batch, d * d // 2)
        z1_flat = z1.reshape(batch, d * d * 2)
        qmu0 = self.qmu0(z0_flat)
        qvar0 = self.qvar0(z0_flat)
        qmu1 = self.qmu1(z1_flat)
        qvar1 = self.qvar1(z1_flat)

        # Convert log-variance to standard deviation.
        stdvar0 = torch.exp(0.5 * qvar0)
        stdvar1 = torch.exp(0.5 * qvar1)

        # Reparameterised z0 sample. randn_like creates the noise on the same
        # device/dtype as the activations, not on a notebook-global device.
        ez0 = qmu0 + torch.randn_like(qmu0) * stdvar0
        ez1 = self._decode_top(ez0.reshape(batch, 8, d // 4, d // 4))

        pmu1, pvar1 = self._prior1(ez1)
        pstdvar1 = torch.exp(0.5 * pvar1)

        # Residual posterior sample of z1: mean pmu1 + qmu1, std pstdvar1 * stdvar1.
        ez2 = pmu1 + qmu1 + torch.randn_like(qmu1) * pstdvar1 * stdvar1
        ez2 = ez2.reshape(batch, 8, d // 2, d // 2)

        recons = self._decode_bottom(ez1, ez2)
        return qmu0, qvar0, qmu1, qvar1, pmu1, pvar1, recons

    def sample(self, bs):
        """Generate ``bs`` images: z0 ~ N(0, I), then z1 ~ prior(z1 | z0), then decode.

        The module keeps whatever train/eval mode it is in. In train mode the
        BatchNorm layers normalise with this batch's statistics.
        :return: tensor of shape (bs, start_channel, D, D) with values in (0, 1)
        """
        d = self.original_dim
        # Allocate the noise where the model's weights live.
        dev = self.finconv.weight.device
        e = torch.randn([bs, 8, d // 4, d // 4], device=dev)
        ez1 = self._decode_top(e)

        pmu1, pvar1 = self._prior1(ez1)
        stdvar1 = torch.exp(0.5 * pvar1)

        z1 = pmu1 + torch.randn_like(pmu1) * stdvar1
        z1 = z1.reshape(bs, 8, d // 2, d // 2)
        return self._decode_bottom(ez1, z1)

    def loss(self, x):
        """Return (klz0, klz1, reconsloss) for the batch ``x``.

        klz0: KL(q(z0|x) || N(0, I)), summed over latent dims and averaged over the batch.
        klz1: KL(q(z1|x, z0) || p(z1|z0)) under the residual parameterisation,
              0.5 * sum(qmu1^2 / exp(pvar1) + exp(qvar1) - qvar1 - 1), averaged over the batch.
        reconsloss: binary cross-entropy averaged over every pixel (needs x in [0, 1]).
        """
        qmu0, qvar0, qmu1, qvar1, pmu1, pvar1, recons = self(x)
        batch = x.shape[0]
        klz0 = 0.5 * torch.sum(torch.square(qmu0) + qvar0.exp() - qvar0 - 1) / batch
        # Averaged over the batch like klz0. Previously klz1 was a batch *sum*: it grew with
        # batch size, and the per-image values logged by the training loop were inflated by
        # batch_size. To reproduce the old effective weighting, multiply the klz1 weight by
        # the batch size.
        klz1 = 0.5 * torch.sum(torch.square(qmu1) / pvar1.exp() + qvar1.exp() - qvar1 - 1) / batch
        reconsloss = nn.BCELoss()(recons, x)
        return klz0, klz1, reconsloss

  


    






In [None]:
batch_size = 64  # images per DataLoader batch (train and test)

In [None]:
# Only conversion to tensors: ToTensor turns each PIL image into a float tensor.
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
# MNIST data, stored under data/mnist (downloaded on first use).
train_dataset = torchvision.datasets.MNIST(root='data/mnist',
                                           train=True,
                                           transform=transform,
                                           download=True)

# download=True is harmless when the files already exist and avoids a failure if only
# the test split is missing.
test_dataset = torchvision.datasets.MNIST(root='data/mnist',
                                          train=False,
                                          transform=transform,
                                          download=True)

# Put into batches. Only the training data needs shuffling; a fixed test order keeps
# any evaluation over test_loader reproducible.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

In [None]:
# Single-channel 28x28 inputs, moved to the selected device.
model = NVAE(start_channel=1, original_dim=28).to(device)

In [None]:
# Adamax with PyTorch's default hyper-parameters.
optim = torch.optim.Adamax(params=model.parameters())

In [None]:
epochs = 50  # number of full passes over train_loader

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

In [None]:
# KL warm-up: both KL weights grow linearly with the epoch index (zero at epoch 0).
KL0_WEIGHT_PER_EPOCH = 0.001
KL1_WEIGHT_PER_EPOCH = 0.0001
N_SAMPLES = 64   # images sampled after each epoch
GRID_SIDE = 8    # GRID_SIDE x GRID_SIDE == N_SAMPLES
IMG_SIDE = 28

for epoch in range(epochs):
    running_kl0_loss = 0.0
    running_kl1_loss = 0.0
    running_recons_loss = 0.0
    num_images = 0
    for img, _label in train_loader:
        img = img.to(device)
        optim.zero_grad()
        klz0, klz1, recons_loss = model.loss(img)
        loss = (recons_loss
                + epoch * KL0_WEIGHT_PER_EPOCH * klz0
                + epoch * KL1_WEIGHT_PER_EPOCH * klz1)
        loss.backward()
        optim.step()
        # Weight each batch value by the batch size so the epoch figure is a per-image mean
        # (the last batch can be smaller).
        n = len(img)
        running_kl0_loss += klz0.item() * n
        running_kl1_loss += klz1.item() * n
        running_recons_loss += recons_loss.item() * n
        num_images += n
    print('epoch: '+str(epoch)+' kl0_loss: '+str(running_kl0_loss/num_images)+' recons_loss: '+str(running_recons_loss/num_images)+' kl1_loss: '+str(running_kl1_loss/num_images))

    # Sample without building an autograd graph. The model stays in train mode, so
    # BatchNorm normalises with this sample batch's statistics, as before.
    with torch.no_grad():
        imgs = model.sample(N_SAMPLES).cpu().reshape(N_SAMPLES, IMG_SIDE, IMG_SIDE)
    fig = plt.figure(figsize=(8., 8.))
    grid = ImageGrid(fig, 111,  # similar to subplot(111)
                     nrows_ncols=(GRID_SIDE, GRID_SIDE),  # 8x8 grid of axes, one per sample
                     axes_pad=0.05,  # pad between axes in inches
                     )
    for ax, im in zip(grid, imgs):
        # Iterating over the grid yields its Axes.
        ax.imshow(im, cmap='gray')
    fig.savefig(str(epoch)+".png")
    # Close each figure; otherwise one open figure per epoch accumulates in memory.
    plt.close(fig)




epoch: 0 kl0_loss: 606.4141020177206 recons_loss: 0.08246776552001635 kl1_loss: 5356494.420431966
epoch: 1 kl0_loss: 103.98859626871744 recons_loss: 0.18555358707904815 kl1_loss: 27655.114556966146
epoch: 2 kl0_loss: 35.97397475382487 recons_loss: 0.20749857869942984 kl1_loss: 397.4461079264323
epoch: 3 kl0_loss: 15.648203428141276 recons_loss: 0.20467347184816997 kl1_loss: 132.21166833089194
epoch: 4 kl0_loss: 9.369819069417318 recons_loss: 0.1984669742425283 kl1_loss: 48.45297568359375
epoch: 5 kl0_loss: 7.066677468363444 recons_loss: 0.19788822547594706 kl1_loss: 16.822802369689942
epoch: 6 kl0_loss: 5.902504239400228 recons_loss: 0.19954941023190817 kl1_loss: 6.5516442260742185
epoch: 7 kl0_loss: 5.115420579020182 recons_loss: 0.20067655108769736 kl1_loss: 3.0555039043426513
epoch: 8 kl0_loss: 4.619134643554688 recons_loss: 0.197725900888443 kl1_loss: 1.7534462911923727
epoch: 9 kl0_loss: 4.31119142074585 recons_loss: 0.19328678325017293 kl1_loss: 1.002563635778427
epoch: 10 kl0_lo



epoch: 20 kl0_loss: 1.7575090525945027 recons_loss: 0.21876556200186412 kl1_loss: 0.04653615934054057
epoch: 21 kl0_loss: 1.5570258632659912 recons_loss: 0.22268498921394347 kl1_loss: 0.0396785405476888
epoch: 22 kl0_loss: 1.3601469254175822 recons_loss: 0.22699188867410025 kl1_loss: 0.03254029804865519
epoch: 23 kl0_loss: 1.1691280216217041 recons_loss: 0.23126601158777874 kl1_loss: 0.02503287035624186
epoch: 24 kl0_loss: 0.9971819259961446 recons_loss: 0.2348970758597056 kl1_loss: 0.021144128227233886
epoch: 25 kl0_loss: 0.8291180100123088 recons_loss: 0.23902759352525074 kl1_loss: 0.016420564317703248
epoch: 26 kl0_loss: 0.6568803822835286 recons_loss: 0.24345691878000894 kl1_loss: 0.01354682772954305
epoch: 27 kl0_loss: 0.5299254002253214 recons_loss: 0.24654099411964417 kl1_loss: 0.01146696524620056
epoch: 28 kl0_loss: 0.43261469926834106 recons_loss: 0.2492763815800349 kl1_loss: 0.009007540734608969
epoch: 29 kl0_loss: 0.3427351002852122 recons_loss: 0.2516372921625773 kl1_loss: 

In [None]:
# Draw 64 samples from the trained model. no_grad avoids building an autograd graph
# (the model is still in train mode, so BatchNorm uses this batch's statistics).
with torch.no_grad():
    imgs = model.sample(64).cpu().reshape(64, 28, 28)

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

In [None]:
fig = plt.figure(figsize=(8., 8.))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(8, 8),  # 8x8 grid of axes, one per sampled image
                 axes_pad=0.05,  # pad between axes in inches
                 )

for ax, im in zip(grid, imgs):
    # Iterating over the grid yields its Axes.
    # A per-image colormap avoids changing pyplot's global default with plt.gray().
    ax.imshow(im, cmap='gray')

plt.show()