In [2]:
import data, model.unet, model.autoencoder, loss, function
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils
import numpy as np
import torch.utils.tensorboard as tb
import torchvision
import scipy.stats as stats
import pickle
import datetime
import os
from sklearn.decomposition import PCA
from PIL import Image

%load_ext autoreload
%autoreload 2

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [4]:
date = datetime.datetime.now()
timestamp = date.strftime(f"{prefix}_%d-%b-%Y_%H.%M.%S")
os.makedirs(f"log/{timestamp}")
tb_writer = tb.SummaryWriter(f"log/{timestamp}")

In [5]:
DEBUG=False

In [6]:
VAL_PORTION = 0.2
ITERATIONS = 100001
VAL_ITERATIONS = 5
VAL_ITERATIONS_OVERFIT = 1
RESOLUTION = 96
CHANNELS = 3
STYLE_DIM = 512

BATCH_SIZE = 5
LOSS_TYPE = 'l2'

CONTENT_LOSS_WEIGHTS = {
    #'relu_1_1' : 1e-2,
    #'relu_4_2' : 5e-3,
    'relu_4_2' : 2e-2,
    #'relu_4_2' : 1e0,
}

STYLE_LOSS_WEIGHTS = {
    'relu_1_1' : 1e3,
    'relu_2_1' : 5e3, # 5e3
    'relu_3_1' : 1e3,
    'relu_4_1' : 1e3, # 1e3
    'relu_5_1' : 1e3,
}

STYLE_LOSS_ALPHA = 1.0
KLD_LOSS_WEIGHT = 5e-5

In [7]:
torch.manual_seed(0)
np.random.seed(0)

TRAINING_PORTION_STYLE=128 # 128

data_style_train = data.load_dataset("../dataset/style_cherrypicked/train", resolution=RESOLUTION)
data_style_test = data.load_dataset("../dataset/style_cherrypicked/test", resolution=RESOLUTION)
data_loader_style_train = DataLoader(data_style_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
data_loader_style_test = DataLoader(data_style_test, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

data_content_test = data.load_debug_dataset('../dataset/content_test', resolution=RESOLUTION)
data_loader_content_test = DataLoader(data_content_test, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

data_loader_test_overfit = data.DatasetPairIterator(data_loader_content_test, data_loader_style_train)
data_loader_test = data.DatasetPairIterator(data_loader_content_test, data_style_test)

In [11]:
DOWNUP_CONVOLUTIONS = 5 #3
ADAIN_CONVOLUTIONS = 7 #3
STYLE_DOWN_CONVOLUTIONS = 5 #3
NUM_LAYERS_NO_CONNECTION = 0
RESIDUAL_STYLE = True # False
RESIDUAL_DOWN = True # False
RESIDUAL_ADAIN = True
RESIDUAL_UP = True
STYLE_NORM = True
DOWN_NORM = 'in'
UP_NORM = 'adain'

if not VAE:
    STYLE_DIM = STYLE_DIM * 2

#normalization = ['adain', 'adain', None, None, None]
unet = model.unet.UNetAutoencoder(3, STYLE_DIM, residual_downsampling=RESIDUAL_DOWN, residual_adain=RESIDUAL_ADAIN, residual_upsampling=RESIDUAL_UP, 
        down_normalization=DOWN_NORM, up_normalization=UP_NORM, num_adain_convolutions=ADAIN_CONVOLUTIONS, 
        num_downup_convolutions=DOWNUP_CONVOLUTIONS, num_downup_without_connections=NUM_LAYERS_NO_CONNECTION, output_activation='sigmoid')

if VAE:
    style_encoder = model.autoencoder.Encoder(2 * STYLE_DIM, normalization=STYLE_NORM, residual=RESIDUAL_STYLE, num_down_convolutions=STYLE_DOWN_CONVOLUTIONS)
else:
    style_encoder = model.autoencoder.Encoder(STYLE_DIM, normalization=STYLE_NORM, residual=RESIDUAL_STYLE, num_down_convolutions=STYLE_DOWN_CONVOLUTIONS)

In [12]:
# Load pre-trained model
state_dict = torch.load('./models/model.pt')
unet.load_state_dict(state_dict['unet_state_dict'])
style_encoder.load_state_dict(state_dict['style_encoder_state_dict'])

<All keys matched successfully>

In [13]:
if torch.cuda.is_available(): 
    unet = unet.cuda()
    style_encoder = style_encoder.cuda()

In [14]:
def forward(content_image, style_image):
    """ Forward pass through the architecture.
    
    Parameters:
    -----------
    content_image : torch.Tensor, shape [batch_size, 3, H, W]
        The content images.
    style_image : torch.Tensor, shape [batch_size, 3, H, W]
        The style images.
    
    Returns:
    --------
    stylized : torch.Tensor, shape [batch_size, 3, H, W]
        The stylizations.
    style_encoding : torch.Tensor, shape [batch_size, STYLE_DIM]
        Style encodings.
    style_encoding_mean : torch.Tensor, shape [batch_size, STYLE_DIM]
        Means for the style encodings.
    style_encoding_logvar : torch.Tensor, shape [batch_size, STYLE_DIM]
        Logarithm of the variances of the style encodings.
    """
    style_stats = style_encoder(style_image)
    style_mean = style_stats[..., : STYLE_DIM]
    style_logvar = style_stats[..., STYLE_DIM : ]
    style_sample = function.sample_normal(style_mean, style_logvar)
    stylized = unet(content_image, style_sample)
    return stylized, style_sample, style_mean, style_logvar

def forward_no_vae(content_image, style_image):
    """ Forward pass through the architecture.
    
    Parameters:
    -----------
    content_image : torch.Tensor, shape [batch_size, 3, H, W]
        The content images.
    style_image : torch.Tensor, shape [batch_size, 3, H, W]
        The style images.
    
    Returns:
    --------
    stylized : torch.Tensor, shape [batch_size, 3, H, W]
        The stylizations.
    style_encoding : torch.Tensor, shape [batch_size, STYLE_DIM]
        Style encodings.
    """
    style_encoding = style_encoder(style_image)
    stylized = unet(content_image, style_encoding)
    return stylized, style_encoding

def forward_interpolate(content_image, style_image1, style_image2, interpolation_factor):
    """ Forward pass through the architecture.
    
    Parameters:
    -----------
    content_image : torch.Tensor, shape [batch_size, 3, H, W]
        The content images.
    style_image : torch.Tensor, shape [batch_size, 3, H, W]
        The style images.
    
    Returns:
    --------
    stylized : torch.Tensor, shape [batch_size, 3, H, W]
        The stylizations.
    style_encoding : torch.Tensor, shape [batch_size, STYLE_DIM]
        Style encodings.
    style_encoding_mean : torch.Tensor, shape [batch_size, STYLE_DIM]
        Means for the style encodings.
    style_encoding_logvar : torch.Tensor, shape [batch_size, STYLE_DIM]
        Logarithm of the variances of the style encodings.
    """
    style_stats1 = style_encoder(style_image1)
    style_mean1 = style_stats1[..., : STYLE_DIM]
    style_logvar1 = style_stats1[..., STYLE_DIM : ]
    style_sample1 = function.sample_normal(style_mean1, style_logvar1)
    
    style_stats2 = style_encoder(style_image2)
    style_mean2 = style_stats2[..., : STYLE_DIM]
    style_logvar2 = style_stats2[..., STYLE_DIM : ]
    style_sample2 = function.sample_normal(style_mean2, style_logvar2)
    
    style_sample = interpolation_factor * style_sample1 + (1 - interpolation_factor) * style_sample2
    
    stylized = unet(content_image, style_sample)
    return stylized, style_sample

def forward_sample(content_image):
    """ Forward pass through the architecture.
    
    Parameters:
    -----------
    content_image : torch.Tensor, shape [batch_size, 3, H, W]
        The content images.
    style_image : torch.Tensor, shape [batch_size, 3, H, W]
        The style images.
    
    Returns:
    --------
    stylized : torch.Tensor, shape [batch_size, 3, H, W]
        The stylizations.
    style_encoding : torch.Tensor, shape [batch_size, STYLE_DIM]
        Style encodings.
    style_encoding_mean : torch.Tensor, shape [batch_size, STYLE_DIM]
        Means for the style encodings.
    style_encoding_logvar : torch.Tensor, shape [batch_size, STYLE_DIM]
        Logarithm of the variances of the style encodings.
    """
    style_sample = torch.randn((BATCH_SIZE, STYLE_DIM), device=content_image.device, requires_grad=False)
    stylized = unet(content_image, style_sample)
    return stylized, style_sample

def forward_interpolate_no_vae(content_image, style_image1, style_image2, interpolation_factor):
    """ Forward pass through the architecture.
    
    Parameters:
    -----------
    content_image : torch.Tensor, shape [batch_size, 3, H, W]
        The content images.
    style_image : torch.Tensor, shape [batch_size, 3, H, W]
        The style images.
    
    Returns:
    --------
    stylized : torch.Tensor, shape [batch_size, 3, H, W]
        The stylizations.
    style_encoding : torch.Tensor, shape [batch_size, STYLE_DIM]
        Style encodings.
    style_encoding_mean : torch.Tensor, shape [batch_size, STYLE_DIM]
        Means for the style encodings.
    style_encoding_logvar : torch.Tensor, shape [batch_size, STYLE_DIM]
        Logarithm of the variances of the style encodings.
    """
    style_embedding1 = style_encoder(style_image1)
    style_embedding2 = style_encoder(style_image2)
    
    style_embedding = interpolation_factor * style_embedding1 + (1 - interpolation_factor) * style_embedding2
    
    stylized = unet(content_image, style_embedding)
    return stylized, style_embedding


# Pertrub styles (training set)
We perturb styles by projecting them using PCA, adjusting the component that captures most of the variance and projecting back.

In [15]:
NUM_PERTURBATIONS = 5

In [16]:
# Embed all training styles
style_embeddings = np.empty((len(data_style_train), STYLE_DIM))
style_encoder.eval()
num_embedded = 0
with torch.no_grad():
    for styles, _ in data_loader_style_train:
        for idx in range(styles.size(0)):
            embedding_stats = style_encoder(styles[idx : idx + 1].cuda())
            embedding_mean = embedding_stats[0, : STYLE_DIM]
            style_embeddings[num_embedded] = embedding_mean.detach().cpu().numpy()
            num_embedded += 1
style_embeddings = style_embeddings[:num_embedded]

In [17]:
pca = PCA(n_components=64)
embeddings_transformed = pca.fit_transform(style_embeddings)
print(f'{pca.explained_variance_ratio_.sum()} of total variance is captured by pca')
print(f'{pca.explained_variance_ratio_[0]} of total variance is captured by first component')

0.9917115380511897 of total variance is captured by pca
0.19497261764524146 of total variance is captured by first component


In [18]:
style_encoder.cpu() # Save my poor GPU memory...
None

In [19]:
epsilons = np.linspace(-10, 10, NUM_PERTURBATIONS)
s = np.tile(embeddings_transformed, (epsilons.shape[0], 1, 1))
s[:, :, 0] += epsilons.reshape((-1, 1))
embeddings_perturbed = pca.inverse_transform(s.reshape((-1, s.shape[-1])))
embeddings_perturbed = embeddings_perturbed.reshape((s.shape[0], s.shape[1], -1))
embeddings_perturbed.shape

(5, 96, 512)

In [None]:
PERTURBATION_ITERATIONS = embeddings_perturbed.shape[1]
with torch.no_grad():
    val_iteration = 0
    unet.eval()
    for (content_image, content_path), _ in data_loader_val:

        if val_iteration >= PERTURBATION_ITERATIONS:
            break

        if torch.cuda.is_available():
            content_image = content_image.to('cuda')
        style_encoding = torch.from_numpy(embeddings_perturbed[:, val_iteration, :]).to('cuda').float()
        content_image = content_image[:style_encoding.size(0)]
        stylized = unet(content_image, style_encoding)
        
            
        tb_writer.add_images('perturbations along first pc', torch.from_numpy(np.concatenate([
            img.detach().cpu().numpy() for img in [content_image[0:1], stylized]
        ])), val_iteration)
        val_iteration += 1


## Perturb along two principal components

In [20]:
print(f'{pca.explained_variance_ratio_[:2].sum()} of total variance is captured by first two components')

0.33270730367451795 of total variance is captured by first two components


In [21]:
xx, yy = np.meshgrid(np.linspace(-10, 10, NUM_PERTURBATIONS), np.linspace(-10, 10, NUM_PERTURBATIONS))

In [22]:
s = np.tile(embeddings_transformed, (NUM_PERTURBATIONS, NUM_PERTURBATIONS, 1, 1))
s[:, :, :, 0] += xx.reshape((NUM_PERTURBATIONS, NUM_PERTURBATIONS, 1))
s[:, :, :, 1] += yy.reshape((NUM_PERTURBATIONS, NUM_PERTURBATIONS, 1))
embeddings_perturbed = pca.inverse_transform(s.reshape((-1, s.shape[-1])))
embeddings_perturbed = embeddings_perturbed.reshape(list(s.shape[:-1]) +  [-1])
embeddings_perturbed.shape

(5, 5, 96, 512)

In [25]:
PERTURBATION_ITERATIONS = embeddings_perturbed.shape[-2]
with torch.no_grad():
    val_iteration = 0
    unet.eval()
    unet.cpu()
    
    for (content_image, content_path), _ in data_loader_val:

        if val_iteration >= PERTURBATION_ITERATIONS:
            break

        if torch.cuda.is_available():
            content_image = content_image
        style_encoding = torch.from_numpy(embeddings_perturbed[:, :, val_iteration, :]).float()
        # Flatten to input as NUM_PERTURBATIONS x NUM_PERTURBATIONS batch
        style_encoding = style_encoding.view(NUM_PERTURBATIONS * NUM_PERTURBATIONS, style_encoding.size(-1))
        
        content_image = content_image[0].unsqueeze(0).expand((style_encoding.size(0), -1, -1, -1))
        
        stylized = unet(content_image, style_encoding)
        tb_writer.add_image('perturbations along first two pcs', 
                            torchvision.utils.make_grid(stylized, nrow=NUM_PERTURBATIONS), val_iteration)
        print(val_iteration)
        val_iteration += 1

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
