In [0]:
#@title #### Copyright © 2019 Evan Davis

#@markdown evan@skimai.com / [@eridgd](http://github.com/eridgd/)

In [0]:
#@title #### Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# <table class="tfo-notebook-buttons" align="left">
#   <td>
#     <a target="_blank" href="https://colab.research.google.com/github/tensorflow/models/blob/master/research/nst_blogpost/4_Neural_Style_Transfer_with_Eager_Execution.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
#   </td>
#   <td>
#     <a target="_blank" href="https://github.com/tensorflow/models/blob/master/research/nst_blogpost/4_Neural_Style_Transfer_with_Eager_Execution.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
#   </td>
# </table>

# Artistic Style Transfer with Convolutional Neural Networks

## Abstractions II &mdash; August 23, 2019 &mdash; Pittsburgh, PA

![](https://github.com/eridgd/ArtNet-WCT-TF2/raw/master/samples/abstractions_header.gif)

This notebook accompanies my [talk on Neural Artistic Style Transfer](https://docs.google.com/presentation/d/1mj_dyBkSylKMhQKzyGu7_oHpr6IbpaGXGuJDnQgrQGE/edit?usp=sharing) for [Abstractions II](https://abstractions.io/) and contains TensorFlow 2.0 implementations for five papers:

* [Fast Universal Style Transfer for Artistic and Photorealistic Rendering](https://arxiv.org/abs/1907.03118) An et al, 2019
* [A Closed-form Solution to Universal Style Transfer](https://arxiv.org/abs/1906.00668) Lu et al, 2019
* [Universal Style Transfer via Feature Transforms](https://arxiv.org/abs/1705.08086) Li et al, 2017
* [Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization](https://arxiv.org/abs/1703.06868) Huang and Belongie, 2017
* [Fast Patch-based Style Transfer of Arbitrary Style](https://arxiv.org/abs/1612.04337) Chen and Schmidt, 2016

# TO RUN THIS NOTEBOOK

This notebook is intended to demonstrate style transfer in real-time and requires a webcam. It may or may not work on a mobile device.

Just a few quick steps to get started:
 * Log in to a Google account.
 * **_File -> Save a copy in Drive..._**  -- Save a copy of this notebook in your GDrive.
 * **_Runtime -> Run all_** -- Initiates a free GPU Cloud VM, downloads models, and starts webcam demo.
 * <a href="#scrollTo=9len-gzKr5hz">Scroll to the bottom</a> to see the live stylized output.



# Setup TensorFlow

## Make sure we're using a GPU instance

If **`nvidia-smi`** shows no GPU you'll need to change the runtime type:
* **_Runtime -> Change runtime type_**
* Under `Hardware accelerator` choose `GPU` and save.
* Run the notebook again on the new instance.

In [0]:
!nvidia-smi

## Install TensorFlow 2.0 Beta

In [0]:
!pip install tensorflow-gpu==2.0.0-beta1

In [0]:
import os
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all at startup.
# Set before `import tensorflow` (next cell) so it is in place when TF initializes the GPU.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

In [0]:
import tensorflow as tf

# The rest of the notebook assumes eager execution (e.g. `plot_image` calls `.numpy()`).
assert tf.executing_eagerly()

# Downloads


### VGG19 Weights

We'll be using a variant of the "VGG19" convolutional network as our image encoder.

In [0]:
# Local filename of the VGG19 Torch weights; downloaded in the next cell and used as the VGG19 default path.
VGG_PATH = 'vgg_normalised.t7'

In [0]:
!wget -c -O {VGG_PATH} "https://www.dropbox.com/s/kh8izr3fkvhitfn/vgg_normalised.t7?dl=1"

### Pre-trained Decoder weights

This is an 'ArtNet'-style decoder architecture that was separately trained to invert VGG features to reconstruct the original input image.

In [0]:
# Local filename of the pre-trained ArtNet decoder weights; downloaded in the next cell.
WEIGHTS_PATH = 'artnet_decoder_full_nnupsample_40kunlabeled_tv0.keras'

In [0]:
!wget -c -O {WEIGHTS_PATH} https://www.dropbox.com/s/f992fv5zv4ao0dw/artnet_decoder_full_nnupsample_40kunlabeled_tv0.keras?dl=1

### Sample style images

For style images we'll use a sample of the WikiArt dataset obtained from https://github.com/cs-chan/ArtGAN/tree/master/WikiArt%20Dataset

In [0]:
# Download the WikiArt sample archive (`-c` lets wget resume a partial download) and unpack it.
!wget -c -O wikiart_samples_abstractions.tar.gz https://www.dropbox.com/s/h1ngp8ajthpitku/wikiart_samples_abstractions.tar.gz?dl=1

!tar xfz wikiart_samples_abstractions.tar.gz

# Collect every .jpg under the extracted folder, recursively. The glob pattern is
# case-sensitive on Linux, so files ending in e.g. `.JPG` or `.jpeg` are not picked up.
from pathlib import Path
style_image_paths = [str(f) for f in Path('wikiart_samples_abstractions').glob('**/*.jpg')]
len(style_image_paths)  # Number of style images found

# Setup Utilities

## Image utilities

In [0]:
def resize_image_keep_aspect(image, lo_dim):
    '''Resize `image` so its shorter side is `lo_dim` pixels, preserving the aspect ratio.

    Args:
        image: HxWxC image tensor (dims 0 and 1 are treated as height and width).
        lo_dim: target length, in pixels, of the shorter side.

    Returns:
        The resized image as returned by `tf.image.resize` (float32).
    '''
    shape = tf.shape(image)
    height, width = shape[0], shape[1]
    short_side = tf.cast(tf.minimum(height, width), tf.float32)
    ratio = short_side / tf.constant(lo_dim, dtype=tf.float32)
    # Round instead of truncating: float error can put short_side / ratio at e.g.
    # 255.9999, and truncation would then yield lo_dim - 1, which makes a later
    # random_crop to (lo_dim, lo_dim) fail.
    new_height = tf.cast(tf.round(tf.cast(height, tf.float32) / ratio), tf.int32)
    new_width = tf.cast(tf.round(tf.cast(width, tf.float32) / ratio), tf.int32)
    return tf.image.resize(image, [new_height, new_width])

In [0]:
def crop_center(image):
    '''Return the largest square crop centered in `image`.

    Uses the static shape, so height and width (dims -3 and -2) must be known.
    '''
    height, width = image.shape[-3], image.shape[-2]
    side = min(height, width)
    top = (height - side) // 2    # 0 when the image is not taller than wide
    left = (width - side) // 2    # 0 when the image is taller than wide
    return tf.image.crop_to_bounding_box(image, top, left, side, side)

In [0]:
def preprocess_image(image, size):
    '''Turn raw JPEG bytes into a random `size` x `size` RGB crop scaled to [0, 1].

    Args:
        image: encoded JPEG contents (e.g. the output of tf.io.read_file).
        size: side length of the square crop; the shorter image side is first
            resized to this length.
    '''
    decoded = tf.image.decode_jpeg(image, channels=3)
    resized = resize_image_keep_aspect(decoded, size)
    cropped = tf.image.random_crop(resized, (size, size, 3))
    return cropped / 255.0  # normalize to [0,1] range

In [0]:
def load_and_preprocess_image(path, size):
    '''Read the JPEG at `path` and return a random `size` x `size` crop in [0, 1].'''
    return preprocess_image(tf.io.read_file(path), size)

In [0]:
import PIL.Image  # Loads the submodule explicitly; a bare `import PIL` does not guarantee `PIL.Image` exists
import IPython.display as ipydisplay
import numpy as np

def plot_image(image):
    '''Display an image inline in the notebook.

    Args:
        image: HxWxC or 1xHxWxC image, either a tf.Tensor or anything NumPy can
            convert, with values expected in [0, 1]. Out-of-range values are
            clipped instead of wrapping around in the uint8 conversion.
    '''
    if len(image.shape) == 4:
        image = image[0]  # Drop the batch dim
    # Build a new array rather than scaling in place, so a caller's NumPy array is never modified
    pixels = image.numpy() if hasattr(image, 'numpy') else np.asarray(image)
    pixels = np.clip(pixels * 255, 0, 255).astype('uint8')
    ipydisplay.display(PIL.Image.fromarray(pixels))

In [0]:
def deconv_output_length(input_length, filter_size, padding, stride):
    """Length of a transposed-convolution output along one spatial dimension.

    Mirrors the helper in TensorFlow's layers utils:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/layers/utils.py#L159

    Arguments:
        input_length: integer (an integer tensor also works), or None.
        filter_size: integer.
        padding: one of "same", "valid", "full"; any other value is treated like "same".
        stride: integer.
    Returns:
        The output length, or None when `input_length` is None.
    """
    if input_length is None:
        return None
    scaled = input_length * stride
    if padding == 'valid':
        return scaled + max(filter_size - stride, 0)
    if padding == 'full':
        return scaled - (stride + filter_size - 2)
    return scaled

## Define Feature Transform Methods

### Whiten-Color Transform (WCT)

In [0]:
def wct_tf(content, style, alpha, eps=1e-9):
    '''TensorFlow version of the Whiten-Color Transform (WCT).

    Whitens the centered content encoding (Uc Dc^-1/2 Uc^T fc), colors it with
    the style covariance (Us Ds^1/2 Us^T), re-adds the style mean, and blends
    the result with the original content encoding.

    See p.4 of the Universal Style Transfer paper for corresponding equations:
    https://arxiv.org/pdf/1705.08086.pdf

    Args:
        content: content encoding of shape 1xHxWxC.
        style: style encoding of shape 1xHsxWsxC (same C as content).
        alpha: blend weight; 1 gives the fully transformed feature, 0 the original content.
        eps: regularizer added to the diagonal of both covariance matrices.

    Returns:
        Tensor of shape 1xHxWxC.
    '''
    # Remove only the batch dim (a bare squeeze would also drop any spatial dim of
    # size 1 and break the transpose), then reorder to CxHxW
    content_t = tf.transpose(tf.squeeze(content, axis=0), (2, 0, 1))
    style_t = tf.transpose(tf.squeeze(style, axis=0), (2, 0, 1))

    Cc, Hc, Wc = tf.unstack(tf.shape(content_t))
    Cs, Hs, Ws = tf.unstack(tf.shape(style_t))

    # CxHxW -> CxH*W
    content_flat = tf.reshape(content_t, (Cc, Hc*Wc))
    style_flat = tf.reshape(style_t, (Cs, Hs*Ws))

    # Content covariance; the (N - 1) denominator is clamped to 1 so a 1x1 map gives eps*I instead of NaN
    mc = tf.reduce_mean(content_flat, axis=1, keepdims=True)
    fc = content_flat - mc
    fcfc = fc @ tf.transpose(fc) / tf.cast(tf.maximum(Hc*Wc - 1, 1), tf.float32) + tf.eye(Cc)*eps

    # Style covariance
    ms = tf.reduce_mean(style_flat, axis=1, keepdims=True)
    fs = style_flat - ms
    fsfs = fs @ tf.transpose(fs) / tf.cast(tf.maximum(Hs*Ws - 1, 1), tf.float32) + tf.eye(Cs)*eps

    # tf.linalg.svd is slower on GPU, see https://github.com/tensorflow/tensorflow/issues/13603
    # (the original author also found torch.svd on CPU to be considerably faster, if PyTorch is available)
    with tf.device('/cpu:0'):
        Sc, Uc, _ = tf.linalg.svd(fcfc)
        Ss, Us, _ = tf.linalg.svd(fsfs)

    # Keep only components with non-negligible singular values (tf.linalg.svd sorts them descending)
    k_c = tf.reduce_sum(tf.cast(tf.greater(Sc, 1e-5), tf.int32))
    k_s = tf.reduce_sum(tf.cast(tf.greater(Ss, 1e-5), tf.int32))

    # Whiten content feature
    Dc_inv_sqrt = tf.linalg.diag(tf.pow(Sc[:k_c], -0.5))
    fc_hat = Uc[:, :k_c] @ Dc_inv_sqrt @ tf.transpose(Uc[:, :k_c]) @ fc

    # Color content with style
    Ds_sqrt = tf.linalg.diag(tf.pow(Ss[:k_s], 0.5))
    fcs_hat = Us[:, :k_s] @ Ds_sqrt @ tf.transpose(Us[:, :k_s]) @ fc_hat

    # Re-center with mean of style
    fcs_hat = fcs_hat + ms

    # Blend whiten-colored feature with original content feature
    blended = alpha * fcs_hat + (1 - alpha) * (fc + mc)

    # CxH*W -> CxHxW
    blended = tf.reshape(blended, (Cc, Hc, Wc))
    # CxHxW -> 1xHxWxC
    blended = tf.expand_dims(tf.transpose(blended, (1, 2, 0)), 0)

    return blended

### Adaptive Instance Normalization (AdaIN)

In [0]:
def adain(content_features, style_features, alpha, epsilon=1e-5):
    '''Adaptive Instance Normalization, blended with the original content.

    Borrowed from https://github.com/jonrei/tf-AdaIN
    Normalizes `content_features` by their own mean/variance over axes 1 and 2,
    then rescales and shifts them with the standard deviation and mean of
    `style_features` over the same axes (per channel, assuming NxHxWxC inputs).
    See "5. Adaptive Instance Normalization" in https://arxiv.org/abs/1703.06868 for details.

    Args:
        content_features: content encoding.
        style_features: style encoding.
        alpha: 1 returns the pure AdaIN output, 0 returns `content_features`.
        epsilon: added to the content variance before normalizing.
    '''
    spatial_axes = [1, 2]
    content_mean, content_var = tf.nn.moments(content_features, spatial_axes, keepdims=True)
    style_mean, style_var = tf.nn.moments(style_features, spatial_axes, keepdims=True)
    stylized = tf.nn.batch_normalization(content_features,
                                         mean=content_mean,
                                         variance=content_var,
                                         offset=style_mean,
                                         scale=tf.sqrt(style_var),
                                         variance_epsilon=epsilon)
    return alpha * stylized + (1 - alpha) * content_features

### "Closed Form" Transform (Optimal Transport)

In [0]:
# import torch

def optimal_tf(content, style, alpha, eps=1e-4):
    '''TensorFlow version of Optimal Transport transform from "A Closed-form Solution to Universal Style Transfer"

    Applies the closed-form linear map
        T = cov_c^-1/2 (cov_c^1/2 cov_s cov_c^1/2)^1/2 cov_c^-1/2
    (cov_c / cov_s: content / style channel covariances) to the centered content
    encoding, re-adds the style mean, and blends with the original content.

       Paper: https://arxiv.org/abs/1906.00668
       Author's Torch code: https://github.com/lu-m13/OptimalStyleTransfer/blob/master/optimal.lua
       PyTorch version: https://github.com/sunshineatnoon/PytorchWCT/blob/a7a8e29b0561c231c8a1ffcdd952d954231d18e6/util.py

    Args:
        content: content encoding of shape 1xHxWxC.
        style: style encoding of shape 1xHsxWsxC (same C as content).
        alpha: blend weight; 1 gives the fully transformed feature, 0 the original content.
        eps: regularizer added to the diagonal of both covariance matrices.

    Returns:
        Tensor of shape 1xHxWxC.
    '''
    # Remove only the batch dim (a bare squeeze would also drop any spatial dim of
    # size 1 and break the transpose), then reorder to CxHxW
    content_t = tf.transpose(tf.squeeze(content, axis=0), (2, 0, 1))
    style_t = tf.transpose(tf.squeeze(style, axis=0), (2, 0, 1))

    Cc, Hc, Wc = tf.unstack(tf.shape(content_t))
    Cs, Hs, Ws = tf.unstack(tf.shape(style_t))

    # CxHxW -> CxH*W
    content_flat = tf.reshape(content_t, (Cc, Hc*Wc))
    style_flat = tf.reshape(style_t, (Cs, Hs*Ws))

    # Content covariance; the (N - 1) denominator is clamped to 1 so a 1x1 map does not divide by zero
    mc = tf.reduce_mean(content_flat, axis=1, keepdims=True)
    fc = content_flat - mc
    fcfc = fc @ tf.transpose(fc) / tf.cast(tf.maximum(Hc*Wc - 1, 1), tf.float32) + tf.eye(Cc)*eps

    # Style covariance
    ms = tf.reduce_mean(style_flat, axis=1, keepdims=True)
    fs = style_flat - ms
    fsfs = fs @ tf.transpose(fs) / tf.cast(tf.maximum(Hs*Ws - 1, 1), tf.float32) + tf.eye(Cs)*eps

    # tf.linalg.svd is slower on GPU, see https://github.com/tensorflow/tensorflow/issues/13603
    with tf.device('/cpu:0'):
        Sc, Uc, _ = tf.linalg.svd(fcfc)

    # Filter small singular values (tf.linalg.svd sorts them descending)
    k_c = tf.reduce_sum(tf.cast(tf.greater(Sc, 1e-5), tf.int32))

    Dc_sqrt = tf.linalg.diag(tf.pow(Sc[:k_c], 0.5))
    Dc_inv_sqrt = tf.linalg.diag(tf.pow(Sc[:k_c], -0.5))

    # Matrix square root of the content covariance and its inverse
    cF_sqrt = Uc[:, :k_c] @ Dc_sqrt @ tf.transpose(Uc[:, :k_c])
    cF_inv_sqrt = Uc[:, :k_c] @ Dc_inv_sqrt @ tf.transpose(Uc[:, :k_c])

    # Square root of cov_c^1/2 cov_s cov_c^1/2
    middle_matrix = cF_sqrt @ fsfs @ cF_sqrt
    with tf.device('/cpu:0'):
        Sm, Um, _ = tf.linalg.svd(middle_matrix)

    k_m = tf.reduce_sum(tf.cast(tf.greater(Sm, 1e-5), tf.int32))
    Dm_sqrt = tf.linalg.diag(tf.pow(Sm[:k_m], 0.5))
    middle_matrix_sqrt = Um[:, :k_m] @ Dm_sqrt @ tf.transpose(Um[:, :k_m])

    # Apply transformation matrix to centered content feature
    transform_matrix = cF_inv_sqrt @ middle_matrix_sqrt @ cF_inv_sqrt
    fcs_hat = transform_matrix @ fc

    # Re-center with mean of style
    fcs_hat = fcs_hat + ms

    # Blend transformed feature with original content feature
    blended = alpha * fcs_hat + (1 - alpha) * (fc + mc)

    # CxH*W -> CxHxW
    blended = tf.reshape(blended, (Cc, Hc, Wc))
    # CxHxW -> 1xHxWxC
    blended = tf.expand_dims(tf.transpose(blended, (1, 2, 0)), 0)

    return blended

### Style-Swap

In [0]:
def style_swap(content, style, patch_size, stride):
    '''Efficiently swap content feature patches with nearest-neighbor style patches

    Every style patch is L2-normalized and used as a conv filter over the
    content encoding; the argmax over filters picks, for each position, the
    style patch with the highest normalized cross-correlation. A transposed conv
    with the unnormalized style patches pastes those patches back, and
    overlapping contributions are averaged.

       Original paper: https://arxiv.org/abs/1612.04337
       Adapted from: https://github.com/rtqichen/style-swap/blob/master/lib/NonparametricPatchAutoencoderFactory.lua

    Args:
        content: content encoding, 1xHxWxC.
        style: style encoding, 1xHsxWsxC (a batch of 1 is assumed by the patch reshape).
        patch_size: side length of the square patches.
        stride: stride used to extract style patches and to match/reassemble them.

    Returns:
        Style-swapped encoding of shape 1xH'xW'xC, where H'/W' come from
        deconv_output_length (equal to H/W when stride == 1).
    '''
    nC = tf.shape(style)[-1]  # Num channels of input content feature and style-swapped output

    ### Extract patches from style image that will be used for conv/deconv layers
    style_patches = tf.image.extract_patches(style, [1, patch_size, patch_size, 1], [1, stride, stride, 1],
                                             [1, 1, 1, 1], 'VALID')
    before_reshape = tf.shape(style_patches)  # 1xRowsxColsxPatch_size*Patch_size*nC
    style_patches = tf.reshape(style_patches, [before_reshape[1]*before_reshape[2], patch_size, patch_size, nC])
    style_patches = tf.transpose(style_patches, [1, 2, 3, 0])  # Patch_size x Patch_size x In_C x N_patches

    # Normalize each style patch as a whole (over its H, W and C axes) so the conv below
    # computes normalized cross-correlation. Normalizing along axis 3 alone would instead
    # normalize each element across all patches and distort the nearest-neighbor match.
    style_patches_norm = tf.nn.l2_normalize(style_patches, axis=[0, 1, 2])

    # Compute cross-correlation/nearest neighbors of patches by using style patches as conv filters
    ss_enc = tf.nn.conv2d(content,
                          style_patches_norm,
                          [1, stride, stride, 1],
                          'VALID')

    # For each spatial position find index of max along channel/patch dim
    ss_argmax = tf.argmax(ss_enc, axis=3)
    encC = tf.shape(ss_enc)[-1]  # Num channels in intermediate conv output, same as # of patches

    # One-hot encode argmax with same size as ss_enc, with 1's in max channel idx for each spatial pos
    ss_oh = tf.one_hot(ss_argmax, encC, 1., 0., 3)

    # Calc size of transposed conv out
    deconv_out_H = deconv_output_length(tf.shape(ss_oh)[1], patch_size, 'valid', stride)
    deconv_out_W = deconv_output_length(tf.shape(ss_oh)[2], patch_size, 'valid', stride)
    deconv_out_shape = tf.stack([1, deconv_out_H, deconv_out_W, nC])

    # Deconv back to original content size with highest matching (unnormalized) style patch swapped in for each content patch
    ss_dec = tf.nn.conv2d_transpose(ss_oh,
                                    style_patches,
                                    deconv_out_shape,
                                    [1, stride, stride, 1],
                                    'VALID')

    ### Interpolate to average overlapping patch locations
    ss_oh_sum = tf.reduce_sum(ss_oh, axis=3, keepdims=True)

    filter_ones = tf.ones([patch_size, patch_size, 1, 1], dtype=tf.float32)

    deconv_out_shape = tf.stack([1, deconv_out_H, deconv_out_W, 1])  # Same spatial size as ss_dec with 1 channel

    # Number of pasted patches covering each output position
    counting = tf.nn.conv2d_transpose(ss_oh_sum,
                                      filter_ones,
                                      deconv_out_shape,
                                      [1, stride, stride, 1],
                                      'VALID')

    counting = tf.tile(counting, [1, 1, 1, nC])  # Repeat along channel dim to make same size as ss_dec

    interpolated_dec = tf.divide(ss_dec, counting)

    return interpolated_dec


def wct_style_swap(content, style, alpha, ss_blend, patch_size=3, stride=1, eps=1e-9):
    '''Modified Whiten-Color Transform that performs style swap on whitened content/style encodings before coloring

    Both encodings are whitened, `style_swap` replaces each whitened content
    patch with its nearest whitened style patch, the result is blended with the
    plain whitened content by `ss_blend`, and the blend is colored with the style
    covariance and mean as in the regular WCT.

    Args:
        content: content encoding, 1xHxWxC.
        style: style encoding, 1xHsxWsxC (same C as content).
        alpha: blend between the transformed feature (1) and the original content (0).
        ss_blend: blend between the style-swapped (1) and plain whitened (0) content.
        patch_size, stride: forwarded to `style_swap`.
        eps: regularizer added to the diagonal of both covariance matrices.

    Returns:
        Tensor of shape 1xHxWxC.
    '''
    # Remove only the batch dim (a bare squeeze would also drop any spatial dim of
    # size 1 and break the transpose), then reorder to CxHxW
    content_t = tf.transpose(tf.squeeze(content, axis=0), (2, 0, 1))
    style_t = tf.transpose(tf.squeeze(style, axis=0), (2, 0, 1))

    Cc, Hc, Wc = tf.unstack(tf.shape(content_t))
    Cs, Hs, Ws = tf.unstack(tf.shape(style_t))

    # CxHxW -> CxH*W
    content_flat = tf.reshape(content_t, (Cc, Hc*Wc))
    style_flat = tf.reshape(style_t, (Cs, Hs*Ws))

    # Content covariance; the (N - 1) denominator is clamped to 1 so a 1x1 map does not divide by zero
    mc = tf.reduce_mean(content_flat, axis=1, keepdims=True)
    fc = content_flat - mc
    fcfc = fc @ tf.transpose(fc) / tf.cast(tf.maximum(Hc*Wc - 1, 1), tf.float32) + tf.eye(Cc)*eps

    # Style covariance
    ms = tf.reduce_mean(style_flat, axis=1, keepdims=True)
    fs = style_flat - ms
    fsfs = fs @ tf.transpose(fs) / tf.cast(tf.maximum(Hs*Ws - 1, 1), tf.float32) + tf.eye(Cs)*eps

    # tf.linalg.svd is slower on GPU, see https://github.com/tensorflow/tensorflow/issues/13603
    with tf.device('/cpu:0'):
        Sc, Uc, _ = tf.linalg.svd(fcfc)
        Ss, Us, _ = tf.linalg.svd(fsfs)

    # Keep only components with non-negligible singular values (sorted descending)
    k_c = tf.reduce_sum(tf.cast(tf.greater(Sc, 1e-5), tf.int32))
    k_s = tf.reduce_sum(tf.cast(tf.greater(Ss, 1e-5), tf.int32))

    ### Whiten content feature
    Dc_inv_sqrt = tf.linalg.diag(tf.pow(Sc[:k_c], -0.5))
    fc_hat = Uc[:, :k_c] @ Dc_inv_sqrt @ tf.transpose(Uc[:, :k_c]) @ fc
    # Reshape before passing to style swap, CxH*W -> 1xHxWxC
    whiten_content = tf.expand_dims(tf.transpose(tf.reshape(fc_hat, [Cc, Hc, Wc]), [1, 2, 0]), 0)

    ### Whiten style before swapping
    Ds_inv_sqrt = tf.linalg.diag(tf.pow(Ss[:k_s], -0.5))
    whiten_style = Us[:, :k_s] @ Ds_inv_sqrt @ tf.transpose(Us[:, :k_s]) @ fs
    # Reshape before passing to style swap, CxH*W -> 1xHxWxC
    whiten_style = tf.expand_dims(tf.transpose(tf.reshape(whiten_style, [Cs, Hs, Ws]), [1, 2, 0]), 0)

    ### Style swap whitened encodings, then blend with the plain whitened content
    ss_feature = style_swap(whiten_content, whiten_style, patch_size, stride)
    Wc_ss_blended = ss_blend * ss_feature + (1 - ss_blend) * whiten_content[0]
    # 1xHxWxC -> CxH*W (requires style_swap to return an HxW map, which holds for stride 1)
    Wc_ss_blended = tf.transpose(tf.reshape(Wc_ss_blended, [Hc*Wc, Cc]), [1, 0])

    ### Color style-swapped encoding with style
    Ds_sqrt = tf.linalg.diag(tf.pow(Ss[:k_s], 0.5))
    fcs_hat = Us[:, :k_s] @ Ds_sqrt @ tf.transpose(Us[:, :k_s]) @ Wc_ss_blended
    fcs_hat = fcs_hat + ms

    ### Blend style-swapped & colored encoding with original content encoding
    blended = alpha * fcs_hat + (1 - alpha) * (fc + mc)
    # CxH*W -> CxHxW
    blended = tf.reshape(blended, (Cc, Hc, Wc))
    # CxHxW -> 1xHxWxC
    blended = tf.expand_dims(tf.transpose(blended, (1, 2, 0)), 0)

    return blended

# Define Model

In [0]:
import os
import numpy as np
import random
import time

from tensorflow.keras.layers import Input, UpSampling2D, MaxPooling2D, Conv2D, Activation

## VGG19 from vgg_normalised.t7

In [0]:
# Borrow a modified version of torchfile to read VGG weight file
!wget --quiet -c https://raw.githubusercontent.com/eridgd/WCT-TF/master/torchfile.py

In [0]:
import torchfile

def vgg_from_t7(t7_file, output_layers=('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1'),
                trainable=False):
    '''Extract VGG layers from a Torch .t7 model into a Keras model
       e.g. vgg = vgg_from_t7('vgg_normalised.t7', output_layers=['relu4_1'])
       Adapted from https://github.com/jonrei/tf-AdaIN/blob/master/AdaIN.py
       Converted caffe->t7 from https://github.com/xunhuang1995/AdaIN-style

    Args:
        t7_file: path to the Torch .t7 model.
        output_layers: names of the layers whose activations the model returns.
            They may be given in any order; outputs follow network order, and
            layers after the last requested one are not built.
        trainable: whether the converted conv layers and the model are trainable.

    Returns:
        A tf.keras.Model mapping an NxHxWx3 input to the requested activations.

    Raises:
        ValueError: if `output_layers` is empty or names a layer not in the .t7 model.
        NotImplementedError: for a Torch module type that is not converted here.
    '''
    wanted = set(output_layers)
    if not wanted:
        raise ValueError('output_layers must name at least one layer')

    t7 = torchfile.load(t7_file, force_8bytes_long=True)

    inp = Input(shape=(None, None, 3), name='vgg_input')
    x = inp

    outputs = []
    found = set()
    for idx, module in enumerate(t7.modules):
        name = module.name.decode() if module.name is not None else None

        if idx == 0:
            name = 'preprocess'  # VGG 1st layer preprocesses with a 1x1 conv to multiply by 255 and subtract BGR mean as bias

        if module._typename == b'nn.SpatialReflectionPadding':
            continue  # Use 'same' zero padding instead of reflection pad
        elif module._typename == b'nn.SpatialConvolution':
            # Reorder Torch weights (out, in, kH, kW) to the Keras kernel layout (kH, kW, in, out);
            # square kernels are assumed since only kH is read
            weight = module.weight.transpose([2, 3, 1, 0])
            x = Conv2D(module.nOutputPlane, module.kH, padding='same', activation=None, name=name,
                       kernel_initializer=tf.constant_initializer(weight),
                       bias_initializer=tf.constant_initializer(module.bias),
                       trainable=trainable)(x)
        elif module._typename == b'nn.ReLU':
            x = Activation('relu', name=name)(x)
        elif module._typename == b'nn.SpatialMaxPooling':
            x = MaxPooling2D(padding='same', name=name)(x)
        else:
            raise NotImplementedError(module._typename)

        if name in wanted:
            outputs.append(x)
            found.add(name)
            if found == wanted:
                break  # Every requested layer is built; the rest of the network is not needed

    missing = [layer for layer in output_layers if layer not in found]
    if missing:
        raise ValueError('Layers {} not found in {}'.format(missing, t7_file))

    model = tf.keras.Model(inputs=inp, outputs=outputs)

    model.trainable = trainable

    return model

In [0]:
class VGG19(tf.keras.Model):
    '''Frozen VGG19 encoder loaded from a Torch .t7 file.

    Wraps the functional model built by `vgg_from_t7`, requesting the
    relu1_1 ... relu5_1 activations; calling this model returns that
    functional model's outputs.
    '''
    def __init__(self, vgg_path=VGG_PATH):
        super().__init__()
        requested = ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1']
        self.vgg_tiers = vgg_from_t7(vgg_path, output_layers=requested, trainable=False)

    def call(self, x):
        # x: image batch NxHxWx3, matching the Input built in vgg_from_t7
        return self.vgg_tiers(x)

## Decoder

This is the only model that is trained. At inference time it allows feature transforms to be applied at multiple decoder layers to achieve stylization.

### Pyramid Fusion Conv layer to combine VGG features

In [0]:
class PyramidFuseConv(tf.keras.layers.Layer):
    '''Bring multi-scale VGG features to one resolution and fuse them.

    The i-th input feature map goes through its own strided conv (kernel size =
    stride = 16, 8, 4, 2, 1; 32 filters; ReLU). The results are concatenated
    along the channel axis and mixed by a 1x1 conv with 512 filters and ReLU.
    '''
    def __init__(self):
        super().__init__()
        # NOTE(review): the trained weights are loaded later with `load_weights`; keep this
        # creation order (downsample convs first, then the fuse conv) so they still match.
        self.downsample_layers = [
            Conv2D(filters=32, kernel_size=psize, strides=psize, padding='same', activation='relu')
            for psize in (16, 8, 4, 2, 1)
        ]
        self.fuse_conv = Conv2D(filters=512, kernel_size=1, activation='relu', padding='same',
                                name='pyramid_fuse_conv')

    def call(self, vgg_feats):
        # zip stops at the shorter of the two, so extra feature maps are ignored
        downsampled = [conv(feat) for feat, conv in zip(vgg_feats, self.downsample_layers)]
        return self.fuse_conv(tf.concat(downsampled, axis=-1))

### ArtNet with Pyramid fuse layer

In [0]:
class ArtNetDecoder(tf.keras.Model):
    '''ArtNet-style decoder that maps VGG encoder features back to an RGB image.

    The VGG feature maps are first fused into one map by `PyramidFuseConv`,
    then passed through decoder tiers 5 -> 1 (the insertion order of
    `self.decoder_tiers`) and a final 3-channel conv. At inference time a
    feature transform (WCT / AdaIN / OT, or style-swap + WCT) can be applied to
    the input of any tier listed in `style_decoder_feats`.
    '''
    def __init__(self):
        super(ArtNetDecoder, self).__init__()
        self.pyramid_fuse_conv = PyramidFuseConv()  # Downsample & fuse intermediate VGG features

        # Each tier is a list of layers applied in sequence; `call` walks the tiers
        # in insertion order (5, 4, 3, 2, 1).
        # NOTE(review): the HxW annotations appear to assume a 512x512 input image.
        # NOTE(review): trained weights are restored with `load_weights`; adding, removing or
        # reordering layers here will likely make the saved weights fail to load.
        self.decoder_tiers = {
            5: [                                                                        #  HxW  / InC->OutC
                Conv2D(filters=512, kernel_size=3, padding='same', activation='relu'),  # 32x32 / 512->512
                UpSampling2D(),                                                         # 32x32 -> 64x64
                Conv2D(filters=512, kernel_size=3, padding='same', activation='relu'),  # 64x64 / 512->512
                Conv2D(filters=512, kernel_size=3, padding='same', activation='relu'),  # 64x64 / 512->512
                Conv2D(filters=512, kernel_size=3, padding='same', activation='relu')], # 64x64 / 512->512
            4: [
                Conv2D(filters=256, kernel_size=3, padding='same', activation='relu'),  # 64x64 / 512->256
                UpSampling2D(),                                                         # 64x64 -> 128x128
                Conv2D(filters=256, kernel_size=3, padding='same', activation='relu'),  # 128x128 / 256->256
                Conv2D(filters=256, kernel_size=3, padding='same', activation='relu'),  # 128x128 / 256->256
                Conv2D(filters=256, kernel_size=3, padding='same', activation='relu')], # 128x128 / 256->256
            3: [
                Conv2D(filters=128, kernel_size=3, padding='same', activation='relu'),  # 128x128 / 256->128
                UpSampling2D(),                                                         # 128x128 -> 256x256
                Conv2D(filters=128, kernel_size=3, padding='same', activation='relu')], # 256x256 / 128->128
            2: [
                Conv2D(filters=64,  kernel_size=3, padding='same', activation='relu'),  # 256x256 / 128->64
                UpSampling2D()],                                                        # 256x256 -> 512x512
            1: [
                Conv2D(filters=64,  kernel_size=3, padding='same', activation='relu')]  # 512x512 / 64->64
        }

        self.out = Conv2D(filters=3, kernel_size=3, padding='same', activation=None)    # 512x512 / 64->3 (linear; see `clip` in call)

    def call(self, vgg_feats, return_tiers=None, style_decoder_feats=None, transform=wct_tf, alpha=1.,
             swap_tiers=None, ss_blend=1., ss_patch_size=3, ss_stride=1, clip=False):
        '''Decode `vgg_feats`, optionally stylizing the inputs of selected tiers.

        Args:
            vgg_feats: VGG feature maps consumed by `PyramidFuseConv`.
            return_tiers: if given, collect the (untransformed) input of each listed
                tier in decoder order (5 -> 1) and return that list as soon as the
                smallest listed tier is reached. If the smallest id is not a decoder
                tier, no early return happens.
            style_decoder_feats: optional dict {tier id: style feature}; the input of
                each listed tier is transformed toward that feature before the tier runs.
            transform: called as transform(x, style_feat, alpha=alpha) for listed
                tiers that are not in `swap_tiers`.
            alpha: blend weight forwarded to the transform.
            swap_tiers: tier ids that use `wct_style_swap` instead of `transform`.
            ss_blend, ss_patch_size, ss_stride: forwarded to `wct_style_swap`.
            clip: if True, clip the decoded output to [0, 1].

        Returns:
            The decoded image tensor, or the list of tier inputs when `return_tiers` is set.
        '''
        x = self.pyramid_fuse_conv(vgg_feats)

        if return_tiers:
            tier_ins = []  # Only defined, and only used, when tier inputs are requested

        for i, tier in self.decoder_tiers.items():
            # Record this tier's input before any transform; stop once the smallest requested tier is reached
            if return_tiers and i in return_tiers:
                tier_ins.append(x)
                if i == min(return_tiers):
                    return tier_ins

            # Optionally apply WCT or AdaIN transform before decoding
            # This is only used at inference time
            if style_decoder_feats is not None and i in style_decoder_feats:
                if swap_tiers and i in swap_tiers:   # Style-swap + WCT
                    x = wct_style_swap(x, style_decoder_feats[i], alpha=alpha, ss_blend=ss_blend, 
                                       patch_size=ss_patch_size, stride=ss_stride)
                else:                                # WCT/AdaIN/OT
                    x = transform(x, style_decoder_feats[i], alpha=alpha)

            for layer in tier:
                x = layer(x)

        x = self.out(x)

        if clip:
            x = tf.clip_by_value(x, 0, 1)
        return x

## Autoencoder with VGG Encoder and ArtNet Decoder

In [0]:
encoder = VGG19()          # VGG19 weights are loaded automatically (from VGG_PATH)
decoder = ArtNetDecoder()  # Trained decoder weights are not loaded yet; that happens in the next cell

### Load trained decoder weights

In [0]:
# Loading weights for a subclassed model requires data to be run through it first.
# See https://www.tensorflow.org/beta/guide/keras/saving_and_serializing#saving_subclassed_models
# A dummy forward pass on a 1x128x128x3 batch of ones builds every decoder layer so its variables exist.
_ = decoder(encoder(tf.ones([1,128,128,3])))  

# Now we can load the weights.
decoder.load_weights(WEIGHTS_PATH)

### Define stylization procedure

`@tf.function` traces the decorated Python function into a TensorFlow graph, using AutoGraph to convert Python control flow. See https://www.tensorflow.org/beta/guide/autograph

In [0]:
@tf.function
def extract_style_feats(encoder, decoder, style_img, tiers=(5, 4, 3)):
    '''Compute the decoder-tier inputs for a style image, keyed by tier id.

    Args:
        encoder: VGG encoder whose output is fed to `decoder`.
        decoder: ArtNetDecoder, called with `return_tiers`.
        style_img: style image batch (1xHxWx3 as used in this notebook).
        tiers: decoder tier ids to extract, in any order.

    Returns:
        dict {tier id: feature at the input of that decoder tier}.
    '''
    # The decoder returns tier inputs in its own traversal order (highest tier first),
    # so pair them with the ids sorted the same way. Zipping the caller's order directly
    # would mislabel the features whenever `tiers` is not already descending.
    ordered_tiers = sorted(set(tiers), reverse=True)
    style_feats = decoder(encoder(style_img), return_tiers=ordered_tiers)
    return dict(zip(ordered_tiers, style_feats))

@tf.function
def stylize(encoder, decoder, content_img, style_img=None, style_feats=None, 
            tiers=[5,4,3], transform=wct_tf, alpha=0.6,
            swap_tiers=None, ss_blend=0.6, ss_patch_size=3, ss_stride=1):
    '''Encode `content_img` and decode it with the selected decoder tiers stylized.

    Args:
        encoder, decoder: the VGG19 encoder and the ArtNetDecoder.
        content_img: content image batch.
        style_img: style image batch; only used when `style_feats` is None.
        style_feats: optional precomputed {tier id: feature} dict (see
            extract_style_feats); passing it avoids re-encoding the style image.
        tiers: tiers to extract from `style_img` when `style_feats` is None.
        transform, alpha, swap_tiers, ss_blend, ss_patch_size, ss_stride:
            forwarded to the decoder.

    Returns:
        The stylized image batch, clipped to [0, 1].
    '''
    if style_feats is not None:
        tier_feats = style_feats
    else:
        tier_feats = extract_style_feats(encoder, decoder, style_img, tiers)

    content_encoding = encoder(content_img)
    stylized = decoder(content_encoding,
                       style_decoder_feats=tier_feats,
                       transform=transform,
                       alpha=alpha,
                       swap_tiers=swap_tiers,
                       ss_blend=ss_blend,
                       ss_patch_size=ss_patch_size,
                       ss_stride=ss_stride,
                       clip=True)
    return stylized

# Webcam Demo

In [0]:
from IPython.display import display, Javascript
from google.colab.output import eval_js
from base64 import b64decode

def setup_stream():
    '''Open the webcam in the browser and prepare hidden capture elements.

    Defines and runs a JS `setupStream()` in the cell output. It requests
    camera access, plays the stream in a hidden <video>, and creates a <canvas>
    sized to the video frames. `video`, `stream` and `canvas` are assigned
    without a declaration, so in non-strict JS they become globals, which the
    JS `takePhoto` used by take_photo() relies on.
    '''
    stream_js = Javascript('''
    async function setupStream() {
      const div = document.createElement('div');
      video = document.createElement('video');
      video.style.display = 'none';
      stream = await navigator.mediaDevices.getUserMedia({video: true});
      document.body.appendChild(div);
      div.appendChild(video);
      div.style.visibility='hidden' 
      video.srcObject = stream;
      await video.play();
      // google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);
      canvas = document.createElement('canvas');
      canvas.width = video.videoWidth;
      canvas.height = video.videoHeight;
  }
  ''')
    display(stream_js)
    eval_js('setupStream()')

def take_photo(filename='photo.jpg', quality=0.8):
    '''Capture the current webcam frame and return it as JPEG bytes.

    Draws the current frame of the JS `video` element onto the JS `canvas`
    (both created by setup_stream(), which must run first) and encodes it as a
    JPEG data URL.

    Args:
        filename: unused; nothing is written to disk.
        quality: JPEG quality passed to `canvas.toDataURL`.

    Returns:
        The decoded JPEG bytes.
    '''
    capture_js = Javascript('''
    async function takePhoto(quality) {
      canvas.getContext('2d').drawImage(video, 0, 0, video.videoWidth, video.videoHeight);
      return canvas.toDataURL('image/jpeg', quality);
    }
    ''')
    display(capture_js)
    data_url = eval_js('takePhoto({})'.format(quality))
    base64_payload = data_url.split(',')[1]  # Drop the "data:image/jpeg;base64" header
    return b64decode(base64_payload)

In [0]:
from PIL import Image
import io
import IPython.display as ipydisplay
import time
import numpy as np

def stylize_webcam(style_image_size, randomize_every, tiers, alpha, transform, passes,
                   swap_tiers, ss_blend, ss_patch_size, ss_stride,
                   max_consecutive_errors=10):
    """Live webcam style transfer that switches to a random style periodically.

    Each iteration does the following:
    - grabs a frame with `take_photo`;
    - stylizes it `passes` times with the cached style features;
    - shows [result | 3-px black gap | square style preview] in the
      'stylization' display, with status/timing text in the 'time' display.

    A style from `style_image_paths` is loaded on the first frame, after any
    failed load, and every `randomize_every` frames after that. The loop never
    returns normally.

    Args:
        style_image_size: size passed to `load_and_preprocess_image` for styles.
        randomize_every: frames between style changes; must be >= 1.
        tiers: levels for `extract_style_feats`, also forwarded to `stylize`.
        alpha, transform, swap_tiers, ss_blend, ss_patch_size, ss_stride:
            forwarded to `stylize`.
        passes: how many times each frame is fed back through `stylize`;
            must be >= 1.
        max_consecutive_errors: give up after this many failures in a row
            (frame or style-load errors). `None` or 0 retries forever, which
            was the previous behaviour.

    Raises:
        ValueError: if `randomize_every` or `passes` is < 1. These settings
            would otherwise make every frame fail.
        RuntimeError: after `max_consecutive_errors` consecutive failures.
    """
    if randomize_every < 1:
        raise ValueError(f"randomize_every must be >= 1, got {randomize_every}")
    if passes < 1:
        raise ValueError(f"passes must be >= 1, got {passes}")

    setup_stream()
    disp_result = ipydisplay.display('', display_id='stylization')
    disp_text = ipydisplay.display('', display_id='time')

    # Style state lives outside the loop. A None preview forces a (re)load, so
    # a failed first frame can't leave later frames pointing at undefined names.
    style_feats_dict = None
    style_img_np = None
    style_height = None

    consecutive_errors = 0
    last_error = None
    i = -1
    while True:
        # Single give-up point, outside the try so it isn't swallowed below.
        if max_consecutive_errors and consecutive_errors >= max_consecutive_errors:
            raise RuntimeError(
                f"stylize_webcam: giving up after {consecutive_errors} "
                f"consecutive failures; last error: {last_error!r}") from last_error
        i += 1
        try:
            binary = take_photo()
            content_img = Image.open(io.BytesIO(binary))
            content_img = np.array(content_img)[None, ...].astype('float32') / 255

            if style_img_np is None or i % randomize_every == 0:
                style_img_np = None  # invalidate until the new style is fully ready
                style_height = content_img.shape[1]
                style_path = random.choice(style_image_paths)
                disp_text.update(f"Loading style from {style_path}")
                try:
                    style_img = load_and_preprocess_image(style_path, style_image_size)  # Img is center-cropped square
                    style_img = style_img[None, ...]  # Expand with batch dim
                    # Pre-calculate the style decoder features once per style.
                    style_feats_dict = extract_style_feats(encoder, decoder, style_img, tiers)
                except Exception as e:  # tf.image.random_crop sometimes has issues
                    print(e)
                    consecutive_errors += 1
                    last_error = e
                    i -= 1  # retry a style load without advancing the frame index
                    continue

                # Square preview as tall as the content frame, for side-by-side display.
                style_img_np = tf.image.resize(style_img[0], (style_height, style_height)).numpy()

            s = time.time()
            result = content_img
            for _ in range(passes):
                result = stylize(encoder, decoder,
                                 result,
                                 style_feats=style_feats_dict,
                                 transform=transform,
                                 tiers=tiers,
                                 alpha=alpha,
                                 swap_tiers=swap_tiers, ss_blend=ss_blend,
                                 ss_patch_size=ss_patch_size, ss_stride=ss_stride)

            disp_text.update(f"Frame {i} stylized in {time.time() - s}")

            # [stylized frame | 3-px wide black separator | style preview]
            output_img = np.hstack([result[0].numpy(),
                                    np.zeros([style_height, 3, 3]),
                                    style_img_np])
            output_img = np.uint8(output_img * 255)

            # Convert to jpeg before displaying to speed things up
            # See https://medium.com/@kostal91/displaying-real-time-webcam-stream-in-ipython-at-relatively-high-framerate-8e67428ac522
            f = io.BytesIO()
            Image.fromarray(output_img).save(f, 'jpeg')
            disp_result.update(ipydisplay.Image(data=f.getvalue()))
            consecutive_errors = 0  # a full frame made it through
        except Exception as err:
            # Per-frame failures are reported and the loop moves on; persistent
            # failures end the loop via max_consecutive_errors above.
            print(str(err))
            consecutive_errors += 1
            last_error = err

In [0]:
#@markdown **Stylization Settings for WCT / AdaIN / Optimal Transport**  

#@markdown ⬅ Press ▶️ to start webcam stylization demo

# Lines tagged #@param / #@markdown drive the Colab form; keep their layout.

style_image_size = 512 #@param {type:"slider", min:128, max:2048, step:32}

randomize_style_every = 8 #@param {type:"integer"}

tiers = 5,4,3,2 #@param {type:"raw"}
tiers = list(map(int, tiers))  # the raw value is a tuple; normalize to a list of ints

alpha = 0.85 #@param {type:"slider", min:0, max:1, step:0.05}

transform = "WCT" #@param ["WCT", "AdaIN", "Optimal Transport"] {allow-input: false}
# Replace the dropdown label with the matching transform function; any other
# label falls through to AdaIN.
transform = (wct_tf if transform == 'WCT'
             else optimal_tf if transform == 'Optimal Transport'
             else adain)

passes = 1 #@param {type:"integer"}

#@markdown **Style-Swap Settings**
style_swap_on = True #@param {type:"boolean"}
style_swap_tiers = 5,  #@param {type:"raw"}
style_swap_tiers = list(map(int, style_swap_tiers))
# Style-swap is turned off by handing stylize_webcam None instead of tiers.
style_swap_tiers = None if style_swap_on is False else style_swap_tiers

style_swap_blend = 0.6 #@param {type:"slider", min:0, max:1, step:0.05}
style_swap_patch_size = 3 #@param {type: "integer"}
style_swap_stride = 1 #@param {type: "integer"}

# multi_level_transform = False #@param {type:"boolean"}


# Start the live webcam loop; this call never returns normally.
stylize_webcam(passes=passes,
               style_image_size=style_image_size,
               randomize_every=randomize_style_every,
               tiers=tiers,
               alpha=alpha,
               transform=transform,
               swap_tiers=style_swap_tiers,
               ss_blend=style_swap_blend,
               ss_patch_size=style_swap_patch_size,
               ss_stride=style_swap_stride)

# Style Swap Webcam Demo (fixed style image)

In [0]:
!wget -c -O sullivan.jpg https://www.dropbox.com/s/rdzkzye8iihgtim/sullivan.jpg?dl=1

In [0]:
from PIL import Image
import io
import IPython.display as ipydisplay
import time
import numpy as np

def stylize_webcam(style_image_size, randomize_every, tiers, alpha, transform, passes,
                   swap_tiers, ss_blend, ss_patch_size, ss_stride,
                   style_path='sullivan.jpg', max_consecutive_errors=10):
    """Live webcam style transfer with a single fixed style image.

    Running this cell replaces any earlier `stylize_webcam` in the kernel.

    Setup happens once:
    - the JPEG at `style_path` is decoded and resized with
      `resize_image_keep_aspect`;
    - it is cast to float32, passed through `crop_center`, and scaled by 1/255;
    - its style features are extracted a single time.

    Each iteration grabs a frame, stylizes it `passes` times, and shows
    [result | 3-px black gap | square style preview]. The loop never returns
    normally.

    Args:
        style_image_size: target size given to `resize_image_keep_aspect`.
        randomize_every: how often (in frames) the square style preview is
            rebuilt at the current frame height; must be >= 1. The style
            itself never changes.
        tiers: levels for `extract_style_feats`, also forwarded to `stylize`.
        alpha, transform, swap_tiers, ss_blend, ss_patch_size, ss_stride:
            forwarded to `stylize`.
        passes: how many times each frame is fed back through `stylize`;
            must be >= 1.
        style_path: JPEG file used as the style (default 'sullivan.jpg', as before).
        max_consecutive_errors: give up after this many failed frames in a row.
            `None` or 0 retries forever, which was the previous behaviour.

    Raises:
        ValueError: if `randomize_every` or `passes` is < 1.
        RuntimeError: after `max_consecutive_errors` consecutive failed frames.
    """
    if randomize_every < 1:
        raise ValueError(f"randomize_every must be >= 1, got {randomize_every}")
    if passes < 1:
        raise ValueError(f"passes must be >= 1, got {passes}")

    setup_stream()
    disp_result = ipydisplay.display('', display_id='stylization')
    disp_text = ipydisplay.display('', display_id='time')

    style_img = tf.io.read_file(style_path)
    style_img = tf.image.decode_jpeg(style_img, channels=3)
    style_img = resize_image_keep_aspect(style_img, style_image_size)
    style_img = tf.cast(style_img, tf.float32)
    style_img = crop_center(style_img)
    style_img /= 255.0  # normalize to [0,1] range
    style_img = style_img[None, ...]  # Expand with batch dim
    # The style never changes, so its decoder features are computed once.
    style_feats_dict = extract_style_feats(encoder, decoder, style_img, tiers)

    # Built on the first successful frame, so a failed first frame can't leave
    # later frames referencing undefined names.
    style_img_np = None
    style_height = None

    consecutive_errors = 0
    last_error = None
    i = -1
    while True:
        # Single give-up point, outside the try so it isn't swallowed below.
        if max_consecutive_errors and consecutive_errors >= max_consecutive_errors:
            raise RuntimeError(
                f"stylize_webcam: giving up after {consecutive_errors} "
                f"consecutive failures; last error: {last_error!r}") from last_error
        i += 1
        try:
            binary = take_photo()
            content_img = Image.open(io.BytesIO(binary))
            content_img = np.array(content_img)[None, ...].astype('float32') / 255

            if style_img_np is None or i % randomize_every == 0:
                # Square preview as tall as the content frame, for side-by-side display.
                style_height = content_img.shape[1]
                style_img_np = tf.image.resize(style_img[0], (style_height, style_height)).numpy()

            s = time.time()
            result = content_img
            for _ in range(passes):
                result = stylize(encoder, decoder,
                                 result,
                                 style_feats=style_feats_dict,
                                 transform=transform,
                                 tiers=tiers,
                                 alpha=alpha,
                                 swap_tiers=swap_tiers, ss_blend=ss_blend,
                                 ss_patch_size=ss_patch_size, ss_stride=ss_stride)

            disp_text.update(f"Frame {i} stylized in {time.time() - s}")

            # [stylized frame | 3-px wide black separator | style preview]
            output_img = np.hstack([result[0].numpy(),
                                    np.zeros([style_height, 3, 3]),
                                    style_img_np])
            output_img = np.uint8(output_img * 255)

            # Convert to jpeg before displaying to speed things up
            # See https://medium.com/@kostal91/displaying-real-time-webcam-stream-in-ipython-at-relatively-high-framerate-8e67428ac522
            f = io.BytesIO()
            Image.fromarray(output_img).save(f, 'jpeg')
            disp_result.update(ipydisplay.Image(data=f.getvalue()))
            consecutive_errors = 0  # a full frame made it through
        except Exception as err:
            # Per-frame failures are reported and the loop moves on; persistent
            # failures end the loop via max_consecutive_errors above.
            print(str(err))
            consecutive_errors += 1
            last_error = err

In [0]:
#@markdown **Stylization Settings for WCT / AdaIN / Optimal Transport**  

#@markdown ⬅ Press ▶️ to start webcam stylization demo

# Lines tagged #@param / #@markdown drive the Colab form; keep their layout.

style_image_size = 512 #@param {type:"slider", min:128, max:2048, step:32}

randomize_style_every = 80 #@param {type:"integer"}

tiers = 5,4,3,2 #@param {type:"raw"}
tiers = list(map(int, tiers))  # the raw value is a tuple; normalize to a list of ints

alpha = 0.85 #@param {type:"slider", min:0, max:1, step:0.05}

transform = "WCT" #@param ["WCT", "AdaIN", "Optimal Transport"] {allow-input: false}
# Replace the dropdown label with the matching transform function; any other
# label falls through to AdaIN.
transform = (wct_tf if transform == 'WCT'
             else optimal_tf if transform == 'Optimal Transport'
             else adain)

passes = 1 #@param {type:"integer"}

#@markdown **Style-Swap Settings**
style_swap_on = True #@param {type:"boolean"}
style_swap_tiers = 5,  #@param {type:"raw"}
style_swap_tiers = list(map(int, style_swap_tiers))
# Style-swap is turned off by handing stylize_webcam None instead of tiers.
style_swap_tiers = None if style_swap_on is False else style_swap_tiers

style_swap_blend = 1 #@param {type:"slider", min:0, max:1, step:0.05}
style_swap_patch_size = 3 #@param {type: "integer"}
style_swap_stride = 1 #@param {type: "integer"}

# multi_level_transform = False #@param {type:"boolean"}


# Start the live webcam loop; this call never returns normally.
stylize_webcam(passes=passes,
               style_image_size=style_image_size,
               randomize_every=randomize_style_every,
               tiers=tiers,
               alpha=alpha,
               transform=transform,
               swap_tiers=style_swap_tiers,
               ss_blend=style_swap_blend,
               ss_patch_size=style_swap_patch_size,
               ss_stride=style_swap_stride)