In [None]:
# Neural style transfer with a pretrained VGG network.
# References:
#   Gatys et al., "A Neural Algorithm of Artistic Style":
#   https://arxiv.org/pdf/1508.06576.pdf
#   Transfer learning using Keras and VGG:
#   https://riptutorial.com/keras/example/32608/transfer-learning-using-keras-and-vgg
# Install a pinned TensorFlow 2.0 beta GPU build (quiet mode).
!pip install -q tensorflow-gpu==2.0.0-beta1

import numpy as np
import tensorflow as tf
import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
# Plot defaults: large square figures, no grid lines on the axes.
mpl.rc('figure', figsize=(12, 12))
mpl.rc('axes', grid=False)

# Name of the VGG19 layer whose activations are used for the content loss.
CONTENT_LAYER = "block2_conv2"
# Multiplier applied to the total-variation term of the training loss.
VARIATION_LOSS_WEIGHT = tf.constant(1e8, dtype=tf.float32)
# Names of the VGG19 layers whose Gram matrices are used for the style loss.
STYLE_LAYERS = [
    "block1_conv2",
    "block2_conv2",
    "block3_conv3",
    "block4_conv3",
    "block5_conv3",
]

In [None]:
# Fetch the content photo and the style painting; get_file returns the local
# path of each downloaded file.
content_path = tf.keras.utils.get_file(
    fname='turtle.jpg',
    origin='https://storage.googleapis.com/download.tensorflow.org/example_images/Green_Sea_Turtle_grazing_seagrass.jpg',
)
style_path = tf.keras.utils.get_file(
    fname='kandinsky.jpg',
    origin='https://storage.googleapis.com/download.tensorflow.org/example_images/Vassily_Kandinsky%2C_1913_-_Composition_7.jpg',
)

def imshow(image, title=None):
    """Draw `image` on the current matplotlib axes.

    Args:
        image: Image tensor. If it has more than three dimensions, its
            leading axis is squeezed out before plotting.
        title: Optional caption; nothing is set when it is falsy.
    """
    has_leading_axis = len(image.shape) > 3
    picture = tf.squeeze(image, axis=0) if has_leading_axis else image

    plt.imshow(picture)
    if not title:
        return
    plt.title(title)

def load_image(image_path, max_dim=512):
    """Read an image file and return it as a float32 batch of one image.

    The file is decoded to three channels, converted to float32 (values
    scaled to the [0, 1] range) and resized, keeping the aspect ratio, so
    that its longer side is `max_dim` pixels (up to integer truncation).

    Args:
        image_path: Path of the image file to read.
        max_dim: Target length, in pixels, of the longer image side.

    Returns:
        The resized image with a new leading batch axis, i.e. a tensor of
        shape [1, height, width, 3] for a still image.
    """
    image = tf.io.read_file(image_path)
    image = tf.image.decode_image(image, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)

    # Height and width only: the trailing channel axis is dropped.
    shape = tf.cast(tf.shape(image)[:-1], tf.float32)
    # tf.reduce_max rather than the builtin max(): the builtin iterates over
    # the tensor element by element, which only works in eager mode.
    long_dim = tf.reduce_max(shape)
    scale = max_dim / long_dim

    new_shape = tf.cast(shape * scale, tf.int32)

    image = tf.image.resize(image, new_shape)
    # Add the batch axis expected by the network.
    image = image[tf.newaxis, :]
    return image

# Load the style picture first, then the content picture, as batched tensors.
style_image, content_image = map(load_image, (style_path, content_path))

In [None]:
def replace_max_by_average_pooling(model):
    """Rebuild `model` as a layer chain that pools by average instead of max.

    Every layer after the input layer is re-applied, in order, to the output
    of the previous one, so the original non-pooling layer objects are reused.
    Each `MaxPooling2D` layer is swapped for a new `AveragePooling2D` layer
    with the same pool size, strides, padding and data format, named
    "<original name>_av".

    Args:
        model: Keras model whose first layer is an `InputLayer` and whose
            remaining layers can be applied one after another.

    Returns:
        A new Keras model going from the original input to the output of the
        last re-applied layer.
    """

    def _average_pooling_twin(layer):
        # Non-pooling layers are returned untouched so the rebuilt model
        # applies the very same layer objects.
        if not isinstance(layer, tf.keras.layers.MaxPooling2D):
            return layer
        return tf.keras.layers.AveragePooling2D(
            pool_size=layer.pool_size,
            strides=layer.strides,
            padding=layer.padding,
            data_format=layer.data_format,
            name=f"{layer.name}_av",
        )

    first_layer, *remaining_layers = model.layers
    assert isinstance(first_layer, tf.keras.layers.InputLayer)

    tensor = first_layer.output
    for original_layer in remaining_layers:
        tensor = _average_pooling_twin(original_layer)(tensor)

    return tf.keras.models.Model(inputs=first_layer.input, outputs=tensor)

# Load VGG19 with ImageNet weights and without its classifier head
# (include_top=False), keeping only the convolutional part. Because no input
# shape is specified, the input tensor shape is (None, None, 3), so the
# network can be applied to images of any size.
vgg_model = tf.keras.applications.VGG19(
    weights='imagenet',
    include_top=False,
)

# Mark the pretrained weights as non-trainable before the network is rebuilt.
vgg_model.trainable = False
# Rebuild the network with average pooling in place of max pooling.
vgg_model = replace_max_by_average_pooling(vgg_model)
# The rebuild calls every reused layer a second time, so each of those layers
# now has two output nodes. get_output_at(1) picks the tensor produced by the
# second call (the rebuilt graph) rather than node 0 (the original graph).
outputs = {
    "style_layers": [vgg_model.get_layer(layer).get_output_at(1) for layer in STYLE_LAYERS],
    "content_layer": vgg_model.get_layer(CONTENT_LAYER).get_output_at(1)
}
# Feature extractor: maps an image batch to a dict holding the style-layer
# activations (a list) and the content-layer activation.
model = tf.keras.Model([vgg_model.input], outputs)

In [None]:
# Fixed references for the loss, stored under the same keys the model emits:
# content features of the content image and style features of the style
# image. Despite its name, `predicted_outputs` holds these reference values.
predicted_outputs = {
    "content_layer": model(content_image)["content_layer"],
    "style_layers": model(style_image)["style_layers"],
}
target_content = predicted_outputs["content_layer"]
target_style = predicted_outputs["style_layers"]

# The image being optimised starts out as the content image.
input_image = tf.Variable(content_image)
opt = tf.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=0.1)

@tf.function
def compute_gram_matrix(style_features):
    """Compute one Gram matrix per image for a batch of feature maps.

    Args:
        style_features: Tensor of shape [batch, height, width, channels].

    Returns:
        Tensor of shape [batch, channels, channels] holding, for every pair
        of channels, the sum over all spatial positions of the product of
        their activations, divided by height * width.
    """
    # Use the dynamic shape: unpacking the static `.shape` breaks as soon as
    # a dimension is unknown (None) at trace time.
    dims = tf.shape(style_features)
    batch, height, width, channels = dims[0], dims[1], dims[2], dims[3]

    # Flatten the spatial grid: [batch, height * width, channels].
    flat_features = tf.reshape(style_features, [batch, -1, channels])

    # flat^T @ flat for every image in the batch: [batch, channels, channels].
    gram_matrix = tf.matmul(flat_features, flat_features, transpose_a=True)

    # Average over the spatial positions.
    num_locations = tf.cast(height * width, gram_matrix.dtype)
    return gram_matrix / num_locations

@tf.function
def style_transfer_loss(outputs, predicted_outputs):
    """Combine the content loss and the style loss into one scalar.

    Args:
        outputs: Dict with "style_layers" (a list of feature tensors) and
            "content_layer" (one feature tensor) for the image being
            optimised.
        predicted_outputs: Dict with the same structure holding the
            reference features that `outputs` is compared against.

    Returns:
        Scalar tensor `alpha * content_loss + beta * style_loss`, where the
        content loss is the mean of half the squared feature difference and
        the style loss is the equally weighted sum, over the style layers, of
        a quarter of the summed squared Gram-matrix difference.
    """
    alpha = tf.constant(1, dtype=tf.float32)
    beta = tf.constant(1000, dtype=tf.float32)

    style_features = outputs["style_layers"]
    reference_style_features = predicted_outputs["style_layers"]

    # Equal weight per style layer, summing to 1 (0.2 each for five layers).
    # Deriving it from the layer count keeps the loss correct when the number
    # of style layers changes; max(..., 1) only guards the empty-list case.
    layer_weight = tf.constant(
        1.0 / max(len(style_features), 1), dtype=tf.float32
    )

    style_loss = tf.reduce_sum([
        tf.reduce_sum(
            tf.square(compute_gram_matrix(current) - compute_gram_matrix(reference))
        ) * layer_weight / 4
        for current, reference in zip(style_features, reference_style_features)
    ])

    content_loss = tf.reduce_mean(
        .5 * tf.square(
            outputs["content_layer"] - predicted_outputs["content_layer"]
        )
    )

    return alpha * content_loss + beta * style_loss


In [None]:


def clip_0_1(image):
    """Return `image` with every value limited to the closed range [0, 1]."""
    lower_bound, upper_bound = 0.0, 1.0
    return tf.clip_by_value(image, lower_bound, upper_bound)

def high_pass_x_y(image):
    """Return differences between neighbouring elements of a 4-D array.

    Args:
        image: 4-D array or tensor that supports slicing and subtraction.

    Returns:
        Tuple `(x_var, y_var)`: `x_var` holds each element minus its
        predecessor along axis 2, `y_var` the same along axis 1. The
        differenced axis is one element shorter than in `image`.
    """
    # Axis 2: elements 1..end minus elements 0..end-1.
    shifted_x, base_x = image[:, :, 1:, :], image[:, :, :-1, :]
    x_var = shifted_x - base_x

    # Axis 1: elements 1..end minus elements 0..end-1.
    shifted_y, base_y = image[:, 1:, :, :], image[:, :-1, :, :]
    y_var = shifted_y - base_y

    return x_var, y_var

def total_variation_loss(image):
    """Return the mean squared neighbour difference of `image`.

    The result is the mean of the squared differences returned by
    `high_pass_x_y` for the first direction plus the same mean for the
    second direction.
    """
    first_deltas, second_deltas = high_pass_x_y(image)
    first_term = tf.reduce_mean(first_deltas ** 2)
    second_term = tf.reduce_mean(second_deltas ** 2)
    return first_term + second_term

@tf.function
def train_step(input_image):
    """Apply one optimiser update to `input_image` in place.

    The loss is the style-transfer loss of the model features against the
    module-level `predicted_outputs`, plus the total-variation loss scaled by
    `VARIATION_LOSS_WEIGHT`. After the update the variable is clipped back to
    the [0, 1] range.

    Args:
        input_image: `tf.Variable` holding the image batch being optimised.
    """
    with tf.GradientTape() as tape:
        features = model(input_image)
        total_loss = style_transfer_loss(features, predicted_outputs)
        total_loss = total_loss + VARIATION_LOSS_WEIGHT * total_variation_loss(input_image)

    image_gradient = tape.gradient(total_loss, input_image)
    opt.apply_gradients([(image_gradient, input_image)])
    # Keep the pixel values in the valid image range after the update.
    input_image.assign(clip_0_1(input_image))

In [None]:
# Optimise for num_epochs rounds of num_steps updates each, and draw a preview
# of the image after every round.
num_steps = 100
num_epochs = 10

# Preview grid: five columns and as many rows as num_epochs needs (2 x 5 for
# ten epochs). Deriving the row count keeps plt.subplot valid when num_epochs
# is changed; -(-a // b) is ceiling division without an extra import.
preview_cols = 5
preview_rows = -(-num_epochs // preview_cols)

for epoch in range(1, 1 + num_epochs):
    for step in range(num_steps):
        train_step(input_image)
    plt.subplot(preview_rows, preview_cols, epoch)
    # tf.identity turns the variable into a plain tensor for plotting.
    imshow(tf.identity(input_image), 'Epoch: {}'.format(epoch))

In [None]:
# Show the content image, the style image and the result side by side.
panels = [
    (content_image, "Content Image"),
    (style_image, "Style Image"),
    (input_image, "New Image"),
]
for position, (picture, caption) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    imshow(picture, caption)

In [None]:
# How to save input_image
from PIL import Image

# Squeeze out the size-1 axes (the batch axis), convert the floats to 8-bit
# integers, and write the result as a PNG file.
image_uint8 = tf.image.convert_image_dtype(tf.squeeze(input_image), tf.uint8)
Image.fromarray(np.array(image_uint8)).save("scanner_darkly_me.png")
