#### For this implementation I used keras layers over tensorflow in eager mode, theorethical aspects can be [found here](https://cs.stanford.edu/people/jcjohns/papers/eccv16/JohnsonECCV16.pdf)

In [146]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import matplotlib as mpl
import IPython.display as display
import functools
mpl.rcParams['axes.grid'] = False
mpl.rcParams['figure.figsize'] = (7,7)

#### Defining helper functions for images. Maximum lenth is 512. Each dimension is scaled down to number divisible by spatial downgradings in our network which guarantees same output and input shapes. Network can learn on images of arbitrary sizes, still it is best to provide training data of same(proportional) sizes for faster learning process.

In [162]:
def load_img(path_to_img, scale=True, div=4):
    img = tf.io.read_file(path_to_img)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    shape = tf.cast(tf.shape(img)[:-1], tf.float32)
    long_dim = max(shape.numpy())
    
    if scale:
        max_dim = 512
    else:
        max_dim = 1000

    scale = max_dim / long_dim
    new_shape = tf.cast(tf.round((shape * scale)/div) * div, tf.int32)
    img = tf.image.resize(img, new_shape)
    img = np.expand_dims(img, axis=0)
    return img

def imshow(image, title=None, scale=False):
    if len(image.shape) > 3:
        image = tf.squeeze(image, axis=0)

    show = image
    plt.imshow(tf.cast(show, tf.uint8))
    if title:
        plt.title(title)

#### Adding functions to preprocess and deprocess images for our loss model(VGG19 in this case). Basically deprocessing consists of shifting channels means and rearranging the order. 

In [152]:
def preprocessed_img(image):
    image = image.copy()
    image = tf.keras.applications.vgg19.preprocess_input(image)
    return image

def deprocess_img(image, shifts=[103.939, 116.779, 123.68]):
    img = image.copy()
    if len(img.shape)==4:
        img = np.squeeze(img, 0)
        
    assert len(img.shape)==3
    
    for i in tf.range(3):
        img[:,:,i] += shifts[i]
        
    #reverse channels
    img = img[:, :, ::-1]
    img = np.clip(img, 0, 255).astype('uint8')
    return img

#### Loss model initializer.

In [153]:
def vgg19_layers(layers):
    vgg19 = tf.keras.applications.VGG19(include_top=False, weights='imagenet')    
    outputs = [vgg19.get_layer(layer).output for layer in layers]
    
    return keras.Model([vgg19.input], outputs)

#### Names of VGG19 layers which will be used to compute content and style losses.

In [140]:
content_layers = ['block5_conv2'] 
style_layers = ['block1_conv2',
                'block2_conv2',
                'block3_conv2', 
                'block4_conv2', 
                'block5_conv1']

content_layers_len = len(content_layers)
style_layers_len = len(style_layers)

#### Calculating gram matrix which will be used for computing style loss (check tf.linalg.einsum for more info).

In [161]:
def gram_matrix(image):
    gram = tf.linalg.einsum('bijc,bijd->bcd', image, image)
    shape = image.get_shape()
    i, j = shape[1], shape[2]
    
    return gram / tf.cast(i * j, tf.float32)

#### Transform model consist of:
  * input convolution layer
  * 2 downsampling convolutions
  * 5 residual blocks
  * 2 upsampling convolutions
  * output convolution layer

#### In some cases like this downsampling can have positive effect because we can have much larger number of filters without performance decrease.

In [155]:
from functools import partial
DefaultConv2D = partial(keras.layers.Conv2D, kernel_size=3, strides=1, padding='SAME', kernel_initializer='lecun_normal')

class ResidualUnit(keras.layers.Layer):
    def __init__(self, filters, strides=1, activation='elu', **kwargs):
        super().__init__(**kwargs)
        self.activation = keras.activations.get(activation)
        self.main_layers = [
            DefaultConv2D(filters, strides=strides),
            keras.layers.BatchNormalization(),
            self.activation,
            DefaultConv2D(filters),
            keras.layers.BatchNormalization(),
        ]
    
    
    def call(self, inputs):
        y = inputs
        for layer in self.main_layers:
            y = layer(y)
            
        return self.activation(y + inputs)

In [166]:
class StyleContentModel(keras.Model):
    def __init__(self, style_layers, content_layers):
        super(StyleContentModel, self).__init__()
        self.vgg19.trainable = False
        self.vgg19 = vgg19_layers(style_layers + content_layers)
        self.activation = keras.activations.get('elu')
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.num_style_layers = len(style_layers)
        
        self.transform_layers = [ # <-- transform part
            DefaultConv2D(32, kernel_size=9),
            keras.layers.BatchNormalization(),
            self.activation,
            DefaultConv2D(64, strides=2, kernel_size=2),
            keras.layers.BatchNormalization(),
            self.activation,
            DefaultConv2D(128, strides=2, kernel_size=2),
            keras.layers.BatchNormalization(),
            self.activation,
            keras.layers.Conv2DTranspose(64, 2, strides=2, kernel_initializer='lecun_normal'), #Lecun with ELU may improve normalization 
            keras.layers.BatchNormalization(),
            self.activation,
            keras.layers.Conv2DTranspose(32, 2, strides=2, kernel_initializer='lecun_normal'),
            keras.layers.BatchNormalization(),
            self.activation,
            DefaultConv2D(3, kernel_size=9),
            keras.activations.get('tanh')
        ]
        for i in range(5):
            self.transform_layers.insert(9, ResidualUnit(128, 1))
        
        
    def call(self, image, transform=True):
        img2 = image
        outputs = image
        if transform:
            for layer in self.transform_layers:
                img2 = layer(img2)

            img2 *= 255. #increasing back to *255 for vgg19
            outputs = img2
                
        outputs = self.vgg19(outputs) # <-- loss part
        style_outputs, content_outputs = (outputs[:self.num_style_layers],
                                          outputs[self.num_style_layers:])
        
        style_outputs = [gram_matrix(style) for style in style_outputs]
        
        styles = {name:style for name, style in zip(self.style_layers, style_outputs)}
        contents = {name:content for name, content in zip(self.content_layers, content_outputs)}
        
        return {'styles':styles, 'contents':contents, 'image':img2}

#### tf.Dataset may not work eagerly in some tf versions, if not, replace it by custom python generator. Channel division by 255 improve transform network precision.

In [167]:
def load_paths(images_path, num):
    filepaths = []
    for i in range(num):
        filepaths.append(images_path + '/' + str(i+1) + '.jpg')
    
    return filepaths

def create_train_dataset(filepaths, repeat=None, shuffle_buffer_size=20, n_parse_threads=5):
    dataset = tf.data.Dataset.list_files(filepaths).repeat(repeat)
    dataset = dataset.map(lambda x: preprocessed_img(load_img(x))/255., num_parallel_calls=n_parse_threads)
    dataset = dataset.map(lambda x: (x, sc_model(x, transform=False)['contents']), num_parallel_calls=n_parse_threads)
    dataset = dataset.shuffle(shuffle_buffer_size)
    return dataset.prefetch(1)

In [124]:
def style_loss(a, b, s_weight=1e-5):
    n = len(a)
    loss = tf.add_n([tf.reduce_mean(tf.square(a[name] - b[name])) for name in a.keys()]) / n
    loss *= s_weight
    
    return loss

In [144]:
def content_loss(a, b, c_weight=1e1):
    n = len(a)
    loss = tf.add_n([tf.reduce_mean(tf.square(a[name] - b[name])) for name in a.keys()]) / n
    loss *= c_weight
    
    return loss

#### Defining artifact loss algorithm which basically calculates differences between nearest image pixels. 

In [157]:
def artifact_filter(image):
    x_pass = image[:,:,1:,:] - image[:,:,:-1,:]
    y_pass = image[:,1:,:,:] - image[:,:-1,:,:]
    return x_pass, y_pass

def artifacts_loss(image, weight):
    image = tf.cast(image, tf.float32)
    x_p, y_p = artifact_filter(image)
    loss = weight * (tf.reduce_mean(tf.square(x_p)) + tf.reduce_mean(tf.square(y_p)))
    return loss

In [158]:
style_path = '/Users/user/Documents/train_images/style.jpg'
images_path = '/Users/user/Documents/train_images/targets'

In [168]:
sc_model = StyleContentModel(style_layers, content_layers)

style_image = preprocessed_img(load_img(style_path, scale=False))

style_learn = sc_model(style_image, transform=False)['styles']
train_dataset = create_train_dataset(load_paths(images_path, 10))

AttributeError: 'StyleContentModel' object has no attribute 'vgg19'

#### Found Nadam working best for this task.

In [160]:
optimizer = keras.optimizers.Nadam(0.00001)

In [None]:
@tf.function()
def train_step(image, content_learn):
    with tf.GradientTape() as tape:
        outputs = sc_model(image)
        style_l = style_loss(outputs['styles'], style_learn)
        content_l = content_loss(outputs['contents'], content_learn)
        artifact_l = artifacts_loss(outputs['image'], 2)
        loss = style_l + content_l + artifact_l
    
    img = outputs['image']    
    vars = sc_model.trainable_variables
    gradients = tape.gradient(loss, vars)
    opt = optimizer.apply_gradients([(g, v) for g, v in zip(gradients, vars)])
    return style_l, content_l, artifact_l, img

In [22]:
epochs = 100000
steps_per_epoch = 40

step = 0
for n in range(epochs):
    for c_image, c_learn in train_dataset.take(steps_per_epoch):
        step += 1
        s_l, c_l, a_l, img = train_step(c_image, c_learn)
    
    tmp = deprocess_img(img.numpy())
    imshow(tmp)    
    display.clear_output(wait=True)
    plt.title("Train step: {} Loss_s: {} Loss_c: {} Loss_a: {}".format(step, s_l, c_l, a_l))
    plt.show()

NameError: name 'train_dataset' is not defined