In [None]:
# imports
import os
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision.utils import make_grid
from torchvision.utils import save_image
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torchvision.datasets as dset
import torchvision
from sklearn.metrics import accuracy_score
import sklearn
import torch.utils.data as torch_data
import torch.nn as nn
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
from sklearn.datasets import load_digits
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F

from tqdm.notebook import tqdm
%matplotlib inline

In [None]:
from poslayers.poslayers import Dense, PosDense, PosConv2d

In [None]:
from vae.vanila_vae import *

# Simple convolutional autoencoder (ConvAE) on CelebA

In [None]:
# Image geometry and dataset location used by every later cell.
IMAGE_WIDTH = IMAGE_HEIGHT = 64
IMAGE_SIZE = IMAGE_HEIGHT * IMAGE_WIDTH  # pixels per channel
INPUT_CHANNELS = 3  # colour channels per image
DATA_PATH = r'./data/celeba/'

In [None]:
from torch.utils.data import random_split

# Loader configuration (tune here rather than inside the DataLoader calls).
SEED = 42               # makes the optional subset and the train/val split reproducible
BATCH_SIZE = 128
NUM_WORKERS = 4
TRAIN_FRACTION = 0.8
SUBSET_FRACTION = None  # e.g. 0.01 to work on a small random subset while debugging

# Resize to the configured geometry, convert to tensors, then normalise each
# channel with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5, which maps [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * INPUT_CHANNELS, (0.5,) * INPUT_CHANNELS),
])

full_dataset = ImageFolder(DATA_PATH, transform=transform)

# One seeded generator drives every random split in this cell.
split_generator = torch.Generator().manual_seed(SEED)
if SUBSET_FRACTION is None:
    dataset = full_dataset
else:
    n_keep = int(len(full_dataset) * SUBSET_FRACTION)
    dataset, _ = random_split(full_dataset, (n_keep, len(full_dataset) - n_keep),
                              generator=split_generator)

train_size = int(len(dataset) * TRAIN_FRACTION)
val_size = len(dataset) - train_size
train_data, val_data = random_split(dataset, (train_size, val_size),
                                    generator=split_generator)

# Training: reshuffle every epoch and drop the ragged last batch.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE,
                          shuffle=True, num_workers=NUM_WORKERS, drop_last=True)
# Validation: fixed order and keep every sample, so the reported metric covers
# the whole held-out split (shuffling adds nothing; drop_last lost up to
# BATCH_SIZE - 1 samples).
val_loader = DataLoader(dataset=val_data, batch_size=BATCH_SIZE,
                        shuffle=False, num_workers=NUM_WORKERS, drop_last=False)

In [None]:
# primitive one: conv encoder + fully-connected decoder
class AutoEncoder(nn.Module):
    """Convolutional encoder paired with a purely fully-connected decoder.

    NOTE: a later cell re-defines ``AutoEncoder``; running that cell shadows
    this class.

    Encoder: two (5x5 conv -> 2x2 max-pool -> SELU) stages, flatten, then two
    linear layers down to ``code_size``. Decoder: two linear layers that emit
    every output value directly, reshaped to (N, C, H, W).

    Reads the module-level constants INPUT_CHANNELS, IMAGE_HEIGHT and
    IMAGE_WIDTH when constructed and when decoding.

    Args:
        code_size: dimensionality of the latent code.
    """

    def __init__(self, code_size):
        super().__init__()
        self.code_size = code_size

        # Encoder specification
        self.enc_cnn_1 = nn.Conv2d(INPUT_CHANNELS, 5, kernel_size=5)
        self.enc_cnn_2 = nn.Conv2d(5, 10, kernel_size=5)
        # Each (5x5 valid conv, 2x2 pool) stage maps a side s -> (s - 4) // 2.
        # Derive the flattened size instead of hard-coding 13x13, which is only
        # correct for 64x64 inputs.
        self.enc_h = ((IMAGE_HEIGHT - 4) // 2 - 4) // 2
        self.enc_w = ((IMAGE_WIDTH - 4) // 2 - 4) // 2
        self.enc_linear_1 = nn.Linear(10 * self.enc_h * self.enc_w, 800)
        self.enc_linear_2 = nn.Linear(800, self.code_size)

        # Decoder specification: one output unit per (channel, row, column).
        self.dec_linear_1 = nn.Linear(self.code_size, 4000)
        self.dec_linear_2 = nn.Linear(4000, IMAGE_HEIGHT * IMAGE_WIDTH * INPUT_CHANNELS)

    def forward(self, images):
        """Return ``(reconstruction, code)`` for a batch of images."""
        code = self.encode(images)
        out = self.decode(code)
        return out, code

    def encode(self, images):
        """Map images of shape (N, C, H, W) to codes of shape (N, code_size)."""
        code = self.enc_cnn_1(images)
        code = F.selu(F.max_pool2d(code, 2))

        code = self.enc_cnn_2(code)
        code = F.selu(F.max_pool2d(code, 2))

        code = code.view([images.size(0), -1])
        code = F.selu(self.enc_linear_1(code))
        code = self.enc_linear_2(code)
        return code

    def decode(self, code):
        """Map codes of shape (N, code_size) back to images (N, C, H, W)."""
        out = F.selu(self.dec_linear_1(code))
        out = F.selu(self.dec_linear_2(out))
        # Image tensors are laid out (N, C, H, W): height comes before width.
        out = out.view([code.size(0), INPUT_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH])
        return out

In [None]:
# advanced one: conv encoder + mirrored upsample / transposed-conv decoder
class AutoEncoder(nn.Module):
    """Convolutional autoencoder with a decoder that mirrors the encoder.

    NOTE: this cell re-defines ``AutoEncoder`` and shadows any earlier
    definition; whichever cell ran last wins.

    Encoder: two (5x5 conv -> 2x2 max-pool -> ReLU) stages, flatten, then two
    linear layers down to ``code_size``. Decoder: two linear layers back to the
    encoder's feature-map size, then two (nearest x2 upsample -> 5x5 transposed
    conv) stages. The last stage uses ``tanh`` so the output spans [-1, 1]; the
    data-loading cell normalises images with mean 0.5 / std 0.5, so targets
    include negative values that a ReLU output could never produce.

    Reads the module-level constants INPUT_CHANNELS, IMAGE_HEIGHT and
    IMAGE_WIDTH when constructed and when decoding.

    Args:
        code_size: dimensionality of the latent code.

    Raises:
        ValueError: if the decoder cannot reproduce IMAGE_HEIGHT x IMAGE_WIDTH.
    """

    def __init__(self, code_size):
        super().__init__()
        self.code_size = code_size

        # Encoder specification
        self.enc_cnn_1 = nn.Conv2d(INPUT_CHANNELS, 5, kernel_size=5) # -> (-4, -4) / 2
        self.enc_cnn_2 = nn.Conv2d(5, 10, kernel_size=5) # -> (-4, -4) / 2
        # Each (5x5 valid conv, 2x2 pool) stage maps a side s -> (s - 4) // 2.
        self.imdim_h = ((IMAGE_HEIGHT - 4) // 2 - 4) // 2
        self.imdim_w = ((IMAGE_WIDTH - 4) // 2 - 4) // 2
        # Kept for backward compatibility (the original square-only attribute).
        self.imdim = self.imdim_w
        # The decoder undoes each stage as s -> 2 * s + 4 (x2 upsample, then a
        # 5x5 transposed conv). Fail here with a clear message rather than
        # inside view() at forward time.
        for side, feat in ((IMAGE_HEIGHT, self.imdim_h), (IMAGE_WIDTH, self.imdim_w)):
            if (2 * feat + 4) * 2 + 4 != side:
                raise ValueError(
                    f"image side {side} cannot be reconstructed exactly by the decoder "
                    f"(feature side {feat} decodes to {(2 * feat + 4) * 2 + 4})")
        self.enc_linear_1 = nn.Linear(10 * self.imdim_h * self.imdim_w, 800)
        self.enc_linear_2 = nn.Linear(800, self.code_size)

        # Decoder specification
        self.dec_linear_1 = nn.Linear(self.code_size, 800)
        self.dec_linear_2 = nn.Linear(800, 10 * self.imdim_h * self.imdim_w)

        self.dec_up1 = nn.UpsamplingNearest2d(scale_factor=2)
        self.dec_cnn_1 = nn.ConvTranspose2d(10, 5, kernel_size=5)

        self.dec_up2 = nn.UpsamplingNearest2d(scale_factor=2)
        self.dec_cnn_2 = nn.ConvTranspose2d(5, INPUT_CHANNELS, kernel_size=5)

    def forward(self, images):
        """Return ``(reconstruction, code)`` for a batch of images."""
        code = self.encode(images)
        out = self.decode(code)
        return out, code

    def encode(self, images):
        """Map images of shape (N, C, H, W) to codes of shape (N, code_size)."""
        code = self.enc_cnn_1(images)
        code = F.relu(F.max_pool2d(code, 2))

        code = self.enc_cnn_2(code)
        code = F.relu(F.max_pool2d(code, 2))

        code = code.view([images.size(0), -1])
        code = F.relu(self.enc_linear_1(code))
        code = self.enc_linear_2(code)
        return code

    def decode(self, code):
        """Map codes of shape (N, code_size) back to images (N, C, H, W) in [-1, 1]."""
        out = F.relu(self.dec_linear_1(code))
        out = F.relu(self.dec_linear_2(out))
        out = out.view([-1, 10, self.imdim_h, self.imdim_w])
        out = self.dec_up1(out)
        out = F.relu(self.dec_cnn_1(out))
        out = self.dec_up2(out)
        # tanh (not ReLU): normalised targets lie in [-1, 1], including negatives.
        out = torch.tanh(self.dec_cnn_2(out))
        # Image tensors are laid out (N, C, H, W): height comes before width.
        out = out.view([-1, INPUT_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH])
        return out

In [None]:
# Hyperparameters
code_size = 500
num_epochs = 5   # NOTE(review): not referenced by any visible cell; the training call passes its own epoch count
batch_size = 64  # NOTE(review): not referenced; the DataLoaders are built with their own batch size
lr = 0.005

device = 'cpu'
net = AutoEncoder(code_size=code_size).to(device)
criterion = nn.MSELoss()
# `optim` is never imported by name in this notebook (it could only leak in via
# the star import), so use the fully-qualified path. Pass `lr` explicitly: it
# was defined above but previously ignored, leaving Adam at its default rate.
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

In [None]:
def train_ae(epochs, net, criterion, optimizer,
             scheduler=None, verbose=True, save_dir=None, device=0,
             train_dl=None, val_dl=None):
    """Train an autoencoder and report mean train/validation loss per epoch.

    ``net(X)`` must return ``(reconstruction, code)``. The loss is
    ``criterion(reconstruction, X)``: the network reconstructs its input, and
    the batch labels are ignored.

    Args:
        epochs: number of passes over the training data.
        net: the autoencoder module.
        criterion: reconstruction loss, called as ``criterion(out, X)``.
        optimizer: optimizer over ``net``'s parameters. It is stepped once per
            training batch and never during validation.
        scheduler: optional LR scheduler, stepped once per epoch.
        verbose: print the per-epoch losses when True.
        save_dir: currently unused; kept for interface compatibility.
        device: currently unused; batches are moved to the device of ``net``'s
            parameters instead, so data always matches the model.
        train_dl: iterable of ``(X, label)`` training batches. Defaults to the
            notebook-level ``train_loader``.
        val_dl: iterable of ``(X, label)`` validation batches. Defaults to the
            notebook-level ``val_loader``.

    Returns:
        None; losses are only printed.
    """
    if train_dl is None:
        train_dl = train_loader
    if val_dl is None:
        val_dl = val_loader

    # Follow the model's device rather than the (historically ignored) `device` arg.
    first_param = next(net.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')
    freq = 1  # print every `freq` epochs

    for epoch in range(1, epochs + 1):
        net.train()
        train_loss = []
        for X, _ in train_dl:
            X = X.to(model_device)
            out, _code = net(X)
            loss = criterion(out, X)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        # Validation must not update the weights: no backward/step, and no
        # autograd graph is built.
        net.eval()
        val_loss = []
        with torch.no_grad():
            for X, _ in val_dl:
                X = X.to(model_device)
                out, _code = net(X)
                val_loss.append(criterion(out, X).item())

        if scheduler is not None:
            scheduler.step()
        if verbose and epoch % freq == 0:
            # Report NaN (instead of a numpy warning) when a loader yielded no batches.
            mean_train = np.mean(train_loss) if train_loss else float('nan')
            mean_val = np.mean(val_loss) if val_loss else float('nan')
            print('Epoch {}/{} || Loss:  Train {:.4f} | Validation {:.4f}'.format(epoch, epochs, mean_train, mean_val))
            

In [None]:
# Train for 20 epochs with the model, loss and optimizer built above.
# NOTE(review): device=1 is passed through, but train_ae does not use its
# `device` argument to place the data.
train_ae(
    epochs=20,
    net=net,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=None,
    verbose=True,
    save_dir=None,
    device=1,
)