In [11]:
import torch
import torch.nn as nn
import numpy as np
from scipy.stats import norm

In [12]:
class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to diagonal-Gaussian parameters.

    Five stride-2 conv blocks halve the spatial size each time, and a final
    4x4 valid convolution collapses the remaining 4x4 map to 1x1. A 128x128
    input (the size the dataloader resizes to) therefore ends up as a flat
    vector of ``2 * output_chan`` values per image. The first half is the mean;
    the second half is a log standard deviation that ``forward`` exponentiates.

    Args:
        im_chan: number of input image channels.
        output_chan: latent dimensionality (stored as ``self.z_dim``).
        hidden_dim: channel width of the first block; later blocks use 2x, 4x, 8x, 8x.
    """

    def __init__(self, im_chan=3, output_chan=128, hidden_dim=16):
        super().__init__()
        self.z_dim = output_chan
        # Channel widths along the stack; consecutive pairs define each hidden block.
        widths = [im_chan, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8, hidden_dim * 8]
        blocks = [self.make_enc_block(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])]
        # Head: 4x4 kernel, stride 1, no padding -> 1x1 spatial output carrying mean and log-std.
        blocks.append(
            self.make_enc_block(widths[-1], 2 * output_chan, kernel_size=4, stride=1, padding=0, final_layer=True)
        )
        # Submodule indices (enc.0 ... enc.5) keep the original layout so saved state_dicts still load.
        self.enc = nn.Sequential(*blocks)

    def make_enc_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, final_layer=False):
        """Return Conv2d -> BatchNorm2d -> LeakyReLU(0.2), or a bare Conv2d when ``final_layer``."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        if final_layer:
            # No normalisation/activation: the head emits raw mean and log-std values.
            return nn.Sequential(conv)
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        """Encode ``x`` and return ``(mean, std)``, each (batch, z_dim) for 128x128 inputs."""
        flat = self.enc(x).view(len(x), -1)
        mean = flat[:, :self.z_dim]
        log_std = flat[:, self.z_dim:]
        # exp keeps the standard deviation strictly positive.
        return mean, log_std.exp()
    

    
        

In [13]:
class Decoder(nn.Module):
    """Transposed-convolution generator mapping a latent vector to an image in (-1, 1).

    The spatial size grows 1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128; the last block
    projects to ``im_chan`` channels and applies Tanh.

    Args:
        z_dim: latent dimensionality; ``forward`` reshapes its input to (batch, z_dim, 1, 1).
        im_chan: number of output image channels.
        hidden_dim: base channel width (the first block uses 8x, then 4x, 2x, 1x, 1x).
    """

    def __init__(self, z_dim=128, im_chan=3, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim
        # (in_channels, out_channels, stride, padding) for each hidden block; kernel is always 4.
        hidden_specs = [
            (z_dim, hidden_dim * 8, 1, 0),           # -> 4 x 4
            (hidden_dim * 8, hidden_dim * 4, 2, 1),  # -> 8 x 8
            (hidden_dim * 4, hidden_dim * 2, 2, 1),  # -> 16 x 16
            (hidden_dim * 2, hidden_dim, 2, 1),      # -> 32 x 32
            (hidden_dim, hidden_dim, 2, 1),          # -> 64 x 64
        ]
        layers = [
            self.make_gen_block(c_in, c_out, kernel_size=4, stride=s, padding=p)
            for c_in, c_out, s, p in hidden_specs
        ]
        # Output block: -> 128 x 128 with im_chan channels, squashed by Tanh.
        layers.append(self.make_gen_block(hidden_dim, im_chan, kernel_size=4, stride=2, padding=1, final_layer=True))
        # Submodule indices (gen.0 ... gen.5) keep the original layout so saved state_dicts still load.
        self.gen = nn.Sequential(*layers)

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, padding=0, final_layer=False):
        """Return ConvTranspose2d -> BatchNorm2d -> ReLU, or ConvTranspose2d -> Tanh when ``final_layer``."""
        upsample = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride, padding)
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(upsample, nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True))

    def forward(self, noise):
        """Decode latent vectors ``noise`` (batch, z_dim) into images."""
        as_feature_map = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(as_feature_map)
    

In [14]:
from torch.distributions.normal import Normal

class VAE(nn.Module):
    """Convolutional VAE: Encoder -> reparameterised Normal sample -> Decoder.

    Args:
        z_dim: latent dimensionality shared by encoder and decoder.
        im_chan: number of image channels.
        hidden_dim: base channel width of the decoder. It used to be accepted and
            silently ignored; it is now forwarded to ``Decoder``, whose own default
            is also 64, so the default model is unchanged.
        enc_hidden_dim: base channel width of the encoder. The default of 16 equals
            the ``Encoder`` default, so existing ``vae.pth`` checkpoints still load.
    """

    def __init__(self, z_dim=128, im_chan=3, hidden_dim=64, enc_hidden_dim=16):
        super(VAE, self).__init__()
        self.z_dim = z_dim
        self.encoder = Encoder(im_chan, z_dim, enc_hidden_dim)
        self.decoder = Decoder(z_dim, im_chan, hidden_dim)

    def forward(self, images):
        """Encode ``images``, draw a differentiable latent sample, and decode it.

        Returns:
            ``(decoding, dist)``: the reconstruction, plus the posterior
            ``Normal(mean, std)`` that the KL term is computed from.
        """
        mean, std = self.encoder(images)
        dist = Normal(mean, std)
        # rsample uses the reparameterisation trick so gradients flow into the encoder.
        z = dist.rsample()
        decoding = self.decoder(z)
        return decoding, dist

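For a Gaussian likelihood with fixed variance, $x \sim \mathcal{N}(\mu, \sigma^2)$, the negative log-likelihood is $\mathrm{NLL}(x) = \frac{(x-\mu)^2}{2\sigma^2} + \text{const}$. Minimising it is therefore equivalent to minimising the squared error $(x-\mu)^2$, which is why a summed MSE is used as the reconstruction loss below.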

In [15]:
# Summed (not averaged) squared error: the Gaussian NLL up to constants. The KL term
# is likewise summed (over latent dims, then over the batch) in the training loop.
reconstruction_loss = nn.MSELoss(reduction="sum")

In [16]:
from torch.distributions.kl import kl_divergence
def kl_divergence_loss(q_dist, prior_std=2.0):
    """KL(q || p) per sample, with prior p = N(0, prior_std^2) applied elementwise.

    Args:
        q_dist: a ``Normal`` posterior (e.g. the ``dist`` returned by ``VAE.forward``).
            The last dimension is treated as the latent dimension and summed over.
        prior_std: standard deviation of the zero-mean isotropic prior. The default of
            2.0 matches the value previously hard-coded here. NOTE: the conventional
            VAE prior is N(0, 1). Whatever value is used here should also be used when
            sampling latents for generation; otherwise samples come from a different
            distribution than the one the KL term regularises towards.

    Returns:
        Tensor shaped like ``q_dist.mean`` without its last dimension.
    """
    prior = Normal(
        torch.zeros_like(q_dist.mean),
        torch.ones_like(q_dist.stddev) * prior_std,
    )
    return kl_divergence(q_dist, prior).sum(-1)

In [17]:
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

# Dataset configuration.
# NOTE(review): absolute, machine-specific path; adjust for your environment.
DATA_ROOT = r"P:\dataset\Flickr-Faces\thumbnails128x128"
IMAGE_SIZE = 128
batch_size = 64

# Resize the PIL image first, then convert to a tensor and normalise to [-1, 1],
# the same range as the decoder's Tanh output. Resizing an already-normalised
# tensor goes through torchvision's tensor code path, whose antialias default
# has changed across versions.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

data = ImageFolder(root=DATA_ROOT, transform=transform)
dataloader = DataLoader(
    dataset=data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    persistent_workers=True,
)

In [18]:
import matplotlib.pyplot as plt
# Wide default canvas so the side-by-side "True" / "Reconstructed" grids are readable.
plt.rcParams.update({"figure.figsize": (16, 8)})

from torchvision.utils import make_grid
from tqdm import tqdm
import time

def show_tensor_images(image_tensor, num_images=4, size=(3, 128, 128), nrow=5):
    """Draw the first ``num_images`` images of a batch as a grid on the current axes.

    Args:
        image_tensor: image batch with values in [-1, 1] (the range produced by the
            dataset normalisation and the decoder's Tanh); presumably (N, C, H, W),
            as ``make_grid`` expects.
        num_images: how many images from the front of the batch to draw.
        size: unused; kept so existing calls keep working.
        nrow: images per grid row (previously hard-coded to 5).

    Does not call ``plt.show()``, so callers can place it inside subplots.
    """
    # Map [-1, 1] back to [0, 1] for imshow, then drop autograd history and move to CPU.
    images = ((image_tensor + 1) / 2).detach().cpu()
    grid = make_grid(images[:num_images], nrow=nrow)
    plt.axis("off")
    # make_grid returns (C, H, W); imshow wants (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    

import os

# Fall back to CPU so the notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
NUM_EPOCHS = 20
LEARNING_RATE = 1e-3
VAE_CHECKPOINT = "vae.pth"
DECODER_CHECKPOINT = "vae_decoder.pth"

vae = VAE().to(device)
# Resume from a previous run if a checkpoint exists. The old unconditional load
# raised FileNotFoundError on a fresh run.
# NOTE(review): only model weights are checkpointed, so Adam's state restarts on every resume.
if os.path.exists(VAE_CHECKPOINT):
    vae.load_state_dict(torch.load(VAE_CHECKPOINT, map_location=device))
vae_opt = torch.optim.Adam(vae.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    for images, _ in tqdm(dataloader):
        images = images.to(device)
        vae_opt.zero_grad()
        recon_images, encoding = vae(images)
        # Both terms are sums over the batch: reconstruction_loss uses reduction='sum',
        # and the per-sample KL values are summed here.
        loss = reconstruction_loss(recon_images, images) + kl_divergence_loss(encoding).sum()
        loss.backward()
        vae_opt.step()

    # Diagnostics below use only the last batch of the epoch.
    print(torch.mean(encoding.loc).item(), torch.mean(encoding.scale).item())
    print("Epoch", epoch, "loss ", loss.item())
    plt.subplot(1, 2, 1)
    show_tensor_images(images)
    plt.title("True")
    plt.subplot(1, 2, 2)
    show_tensor_images(recon_images)
    plt.title("Reconstructed")
    plt.show()
    torch.save(vae.decoder.state_dict(), DECODER_CHECKPOINT)
    torch.save(vae.state_dict(), VAE_CHECKPOINT)
        

  5%|███▊                                                                            | 48/1024 [00:06<02:16,  7.16it/s]


KeyboardInterrupt: 

In [None]:
# The checkpoint was written from the training device (possibly CUDA). Map it to CPU,
# because the sampling cells below create their latent vectors on the CPU.
gen = Decoder()
gen.load_state_dict(torch.load("vae_decoder.pth", map_location="cpu"))
# eval(): BatchNorm uses its running statistics. The trailing ';' hides the module repr.
gen.eval();

In [None]:
def show_tensor_images(image_tensor, num_images=16, size=(3, 128, 128), nrow=3):
    """Display the first ``num_images`` images of a batch in a grid and show the figure.

    NOTE(review): this redefines the training-time ``show_tensor_images`` defined
    earlier, with sampling-friendly defaults (16 images, 3 per row). Unlike that
    version, it leaves the axes visible and calls ``plt.show()`` itself.
    ``size`` is unused.
    """
    # [-1, 1] -> [0, 1], then drop autograd history and move to host memory.
    batch = image_tensor.add(1).div(2).detach().cpu()
    plt.imshow(make_grid(batch[:num_images], nrow=nrow).permute(1, 2, 0).squeeze())
    plt.show()

In [22]:
# Sample latents from the same prior the KL term regularises towards:
# kl_divergence_loss uses N(0, 2^2). Plain torch.randn draws from N(0, 1), which
# covers only the inner part of the region the encoder's posteriors are pulled into.
N_SAMPLES = 10
LATENT_DIM = 128
PRIOR_STD = 2.0
z = PRIOR_STD * torch.randn(N_SAMPLES, LATENT_DIM)

In [23]:
# Inference only: no_grad avoids building an autograd graph for the decoder pass.
with torch.no_grad():
    images = gen(z)

In [None]:
# Render the decoded samples with the helper's default grid settings.
show_tensor_images(image_tensor=images)

In [None]:
# Load the full VAE checkpoint onto the CPU just to inspect its entries; no GPU is
# needed for that, and this also works on machines without CUDA.
a = torch.load("vae.pth", map_location="cpu")

In [106]:
# List the parameter/buffer names stored in the checkpoint; they mirror the
# nn.Sequential nesting of the model, e.g. 'encoder.enc.0.0.weight'.
a.keys()


odict_keys(['encoder.enc.0.0.weight', 'encoder.enc.0.0.bias', 'encoder.enc.0.1.weight', 'encoder.enc.0.1.bias', 'encoder.enc.0.1.running_mean', 'encoder.enc.0.1.running_var', 'encoder.enc.0.1.num_batches_tracked', 'encoder.enc.1.0.weight', 'encoder.enc.1.0.bias', 'encoder.enc.1.1.weight', 'encoder.enc.1.1.bias', 'encoder.enc.1.1.running_mean', 'encoder.enc.1.1.running_var', 'encoder.enc.1.1.num_batches_tracked', 'encoder.enc.2.0.weight', 'encoder.enc.2.0.bias', 'encoder.enc.2.1.weight', 'encoder.enc.2.1.bias', 'encoder.enc.2.1.running_mean', 'encoder.enc.2.1.running_var', 'encoder.enc.2.1.num_batches_tracked', 'encoder.enc.3.0.weight', 'encoder.enc.3.0.bias', 'encoder.enc.3.1.weight', 'encoder.enc.3.1.bias', 'encoder.enc.3.1.running_mean', 'encoder.enc.3.1.running_var', 'encoder.enc.3.1.num_batches_tracked', 'encoder.enc.4.0.weight', 'encoder.enc.4.0.bias', 'encoder.enc.4.1.weight', 'encoder.enc.4.1.bias', 'encoder.enc.4.1.running_mean', 'encoder.enc.4.1.running_var', 'encoder.enc.4.1.