In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import time
import IPython.display as display

from util import imshow, load_img, save_img
from model import StyleTransferModel, print_stats
from losses import clip_0_1, style_loss

# Fetch the content photograph and the style painting. `get_file` downloads
# each URL once and returns the path of the locally cached copy.
content_path = tf.keras.utils.get_file(
    'neckarfront.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/0/00/Tuebingen_Neckarfront.jpg')
style_path = tf.keras.utils.get_file(
    'starry-night.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg')

# Decode both files into image tensors via the project's `load_img` helper.
content_image, style_image = load_img(content_path), load_img(style_path)

# Show the two inputs side by side in a 1x2 grid.
for panel, (picture, caption) in enumerate(
        [(content_image, 'Content Image'), (style_image, 'Style Image')],
        start=1):
    plt.subplot(1, 2, panel)
    imshow(picture, caption)

plt.show()

In [None]:
# Reconstruct the style image from random noise, once per stage, using a
# growing prefix of the style layers: stage k optimizes a noise image so that
# its StyleTransferModel 'style' outputs for the first k layers match those of
# the style image. Each stage's result is saved as style_<last layer>.png so the
# contribution of deeper layers can be compared.
style_layers = ['block1_conv1',
                'block2_conv1',
                'block3_conv1',
                'block4_conv1',
                'block5_conv1']

# Per-layer style weights before normalization. An alternative weighting that
# scales each layer by its channel count is:
#   [1e3 / n**2 for n in [64, 128, 256, 512, 512]]
all_style_weights = [1.0, 1.0, 1.0, 1.0, 1.0]

# Training schedule shared by every stage.
epochs = 20
steps_per_epoch = 100

for idx, layer_name in enumerate(style_layers):
    active_layers = style_layers[:idx + 1]

    # Only the 'style' outputs are used in this cell; the content layer list is
    # passed to the constructor but its output is not read here.
    extractor = StyleTransferModel(active_layers, ['block1_conv1'])

    # The variable to optimize, initialised with uniform noise.
    image = tf.Variable(tf.random.uniform(style_image.shape))

    # Target style representation, computed once from the style image.
    style_targets = extractor(style_image)['style']

    # Weights for the active layers, normalized so they sum to 1.
    active_weights = all_style_weights[:idx + 1]
    weight_total = sum(active_weights)
    style_weights = tf.constant([w / weight_total for w in active_weights])

    # A fresh optimizer per stage: every stage optimizes a brand-new variable.
    # Reusing one Adam instance would carry its step counter (which drives bias
    # correction) into later stages, and newer Keras optimizers reject
    # variables they were not originally built with.
    opt = tf.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=1e-1)

    # Redefined per stage because it closes over this stage's extractor,
    # targets, weights and optimizer.
    @tf.function()
    def train_step(image):
        """Apply one optimizer step to `image` on the style loss, then clamp it with clip_0_1."""
        with tf.GradientTape() as tape:
            outputs = extractor(image)
            total_loss = style_loss(outputs['style'], style_targets, style_weights)

        grad = tape.gradient(total_loss, image)
        opt.apply_gradients([(grad, image)])
        image.assign(clip_0_1(image))

    start = time.time()

    step = 0
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            step += 1
            train_step(image)
            print(".", end='')
        # Refresh the notebook output with the current reconstruction.
        display.clear_output(wait=True)
        imshow(image.read_value())
        plt.title("Train step: {}".format(step))
        print(active_layers)
        plt.show()

    end = time.time()
    print("Total time: {:.1f}".format(end - start))

    # NOTE(review): assumes load_img returns a batched tensor (leading dim of 1),
    # so image[0] is the single image — confirm against util.load_img.
    save_img(image[0], 'style_{}.png'.format(layer_name))