In [1]:
import data, model.unet, model.autoencoder, loss, function
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils
import numpy as np
import torch.utils.tensorboard as tb
import torchvision
import scipy.stats as stats
import pickle
import datetime
import os
from sklearn.decomposition import PCA
from PIL import Image

%load_ext autoreload
%autoreload 2

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Set up a timestamped log directory for this run and attach a TensorBoard
# writer to it. The timestamp has one-second resolution.
date = datetime.datetime.now()
# Plain string: the format only contains strftime directives, no f-string fields.
timestamp = date.strftime("ae_and_vae_%d-%b-%Y_%H.%M.%S")
# exist_ok=True keeps this cell re-runnable if the directory is already present
# (without it os.makedirs raises FileExistsError).
os.makedirs(f"log/{timestamp}", exist_ok=True)
tb_writer = tb.SummaryWriter(f"log/{timestamp}")

In [3]:
# --- Run configuration -------------------------------------------------------
# Scalar settings. Only some of them are referenced by the cells visible in this
# notebook (RESOLUTION, STYLE_DIM, BATCH_SIZE); the rest are presumably consumed
# by training code elsewhere -- TODO confirm before removing any of them.
VAL_PORTION = 0.2
ITERATIONS = 100_001
VAL_ITERATIONS = 5
VAL_ITERATIONS_OVERFIT = 1
RESOLUTION = 96
CHANNELS = 3
STYLE_DIM = 512

BATCH_SIZE = 16
LOSS_TYPE = 'l2'

# --- Loss weights ------------------------------------------------------------
# Per-layer weights, keyed by layer name.
# Alternative content settings that were left commented out:
#   'relu_1_1': 1e-2, 'relu_4_2': 5e-3, 'relu_4_2': 1e0
CONTENT_LOSS_WEIGHTS = dict(relu_4_2=2e-2)

STYLE_LOSS_WEIGHTS = dict(
    relu_1_1=1e3,
    relu_2_1=5e3,
    relu_3_1=1e3,
    relu_4_1=1e3,
    relu_5_1=1e3,
)

STYLE_LOSS_ALPHA = 1.0
KLD_LOSS_WEIGHT = 5e-5

In [15]:
# Seed both RNGs before any dataset / loader is created.
torch.manual_seed(0)
np.random.seed(0)

# NOTE(review): not referenced by any cell visible in this notebook.
TRAINING_PORTION_STYLE = 128  # 128

# Every loader below uses the same batching options.
_loader_options = dict(batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Style images: separate train and test folders.
data_style_train = data.load_dataset(
    "../dataset/style_cherrypicked/train", resolution=RESOLUTION
)
data_style_test = data.load_dataset(
    "../dataset/style_cherrypicked/test", resolution=RESOLUTION
)
data_loader_style_train = DataLoader(data_style_train, **_loader_options)
data_loader_style_test = DataLoader(data_style_test, **_loader_options)

# Content images used for evaluation.
data_content_test = data.load_debug_dataset(
    '../dataset/content_test', resolution=RESOLUTION
)
data_loader_content_test = DataLoader(data_content_test, **_loader_options)

# Pair the content loader with each style loader:
#   *_overfit -> test content x training styles
#   plain     -> test content x held-out test styles
data_loader_test_overfit = data.DatasetPairIterator(
    data_loader_content_test, data_loader_style_train
)
data_loader_test = data.DatasetPairIterator(
    data_loader_content_test, data_loader_style_test
)

In [5]:
# Architecture switches (trailing comments are the values noted in the original cell).
DOWNUP_CONVOLUTIONS = 5  # 3
ADAIN_CONVOLUTIONS = 7  # 3
STYLE_DOWN_CONVOLUTIONS = 5  # 3
NUM_LAYERS_NO_CONNECTION = 0
RESIDUAL_STYLE = True  # False
RESIDUAL_DOWN = True  # False
RESIDUAL_ADAIN = True
RESIDUAL_UP = True
STYLE_NORM = True
DOWN_NORM = 'in'
UP_NORM = 'adain'

# Keyword arguments shared by both U-Nets; they differ only in their second
# positional argument (presumably the style-input width -- TODO confirm against
# model.unet.UNetAutoencoder).
_unet_options = dict(
    residual_downsampling=RESIDUAL_DOWN,
    residual_adain=RESIDUAL_ADAIN,
    residual_upsampling=RESIDUAL_UP,
    down_normalization=DOWN_NORM,
    up_normalization=UP_NORM,
    num_adain_convolutions=ADAIN_CONVOLUTIONS,
    num_downup_convolutions=DOWNUP_CONVOLUTIONS,
    num_downup_without_connections=NUM_LAYERS_NO_CONNECTION,
    output_activation='sigmoid',
)
# VAE variant: STYLE_DIM. Plain autoencoder variant: 2 * STYLE_DIM.
unet_vae = model.unet.UNetAutoencoder(3, STYLE_DIM, **_unet_options)
unet = model.unet.UNetAutoencoder(3, 2 * STYLE_DIM, **_unet_options)

# Both style encoders are built with identical arguments.
_style_encoder_options = dict(
    normalization=STYLE_NORM,
    residual=RESIDUAL_STYLE,
    num_down_convolutions=STYLE_DOWN_CONVOLUTIONS,
)
style_encoder_vae = model.autoencoder.Encoder(2 * STYLE_DIM, **_style_encoder_options)
style_encoder = model.autoencoder.Encoder(2 * STYLE_DIM, **_style_encoder_options)

In [6]:
# Restore the pre-trained VAE weights. One checkpoint file holds the state dicts
# of both the U-Net and the style encoder; all tensors are mapped onto the CPU.
state_dict = torch.load('./models/model.pt', map_location='cpu')
unet_vae.load_state_dict(state_dict['unet_state_dict'])
# Last expression of the cell, so its result is what the notebook displays.
style_encoder_vae.load_state_dict(state_dict['style_encoder_state_dict'])

<All keys matched successfully>

In [8]:
# Restore the pre-trained plain-autoencoder weights. One checkpoint file holds
# the state dicts of both the U-Net and the style encoder; all tensors are
# mapped onto the CPU.
state_dict_ae = torch.load('./models/model_non_vae.pt', map_location='cpu')
unet.load_state_dict(state_dict_ae['unet_state_dict'])
# Last expression of the cell, so its result is what the notebook displays.
style_encoder.load_state_dict(state_dict_ae['style_encoder_state_dict'])


<All keys matched successfully>

In [9]:
def forward(content_image, style_image):
    """ Stylize content images with the variational model.

    The output of ``style_encoder_vae`` is split along its last axis: the first
    ``STYLE_DIM`` entries are used as the mean, the remaining entries as the
    log-variance. A style code is drawn with ``function.sample_normal`` and
    passed, together with the content images, to ``unet_vae``.

    Parameters:
    -----------
    content_image : torch.Tensor, shape [batch_size, 3, H, W]
        The content images.
    style_image : torch.Tensor, shape [batch_size, 3, H, W]
        The style images.

    Returns:
    --------
    stylized : torch.Tensor, shape [batch_size, 3, H, W]
        The stylizations.
    style_encoding : torch.Tensor, shape [batch_size, STYLE_DIM]
        The sampled style encodings that were fed to the U-Net.
    style_encoding_mean : torch.Tensor, shape [batch_size, STYLE_DIM]
        Means for the style encodings.
    style_encoding_logvar : torch.Tensor, shape [batch_size, STYLE_DIM]
        Logarithm of the variances of the style encodings.
    """
    encoder_output = style_encoder_vae(style_image)
    # First STYLE_DIM entries -> mean, everything after -> log-variance.
    mean, logvar = encoder_output[..., :STYLE_DIM], encoder_output[..., STYLE_DIM:]
    sample = function.sample_normal(mean, logvar)
    return unet_vae(content_image, sample), sample, mean, logvar

def forward_no_vae(content_image, style_image):
    """ Stylize content images with the plain (non-variational) autoencoder.

    The style images are encoded with ``style_encoder`` and the resulting
    encoding is passed unchanged, together with the content images, to ``unet``.

    Parameters:
    -----------
    content_image : torch.Tensor, shape [batch_size, 3, H, W]
        The content images.
    style_image : torch.Tensor, shape [batch_size, 3, H, W]
        The style images.

    Returns:
    --------
    stylized : torch.Tensor, shape [batch_size, 3, H, W]
        The stylizations.
    style_encoding : torch.Tensor
        Style encodings as returned by ``style_encoder``.
        NOTE(review): the encoder is constructed with ``2 * STYLE_DIM``, so the
        last dimension is presumably ``2 * STYLE_DIM`` rather than ``STYLE_DIM``
        -- TODO confirm.
    """
    encoding = style_encoder(style_image)
    return unet(content_image, encoding), encoding

# Compare Autoencoder to Variational Autoencoder on train styles

In [17]:
# Log side-by-side AE vs. VAE stylizations for test content paired with
# *training* styles to TensorBoard under the tag 'vae vs ae'.
MAX_ITERS = 100  # maximum number of (content, style) batch pairs to log

# Put every network used below into evaluation mode. unet_vae is run by
# forward() and therefore has to be switched as well, otherwise it stays in
# whatever mode it was left in (training mode for a freshly built nn.Module).
unet.eval()
unet_vae.eval()
style_encoder.eval()
style_encoder_vae.eval()

with torch.no_grad():  # inference only, no autograd graph needed
    iteration = 0
    for (content_image, content_path), (style_image, style_path) in data_loader_test_overfit:

        if iteration >= MAX_ITERS:
            break

        stylized_vae = forward(content_image, style_image)[0]
        stylized_ae = forward_no_vae(content_image, style_image)[0]

        # Four batches are stacked along dim 0 and laid out with one batch per
        # grid row: content, style, AE output, VAE output.
        comparison_grid = torchvision.utils.make_grid(
            torch.cat([content_image, style_image, stylized_ae, stylized_vae], dim=0),
            nrow=content_image.size(0),
        )
        tb_writer.add_image('vae vs ae', comparison_grid, iteration)
        iteration += 1
        print(f'\r{iteration}', end='\r')

KeyboardInterrupt: 

# Compare Autoencoder to Variational Autoencoder on test styles

In [18]:
# Log side-by-side AE vs. VAE stylizations for test content paired with
# held-out *test* styles to TensorBoard under the tag 'vae vs ae test'.
MAX_ITERS = 100  # maximum number of (content, style) batch pairs to log

# Put every network used below into evaluation mode. unet_vae is run by
# forward() and therefore has to be switched as well, otherwise it stays in
# whatever mode it was left in (training mode for a freshly built nn.Module).
unet.eval()
unet_vae.eval()
style_encoder.eval()
style_encoder_vae.eval()

with torch.no_grad():  # inference only, no autograd graph needed
    iteration = 0
    for (content_image, content_path), (style_image, style_path) in data_loader_test:

        if iteration >= MAX_ITERS:
            break

        stylized_vae = forward(content_image, style_image)[0]
        stylized_ae = forward_no_vae(content_image, style_image)[0]

        # Four batches are stacked along dim 0 and laid out with one batch per
        # grid row: content, style, AE output, VAE output.
        comparison_grid = torchvision.utils.make_grid(
            torch.cat([content_image, style_image, stylized_ae, stylized_vae], dim=0),
            nrow=content_image.size(0),
        )
        tb_writer.add_image('vae vs ae test', comparison_grid, iteration)
        iteration += 1
        print(f'\r{iteration}', end='\r')

8

KeyboardInterrupt: 