#### Import the necessary packages

In [None]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
from time import time

#### Load the base and style images, adjust image size

In [None]:
# locations of the content (base) image and the style reference image
base_img_path = "../images/san_francisco.jpg"
style_ref_img_path = "../images/van_gogh_starry_night.jpg"

# work at a fixed height of 400 pixels; the width is scaled so that the
# aspect ratio of the base image is kept (size is unpacked as width, height)
img_height = 400
original_width, original_height = keras.utils.load_img(base_img_path).size
img_width = round(original_width * img_height / original_height) # = 533

#### Utility functions for pre- and post-processing images

In [None]:
def preprocess_image(image_path):
    '''
    reads the image stored at image_path, resized to (img_height, img_width)
    returns it as a batch of one array with vgg19 preprocessing applied
    '''
    pil_img = keras.utils.load_img(image_path, target_size=(img_height, img_width))
    # add a leading batch axis in front of the image array
    batch = keras.utils.img_to_array(pil_img)[np.newaxis, ...]
    return keras.applications.vgg19.preprocess_input(batch)

def deprocess_image(img):
    '''
    converts a numpy array into a valid image
    reverts vgg19 preprocessing

    img: array holding img_height * img_width * 3 values (e.g. a batch of one)
    returns: a new uint8 array of shape (img_height, img_width, 3) in RGB order
    the input array is left untouched
    '''
    # astype always copies, so the in-place additions below can never leak
    # back into the caller's array (reshape alone would only return a view);
    # integer inputs are promoted to a floating dtype, floats keep theirs
    source = np.asarray(img)
    pixels = source.astype(np.result_type(source.dtype, np.float32))
    pixels = pixels.reshape((img_height, img_width, 3))

    # reverses a transformation done by vgg19.preprocess_input:
    # add back the per-channel offsets (channels are in BGR order here)
    pixels[:, :, 0] += 103.939
    pixels[:, :, 1] += 116.779
    pixels[:, :, 2] += 123.68

    # BGR to RGB, also part of reverting the transformation
    pixels = pixels[:, :, ::-1]

    # clamp to the displayable range before dropping to 8 bits
    return np.clip(pixels, 0, 255).astype("uint8")

#### Load the VGG19 model minus the final fully connected layers

In [None]:
# VGG19 with pretrained ImageNet weights; include_top=False leaves out
# the layers at the top of the network, which are not needed here
model = keras.applications.vgg19.VGG19(include_top=False, weights="imagenet")

# map every layer's name to that layer's (symbolic) output
outputs_dict = {layer.name: layer.output for layer in model.layers}

# model that returns the activations of all layers as a dict keyed by layer name
feature_extractor = keras.Model(inputs=model.inputs, outputs=outputs_dict)

#### Define the content loss

In [None]:
def content_loss(base_img, combination_img):
    '''
    content loss: sum of squared element-wise differences between the
    combination (generated) image and the base image
    '''
    difference = combination_img - base_img
    squared = tf.square(difference)
    return tf.reduce_sum(squared)

#### Define the style loss

In [None]:
def gram_matrix(x):
    '''
    gram matrix of x: the last axis is moved to the front, every slice along it
    is flattened into a row, and the row matrix is multiplied by its transpose
    '''
    leading = tf.transpose(x, perm=(2, 0, 1))
    n_rows = tf.shape(leading)[0]
    flat = tf.reshape(leading, (n_rows, -1))
    # flat @ flat.T
    return tf.matmul(flat, flat, transpose_b=True)

def style_loss(style_img, combination_img):
    r'''
    computes the style loss between gram style S and gram combination C matrices (images)
    essentially computes
                             \sum_{ij} (S - C)_{ij}^2
                            --------------------------
                             (2 * channels * size)^2

    channels is fixed at 3 and size is img_height * img_width, so the
    normalisation is the same constant for every layer it is applied to
    (raw docstring: the backslash in \sum is not a valid escape sequence)
    '''
    S = gram_matrix(style_img)
    C = gram_matrix(combination_img)
    channels = 3
    size = img_height * img_width
    normaliser = 4.0 * (channels ** 2) * (size ** 2)
    return tf.reduce_sum(tf.square(S - C)) / normaliser

#### Define the total variation (continuity) loss

In [None]:
def total_variation_loss(x):
    '''
    total variation loss ensures continuity across resulting image
    x: 4D tensor indexed as (batch, row, column, channel)
    returns: sum over all positions of (a + b) ** 1.25, where a and b are the
    squared differences to the next row (vertical) and next column (horizontal)

    the slices are relative to the shape of x itself, so the function does
    not depend on the global image size
    '''
    # every pixel that has both a lower and a right-hand neighbour
    core = x[:, :-1, :-1, :]
    a = tf.square(core - x[:, 1:, :-1, :])   # vs. the pixel one row below
    b = tf.square(core - x[:, :-1, 1:, :])   # vs. the pixel one column right
    return tf.reduce_sum(tf.pow(a + b, 1.25))

#### Define the total loss, set loss weights

In [None]:
# style loss is measured on the first convolution of each of the five blocks:
# block1_conv1, block2_conv1, block3_conv1, block4_conv1, block5_conv1
style_layer_names = [f"block{block}_conv1" for block in range(1, 6)]

# content loss is measured on a single layer
content_layer_name = "block5_conv2"

# relative weights of the three terms of the total loss
content_weight = 2.5e-8
style_weight = 1e-6
total_variation_weight = 1e-6

def compute_loss(combination_image, base_image, style_reference_image):
    '''
    total loss of the combination image:
    weighted content loss + weighted style loss (averaged over the style layers)
    + weighted total variation loss
    '''
    # positions of the three images inside the stacked batch below
    base_idx, style_idx, combo_idx = 0, 1, 2

    # push all three images through the network in one batch and collect
    # the activations of every layer (dict keyed by layer name)
    stacked = tf.concat([base_image, style_reference_image, combination_image], axis=0)
    activations = feature_extractor(stacked)

    total = tf.zeros(shape=())

    # content term: base vs. combination activations on the content layer
    content_acts = activations[content_layer_name]
    total = total + content_weight * content_loss(
        content_acts[base_idx, :, :, :], content_acts[combo_idx, :, :, :]
    )

    # style term: style reference vs. combination on every style layer,
    # each layer contributing an equal share of style_weight
    per_layer_weight = style_weight / len(style_layer_names)
    for name in style_layer_names:
        layer_acts = activations[name]
        layer_style_loss = style_loss(
            layer_acts[style_idx, :, :, :], layer_acts[combo_idx, :, :, :]
        )
        total += per_layer_weight * layer_style_loss

    # smoothness term, computed on the combination image itself
    total += total_variation_weight * total_variation_loss(combination_image)

    return total

#### Compute the loss and gradients

In [None]:
@tf.function # compiled as a tf.function so that each training step is fast
def compute_loss_and_grads(combination_image, base_image, style_reference_image):
    '''
    returns the total loss together with its gradient
    with respect to the combination image
    '''
    with tf.GradientTape() as tape:
        # the loss has to be evaluated inside the tape to be differentiable
        total_loss = compute_loss(combination_image, base_image, style_reference_image)

    gradient = tape.gradient(total_loss, combination_image)
    return total_loss, gradient

#### Load the SGD optimizer with a decaying learning rate

In [None]:
# learning rate schedule: start at 100 and decay exponentially,
# by a factor of 0.96 (i.e. 4% lower) per 100 steps
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=100.0,
    decay_steps=100,
    decay_rate=0.96,
)

# SGD optimizer driven by that schedule (no momentum is configured here)
# TODO: lots of optimization (pun intended) could go in here
optimizer = keras.optimizers.SGD(lr_schedule)

#### Preprocess images before training

In [None]:
# base (content) and style reference images, preprocessed in the same way
base_img, style_ref_img = (
    preprocess_image(path) for path in (base_img_path, style_ref_img_path)
)

# the generated image starts out as the preprocessed base image;
# it is a Variable because the optimizer updates it at every step
combination_img = tf.Variable(preprocess_image(base_img_path))

#### Train the network

In [None]:
iterations = 4000
for i in range(1, iterations + 1):
    tic = time()
    # one optimisation step: evaluate loss and gradients, then let the
    # optimizer move the combination image in a loss-decreasing direction
    loss, grads = compute_loss_and_grads(combination_img, base_img, style_ref_img)
    optimizer.apply_gradients([(grads, combination_img)])
    toc = time()

    # per-step progress report; comment this line out to silence it
    print(f"iteration {i:04}: loss={loss:.4f} time={(toc - tic):02.2f}")

    # only every 100th iteration is written to disk
    if i % 100 != 0:
        continue

    # turn the current state back into an ordinary image and save it
    img = deprocess_image(combination_img.numpy())
    fname = f"../images/generated_image_at_iteration_{i}.png"
    keras.utils.save_img(fname, img)