In [None]:
import data, model.unet, model.autoencoder, loss, function
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import torch.utils.tensorboard as tb
import torchvision
import scipy.stats as stats
import pickle
import datetime
import os

%load_ext autoreload
%autoreload 2

In [None]:
# Run configuration: VAE selects the variational variant and also decides
# the prefix of this run's log directory.
VAE = True

prefix = "vae" if VAE else "no_vae"

# One directory per run, named "<prefix>_<start time>"; it receives the
# TensorBoard event files and the model checkpoints saved during training.
date = datetime.datetime.now()
timestamp = date.strftime(f"{prefix}_%d-%b-%Y_%H.%M.%S")
# exist_ok=True keeps this cell re-runnable when the directory already exists
# (e.g. the cell is executed twice within the same second).
os.makedirs(f"log/{timestamp}", exist_ok=True)
tb_writer = tb.SummaryWriter(f"log/{timestamp}")

In [None]:
# Debug switch; it is not read anywhere in the visible part of this notebook.
DEBUG=False

In [None]:
# --- Data / schedule -------------------------------------------------------
VAL_PORTION = 0.2               # fraction passed to random_split for the validation part
ITERATIONS = 100_001            # upper bound on training-loop iterations
VAL_ITERATIONS = 5              # validation batches evaluated per validation round
VAL_ITERATIONS_OVERFIT = 1      # batches rendered with training styles per validation round
RESOLUTION = 96                 # passed as `resolution=` to the dataset loaders
CHANNELS = 3
STYLE_DIM = 512                 # size of the style code fed to the U-Net

BATCH_SIZE = 16
LOSS_TYPE = 'l2'                # forwarded as `loss=` to the perceptual and style losses

# --- Loss weights ----------------------------------------------------------
# Per-layer weights handed to loss.perceptual_loss.
CONTENT_LOSS_WEIGHTS = dict(relu_4_2=2e-2)

# Per-layer weights handed to loss.style_loss.
STYLE_LOSS_WEIGHTS = dict(
    relu_1_1=1e3,
    relu_2_1=5e3,
    relu_3_1=1e3,
    relu_4_1=1e3,
    relu_5_1=1e3,
)

STYLE_LOSS_ALPHA = 1.0          # multiplier of the style loss in the total loss
KLD_LOSS_WEIGHT = 5e-5          # multiplier of the KL-divergence term in the total loss

In [None]:
# Builds every dataset / data loader used for training and validation.
# True: load the pre-split "style_cherrypicked" train/validation directories.
# False: build the style set from "../dataset/style" via data.filter_images.
CHERRYPICKED_DATASET_100 = True

# Directory names handed to data.filter_images; only used in the
# non-cherrypicked branch below.
bad_dirs = [
    "Baroque",
    "Contemporary_Realism",
    "Early_Renaissance",
    "High_Renaissance",
    "Mannerism_Late_Renaissance",
    "New_Realism",
    "Northern_Renaissance",
    "Realism",
    "Rococo",
    "Impressionism",
    "Minimalism",
    "Pointillism",
    "Pop_Art",
    "Romanticism",
    "Symbolism"
]

# Seed torch and numpy before the random_split calls below so the splits are
# repeatable; the order of the following statements therefore matters.
torch.manual_seed(0)
np.random.seed(0)

# Number of style images kept by the first random_split (non-cherrypicked branch only).
TRAINING_PORTION_STYLE=128

if CHERRYPICKED_DATASET_100:
    data_style_train = data.load_dataset("../dataset/style_cherrypicked/train", resolution=RESOLUTION)
    data_style_val = data.load_dataset("../dataset/style_cherrypicked/validation", resolution=RESOLUTION)
else:
    paths = data.list_images("../dataset/style")
    filtered_paths = data.filter_images(paths, bad_dirs, "style_images_no_face.pkl", "../dataset/style/wikiart")

    data_style = data.load_dataset_from_list(filtered_paths, resolution=RESOLUTION)
    # Keep TRAINING_PORTION_STYLE random images, then split those into train / validation.
    data_style, _ = torch.utils.data.random_split(data_style, [TRAINING_PORTION_STYLE, len(data_style) - TRAINING_PORTION_STYLE])
    data_style_train, data_style_val = torch.utils.data.random_split(data_style, [len(data_style) - int(VAL_PORTION * len(data_style)), int(VAL_PORTION * len(data_style))])
    
data_loader_style_train = DataLoader(data_style_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
data_loader_style_val = DataLoader(data_style_val, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Content images: keep TRAINING_PORTION random images, then split off VAL_PORTION of them.
TRAINING_PORTION=2048
data_content = data.load_debug_dataset('../dataset/content', resolution=RESOLUTION)
data_content, _ = torch.utils.data.random_split(data_content, [TRAINING_PORTION, len(data_content) - TRAINING_PORTION])
data_content_train, data_content_val = torch.utils.data.random_split(data_content, [len(data_content) - int(VAL_PORTION * len(data_content)), int(VAL_PORTION * len(data_content))])
data_loader_content_train = DataLoader(data_content_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# NOTE(review): this rebinds data_content_val, so the validation part of the
# split above is never used; validation content comes from '../dataset/debug/content'.
data_content_val = data.load_debug_dataset('../dataset/debug/content', resolution=RESOLUTION)
data_loader_content_val = DataLoader(data_content_val, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Pair a content loader with a style loader; the loops in this notebook unpack
# each item as ((content_image, content_path), (style_image, style_path)).
data_loader_train = data.DatasetPairIterator(data_loader_content_train, data_loader_style_train)
data_loader_val = data.DatasetPairIterator(data_loader_content_val, data_loader_style_val)

# NO REAL VALIDATION, USES TRAINING STYLE DATA
data_loader_val_overfit = data.DatasetPairIterator(data_loader_content_val, data_loader_style_train)

In [None]:
# Architecture settings for the U-Net and the style encoder.
DOWNUP_CONVOLUTIONS = 5
ADAIN_CONVOLUTIONS = 7
STYLE_DOWN_CONVOLUTIONS = 5
NUM_LAYERS_NO_CONNECTION = 0
RESIDUAL_STYLE = True
RESIDUAL_DOWN = True
RESIDUAL_ADAIN = True
RESIDUAL_UP = True
STYLE_NORM = True
DOWN_NORM = 'in'
UP_NORM = 'adain'

# Without the variational part the style code is twice as wide.
# NOTE: this rebinds the global STYLE_DIM, so re-running this cell without
# re-running the hyperparameter cell doubles it again.
if not VAE:
    STYLE_DIM = STYLE_DIM * 2

unet = model.unet.UNetAutoencoder(
    3,
    STYLE_DIM,
    residual_downsampling=RESIDUAL_DOWN,
    residual_adain=RESIDUAL_ADAIN,
    residual_upsampling=RESIDUAL_UP,
    down_normalization=DOWN_NORM,
    up_normalization=UP_NORM,
    num_adain_convolutions=ADAIN_CONVOLUTIONS,
    num_downup_convolutions=DOWNUP_CONVOLUTIONS,
    num_downup_without_connections=NUM_LAYERS_NO_CONNECTION,
    output_activation='sigmoid',
)

# The VAE encoder emits mean and log-variance side by side (2 * STYLE_DIM values);
# the plain encoder emits the STYLE_DIM-wide code directly.
encoder_output_dim = 2 * STYLE_DIM if VAE else STYLE_DIM
style_encoder = model.autoencoder.Encoder(
    encoder_output_dim,
    normalization=STYLE_NORM,
    residual=RESIDUAL_STYLE,
    num_down_convolutions=STYLE_DOWN_CONVOLUTIONS,
)

# The loss network is only evaluated, never trained; assigning the result
# to `_` keeps the notebook from echoing it.
loss_net = loss.LossNet()
_ = loss_net.eval()

In [None]:
# Move all three networks to the GPU when one is available.
if torch.cuda.is_available():
    unet, style_encoder, loss_net = unet.cuda(), style_encoder.cuda(), loss_net.cuda()

# Parameters handed to the optimizer: the U-Net's first, then the style
# encoder's. The loss network's parameters are not included.
trainable_parameters = [*unet.parameters(), *style_encoder.parameters()]

In [None]:
# Adam over the U-Net and style-encoder parameters.
optimizer = torch.optim.Adam(
    trainable_parameters,
    lr=5e-5,
)
# The learning rate is reduced when the value passed to lr_scheduler.step()
# (the average validation loss in the training loops) stops decreasing.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=5,
    verbose=True,
    min_lr=1e-6,
)

In [None]:
def forward(content_image, style_image):
    """ Stylize content images with a style code sampled from the encoder output.

    Parameters:
    -----------
    content_image : torch.Tensor, shape [batch_size, 3, H, W]
        The content images.
    style_image : torch.Tensor, shape [batch_size, 3, H, W]
        The style images.

    Returns:
    --------
    stylized : torch.Tensor, shape [batch_size, 3, H, W]
        The stylizations.
    style_sample : torch.Tensor, shape [batch_size, STYLE_DIM]
        Style encodings drawn via function.sample_normal.
    style_mean : torch.Tensor, shape [batch_size, STYLE_DIM]
        Means for the style encodings.
    style_logvar : torch.Tensor, shape [batch_size, STYLE_DIM]
        Logarithm of the variances of the style encodings.
    """
    encoder_output = style_encoder(style_image)
    # Along the last axis: the first STYLE_DIM entries are the mean,
    # everything after them is the log-variance.
    mean = encoder_output[..., :STYLE_DIM]
    logvar = encoder_output[..., STYLE_DIM:]
    sample = function.sample_normal(mean, logvar)
    return unet(content_image, sample), sample, mean, logvar

def forward_no_vae(content_image, style_image):
    """ Stylize content images using the raw encoder output as style code (no sampling).

    Parameters:
    -----------
    content_image : torch.Tensor, shape [batch_size, 3, H, W]
        The content images.
    style_image : torch.Tensor, shape [batch_size, 3, H, W]
        The style images.

    Returns:
    --------
    stylized : torch.Tensor, shape [batch_size, 3, H, W]
        The stylizations.
    style_encoding : torch.Tensor, shape [batch_size, STYLE_DIM]
        Style encodings.
    """
    # The encoder output is passed to the U-Net unchanged.
    code = style_encoder(style_image)
    return unet(content_image, code), code

def forward_interpolate(content_image, style_image1, style_image2, interpolation_factor):
    """ Stylize content images with a blend of two sampled style codes.

    Parameters:
    -----------
    content_image : torch.Tensor, shape [batch_size, 3, H, W]
        The content images.
    style_image1 : torch.Tensor, shape [batch_size, 3, H, W]
        The first style images.
    style_image2 : torch.Tensor, shape [batch_size, 3, H, W]
        The second style images (for interpolation).
    interpolation_factor : float
        Value between 0 and 1; weight of style_image1, while style_image2
        receives 1 - interpolation_factor.

    Returns:
    --------
    stylized : torch.Tensor, shape [batch_size, 3, H, W]
        The stylizations.
    style_sample : torch.Tensor, shape [batch_size, STYLE_DIM]
        The interpolated style encodings.
    """
    def encode_and_sample(style_image):
        # Mean = first STYLE_DIM entries of the last axis, log-variance = the rest.
        encoded = style_encoder(style_image)
        return function.sample_normal(encoded[..., :STYLE_DIM], encoded[..., STYLE_DIM:])

    # Encode and sample the first style completely before touching the second.
    first_sample = encode_and_sample(style_image1)
    second_sample = encode_and_sample(style_image2)

    style_sample = interpolation_factor * first_sample + (1 - interpolation_factor) * second_sample
    return unet(content_image, style_sample), style_sample

def forward_sample(content_image):
    """ Forward pass through the architecture, uses sampling to create random styles.

    One style code per content image is drawn from a standard normal
    distribution (torch.randn) and used to stylize that image.

    Parameters:
    -----------
    content_image : torch.Tensor, shape [batch_size, 3, H, W]
        The content images.

    Returns:
    --------
    stylized : torch.Tensor, shape [batch_size, 3, H, W]
        The stylizations.
    style_sample : torch.Tensor, shape [batch_size, STYLE_DIM]
        The randomly drawn style encodings.
    """
    # Size the sample from the actual batch instead of the global BATCH_SIZE so
    # the function also works for partial or differently sized batches.
    batch_size = content_image.shape[0]
    style_sample = torch.randn((batch_size, STYLE_DIM), device=content_image.device, requires_grad=False)
    stylized = unet(content_image, style_sample)
    return stylized, style_sample

def forward_sample_from_embedding(content_image, style_mean, style_logvar):
    """ Stylize content images with a style code sampled from a given mean and log-variance.

    Parameters:
    -----------
    content_image : torch.Tensor, shape [batch_size, 3, H, W]
        The content images.
    style_mean : torch.Tensor, shape [batch_size, STYLE_DIM]
        The style embedding mean.
    style_logvar : torch.Tensor, shape [batch_size, STYLE_DIM]
        The style embedding logvar.

    Returns:
    --------
    stylized : torch.Tensor, shape [batch_size, 3, H, W]
        The stylizations.
    style_sample : torch.Tensor, shape [batch_size, STYLE_DIM]
        The sampled style encodings.
    """
    # The style encoder is bypassed: the distribution parameters come from the caller.
    drawn = function.sample_normal(style_mean, style_logvar)
    return unet(content_image, drawn), drawn

def forward_interpolate_no_vae(content_image, style_image1, style_image2, interpolation_factor):
    """ Stylize content images with a blend of two deterministic style encodings.

    Parameters:
    -----------
    content_image : torch.Tensor, shape [batch_size, 3, H, W]
        The content images.
    style_image1 : torch.Tensor, shape [batch_size, 3, H, W]
        The first style images.
    style_image2 : torch.Tensor, shape [batch_size, 3, H, W]
        The second style images (for interpolation).
    interpolation_factor : float
        Value between 0 and 1; weight of style_image1, while style_image2
        receives 1 - interpolation_factor.

    Returns:
    --------
    stylized : torch.Tensor, shape [batch_size, 3, H, W]
        The stylizations.
    style_embedding : torch.Tensor, shape [batch_size, STYLE_DIM]
        The interpolated style encodings.
    """
    first_code = style_encoder(style_image1)
    second_code = style_encoder(style_image2)

    # Convex combination of the two encoder outputs; no sampling involved.
    style_embedding = interpolation_factor * first_code + (1 - interpolation_factor) * second_code
    return unet(content_image, style_embedding), style_embedding


In [None]:
# Snapshot of the run configuration, written once to TensorBoard as text so
# every log directory records the settings it was produced with.
parameters = {
    "RESOLUTION": RESOLUTION,
    "STYLE_DIM": STYLE_DIM,
    "BATCH_SIZE": BATCH_SIZE,
    "LOSS_TYPE": LOSS_TYPE,
    "CONTENT_LOSS_WEIGHTS": CONTENT_LOSS_WEIGHTS,
    "STYLE_LOSS_WEIGHTS": STYLE_LOSS_WEIGHTS,
    "STYLE_LOSS_ALPHA": STYLE_LOSS_ALPHA,
    "KLD_LOSS_WEIGHT": KLD_LOSS_WEIGHT,
    "DOWNUP_CONVOLUTIONS": DOWNUP_CONVOLUTIONS,
    "ADAIN_CONVOLUTIONS": ADAIN_CONVOLUTIONS,
    "STYLE_DOWN_CONVOLUTIONS": STYLE_DOWN_CONVOLUTIONS,
    "NUM_LAYERS_NO_CONNECTION": NUM_LAYERS_NO_CONNECTION,
    "TRAINING_PORTION_STYLE": TRAINING_PORTION_STYLE,
    "TRAINING_PORTION": TRAINING_PORTION,
    "VAL_PORTION": VAL_PORTION,
    "RESIDUAL_STYLE": RESIDUAL_STYLE,
    "RESIDUAL_DOWN": RESIDUAL_DOWN,
    "RESIDUAL_ADAIN": RESIDUAL_ADAIN,
    "RESIDUAL_UP": RESIDUAL_UP,
    "STYLE_NORM": STYLE_NORM,
    "DOWN_NORM": DOWN_NORM,
    "UP_NORM": UP_NORM,
    "CHERRYPICKED_DATASET_100": CHERRYPICKED_DATASET_100,
    "VAE": VAE,
}

tb_writer.add_text("parameters", str(parameters))

In [None]:
# Training loop for the variational model.
# Every iteration: one optimizer step on perceptual + style + KL loss.
# Every 1000 iterations (including 0): checkpoint both networks.
# Every 100 iterations (including 0): validate, log to TensorBoard and step
# the learning-rate scheduler on the average validation loss.
iteration = 0
val_step = 0           # global step of the 'validation ...' TensorBoard series
val_step_overfit = 0   # global step of the 'overfit images' TensorBoard series
for (content_image, content_path), (style_image, style_path) in data_loader_train:
    if iteration >= ITERATIONS:
        break
    if torch.cuda.is_available():
        content_image = content_image.to('cuda')
        style_image = style_image.to('cuda')

    unet.train()
    style_encoder.train()
    optimizer.zero_grad()

    stylized, style_encoding, style_mean, style_logvar = forward(content_image, style_image)

    features_content = loss_net(content_image)
    features_style = loss_net(style_image)
    features_stylized = loss_net(stylized)

    perceptual_loss = loss.perceptual_loss(features_content, features_stylized, CONTENT_LOSS_WEIGHTS, loss=LOSS_TYPE)
    style_loss = loss.style_loss(features_style, features_stylized, STYLE_LOSS_WEIGHTS, loss=LOSS_TYPE)
    kld_loss = loss.kld_loss(style_mean, style_logvar)
    total_loss = perceptual_loss + STYLE_LOSS_ALPHA * style_loss + KLD_LOSS_WEIGHT * kld_loss
    total_loss.backward()
    optimizer.step()

    tb_writer.add_scalar('train loss', total_loss.item(), iteration)
    tb_writer.add_scalar('train perceptual loss', perceptual_loss.item(), iteration)
    tb_writer.add_scalar('train kld loss', kld_loss.item(), iteration)
    tb_writer.add_scalar('train style loss', style_loss.item(), iteration)
    print(f'\r{iteration:5d} / {ITERATIONS}: loss : {total_loss.item():.4f} -- perceptual loss : {perceptual_loss.item():.4f} -- style loss : {style_loss.item():.4f} -- kld loss : {kld_loss.item():.4f}', end='\r')

    if iteration % 1000 == 0:
        torch.save({
            'unet_state_dict': unet.state_dict(),
            'style_encoder_state_dict': style_encoder.state_dict(),
        }, f"log/{timestamp}/model_{iteration}.pt")

    if iteration % 100 == 0:
        # Validate
        print('\nValidation...')
        total_val_loss = 0.0
        with torch.no_grad():
            val_iteration = 0
            val_iteration_overfit = 0
            unet.eval()
            style_encoder.eval()
            fig = plt.figure(figsize=(20,10))
            for (content_image, content_path), (style_image, style_path) in data_loader_val:

                if val_iteration >= VAL_ITERATIONS:
                    break

                if torch.cuda.is_available():
                    content_image = content_image.to('cuda')
                    style_image = style_image.to('cuda')

                stylized, style_encoding, style_mean, style_logvar = forward(content_image, style_image)

                features_content = loss_net(content_image)
                features_style = loss_net(style_image)
                features_stylized = loss_net(stylized)

                perceptual_loss = loss.perceptual_loss(features_content, features_stylized, CONTENT_LOSS_WEIGHTS, loss=LOSS_TYPE)
                style_loss = loss.style_loss(features_style, features_stylized, STYLE_LOSS_WEIGHTS, loss=LOSS_TYPE)
                kld_loss = loss.kld_loss(style_mean, style_logvar)
                total_loss = perceptual_loss + STYLE_LOSS_ALPHA * style_loss + KLD_LOSS_WEIGHT * kld_loss
                # Accumulate a Python float so the average prints as a plain number.
                total_val_loss += total_loss.item()

                # Top row: histogram of all sampled style values against the
                # standard normal pdf. Bottom row: histogram of the per-dimension
                # standard deviation over the batch.
                encoding_np = style_encoding.detach().cpu().numpy()
                ax_values = fig.add_subplot(2, VAL_ITERATIONS, val_iteration + 1)
                x = np.linspace(-5, 5, 100)
                ax_values.plot(x, stats.norm.pdf(x), color='red', linestyle='dashed')
                ax_values.hist(encoding_np.reshape(-1), density=True)
                ax_std = fig.add_subplot(2, VAL_ITERATIONS, VAL_ITERATIONS + val_iteration + 1)
                ax_std.hist(encoding_np.std(0), density=True)

                tb_writer.add_scalar('validation loss', total_loss.item(), val_step)
                tb_writer.add_scalar('validation perceptual loss', perceptual_loss.item(), val_step)
                tb_writer.add_scalar('validation style loss', style_loss.item(), val_step)
                tb_writer.add_scalar('validation kld loss', kld_loss.item(), val_step)
                tb_writer.add_images('validation images', torch.cat([content_image, style_image, stylized]).detach().cpu(), val_step)
                val_iteration += 1
                val_step += 1
                print(f'\r{val_iteration:5d} / {VAL_ITERATIONS}: loss : {total_loss.item():.4f} -- perceptual loss : {perceptual_loss.item():.4f} -- style loss : {style_loss.item():.4f} -- kld loss : {kld_loss.item():.4f}', end='\r')

            # generate stylization from training styles
            for (content_image, content_path), (style_image, style_path) in data_loader_val_overfit:
                if val_iteration_overfit >= VAL_ITERATIONS_OVERFIT:
                    break

                if torch.cuda.is_available():
                    content_image = content_image.to('cuda')
                    style_image = style_image.to('cuda')

                # Only the images are logged here, so no loss-network passes are needed.
                stylized, style_encoding, style_mean, style_logvar = forward(content_image, style_image)

                # This series has its own step counter (previously val_step was
                # used and val_step_overfit was incremented but never read).
                tb_writer.add_images('overfit images', torch.cat([content_image, style_image, stylized]).detach().cpu(), val_step_overfit)
                val_iteration_overfit += 1
                val_step_overfit += 1

            plt.show()
            # Release the figure; a new one is created every validation round.
            plt.close(fig)

            if val_iteration > 0:
                # Average over the batches actually evaluated, which can be
                # fewer than VAL_ITERATIONS if the loader runs out early.
                total_val_loss /= val_iteration
                print(f'\nAverage val loss: {total_val_loss}')
                lr_scheduler.step(total_val_loss)
            else:
                # Avoid feeding a meaningless 0.0 to the plateau scheduler.
                print('\nNo validation batches available, skipping lr scheduler step.')
            print(f'Training with lr {optimizer.param_groups[0]["lr"]}...')

    iteration += 1
    
    

# -------------------------------------------------------------------------------
# The following part of the Notebook is for evaluation and to create visualizations of the results


# Load previous model

In [None]:
# Rebuild the networks and restore their weights from a saved checkpoint.
# The architecture settings below must match the ones the checkpoint was
# trained with, otherwise load_state_dict fails.
LOAD_PATH = "log_alphas/alpha_1_vae_25-Jan-2020_16.40.00/model_29000.pt"

DOWNUP_CONVOLUTIONS = 5 #3
ADAIN_CONVOLUTIONS = 7 #3
STYLE_DOWN_CONVOLUTIONS = 5 #3
NUM_LAYERS_NO_CONNECTION = 0
RESIDUAL_STYLE = True # False
RESIDUAL_DOWN = True # False
RESIDUAL_ADAIN = True
RESIDUAL_UP = True
STYLE_NORM = True
DOWN_NORM = 'in'
UP_NORM = 'adain'

# NOTE(review): VAE is reset here but STYLE_DIM is taken as-is from the
# notebook state; if the model-construction cell ran with VAE = False,
# STYLE_DIM is still doubled at this point - confirm before loading.
VAE = True

unet = model.unet.UNetAutoencoder(3, STYLE_DIM, residual_downsampling=RESIDUAL_DOWN, residual_adain=RESIDUAL_ADAIN, residual_upsampling=RESIDUAL_UP, 
        down_normalization=DOWN_NORM, up_normalization=UP_NORM, num_adain_convolutions=ADAIN_CONVOLUTIONS, 
        num_downup_convolutions=DOWNUP_CONVOLUTIONS, num_downup_without_connections=NUM_LAYERS_NO_CONNECTION, output_activation='sigmoid')

# The VAE encoder emits mean and log-variance (2 * STYLE_DIM values).
if VAE:
    style_encoder = model.autoencoder.Encoder(2 * STYLE_DIM, normalization=STYLE_NORM, residual=RESIDUAL_STYLE, num_down_convolutions=STYLE_DOWN_CONVOLUTIONS)
else:
    style_encoder = model.autoencoder.Encoder(STYLE_DIM, normalization=STYLE_NORM, residual=RESIDUAL_STYLE, num_down_convolutions=STYLE_DOWN_CONVOLUTIONS)

loss_net = loss.LossNet()
_ = loss_net.eval()

# Deserialize onto the CPU so a checkpoint written on a GPU machine also loads
# on a CPU-only one; the networks are moved to the GPU below when available.
checkpoint = torch.load(LOAD_PATH, map_location='cpu')
unet.load_state_dict(checkpoint["unet_state_dict"])
style_encoder.load_state_dict(checkpoint["style_encoder_state_dict"])


if torch.cuda.is_available(): 
    unet = unet.cuda()
    style_encoder = style_encoder.cuda()
    loss_net = loss_net.cuda()

# Sampling random styles

In [None]:
# Stylize validation content images with style codes drawn by forward_sample
# and log content + result to TensorBoard, one batch per step.
SAMPLING_ITERATIONS = 100
with torch.no_grad():
    val_iteration = 0
    unet.eval()
    style_encoder.eval()
    # The style half of each pair is not needed here; only the content images are used.
    for (content_image, content_path), (style_image, style_path) in data_loader_val:

        if val_iteration >= SAMPLING_ITERATIONS:
            break

        if torch.cuda.is_available():
            content_image = content_image.to('cuda')

        stylized, style_encoding = forward_sample(content_image)

        # No losses are computed in this cell, so the loss network is not run
        # and no matplotlib figure is created.
        tb_writer.add_images('randomly sampled styles', torch.cat([content_image, stylized]).detach().cpu(), val_iteration)
        val_iteration += 1


# Style Interpolation

In [None]:
# Style interpolation on validation data: for every batch, blend two style
# batches with each factor and log one image strip to TensorBoard.
INTERPOLATION_ITERATIONS = 30
INTERPOLOATION_FACTORS = [1.0, .75, .5, .25, 0.0]

with torch.no_grad():
    val_iteration = 0
    unet.eval()
    style_encoder.eval()
    for (content_image, content_path), (style_image1, style_path) in data_loader_val:

        # Check the stop condition first so no extra style batch is fetched
        # on the final pass.
        if val_iteration >= INTERPOLATION_ITERATIONS:
            break

        # Second style batch for the interpolation.
        # NOTE(review): iter() is called on the same pair iterator the outer loop
        # consumes; whether this yields an independent iterator depends on
        # data.DatasetPairIterator - confirm.
        _, (style_image2, style_path) = next(iter(data_loader_val))

        if torch.cuda.is_available():
            content_image = content_image.to('cuda')
            style_image1 = style_image1.to('cuda')
            style_image2 = style_image2.to('cuda')

        # One stylization per factor; only the first 8 images of each batch are kept.
        interpolations = []
        for interpolation in INTERPOLOATION_FACTORS:
            stylized, style_encoding = forward_interpolate(content_image, style_image1, style_image2, interpolation)
            interpolations.append(stylized[:8])

        # Layout: content, style 1, interpolations (factor 1.0 -> 0.0), style 2.
        img_list = [content_image[:8], style_image1[:8]] + interpolations + [style_image2[:8]]

        tb_writer.add_images('interpolation images validation', torch.cat(img_list).detach().cpu(), val_iteration)
        val_iteration += 1


# Style interpolation overfit

In [None]:
# Style interpolation with validation content and *training* styles: for every
# batch, blend two style batches with each factor and log one image strip.
INTERPOLATION_ITERATIONS = 30
INTERPOLOATION_FACTORS = [1.0, .75, .5, .25, 0.0]

with torch.no_grad():
    val_iteration = 0
    unet.eval()
    style_encoder.eval()
    for (content_image, content_path), (style_image1, style_path) in data_loader_val_overfit:

        # Check the stop condition first so no extra style batch is fetched
        # on the final pass.
        if val_iteration >= INTERPOLATION_ITERATIONS:
            break

        # Second style batch for the interpolation.
        # NOTE(review): iter() is called on the same pair iterator the outer loop
        # consumes; whether this yields an independent iterator depends on
        # data.DatasetPairIterator - confirm.
        _, (style_image2, style_path) = next(iter(data_loader_val_overfit))

        if torch.cuda.is_available():
            content_image = content_image.to('cuda')
            style_image1 = style_image1.to('cuda')
            style_image2 = style_image2.to('cuda')

        # One stylization per factor; only the first 8 images of each batch are kept.
        interpolations = []
        for interpolation in INTERPOLOATION_FACTORS:
            stylized, style_encoding = forward_interpolate(content_image, style_image1, style_image2, interpolation)
            interpolations.append(stylized[:8])

        # Layout: content, style 1, interpolations (factor 1.0 -> 0.0), style 2.
        img_list = [content_image[:8], style_image1[:8]] + interpolations + [style_image2[:8]]

        tb_writer.add_images('interpolation images overfit', torch.cat(img_list).detach().cpu(), val_iteration)
        val_iteration += 1

# Training without VAE

In [None]:
# Training loop for the non-variational model (no sampling, no KL term).
# Every iteration: one optimizer step on perceptual + style loss.
# Every 1000 iterations (including 0): checkpoint both networks.
# Every 100 iterations (including 0): validate, log to TensorBoard and step
# the learning-rate scheduler on the average validation loss.
#
# NOTE(review): this loop reuses `unet`, `style_encoder`, `optimizer`,
# `tb_writer` and `timestamp` from earlier cells. `optimizer` was built from
# the networks that existed when the optimizer cell ran; if the checkpoint
# loading cell was executed afterwards, `unet` / `style_encoder` are different
# objects than the ones the optimizer updates - confirm the intended cell order.
parameters['VAE'] = False

iteration = 0
val_step = 0           # global step of the 'validation ...' TensorBoard series
val_step_overfit = 0   # global step of the 'overfit images' TensorBoard series
for (content_image, content_path), (style_image, style_path) in data_loader_train:
    if iteration >= ITERATIONS:
        break
    if torch.cuda.is_available():
        content_image = content_image.to('cuda')
        style_image = style_image.to('cuda')

    unet.train()
    style_encoder.train()
    optimizer.zero_grad()

    stylized, style_encoding = forward_no_vae(content_image, style_image)

    features_content = loss_net(content_image)
    features_style = loss_net(style_image)
    features_stylized = loss_net(stylized)

    perceptual_loss = loss.perceptual_loss(features_content, features_stylized, CONTENT_LOSS_WEIGHTS, loss=LOSS_TYPE)
    style_loss = loss.style_loss(features_style, features_stylized, STYLE_LOSS_WEIGHTS, loss=LOSS_TYPE)
    total_loss = perceptual_loss + STYLE_LOSS_ALPHA * style_loss
    total_loss.backward()
    optimizer.step()

    tb_writer.add_scalar('train loss', total_loss.item(), iteration)
    tb_writer.add_scalar('train perceptual loss', perceptual_loss.item(), iteration)
    tb_writer.add_scalar('train style loss', style_loss.item(), iteration)
    print(f'\r{iteration:5d} / {ITERATIONS}: loss : {total_loss.item():.4f} -- perceptual loss : {perceptual_loss.item():.4f} -- style loss : {style_loss.item():.4f}', end='\r')

    if iteration % 1000 == 0:
        torch.save({
            'unet_state_dict': unet.state_dict(),
            'style_encoder_state_dict': style_encoder.state_dict(),
        }, f"log/{timestamp}/model_{iteration}.pt")

    if iteration % 100 == 0:
        # Validate
        print('\nValidation...')
        total_val_loss = 0.0
        with torch.no_grad():
            val_iteration = 0
            val_iteration_overfit = 0
            unet.eval()
            style_encoder.eval()
            fig = plt.figure(figsize=(20,10))
            for (content_image, content_path), (style_image, style_path) in data_loader_val:

                if val_iteration >= VAL_ITERATIONS:
                    break

                if torch.cuda.is_available():
                    content_image = content_image.to('cuda')
                    style_image = style_image.to('cuda')

                stylized, style_encoding = forward_no_vae(content_image, style_image)

                features_content = loss_net(content_image)
                features_style = loss_net(style_image)
                features_stylized = loss_net(stylized)

                perceptual_loss = loss.perceptual_loss(features_content, features_stylized, CONTENT_LOSS_WEIGHTS, loss=LOSS_TYPE)
                style_loss = loss.style_loss(features_style, features_stylized, STYLE_LOSS_WEIGHTS, loss=LOSS_TYPE)
                total_loss = perceptual_loss + STYLE_LOSS_ALPHA * style_loss
                # Accumulate a Python float so the average prints as a plain number.
                total_val_loss += total_loss.item()

                # Top row: histogram of all style-code values against the
                # standard normal pdf. Bottom row: histogram of the per-dimension
                # standard deviation over the batch.
                encoding_np = style_encoding.detach().cpu().numpy()
                ax_values = fig.add_subplot(2, VAL_ITERATIONS, val_iteration + 1)
                x = np.linspace(-5, 5, 100)
                ax_values.plot(x, stats.norm.pdf(x), color='red', linestyle='dashed')
                ax_values.hist(encoding_np.reshape(-1), density=True)
                ax_std = fig.add_subplot(2, VAL_ITERATIONS, VAL_ITERATIONS + val_iteration + 1)
                ax_std.hist(encoding_np.std(0), density=True)

                tb_writer.add_scalar('validation loss', total_loss.item(), val_step)
                tb_writer.add_scalar('validation perceptual loss', perceptual_loss.item(), val_step)
                tb_writer.add_scalar('validation style loss', style_loss.item(), val_step)
                tb_writer.add_images('validation images', torch.cat([content_image, style_image, stylized]).detach().cpu(), val_step)
                val_iteration += 1
                val_step += 1
                print(f'\r{val_iteration:5d} / {VAL_ITERATIONS}: loss : {total_loss.item():.4f} -- perceptual loss : {perceptual_loss.item():.4f} -- style loss : {style_loss.item():.4f}', end='\r')

            # Stylizations of validation content with training styles.
            for (content_image, content_path), (style_image, style_path) in data_loader_val_overfit:

                if val_iteration_overfit >= VAL_ITERATIONS_OVERFIT:
                    break

                if torch.cuda.is_available():
                    content_image = content_image.to('cuda')
                    style_image = style_image.to('cuda')

                # Only the images are logged here, so no loss-network passes are needed.
                stylized, style_encoding = forward_no_vae(content_image, style_image)

                # This series has its own step counter (previously val_step was
                # used and val_step_overfit was incremented but never read).
                tb_writer.add_images('overfit images', torch.cat([content_image, style_image, stylized]).detach().cpu(), val_step_overfit)
                val_iteration_overfit += 1
                val_step_overfit += 1

            plt.show()
            # Release the figure; a new one is created every validation round.
            plt.close(fig)

            if val_iteration > 0:
                # Average over the batches actually evaluated, which can be
                # fewer than VAL_ITERATIONS if the loader runs out early.
                total_val_loss /= val_iteration
                print(f'\nAverage val loss: {total_val_loss}')
                lr_scheduler.step(total_val_loss)
            else:
                # Avoid feeding a meaningless 0.0 to the plateau scheduler.
                print('\nNo validation batches available, skipping lr scheduler step.')
            print(f'Training with lr {optimizer.param_groups[0]["lr"]}...')

    iteration += 1


# Style interpolation overfit no VAE

In [None]:
# Style interpolation without the variational part, using validation content
# and *training* styles: blend the two encoder outputs with each factor and
# log one image strip per batch.
INTERPOLATION_ITERATIONS = 30
INTERPOLOATION_FACTORS = [1.0, .75, .5, .25, 0.0]

with torch.no_grad():
    val_iteration = 0
    unet.eval()
    style_encoder.eval()
    for (content_image, content_path), (style_image1, style_path) in data_loader_val_overfit:

        # Check the stop condition first so no extra style batch is fetched
        # on the final pass.
        if val_iteration >= INTERPOLATION_ITERATIONS:
            break

        # Second style batch for the interpolation.
        # NOTE(review): iter() is called on the same pair iterator the outer loop
        # consumes; whether this yields an independent iterator depends on
        # data.DatasetPairIterator - confirm.
        _, (style_image2, style_path) = next(iter(data_loader_val_overfit))

        if torch.cuda.is_available():
            content_image = content_image.to('cuda')
            style_image1 = style_image1.to('cuda')
            style_image2 = style_image2.to('cuda')

        # One stylization per factor; only the first 8 images of each batch are kept.
        interpolations = []
        for interpolation in INTERPOLOATION_FACTORS:
            stylized, style_encoding = forward_interpolate_no_vae(content_image, style_image1, style_image2, interpolation)
            interpolations.append(stylized[:8])

        # Layout: content, style 1, interpolations (factor 1.0 -> 0.0), style 2.
        img_list = [content_image[:8], style_image1[:8]] + interpolations + [style_image2[:8]]

        tb_writer.add_images('interpolation images overfit', torch.cat(img_list).detach().cpu(), val_iteration)
        val_iteration += 1

# Create survey images

In [None]:
# Data loaders for the survey images: test content paired either with the
# training styles or with the held-out test styles.
data_content_test = data.load_debug_dataset(
    '../dataset/test/survey',
    resolution=RESOLUTION,
)
data_loader_content_test = DataLoader(
    data_content_test,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

data_style_test = data.load_debug_dataset(
    '../dataset/style_cherrypicked/test/',
    resolution=RESOLUTION,
)
data_loader_style_test = DataLoader(
    data_style_test,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

# Test content combined with the *training* styles.
data_loader_test_overfit = data.DatasetPairIterator(
    data_loader_content_test,
    data_loader_style_train,
)

# Test content combined with the test styles.
data_loader_test = data.DatasetPairIterator(
    data_loader_content_test,
    data_loader_style_test,
)

In [None]:
TEST_ITERATIONS=100

# Stylize test content with training styles and log content, style and result
# to TensorBoard, one batch per step.
with torch.no_grad():
    test_iteration = 0
    unet.eval()
    style_encoder.eval()
    for (content_image, content_path), (style_image, style_path) in data_loader_test_overfit:

        if test_iteration >= TEST_ITERATIONS:
            break

        if torch.cuda.is_available():
            content_image = content_image.to('cuda')
            style_image = style_image.to('cuda')

        stylized, style_encoding, style_mean, style_logvar = forward(content_image, style_image)

        # No loss is computed here, so the loss network is not evaluated.
        tb_writer.add_images('test images', torch.cat([content_image, style_image, stylized]).detach().cpu(), test_iteration)
        test_iteration += 1


In [None]:
TEST_ITERATIONS=100

# Stylize test content with test styles and log content, style and result
# to TensorBoard, one batch per step.
with torch.no_grad():
    test_iteration = 0
    unet.eval()
    style_encoder.eval()
    for (content_image, content_path), (style_image, style_path) in data_loader_test:

        if test_iteration >= TEST_ITERATIONS:
            break

        if torch.cuda.is_available():
            content_image = content_image.to('cuda')
            style_image = style_image.to('cuda')

        stylized, style_encoding, style_mean, style_logvar = forward(content_image, style_image)

        # No loss is computed here, so the loss network is not evaluated.
        tb_writer.add_images('test images, test styles', torch.cat([content_image, style_image, stylized]).detach().cpu(), test_iteration)
        test_iteration += 1

# Style interpolation: test content, training styles

In [None]:
INTERPOLATION_ITERATIONS = 50
INTERPOLOATION_FACTORS = [1.0, .75, .5, .25, 0.0]

# Style interpolation on test content with training-set styles.
# For every sample one grid row is logged:
#   content | style 1 | one stylization per interpolation factor | style 2
with torch.no_grad():
    val_iteration = 0
    # Put both networks into evaluation mode before running inference.
    unet.eval()
    style_encoder.eval()
    for (content_image, content_path), (style_image1, style_path) in data_loader_test_overfit:

        # Check the cap first so no second style batch is fetched for an
        # iteration that will not be rendered.
        if val_iteration >= INTERPOLATION_ITERATIONS:
            break

        # Second style batch, taken from a fresh `iter()` over the same pairing.
        # NOTE(review): whether this is independent of the outer loop's iterator
        # depends on DatasetPairIterator.__iter__ - confirm.
        _, (style_image2, style_path) = next(iter(data_loader_test_overfit))

        if torch.cuda.is_available():
            content_image = content_image.to('cuda')
            style_image1 = style_image1.to('cuda')
            style_image2 = style_image2.to('cuda')

        # One stylized batch per interpolation factor.
        interpolations = []
        for interpolation in INTERPOLOATION_FACTORS:
            stylized, style_encoding = forward_interpolate(content_image, style_image1, style_image2, interpolation)
            interpolations.append(stylized)

        # Stack the factors on a new dim 1 -> (batch, factors, ...), then put the
        # content / style images around them so dim 1 enumerates one grid row.
        interpolations = torch.stack(interpolations, dim=1)
        rows = torch.cat([
            content_image.unsqueeze(1),
            style_image1.unsqueeze(1),
            interpolations,
            style_image2.unsqueeze(1),
        ], dim=1)

        # Merge (batch, row) into a flat image list, sample after sample. Uses the
        # actual batch size instead of assuming BATCH_SIZE.
        interpolation_tensor = rows.flatten(0, 1)

        # Row width follows the number of factors (3 fixed images + factors).
        grid = torchvision.utils.make_grid(interpolation_tensor, nrow=rows.shape[1])
        tb_writer.add_image("vae style interpolation grid", grid, val_iteration)

        val_iteration += 1

# Style interpolation: test content, test styles

In [None]:
INTERPOLATION_ITERATIONS = 100
INTERPOLOATION_FACTORS = [1.0, .75, .5, .25, 0.0]

# Style interpolation on test content with test styles.
# Each TensorBoard step logs, for the first 8 samples of the batch:
# content, style 1, one stylization per interpolation factor, style 2
# (concatenated along the batch dimension in that order).
with torch.no_grad():
    val_iteration = 0
    # Put both networks into evaluation mode before running inference.
    unet.eval()
    style_encoder.eval()
    for (content_image, content_path), (style_image1, style_path) in data_loader_test:

        # Check the cap first so no second style batch is fetched for an
        # iteration that will not be rendered.
        if val_iteration >= INTERPOLATION_ITERATIONS:
            break

        # Second style batch, taken from a fresh `iter()` over the same pairing.
        # NOTE(review): whether this is independent of the outer loop's iterator
        # depends on DatasetPairIterator.__iter__ - confirm.
        _, (style_image2, style_path) = next(iter(data_loader_test))

        if torch.cuda.is_available():
            content_image = content_image.to('cuda')
            style_image1 = style_image1.to('cuda')
            style_image2 = style_image2.to('cuda')

        # One stylized batch per interpolation factor; only the first 8 samples
        # of each are kept for logging.
        interpolations = []
        for interpolation in INTERPOLOATION_FACTORS:
            stylized, style_encoding = forward_interpolate(content_image, style_image1, style_image2, interpolation)
            interpolations.append(stylized[:8])

        img_list = [content_image[:8], style_image1[:8]] + interpolations + [style_image2[:8]]

        # Concatenate directly in torch instead of round-tripping through numpy;
        # `add_images` accepts a CPU tensor.
        tb_writer.add_images(
            'interpolation images test content, test style',
            torch.cat(img_list, dim=0).detach().cpu(),
            val_iteration,
        )
        val_iteration += 1

# Survey images no VAE

In [None]:
TEST_ITERATIONS = 100

# Survey content images, batched with BATCH_SIZE, reshuffled on every pass,
# incomplete final batch dropped.
data_content_test = data.load_debug_dataset('../dataset/test/survey', resolution=RESOLUTION)
data_loader_content_test = DataLoader(data_content_test, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Test content paired with the training-set style loader
# (`data_loader_style_train` is defined in an earlier cell).
data_loader_test_overfit = data.DatasetPairIterator(data_loader_content_test, data_loader_style_train)

# Render survey images through the non-VAE forward pass.
# Every TensorBoard step logs one batch each of content, style and stylized
# images, concatenated along the batch dimension in that order.
with torch.no_grad():
    test_iteration = 0
    # Put both networks into evaluation mode before running inference.
    unet.eval()
    style_encoder.eval()
    for (content_image, content_path), (style_image, style_path) in data_loader_test_overfit:

        # Explicit cap on the number of logged batches.
        if test_iteration >= TEST_ITERATIONS:
            break

        if torch.cuda.is_available():
            content_image = content_image.to('cuda')
            style_image = style_image.to('cuda')

        stylized, style_encoding = forward_no_vae(content_image, style_image)

        # No loss-network pass here: this cell only logs images, so the
        # perceptual features are not needed and would be wasted compute.
        # The tensors are concatenated directly instead of round-tripping
        # through numpy; `add_images` accepts a CPU tensor.
        tb_writer.add_images(
            'test images no vae',
            torch.cat([content_image, style_image, stylized], dim=0).detach().cpu(),
            test_iteration,
        )
        test_iteration += 1

# Sampling from style embedding

In [None]:
# Debug style images, batched with BATCH_SIZE, reshuffled on every pass,
# incomplete final batch dropped.
data_style_debug = data.load_debug_dataset('../dataset/debug/style_test3', resolution=RESOLUTION)
data_loader_style_debug = DataLoader(data_style_debug, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Test content (loader defined in an earlier cell) paired with the debug styles.
data_loader_debug = data.DatasetPairIterator(data_loader_content_test, data_loader_style_debug)

TEST_ITERATIONS = 1
SAMPLES = 16

# Log a grid with one batch per row: content, style, the regular stylization,
# then SAMPLES further stylizations produced by `forward_sample_from_embedding`.
with torch.no_grad():
    test_iteration = 0
    # Put both networks into evaluation mode before running inference.
    unet.eval()
    style_encoder.eval()
    for (content_image, content_path), (style_image, style_path) in data_loader_debug:

        # Explicit cap on the number of logged batches.
        if test_iteration >= TEST_ITERATIONS:
            break

        if torch.cuda.is_available():
            content_image = content_image.to('cuda')
            style_image = style_image.to('cuda')

        stylized, style_encoding, style_mean, style_logvar = forward(content_image, style_image)

        # SAMPLES repeated calls with the same mean / log-variance; only the first
        # returned value (the stylized batch) is kept.
        # NOTE(review): presumably each call draws a new latent from the encoded
        # style distribution - confirm in forward_sample_from_embedding.
        stylized_samples = torch.cat([
            forward_sample_from_embedding(content_image, style_mean, style_logvar)[0]
            for _ in range(SAMPLES)
        ], dim=0)

        # Everything is concatenated batch after batch, so the row width must equal
        # the batch size for each batch to occupy exactly one grid row.
        grid = torchvision.utils.make_grid(
            torch.cat([content_image, style_image, stylized, stylized_samples], dim=0),
            nrow=content_image.shape[0],
        )

        tb_writer.add_image('samples from style embedding', grid, test_iteration)
        test_iteration += 1