In [17]:
import data, model.unet, model.autoencoder, loss, function
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils
import numpy as np
import torch.utils.tensorboard as tb
import torchvision
import scipy.stats as stats
import pickle
import datetime
import os
from sklearn.decomposition import PCA
from PIL import Image
from scipy import spatial

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# One TensorBoard log directory per run, named after the wall-clock start time.
date = datetime.datetime.now()
timestamp = date.strftime("ae_and_vae_%d-%b-%Y_%H.%M.%S")
# exist_ok=True: the timestamp only has second resolution, so re-running this
# cell within the same second must not fail with FileExistsError.
os.makedirs(f"log/{timestamp}", exist_ok=True)
tb_writer = tb.SummaryWriter(f"log/{timestamp}")

In [3]:
# Run-wide hyper-parameters, grouped by topic.
VAL_PORTION, ITERATIONS = 0.2, 100001
VAL_ITERATIONS, VAL_ITERATIONS_OVERFIT = 5, 1
RESOLUTION, CHANNELS, STYLE_DIM = 96, 3, 512

BATCH_SIZE = 16
LOSS_TYPE = 'l2'

# Content loss weights, keyed by layer name.
# Alternatives that were tried before: relu_1_1 @ 1e-2, relu_4_2 @ 5e-3, relu_4_2 @ 1e0.
CONTENT_LOSS_WEIGHTS = dict(relu_4_2=2e-2)

# Style loss weights for the layers relu_1_1 ... relu_5_1 (in that order).
STYLE_LOSS_WEIGHTS = {
    f'relu_{block}_1': weight
    for block, weight in enumerate((1e3, 5e3, 1e3, 1e3, 1e3), start=1)
}

STYLE_LOSS_ALPHA = 1.0
KLD_LOSS_WEIGHT = 5e-5

In [4]:
# Seed torch and numpy before any dataset or loader is created.
torch.manual_seed(0)
np.random.seed(0)

TRAINING_PORTION_STYLE = 128  # 128


def _batched_loader(dataset):
    """Wrap `dataset` in a shuffling DataLoader that drops the last incomplete batch."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)


# Style images: separate train and test splits.
data_style_train = data.load_dataset("../dataset/style_cherrypicked/train", resolution=RESOLUTION)
data_style_test = data.load_dataset("../dataset/style_cherrypicked/test", resolution=RESOLUTION)
data_loader_style_train = _batched_loader(data_style_train)
data_loader_style_test = _batched_loader(data_style_test)

# Content images used for evaluation.
data_content_test = data.load_debug_dataset('../dataset/content_test', resolution=RESOLUTION)
data_loader_content_test = _batched_loader(data_content_test)

# Pair the content batches with style batches: the "overfit" iterator draws
# styles from the training split, the other one from the test split.
data_loader_test_overfit = data.DatasetPairIterator(data_loader_content_test, data_loader_style_train)
data_loader_test = data.DatasetPairIterator(data_loader_content_test, data_loader_style_test)

In [5]:
# Architecture switches shared by both model variants.
DOWNUP_CONVOLUTIONS = 5  # 3
ADAIN_CONVOLUTIONS = 7  # 3
STYLE_DOWN_CONVOLUTIONS = 5  # 3
NUM_LAYERS_NO_CONNECTION = 0
RESIDUAL_STYLE = True  # False
RESIDUAL_DOWN = True  # False
RESIDUAL_ADAIN = True
RESIDUAL_UP = True
STYLE_NORM = True
DOWN_NORM = 'in'
UP_NORM = 'adain'

# Keyword arguments that are identical for the VAE and the plain UNet.
_unet_options = dict(
    residual_downsampling=RESIDUAL_DOWN,
    residual_adain=RESIDUAL_ADAIN,
    residual_upsampling=RESIDUAL_UP,
    down_normalization=DOWN_NORM,
    up_normalization=UP_NORM,
    num_adain_convolutions=ADAIN_CONVOLUTIONS,
    num_downup_convolutions=DOWNUP_CONVOLUTIONS,
    num_downup_without_connections=NUM_LAYERS_NO_CONNECTION,
    output_activation='sigmoid',
)
# The VAE variant is built with STYLE_DIM, the plain variant with 2 * STYLE_DIM.
unet_vae = model.unet.UNetAutoencoder(3, STYLE_DIM, **_unet_options)
unet = model.unet.UNetAutoencoder(3, 2 * STYLE_DIM, **_unet_options)

# Both style encoders are built with the same arguments (2 * STYLE_DIM).
_encoder_options = dict(
    normalization=STYLE_NORM,
    residual=RESIDUAL_STYLE,
    num_down_convolutions=STYLE_DOWN_CONVOLUTIONS,
)
style_encoder_vae = model.autoencoder.Encoder(2 * STYLE_DIM, **_encoder_options)
style_encoder = model.autoencoder.Encoder(2 * STYLE_DIM, **_encoder_options)

In [6]:
# Restore the pre-trained VAE weights (UNet + style encoder) onto the CPU.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
state_dict = torch.load('./models/model.pt', map_location='cpu')
unet_vae.load_state_dict(state_dict['unet_state_dict'])
style_encoder_vae.load_state_dict(state_dict['style_encoder_state_dict'])

<All keys matched successfully>

In [7]:
# Restore the pre-trained non-VAE (plain autoencoder) weights onto the CPU.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
state_dict_ae = torch.load('./models/model_non_vae.pt', map_location='cpu')
unet.load_state_dict(state_dict_ae['unet_state_dict'])
style_encoder.load_state_dict(state_dict_ae['style_encoder_state_dict'])


<All keys matched successfully>

# Embed all training styles with the AE and VAE models

In [9]:
# Embed all training styles. This way, we can find the style most similar to the sampled style.
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell also runs on machines without CUDA.
embed_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

style_images, embeddings, embeddings_vae = [], [], []
style_encoder.to(embed_device)
style_encoder.eval()
style_encoder_vae.to(embed_device)
style_encoder_vae.eval()

with torch.no_grad():
    for idx, (style_image, _) in enumerate(data_style_train):
        print(f'\r{idx}...', end='\r')
        # Keep the un-batched image so it can be looked up by index later.
        style_images.append(style_image)
        style_image = style_image.unsqueeze(0).to(embed_device)

        embedding = style_encoder(style_image)
        embedding_vae_stats = style_encoder_vae(style_image)
        # Only the first STYLE_DIM entries of the VAE encoder output are kept
        # as the embedding (presumably the mean of the posterior - TODO confirm).
        embedding_vae = embedding_vae_stats[..., : STYLE_DIM]
        embeddings.append(embedding)
        embeddings_vae.append(embedding_vae)

99...

In [10]:
# Collapse the per-image Python lists into single tensors (one row per image).
# The isinstance guards make the cell idempotent: on a second run the names are
# already tensors, and torch.stack / torch.cat would otherwise raise.
if isinstance(style_images, list):
    style_images = torch.stack(style_images, dim=0)
if isinstance(embeddings, list):
    embeddings = torch.cat(embeddings, dim=0)
if isinstance(embeddings_vae, list):
    embeddings_vae = torch.cat(embeddings_vae, dim=0)

In [16]:
# Move both encoders off the GPU again now that the embeddings are computed.
# A for-statement produces no cell output, so the module repr is not echoed.
for encoder in (style_encoder, style_encoder_vae):
    encoder.cpu()

In [18]:
# Index both embedding sets with a k-d tree for nearest-neighbour lookups.
# The embeddings are converted to CPU numpy arrays first.
kd_embeddings, kd_embeddings_vae = (
    spatial.KDTree(embedding_set.detach().cpu().numpy())
    for embedding_set in (embeddings, embeddings_vae)
)

In [19]:
# Sanity check: (num_styles, embedding_dim) of the AE and the VAE embeddings.
embeddings.shape, embeddings_vae.shape

(torch.Size([100, 1024]), torch.Size([100, 512]))

# Sample styles for the AE and VAE models and find the closest training style

In [24]:
def closest_style_images(kd_tree, style, images=None):
    """ Finds images that are closest to a style.

    All rows of `style` are looked up with a single batched k-d tree query
    instead of one query per row.

    Parameters:
    -----------
    kd_tree : scipy.spatial.KDTree
        The kd tree that stores all embeddings.
    style : torch.Tensor, shape [batch_size, style_dim]
        The styles to find nearest neighbour of.
    images : torch.Tensor, shape [num_images, 3, H, W], optional
        Images to pick from; row i must belong to the i-th point stored in `kd_tree`.
        Defaults to the notebook-level `style_images` tensor.

    Returns:
    --------
    images : torch.Tensor, shape [batch_size, 3, H, W]
        Style images that are closest to the style in style space.
    """
    if images is None:
        # Fall back to the notebook-level tensor of all training style images.
        images = style_images
    if style.size(0) == 0:
        # Nothing to look up: return an empty batch with the right trailing shape.
        return images[:0]
    # A 2-D query returns one (distance, index) pair per row; distances are not needed.
    _, nearest = kd_tree.query(style.detach().cpu().numpy())
    idxs = torch.as_tensor(nearest, dtype=torch.long).reshape(-1)
    return images[idxs]

In [26]:
# Sample random style codes for both models, stylize the content batches with
# them and log the result next to the nearest training style image.
MAX_ITERS = 100
iteration = 0
unet.eval()
unet_vae.eval()
try:
    with torch.no_grad():
        # The style batch from the loader is not used: styles are sampled instead.
        for (content_image, content_path), (style_image, style_path) in data_loader_test_overfit:

            if iteration >= MAX_ITERS:
                break

            batch_size = content_image.size(0)
            # The plain model takes a 2 * STYLE_DIM code, the VAE model a STYLE_DIM code.
            style = torch.randn((batch_size, 2 * STYLE_DIM), device=content_image.device, requires_grad=False)
            style_vae = torch.randn((batch_size, STYLE_DIM), device=content_image.device, requires_grad=False)
            stylized = unet(content_image, style)
            stylized_vae = unet_vae(content_image, style_vae)

            # Find closest style image; move it to the batch's device so the
            # concatenation below cannot fail on a device mismatch.
            closest = closest_style_images(kd_embeddings, style).to(content_image.device)
            closest_vae = closest_style_images(kd_embeddings_vae, style_vae).to(content_image.device)

            # nrow=batch_size puts each of the five groups on its own grid row:
            # content, AE stylized, AE closest, VAE stylized, VAE closest.
            tb_writer.add_image('style samples and closest',
                                torchvision.utils.make_grid(torch.cat(
                                    [content_image, stylized, closest, stylized_vae, closest_vae], dim=0
                                ), nrow=batch_size),
                                iteration)
            iteration += 1
            print(f'\r{iteration}', end='\r')
finally:
    # Write out pending events even when the loop is interrupted (e.g. Ctrl-C).
    tb_writer.flush()

74

KeyboardInterrupt: 