In [254]:
#imports
%matplotlib inline
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch
from sklearn.datasets import load_digits
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.utils.data as torch_data
import sklearn
from sklearn.metrics import accuracy_score

In [255]:
from poslayers.poslayers import Dense, PosDense, PosConv2d

In [257]:
from torchvision.datasets import CelebA

# Simple ConvAE

In [265]:
import random

import torch
from   torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from   torchvision import datasets, transforms
from PIL import Image

import numpy as np

In [316]:
class CelebA():
    """Minimal batch loader for a folder of sequentially numbered CelebA JPEGs.

    Files are expected to be named ``000001.jpg``, ``000002.jpg``, ... directly
    inside ``root_path``. The first ``round(train_part * dataset_size)`` files
    form the train split and the remaining ones the test split. Each split has
    its own index map that is shuffled independently, so images never move
    between splits.

    Note: defining this class rebinds the name ``CelebA``, hiding any class of
    the same name imported earlier in the notebook.

    Args:
        root_path: Directory that contains the image files.
        shape: Target size handed to ``PIL.Image.Image.resize``.
        dataset_size: Number of image files to use (numbered from 1).
        train_part: Fraction of ``dataset_size`` assigned to the train split.
    """

    def __init__(self, root_path, shape=(128, 128), dataset_size=30000, train_part=0.8):
        self.root_path = root_path
        self.shape = shape
        self.dataset_size = dataset_size
        self.train_size = round(train_part * dataset_size)
        self.totensor = transforms.ToTensor()
        # File names are 1-based, hence the "+ 1" on both index maps.
        self.train_idx_map = np.arange(self.train_size) + 1
        self.test_idx_map = np.arange(self.train_size, self.dataset_size) + 1
        self.shuffle()

    def __getitem__(self, index):
        """Return image ``index`` (0-based, unshuffled file order) as a tensor.

        Raises:
            IndexError: If ``index`` is outside ``[0, dataset_size)``. Raising
                IndexError also lets plain ``for img in ds`` iteration stop.
        """
        if not 0 <= index < self.dataset_size:
            raise IndexError(f'image index {index} out of range [0, {self.dataset_size})')
        return self._load_image(index + 1)

    # Original (misspelled) accessor name, kept so existing calls keep working.
    __get_item__ = __getitem__

    def shuffle(self):
        """Shuffle the train and test index maps in place, each within its own split."""
        np.random.shuffle(self.train_idx_map)
        np.random.shuffle(self.test_idx_map)

    def load_train_batch(self, batch_first_idx, batch_size):
        """Load up to ``batch_size`` train images starting at ``batch_first_idx``.

        The last batch of the split is shorter when ``train_size`` is not a
        multiple of ``batch_size``.

        Returns:
            A tensor stacking the loaded images along a new first dimension.

        Raises:
            IndexError: If the requested range contains no images.
        """
        return self._load_batch(self.train_idx_map, batch_first_idx, batch_size)

    def load_test_batch(self, batch_first_idx, batch_size):
        """Load up to ``batch_size`` test images starting at ``batch_first_idx``.

        The last batch of the split is shorter when the test split size is not
        a multiple of ``batch_size``.

        Returns:
            A tensor stacking the loaded images along a new first dimension.

        Raises:
            IndexError: If the requested range contains no images.
        """
        return self._load_batch(self.test_idx_map, batch_first_idx, batch_size)

    def _load_image(self, file_number):
        """Open image ``file_number`` (1-based), resize it and convert it to a tensor."""
        # The context manager closes the underlying file once the resized copy exists.
        with Image.open(f'{self.root_path}/{file_number:06d}.jpg') as img:
            return self.totensor(img.resize(self.shape))

    def _load_batch(self, idx_map, batch_first_idx, batch_size):
        """Stack the images addressed by ``idx_map[batch_first_idx : batch_first_idx + batch_size]``."""
        if batch_first_idx < 0:
            raise IndexError(f'batch_first_idx must be non-negative, got {batch_first_idx}')
        # Clamp the end of the window to the split so the final batch is just shorter.
        stop = min(batch_first_idx + batch_size, len(idx_map))
        file_numbers = idx_map[batch_first_idx:stop]
        if len(file_numbers) == 0:
            raise IndexError(
                f'no images in range [{batch_first_idx}, {batch_first_idx + batch_size}) '
                f'of a split with {len(idx_map)} images'
            )
        return torch.stack([self._load_image(number) for number in file_numbers], dim=0)

In [317]:
class AutoEncoder(nn.Module):
    """Autoencoder with a convolutional encoder and a fully connected decoder.

    Encoder: two ``Conv2d(kernel_size=5)`` layers, each followed by 2x2 max
    pooling and SELU, then two linear layers producing a ``code_size`` vector.
    Decoder: two linear layers ending in a sigmoid, reshaped to an image batch
    of shape ``(batch, channels, height, width)``.

    Args:
        code_size: Length of the latent code vector.
        input_channels: Number of image channels. Defaults to the notebook
            constant ``INPUT_CHANNELS``.
        image_height: Input image height. Defaults to ``IMAGE_HEIGHT``.
        image_width: Input image width. Defaults to ``IMAGE_WIDTH``.

    Raises:
        ValueError: If the image is too small to survive both conv/pool stages.
    """

    def __init__(self, code_size, input_channels=None, image_height=None, image_width=None):
        super().__init__()
        self.code_size = code_size
        # Geometry falls back to the notebook-level constants so that
        # ``AutoEncoder(code_size)`` keeps working exactly as before.
        self.input_channels = INPUT_CHANNELS if input_channels is None else input_channels
        self.image_height = IMAGE_HEIGHT if image_height is None else image_height
        self.image_width = IMAGE_WIDTH if image_width is None else image_width

        enc_h, enc_w = self._encoded_hw(self.image_height, self.image_width)
        if enc_h < 1 or enc_w < 1:
            raise ValueError(
                f'image of size {self.image_height}x{self.image_width} is too small '
                'for two 5x5 convolutions with 2x2 pooling'
            )

        # Encoder specification
        self.enc_cnn_1 = nn.Conv2d(self.input_channels, 5, kernel_size=5)
        self.enc_cnn_2 = nn.Conv2d(5, 10, kernel_size=5)
        # Flattened feature count is derived from the input size (10 * 13 * 13 for 64x64).
        self.enc_linear_1 = nn.Linear(10 * enc_h * enc_w, 800)
        self.enc_linear_2 = nn.Linear(800, self.code_size)

        # Decoder specification
        self.dec_linear_1 = nn.Linear(self.code_size, 4000)
        self.dec_linear_2 = nn.Linear(
            4000, self.image_height * self.image_width * self.input_channels
        )

    @staticmethod
    def _encoded_hw(height, width):
        """Return the spatial size left after both conv(5) + max-pool(2) stages."""
        def shrink(size):
            for _ in range(2):
                # An unpadded 5x5 convolution removes 4 pixels; 2x2 pooling floors the half.
                size = (size - 4) // 2
            return size
        return shrink(height), shrink(width)

    def forward(self, images):
        """Encode then decode ``images``; returns ``(reconstruction, code)``."""
        code = self.encode(images)
        out = self.decode(code)
        return out, code

    def encode(self, images):
        """Map an image batch to a ``(batch, code_size)`` code tensor."""
        code = self.enc_cnn_1(images)
        code = F.selu(F.max_pool2d(code, 2))

        code = self.enc_cnn_2(code)
        code = F.selu(F.max_pool2d(code, 2))

        code = code.view([images.size(0), -1])
        code = F.selu(self.enc_linear_1(code))
        code = self.enc_linear_2(code)
        return code

    def decode(self, code):
        """Map a code batch back to images with values in [0, 1]."""
        out = F.selu(self.dec_linear_1(code))
        # torch.sigmoid replaces the deprecated F.sigmoid spelling.
        out = torch.sigmoid(self.dec_linear_2(out))
        # Height comes before width so the layout matches the (N, C, H, W) input batch.
        out = out.view([code.size(0), self.input_channels, self.image_height, self.image_width])
        return out

In [318]:
from tqdm.notebook import tqdm

In [319]:
# Image geometry shared by the dataset loader and the autoencoder.
IMAGE_WIDTH = IMAGE_HEIGHT = 64
IMAGE_SIZE = IMAGE_WIDTH * IMAGE_HEIGHT  # pixels per channel
INPUT_CHANNELS = 3

# Hyperparameters
code_size = 500
num_epochs = 5
batch_size = 64
lr = 0.002
loss_fn = nn.BCELoss()


net = AutoEncoder(code_size=code_size)
celeba = CelebA('./celeba/img_align_celeba', shape=(IMAGE_WIDTH, IMAGE_HEIGHT))
# Pass the configured learning rate explicitly; otherwise Adam silently
# falls back to its library default and `lr` above has no effect.
optimizer = optim.Adam(net.parameters(), lr=lr)

In [320]:
def train_ae(epochs, net, criterion, optimizer, ds, batch_size=128,
             scheduler=None, verbose=True, save_dir=None, device=0):
    """Train an autoencoder and evaluate it on the held-out split every epoch.

    Args:
        epochs: Number of passes over the train split.
        net: Module whose forward call returns ``(reconstruction, code)``.
        criterion: Loss called as ``criterion(reconstruction, batch)``.
        optimizer: Optimizer over ``net``'s parameters.
        ds: Dataset object providing ``train_size``, ``dataset_size``,
            ``load_train_batch(first_idx, batch_size)`` and
            ``load_test_batch(first_idx, batch_size)``.
        batch_size: Number of images requested per batch.
        scheduler: Optional LR scheduler; stepped once per epoch.
        verbose: When true, print mean losses every third epoch.
        save_dir: Accepted for interface compatibility; not used here.
        device: Accepted for interface compatibility; not used here
            (tensors stay wherever ``ds`` and ``net`` put them).
    """
    freq = 3  # report losses every `freq` epochs

    for epoch in range(1, epochs + 1):

        net.train()
        train_loss = []
        for batch_idx in tqdm(range(0, ds.train_size, batch_size)):
            batch = ds.load_train_batch(batch_idx, batch_size)
            optimizer.zero_grad()
            out, _ = net(batch)
            loss = criterion(out, batch)
            loss.backward()
            optimizer.step()
            # Store a plain float: keeping the tensor would retain its autograd
            # graph and np.mean cannot convert tensors that require grad.
            train_loss.append(loss.item())

        net.eval()
        val_loss = []
        # Validation must not update the model: no gradients, no optimizer steps.
        with torch.no_grad():
            for batch_idx in tqdm(range(0, ds.dataset_size - ds.train_size, batch_size)):
                batch = ds.load_test_batch(batch_idx, batch_size)
                out, _ = net(batch)
                val_loss.append(criterion(out, batch).item())

        if scheduler is not None:
            scheduler.step()
        if verbose and epoch % freq == 0:
            print('Epoch {}/{} || Loss:  Train {:.4f} | Validation {:.4f}'.format(epoch, epochs, np.mean(train_loss), np.mean(val_loss)))

In [321]:
# Launch training; every argument is spelled out by keyword for readability.
train_ae(
    epochs=50,
    net=net,
    criterion=loss_fn,
    optimizer=optimizer,
    ds=celeba,
    batch_size=64,
    scheduler=None,
    verbose=True,
    save_dir=None,
    device=1,
)

HBox(children=(IntProgress(value=0, max=375), HTML(value='')))

KeyboardInterrupt: 