In [1]:
import os
import numpy as np
import time
import pickle
import matplotlib.pyplot as plt
import matplotlib
from sklearn.metrics import accuracy_score, f1_score, classification_report
from scipy.stats import norm
from itertools import cycle


import torch
from torch import nn
import torch.optim as opt
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
from torchvision import transforms
from livelossplot import PlotLosses
from mpl_toolkits.axes_grid1 import ImageGrid
from collections import OrderedDict

from sprites import Sprites
from cycle_consistent_vae import Encoder, Decoder
from latent_classifier import Classifier


# Compatibility shim: NumPy 1.17 exposed the bit-generator module only as
# `np.random._bit_generator`; NumPy 1.18+ made it public as `np.random.bit_generator`
# (the private name may no longer exist), so an unconditional alias can raise
# AttributeError on newer NumPy. Alias only when the public name is missing.
# NOTE(review): the code that needs `np.random.bit_generator` is not visible here — confirm.
if not hasattr(np.random, 'bit_generator'):
    np.random.bit_generator = np.random._bit_generator

In [2]:
def mse_loss(inp, target):
    """Mean squared error: summed squared difference divided by the element count of ``inp``."""
    squared_error = (inp - target) ** 2
    return squared_error.sum() / inp.numel()


def l1_loss(inp, target):
    """Mean absolute error: summed absolute difference divided by the element count of ``inp``."""
    absolute_error = (inp - target).abs()
    return absolute_error.sum() / inp.numel()

def reparameterize(training, mu, logvar):
    """Reparameterisation trick.

    When ``training`` is truthy, return ``mu + sigma * eps`` with
    ``sigma = exp(logvar / 2)`` and ``eps`` drawn from a standard normal shaped like
    ``sigma``. Otherwise return ``mu`` itself (the same object, no sampling).
    """
    if not training:
        return mu
    sigma = logvar.mul(0.5).exp()
    noise = torch.randn_like(sigma)
    return noise * sigma + mu

def weights_init(layer):
    """Initialise one layer's parameters in place; suitable for ``module.apply(weights_init)``.

    - ``nn.Conv2d`` / ``nn.Linear``: weight ~ N(0, 0.05), bias = 0.
    - ``nn.BatchNorm2d``: weight ~ N(1, 0.02), bias = 0.
    - Any other layer type (``nn.ConvTranspose2d`` is not a ``nn.Conv2d`` subclass)
      is left untouched.

    Parameters that are ``None`` (``bias=False`` convolutions/linears,
    ``affine=False`` batch norm) are skipped. Previously, ``.data`` on ``None``
    raised ``AttributeError``.
    """
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        mean, std = 0.0, 0.05
    elif isinstance(layer, nn.BatchNorm2d):
        mean, std = 1.0, 0.02
    else:
        return

    if layer.weight is not None:
        layer.weight.data.normal_(mean, std)
    if layer.bias is not None:
        layer.bias.data.zero_()

def kl_divergence_loss(mu, logvar, batch_size=None, num_pixels=60 * 60):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed and then normalised per pixel.

    The original expression ``3 * KL / (BATCH_SIZE * 3 * 60 * 60)`` reduces to
    ``KL / (BATCH_SIZE * 60 * 60)``. The cancelling channel factors are dropped,
    and the remaining divisors are parameters.

    Args:
        mu: posterior mean; assumed ``(batch, latent_dim)`` — TODO confirm.
        logvar: posterior log-variance, same shape as ``mu``.
        batch_size: batch divisor. Defaults to ``mu.size(0)`` (1 for a 0-dim ``mu``)
            instead of the notebook-global ``BATCH_SIZE``, which is only defined in a
            later cell. For batches from the ``drop_last=True`` loaders the two agree.
        num_pixels: per-image pixel divisor; 60 * 60 as in the original
            (presumably image height * width — confirm).

    Returns:
        0-dim tensor.
    """
    if batch_size is None:
        batch_size = mu.size(0) if mu.dim() > 0 else 1
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return kl / (batch_size * num_pixels)

def imshow_grid(images, shape=(2, 8), name='default', save=False, save_dir='./reconstructed_images'):
    """
    Plot images in a grid of a given shape.
    Initial code from: https://github.com/pumpikano/tf-dann/blob/master/utils.py

    Args:
        images: indexable sequence of images accepted by ``Axes.imshow``. It must
            contain at least ``shape[0] * shape[1]`` items (``IndexError`` otherwise).
        shape: ``(rows, cols)`` of the grid. Defaults to an immutable tuple rather
            than a shared list; lists are still accepted.
        name: file stem of the saved image when ``save`` is True.
        save: when True, write ``<save_dir>/<name>.png``; otherwise call ``plt.show()``.
        save_dir: output directory, created if it does not exist.
    """
    fig = plt.figure(1)
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)

    size = shape[0] * shape[1]
    for i in range(size):
        grid[i].axis('off')
        grid[i].imshow(images[i])  # The AxesGrid object works as a list of axes.

    if save:
        # savefig does not create missing directories.
        os.makedirs(save_dir, exist_ok=True)
        fig.savefig(os.path.join(save_dir, str(name) + '.png'))
        # Close (not just clear) figure 1 so it is not kept alive between calls.
        plt.close(fig)
    else:
        plt.show()

In [3]:
BATCH_SIZE = 16

train_data = Sprites()
test_data = Sprites(split='test')

# Both loaders share one configuration. drop_last=True discards the final partial
# batch, so every batch holds exactly BATCH_SIZE samples.
LOADER_KWARGS = dict(batch_size=BATCH_SIZE, shuffle=True, num_workers=2, drop_last=True)
train_loader = DataLoader(train_data, **LOADER_KWARGS)
test_loader = DataLoader(test_data, **LOADER_KWARGS)

In [4]:
cuda = 1  # preferred GPU index
if torch.cuda.is_available():
    # A non-existent GPU ordinal only fails later, at the first `.to(device)`.
    # Fall back to GPU 0 on machines with fewer GPUs than requested.
    device = torch.device("cuda:{}".format(cuda if cuda < torch.cuda.device_count() else 0))
else:
    device = torch.device("cpu")

In [5]:
MODEL_PATH = "./models/cycle_vae_06052020-030456_99.pth"
# map_location='cpu' lets a checkpoint saved from a GPU load on any machine.
# The encoder/decoder are moved to `device` explicitly in later cells.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location='cpu')

In [6]:
# Latent sizes: Z is the style ("unspecified") code, S is the class ("specified") code.
Z_DIM, S_DIM = 16, 16

In [7]:
# Rebuild the pretrained cycle-consistent VAE halves, then restore their weights.
encoder = Encoder(style_dim=Z_DIM, class_dim=S_DIM)
decoder = Decoder(style_dim=Z_DIM, class_dim=S_DIM)
encoder.load_state_dict(checkpoint['encoder'])
decoder.load_state_dict(checkpoint['decoder'])

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [8]:
# Move the encoder to the target device and switch it to inference mode (displays the module).
encoder.to(device).eval()

Encoder(
  (conv_model): Sequential(
    (conv_1): Conv2d(3, 16, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
    (bn_1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu_1): ReLU(inplace)
    (conv_2): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
    (bn_2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu_2): ReLU(inplace)
    (conv_3): Conv2d(32, 64, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
    (bn_3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu_3): ReLU(inplace)
    (conv_4): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
    (bn_4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu_4): ReLU(inplace)
  )
  (style_mu): Linear(in_features=512, out_features=16, bias=True)
  (style_logvar): Linear(in_features=512, out_features=16, bias=True)
  (class_output): Line

In [9]:
# Move the decoder to the target device and switch it to inference mode (displays the module).
decoder.to(device).eval()

Decoder(
  (style_input): Sequential(
    (0): Linear(in_features=16, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace)
  )
  (class_input): Sequential(
    (0): Linear(in_features=16, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace)
  )
  (deconv_model): Sequential(
    (deconv_1): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(2, 2))
    (de_bn_1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (leakyrelu_1): LeakyReLU(negative_slope=0.2, inplace)
    (deconv_2): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2))
    (de_bn_2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (leakyrelu_2): LeakyReLU(negative_slope=0.2, inplace)
    (deconv_3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2))
    (de_bn_3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (leakyrelu_3): LeakyReLU(negative_slope=0.2, i

# Latent Classifier - Unspecified Features

In [None]:
# NOTE(review): `num_classes` was left blank, which is a SyntaxError. The training
# cell below allocates 8 prediction columns per sample (`torch.zeros(..., 8)`), so
# 8 is used here. Confirm against the label set of the target Sprites attribute.
NUM_CLASSES = 8
model = Classifier(z_dim=Z_DIM, num_classes=NUM_CLASSES)

In [11]:
# Run configuration for the latent classifier.
TIME_STAMP = time.strftime("%d%m%Y-%H%M%S")  # ddmmYYYY-HHMMSS tag used in checkpoint file names

# Optimisation schedule
NUM_EPOCHS = 200
LEARNING_RATE = 1e-3
MOMENTUM, WEIGHT_DECAY = 0.9, 1e-4

# NOTE(review): not read by any visible cell; device selection uses `cuda`/`device` above.
CUDA = True

In [14]:
is_better = True
# Despite its name, the training loop stores the best (lowest) validation loss in
# `prev_acc`. It starts at +inf so the first validation result counts as an improvement.
prev_acc = float('inf')
name = "unspecified_latent_classifier"

# The loss plot is written under ./figures. Matplotlib's savefig does not create
# missing directories, so make sure the directory exists first.
FIGURE_DIR = './figures'
os.makedirs(FIGURE_DIR, exist_ok=True)
liveloss = PlotLosses(fig_path=os.path.join(FIGURE_DIR, name + ".png"))

In [15]:
dataloaders = {'train': train_loader, 'validation': test_loader}

# The training loop calls `optimizer.zero_grad()` / `optimizer.step()` and saves
# `optimizer.state_dict()`, but no earlier cell creates an optimizer. The
# MOMENTUM and WEIGHT_DECAY constants in the config cell match plain SGD.
# NOTE(review): optimizer type inferred from those constants — confirm.
optimizer = opt.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

In [None]:
# Train / validate for NUM_EPOCHS epochs. The checkpoint is saved whenever the
# validation loss improves, and the live loss plot is updated every epoch.
# FIXME(review): this cell looks copied from the VAE training loop and not yet
# adapted to the latent classifier. As visible in this notebook:
#   * `vae_loss`, `DISPLAY_STATE` and `plot_model_state` are never defined, so the
#     cell raises NameError on a fresh kernel.
#   * `optimizer` must exist before this cell runs (it is used for zero_grad/step
#     and saved in the checkpoint).
#   * `model(...)` is unpacked as `(recons, inp, mu, logvar)`, which is a VAE-style
#     output rather than a classifier's. TODO confirm the Classifier forward API.
#   * `predicted_phase` / `target_phase` are allocated but never filled, so the
#     imported accuracy / F1 metrics are never computed.
for epoch in range(NUM_EPOCHS):
    logs = {}
    # Set once per epoch. The "Phase time" printed below is therefore cumulative
    # since the start of the epoch, not the duration of a single phase.
    t_start = time.time()
    
    for phase in ['train', 'validation']:
        if phase == 'train':
            model.train()
            
        else:
            model.eval()
        # The model is moved to `device` at the start of every phase and back to
        # the CPU at the end of it (`model.to('cpu')` below).
        model.to(device)
        
        print("Started Phase")

        running_loss = 0.0
                
        # Per-sample prediction (8 columns) / target buffers. Not written to below.
        predicted_phase = torch.zeros(len(dataloaders[phase].dataset), 8)
        target_phase = torch.zeros(len(dataloaders[phase].dataset))
        
        if phase == 'validation':
            
            with torch.no_grad():
                
                # Stays None if the loader yields no batches. That None would then
                # be passed to plot_model_state.
                recons = None
                
                for (i,batch) in enumerate(dataloaders[phase]):
                    input_tensor = batch[0]
                    bs = input_tensor.shape[0]

                    # torch.Tensor(...) makes a float copy, which is then moved to `device`.
                    (recons, inp, mu, logvar) = model(torch.Tensor(input_tensor).to(device))

                    loss = vae_loss(inp,recons,mu,logvar)

                    # input_tensor itself was never moved to `device`, so this line
                    # presumably has no effect (it is already on the CPU).
                    input_tensor = input_tensor.cpu()
                    # Loss weighted by batch size. This gives a per-sample sum if
                    # vae_loss returns a batch mean (TODO confirm).
                    running_loss += loss.detach() * bs
                
                # NOTE(review): debugging print.
                print(epoch)
                    
                # FIXME(review): DISPLAY_STATE and plot_model_state are not defined in this notebook.
                if epoch % DISPLAY_STATE == 0:

                    plot_model_state(recons, epoch)
     
        else:
            
            for (i,batch) in enumerate(dataloaders[phase]):
                input_tensor = batch[0]
                bs = input_tensor.shape[0]
                
                (recons, inp, mu, logvar) = model(torch.Tensor(input_tensor).to(device))

                loss = vae_loss(inp,recons,mu,logvar)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()


                input_tensor = input_tensor.cpu()
                running_loss += loss.detach() * bs
    

        # NOTE(review): this divides by the full dataset length. The loaders use
        # drop_last=True, so the samples in a final partial batch are never seen and
        # this slightly underestimates the mean loss. If a loader yields no batches,
        # running_loss stays a Python float and `.item()` below fails.
        epoch_loss = running_loss / len(dataloaders[phase].dataset)

        
        model.to('cpu')

        # Training metrics are logged under 'log loss', validation under 'val_log loss'.
        prefix = ''
        if phase == 'validation':
            prefix = 'val_'

        logs[prefix + 'log loss'] = epoch_loss.item()
        
        print('Phase time - ',time.time() - t_start)

    # NOTE(review): `delta` is never used.
    delta = time.time() - t_start
    # `prev_acc` holds the lowest validation loss so far. Lower is better.
    is_better = logs['val_log loss'] < prev_acc
    if is_better:
        prev_acc = logs['val_log loss']
        # NOTE(review): saves under ../models2/, which must already exist, while
        # the VAE checkpoint was read from ./models/. The file name embeds the
        # validation loss.
        torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(), 'loss': logs['log loss']}, "../models2/"+name+"_"+TIME_STAMP+"_"+str(logs['val_log loss'])+".pth")


    liveloss.update(logs)
    liveloss.draw()