In [1]:
import pandas as pd
import numpy as np
import os
import seaborn as sn
import random
import copy
import matplotlib.pyplot as plt
import pygal
import cv2
from sklearn.metrics import confusion_matrix
from PIL import Image
from IPython import display
from tqdm import tqdm, trange
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as trns
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Directory that receives every figure and image grid written by this notebook.
visual_folder = "./visual"
# Create it up front: the save calls used later write straight to this path and
# fail with FileNotFoundError when the directory is missing (e.g. fresh checkout).
os.makedirs(visual_folder, exist_ok=True)

In [2]:
class animateDataset(Dataset):
    """Unlabelled image dataset read from a flat directory.

    Every entry of the directory is treated as an image file. An item is the
    image at that index, converted to RGB and passed through ``transform``.

    Args:
        root: Directory containing the image files. Defaults to ``<cwd>/data``;
            the working directory is read once, at construction time.
        transform: Callable applied to the RGB PIL image. Defaults to
            ``trns.Compose([trns.ToTensor()])``.
    """
    def __init__(self, root=None, transform=None):
        # Resolve the folder once so a later os.chdir() cannot break __getitem__.
        self.root = os.path.join(os.getcwd(), 'data') if root is None else root
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable from one run to the next.
        self.files = sorted(os.listdir(self.root))
        self.transform = trns.Compose([trns.ToTensor()]) if transform is None else transform

    def __getitem__(self,index):
        """Return the transformed RGB image stored at position ``index``."""
        path = os.path.join(self.root, self.files[index])
        # Image.open is lazy and keeps the file handle open; convert() forces the
        # pixel data to be read and the context manager then closes the handle.
        with Image.open(path) as raw:
            image = raw.convert('RGB')
        return self.transform(image)

    def __len__(self):
        """Number of files found in the data directory."""
        return len(self.files)

In [3]:
batch_size = 64
# Fixed seed so the 70/30 partition is the same every time the notebook is run.
split_seed = 0

dataset = animateDataset()
# 70% of the images for training; the remainder (including any rounding
# leftover) is held out for testing.
train_sz = int(len(dataset)*0.7)
test_sz = len(dataset)-train_sz
train_set,test_set = random_split(dataset,[train_sz,test_sz],
                                  generator=torch.Generator().manual_seed(split_seed))

# Training batches are reshuffled every epoch.
train_loader = DataLoader(dataset=train_set,
                          batch_size=batch_size, 
                          shuffle=True,
                          num_workers=4)

# Test batches keep a fixed order so saved reconstructions are comparable.
test_loader = DataLoader(dataset=test_set,
                          batch_size=batch_size, 
                          shuffle=False,
                          num_workers=4)

In [4]:
class VAEmodule(nn.Module):
    """Convolutional variational autoencoder for square RGB images.

    Encoder: four stride-2 convolution blocks (each halves the spatial size)
    followed by two linear heads with BatchNorm1d that produce the latent
    ``mean`` and ``logvar``.
    Decoder: a linear layer reshaped to a ``(decoder_channel*4, image_size//8,
    image_size//8)`` feature map, three stride-2 transposed-convolution blocks
    (each doubles the spatial size), a final stride-1 transposed convolution to
    3 channels, and a sigmoid so outputs lie in [0, 1].

    Args:
        image_size: Height/width of the input images. Must be a positive
            multiple of 16 so the encoder and decoder shapes line up.
        encoder_channel: Channel count of the first encoder block; later blocks
            use 2x, 4x and 8x this value.
        decoder_channel: Base channel count of the decoder (must be >= 2, since
            the last block uses ``decoder_channel // 2`` channels).
        latent_size: Dimensionality of the latent code.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 16 or any
            of the other sizes is too small to build the layers.
    """
    def __init__(self, image_size=64, encoder_channel=64, decoder_channel=64, latent_size=512):
        super().__init__()
        if image_size <= 0 or image_size % 16 != 0:
            raise ValueError("image_size must be a positive multiple of 16, got %r" % (image_size,))
        if encoder_channel < 1 or latent_size < 1:
            raise ValueError("encoder_channel and latent_size must be positive")
        if decoder_channel < 2:
            raise ValueError("decoder_channel must be at least 2")

        self.image_size = image_size
        self.encoder_channel = encoder_channel
        self.decoder_channel = decoder_channel
        self.latent_size = latent_size

        # ---------------------------------- encoder ----------------------------------
        self.en_conv1 = self.conv_layer(3,self.encoder_channel)
        self.en_conv2 = self.conv_layer(self.encoder_channel,self.encoder_channel*2)
        self.en_conv3 = self.conv_layer(self.encoder_channel*2,self.encoder_channel*4)
        self.en_conv4 = self.conv_layer(self.encoder_channel*4,self.encoder_channel*8)
        # After four halvings the feature map is (image_size // 16) pixels per side.
        flat_features = self.encoder_channel*8*(self.image_size//16)**2
        # Two parallel heads over the same flattened features: mean and log-variance.
        self.en_fc1 = nn.Sequential(nn.Linear(flat_features, self.latent_size),
                                    nn.BatchNorm1d(self.latent_size))
        self.en_fc2 = nn.Sequential(nn.Linear(flat_features, self.latent_size),
                                    nn.BatchNorm1d(self.latent_size))

        # ---------------------------------- decoder ----------------------------------
        self.de_fc1 = nn.Linear(self.latent_size, self.decoder_channel*4*(self.image_size//8)**2)
        # Applied after de_fc1's output has been reshaped to a 4-D feature map.
        self.de_fc2 = nn.Sequential(nn.BatchNorm2d(self.decoder_channel*4),
                                    nn.ReLU())
        self.de_conv1 = self.de_conv_layer(self.decoder_channel*4, self.decoder_channel*4)
        self.de_conv2 = self.de_conv_layer(self.decoder_channel*4, self.decoder_channel*2)
        self.de_conv3 = self.de_conv_layer(self.decoder_channel*2, self.decoder_channel//2)
        # Stride-1 projection to RGB; kernel 5 with padding 2 keeps the spatial size.
        self.dc1 = nn.ConvTranspose2d(self.decoder_channel//2, 3, 5, padding=2)

    def conv_layer(self,input_channel, output_channel, kernel=5, stride=2, padding=2):
        """Build a Conv2d -> BatchNorm2d -> ReLU block."""
        return nn.Sequential(nn.Conv2d(input_channel, output_channel, kernel, stride=stride, padding=padding),
                             nn.BatchNorm2d(output_channel),
                             nn.ReLU())

    def de_conv_layer(self,input_channel, output_channel, kernel=6, stride=2, padding=2):
        """Build a ConvTranspose2d -> BatchNorm2d -> ReLU block.

        With the default kernel/stride/padding the output is exactly twice the
        input's spatial size: (n - 1) * 2 - 2 * 2 + 6 = 2n.
        """
        return nn.Sequential(nn.ConvTranspose2d(input_channel, output_channel, kernel, stride=stride, padding=padding),
                             nn.BatchNorm2d(output_channel),
                             nn.ReLU())

    def encoder(self,x):
        """Map an image batch to the latent ``(mean, logvar)`` pair."""
        x = self.en_conv1(x)
        x = self.en_conv2(x)
        x = self.en_conv3(x)
        x = self.en_conv4(x)
        # Flatten everything except the batch dimension for the linear heads.
        x = x.view(x.size(0),-1)
        mean = self.en_fc1(x)
        logvar = self.en_fc2(x)
        return mean, logvar

    def reparameter(self, mean, logvar):
        """Sample ``mean + eps * exp(0.5 * logvar)`` with ``eps`` drawn from N(0, I).

        The noise is drawn on every call, regardless of train/eval mode.
        """
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mean + eps*std

    def decoder(self,x):
        """Map latent codes of shape ``(batch, latent_size)`` to images in [0, 1]."""
        x = self.de_fc1(x)
        x = x.view(-1, self.decoder_channel*4, (self.image_size//8), (self.image_size//8))
        x = self.de_fc2(x)
        x = self.de_conv1(x)
        x = self.de_conv2(x)
        x = self.de_conv3(x)
        x = self.dc1(x)
        x = torch.sigmoid(x)
        return x

    def forward(self,x):
        """Encode, sample a latent code, decode.

        Returns:
            Tuple ``(reconstruction, mean, logvar)``.
        """
        mean, logvar = self.encoder(x)
        x = self.reparameter(mean, logvar)
        x = self.decoder(x)
        return x ,mean ,logvar

In [5]:
def loss_function(x, x_hat, mean, logvar, kl_weight=0.0):
    """Training loss: summed reconstruction BCE plus an optional weighted KL term.

    Args:
        x: Target images with values in [0, 1].
        x_hat: Reconstructed images with values in [0, 1], same shape as ``x``.
        mean: Latent means produced by the encoder.
        logvar: Latent log-variances produced by the encoder.
        kl_weight: Multiplier for the KL divergence between the encoder's
            diagonal Gaussian and a standard normal. The default of 0.0
            reproduces the reconstruction-only loss; pass 1.0 for the usual
            negative-ELBO objective.

    Returns:
        Scalar tensor, summed (not averaged) over the whole batch.
    """
    BCE = F.binary_cross_entropy(x_hat, x, reduction='sum')
    if kl_weight == 0:
        # Skip the KL term entirely so an overflowing logvar.exp() cannot turn
        # a zero-weighted term into NaN.
        return BCE

    KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    return BCE + kl_weight * KLD

In [None]:
# Optimisation hyper-parameters.
epoch = 240
learning_rate = 0.001

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

vae = VAEmodule().to(device)
optimizer = optim.Adam(params=vae.parameters(), lr=learning_rate)

# Bookkeeping. Only ELBO / ELBO_logs are written to by the loop below.
iteration = 0
train_accuracy, test_accuracy = [], []
ELBO, ELBO_logs = [], []

t = trange(epoch)

for e in t:
    # One full pass over the training set.
    for imgs in train_loader:
        imgs = imgs.to(device)

        optimizer.zero_grad()
        x_hat, mean, logvar = vae(imgs)
        elbo = loss_function(imgs, x_hat, mean, logvar)
        elbo.backward()
        optimizer.step()

        # The loss is summed over the batch, so divide by the batch size to
        # record a per-image value.
        ELBO.append(elbo.item() / imgs.shape[0])

    # Store this epoch's average of the per-batch values, then start afresh.
    ELBO_logs.append(np.mean(ELBO))
    ELBO = []

    t.set_description("training loss:%.4f" % ELBO_logs[-1])

# Leave the model in evaluation mode for the cells that follow.
vae.eval()


training loss:5802.5401:  13%|█▎        | 32/240 [11:14<1:13:25, 21.18s/it]

In [None]:
# Per-epoch values collected in ELBO_logs, drawn as a curve, written to the
# visual folder and then shown inline.
plt.plot(ELBO_logs)
plt.xlabel(xlabel='epochs')
plt.ylabel(ylabel='Evidence Lower Bound')
plt.savefig(visual_folder+'/train_elbo.png', bbox_inches="tight")
plt.show()

In [None]:
def save_imgs(imgs, file_name):
    """Write a batch of images as one grid image and display the same grid inline.

    Args:
        imgs: Image batch in any form accepted by ``torchvision.utils.save_image``
            and ``torchvision.utils.make_grid``.
        file_name: Name of the file (including extension) to create inside
            ``visual_folder``.
    """
    # The save call below does not create missing directories.
    os.makedirs(visual_folder, exist_ok=True)

    # Model outputs may still carry autograd history or live on the GPU; a
    # plain CPU tensor is all that saving and plotting need.
    if torch.is_tensor(imgs):
        imgs = imgs.detach().cpu()

    torchvision.utils.save_image(imgs,visual_folder+'/'+file_name)

    grid = torchvision.utils.make_grid(imgs)
    toPIL = trns.ToPILImage()
    plt.imshow(toPIL(grid))
    plt.show()
    
def reconstruct(dataloader, model, batch_num=1):
    """Run ``model`` on the first ``batch_num`` batches of ``dataloader``.

    Args:
        dataloader: Iterable yielding image batches (tensors).
        model: Callable returning ``(reconstruction, mean, logvar)`` for a batch.
        batch_num: Number of batches to process. If the loader yields fewer,
            every available batch is used.

    Returns:
        Tuple ``(recons, origin)`` of CPU tensors concatenated along dim 0.
        Batches appear newest-first (the last processed batch comes first), and
        both tensors use the same order, so row ``k`` of ``recons`` is the
        reconstruction of row ``k`` of ``origin``.

    Raises:
        ValueError: If ``dataloader`` yields no batches at all.
    """
    # Send inputs to wherever the model's weights live; fall back to the
    # notebook-level `device` for callables that expose no parameters.
    try:
        target_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        target_device = device

    recons_batches = []
    origin_batches = []
    # Inference only: avoid building (and holding on to) autograd graphs.
    with torch.no_grad():
        for imgs in dataloader:
            imgs = imgs.to(target_device)
            x_hat, _, _ = model(imgs)
            recons_batches.append(x_hat)
            origin_batches.append(imgs)
            if len(recons_batches) == batch_num:
                break

    if not recons_batches:
        raise ValueError("dataloader yielded no batches")

    # Concatenate once, newest batch first.
    recons = torch.cat(recons_batches[::-1], 0)
    origin = torch.cat(origin_batches[::-1], 0)
    return recons.cpu() ,origin.cpu()

In [None]:
# Reconstruct a single batch from the test loader, then save and show the
# reconstructions and the matching originals as two separate grids.
imgs, origins = reconstruct(dataloader=test_loader, model=vae, batch_num=1)
save_imgs(imgs=imgs, file_name='reconstruct.png')
save_imgs(imgs=origins, file_name='origin.png')

In [None]:
toPIL = trns.ToPILImage()
# Decode latent codes drawn from a standard normal. This is pure inference, so
# run it without autograd: the output then carries no graph and no grad flag.
with torch.no_grad():
    sample = torch.randn(batch_size, vae.latent_size).to(device)
    imgs = vae.decoder(sample).cpu()
save_imgs(imgs,'test.png')

In [None]:
def interpolate(t,size):
    """Linearly interpolate ``size`` evenly spaced points from ``t[0]`` to ``t[1]``.

    Both endpoints are included: row 0 equals ``t[0]`` and, for ``size > 1``,
    the last row equals ``t[1]``. With ``size == 1`` only ``t[0]`` is returned.

    Args:
        t: Indexable holding the two endpoint tensors at positions 0 and 1
            (for example a tensor of shape ``(2, D)``).
        size: Number of points to generate; must be at least 1.

    Returns:
        Tensor of shape ``(size, D)``, where ``D`` is the number of elements in
        each endpoint.

    Raises:
        ValueError: If ``size`` is smaller than 1.
    """
    if size < 1:
        raise ValueError("size must be at least 1, got %r" % (size,))

    start = t[0].reshape(1, -1)
    end = t[1].reshape(1, -1)

    # Weights run from 0 to 1 inclusive, so the final row lands exactly on the
    # second endpoint instead of stopping one step short of it.
    weight_dtype = start.dtype if start.is_floating_point() else torch.get_default_dtype()
    weights = torch.linspace(0.0, 1.0, size, dtype=weight_dtype, device=start.device).view(size, 1)

    # Broadcasting builds all rows at once, with no repeated concatenation.
    return start + weights * (end - start)

# Two random latent codes act as the endpoints of the interpolation.
with torch.no_grad():
    sample = torch.randn(2, vae.latent_size).to(device)
    imgs = vae.decoder(sample).cpu()
# Saved under its own name: the grid written further down uses
# 'interpolate.png' and would otherwise overwrite this image.
save_imgs(imgs,'interpolate_endpoints.png')

# Expand the two endpoints into a 16-step path through latent space.
sample = interpolate(sample,16)

# Decoding is inference only, so no autograd graph is needed.
with torch.no_grad():
    imgs = vae.decoder(sample).cpu()
save_imgs(imgs,'interpolate.png')