## Imports

In [1]:
from __future__ import print_function

import os
import pdb
import nltk
import json
import utils
import pickle
import shutil
import PIL  
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
from torch.nn.utils.rnn import pad_sequence
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader

from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

import utils
import spacy    
import itertools
import collections
import matplotlib.pyplot as plt

# Seed PyTorch's global random number generator so runs started from a fresh
# kernel draw the same random numbers.
SEED = 69
torch.manual_seed(SEED)

# Flag telling later cells whether a CUDA device can be used.
cuda = torch.cuda.is_available()
cuda

True

## Model

In [2]:
# https://github.com/sksq96/pytorch-vae/blob/master/vae-cnn.ipynb

class Flatten(nn.Module):
    """Collapse every non-batch dimension of a tensor into a single one.

    Placed at the end of the encoder so the convolutional feature map can be
    consumed by fully connected layers.
    """

    def forward(self, input):
        """Return ``input`` viewed as ``(input.size(0), -1)`` (no copy is made)."""
        batch = input.shape[0]
        return input.view(batch, -1)
    
class UnFlatten(nn.Module):
    """Turn a flat ``(batch, h)`` tensor into a ``(batch, h, 1, 1)`` feature map."""

    def __init__(self, h):
        super(UnFlatten, self).__init__()
        # Channel count of the 1x1 feature map produced by forward().
        self.h = h

    def forward(self, input):
        """Return ``input`` viewed as ``(input.size(0), self.h, 1, 1)``."""
        batch = input.shape[0]
        target_shape = (batch, self.h, 1, 1)
        return input.view(*target_shape)
    
class VAE_CNN(nn.Module):
    """Convolutional variational autoencoder.

    Layout:
        * encoder -- seven ``Conv2d`` + ``ReLU`` stages followed by ``Flatten``.
        * ``fc1`` / ``fc2`` -- map the flattened features to the latent mean
          and log-variance respectively.
        * ``fc3`` -- maps a latent vector back to ``h_dim`` features.
        * decoder -- ``UnFlatten`` to a 1x1 map, five ``ConvTranspose2d``
          stages, a 2x2 max-pool and a final ``Sigmoid``.

    Args:
        image_channels: Number of channels of the input (and output) images.
        h_dim: Size of the flattened encoder output; it must equal
            512 * H' * W' where H' x W' is the spatial size left after the
            last encoder convolution for the images being fed in.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, image_channels=1, h_dim=1024, z_dim=32):
        super(VAE_CNN, self).__init__()
        self.h = h_dim

        # (in_channels, out_channels, kernel_size, stride) of each encoder conv.
        encoder_specs = [
            (image_channels, 128, 4, 1),
            (128, 128, 4, 1),
            (128, 256, 4, 2),
            (256, 512, 4, 2),
            (512, 512, 4, 1),
            (512, 512, 2, 2),
            (512, 512, 3, 2),
        ]
        encoder_layers = []
        for c_in, c_out, kernel, stride in encoder_specs:
            encoder_layers.append(nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride))
            encoder_layers.append(nn.ReLU())
        encoder_layers.append(Flatten())
        self.encoder = nn.Sequential(*encoder_layers)

        # Latent heads (mean, log-variance) and the projection back to h_dim.
        self.fc1 = nn.Linear(h_dim, z_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(z_dim, h_dim)

        # (in_channels, out_channels, kernel_size) of the ReLU-activated
        # transposed convolutions; every one of them uses stride 2.
        decoder_specs = [
            (h_dim, 512, 5),
            (512, 512, 5),
            (512, 256, 5),
            (256, 128, 6),
        ]
        decoder_layers = [UnFlatten(self.h)]
        for c_in, c_out, kernel in decoder_specs:
            decoder_layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=kernel, stride=2))
            decoder_layers.append(nn.ReLU())
        # Output stage: back to image channels, halve the resolution, squash to (0, 1).
        decoder_layers.append(nn.ConvTranspose2d(128, image_channels, kernel_size=6, stride=2))
        decoder_layers.append(nn.MaxPool2d(2))
        decoder_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(0.5 * logvar)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def bottleneck(self, h):
        """Map encoder features to ``(z, mu, logvar)``; ``z`` is a fresh sample."""
        mu = self.fc1(h)
        logvar = self.fc2(h)
        return self.reparameterize(mu, logvar), mu, logvar

    def encode(self, x):
        """Encode a batch of images into ``(z, mu, logvar)``."""
        features = self.encoder(x)
        z, mu, logvar = self.bottleneck(features)
        return z, mu, logvar

    def decode(self, z):
        """Decode latent vectors ``z`` into images."""
        return self.decoder(self.fc3(z))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        latent, mu, logvar = self.encode(x)
        return self.decode(latent), mu, logvar

## Train / Test

In [3]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: summed reconstruction BCE plus KL divergence.

    Args:
        recon_x: Reconstructed batch; values must lie in [0, 1] as required by
            binary cross-entropy (e.g. the output of a sigmoid).
        x: Target batch with the same number of elements as ``recon_x``.
        mu: Latent means.
        logvar: Latent log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor: the loss summed over all elements and the whole batch
        (it is not averaged).
    """
    # Flatten to 1-D rather than hard-coding 4096 (= 64 * 64) pixels per row:
    # with reduction='sum' the grouping of elements does not affect the result,
    # so images of any size are accepted.
    BCE = F.binary_cross_entropy(recon_x.reshape(-1), x.reshape(-1), reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2); the sign is flipped
    # here because the KL term is minimised as part of the loss.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD


def train(epoch, log_interval):
    """Run one training epoch and return the mean loss per sample.

    Uses the notebook-level globals ``model``, ``optimizer``, ``train_loader``
    and ``device``.

    Args:
        epoch: Epoch number, used only in the printed progress messages.
        log_interval: A progress line is printed every ``log_interval`` batches.

    Returns:
        The accumulated loss divided by ``len(train_loader.dataset)``.
    """
    model.train()
    running_loss = 0
    for step, (batch, _) in enumerate(train_loader):
        batch = batch.to(device)

        optimizer.zero_grad()
        recon, mu, logvar = model(batch)
        batch_loss = loss_function(recon, batch, mu, logvar)
        batch_loss.backward()
        running_loss += batch_loss.item()
        optimizer.step()

        # Only report on every `log_interval`-th batch.
        if step % log_interval != 0:
            continue
        seen = step * len(batch)
        percent = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset), percent,
            batch_loss.item() / len(batch)))

    mean_loss = running_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, mean_loss))

    return mean_loss


def test(epoch, batch_size):
    """Evaluate the model on ``test_loader`` and save a reconstruction grid.

    Uses the notebook-level globals ``model``, ``test_loader`` and ``device``.
    For the first batch, up to 8 inputs and their reconstructions are written
    to ``./results/recon_<epoch>.png``.

    Args:
        epoch: Epoch number, used only in the output file name.
        batch_size: Unused; kept so existing callers keep working. The number
            of images in the first batch is read from the batch itself.

    Returns:
        The accumulated loss divided by ``len(test_loader.dataset)``.
    """
    model.eval()
    test_loss = 0
    loader = test_loader
    with torch.no_grad():
        for i, (data, _) in enumerate(loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                # Take the shape from the batch itself instead of `batch_size`:
                # the first batch is smaller than `batch_size` whenever the
                # test set holds fewer samples, and a fixed-size view would
                # then raise.
                new_recon_batch = recon_batch.view_as(data)[:n]
                comparison = torch.cat([data[:n], new_recon_batch])
                # save_image needs the target directory to exist.
                if not os.path.isdir('./results'):
                    os.makedirs('./results')
                save_image(comparison.cpu(),
                           './results/recon_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss

## Run

In [4]:
# Parameters
batch_size = 64
epochs = 50
log_interval = 100
h = 2048  # flattened encoder output: 512 channels x 2 x 2 for 64x64 inputs
z = 256   # latent dimensionality
resume_epoch = 19  # checkpoint to load; training continues at resume_epoch + 1
n_samples = 64     # number of images decoded from random latents each epoch

# Cuda
device = torch.device("cuda" if cuda else "cpu")
print('Device: ', device)

# Output directory used by test() and by the sampling step below.
if not os.path.isdir('./results'):
    os.makedirs('./results')

# Data
kwargs = {'num_workers': 8} if cuda else {}

# https://stackoverflow.com/questions/52439364/how-to-convert-rgb-images-to-grayscale-in-pytorch-dataloader
trainTransform  = transforms.Compose([transforms.Grayscale(num_output_channels=1), transforms.ToTensor()])

# https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder
train_percentage = .8
dataset = datasets.ImageFolder('./image_folder_data/', transform=trainTransform)

train_num = int(len(dataset) * train_percentage)
test_num = len(dataset) - train_num

train_set, test_set = torch.utils.data.random_split(dataset, [train_num, test_num])
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, **kwargs)

# Model
model = VAE_CNN(1, h, z).to(device)
# map_location lets a checkpoint saved from a GPU run be loaded on a CPU-only
# machine (and vice versa) instead of failing while deserialising.
model.load_state_dict(torch.load('./pretrained_models/model%d.pt' % resume_epoch,
                                 map_location=device))
# NOTE(review): only the model weights are checkpointed, so Adam restarts with
# fresh moment estimates on every resume.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# NOTE(review): these lists start empty on every run, so 'losses.pkl' is
# overwritten with only the epochs trained since the resume point.
trains = []
tests = []
print('Running Epochs...')
for epoch in range(resume_epoch + 1, epochs + 1):
    tr = train(epoch, log_interval)
    te = test(epoch, batch_size)
    # Decode random latent vectors to monitor what the prior generates.
    with torch.no_grad():
        sample = torch.randn(n_samples, z).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(n_samples, 1, 64, 64),
                   './results/sample_' + str(epoch) + '.png')
    trains.append(tr)
    tests.append(te)
    new_list = [trains, tests]
    with open('losses.pkl', 'wb') as b:
        pickle.dump(new_list, b)
    
    torch.save(model.state_dict(), './pretrained_models/model%d.pt' % epoch)

Device:  cuda
Running Epochs...


KeyboardInterrupt: 