In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import time
import IPython.display as display

from util import imshow, load_img, save_img
from model import StyleTransferModel, print_stats
from losses import clip_0_1, content_loss

# Fetch the content photograph and the style painting, then load each one
# through the project's `load_img` helper.
content_path = tf.keras.utils.get_file(
    'neckarfront.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/0/00/Tuebingen_Neckarfront.jpg')
style_path = tf.keras.utils.get_file(
    'starry-night.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg')

# Tuple assignment evaluates left to right, so the content image is loaded
# before the style image.
content_image, style_image = load_img(content_path), load_img(style_path)

# Show both inputs side by side: content on the left, style on the right.
input_panels = [(content_image, 'Content Image'),
                (style_image, 'Style Image')]
for panel_index, (panel_image, panel_title) in enumerate(input_panels, start=1):
    plt.subplot(1, 2, panel_index)
    imshow(panel_image, panel_title)

plt.show()

In [None]:
# Reconstruct the content image from each listed layer in turn.
# For every layer, start from uniform noise and optimise the pixels so that
# the extractor's 'content' output for that layer matches the output on the
# real content image. Save the result as content_<layer>.png.
content_layers = ['block1_conv1',
                  'block2_conv1',
                  'block3_conv1',
                  'block4_conv1',
                  'block5_conv1']

# Optimisation schedule, shared by every layer's reconstruction.
epochs = 20
steps_per_epoch = 100

for content_layer in content_layers:
    # Only `content_layer` feeds the loss below. The style-layer argument is
    # passed to the constructor, but its output is never read in this cell.
    extractor = StyleTransferModel(['block1_conv1'], [content_layer])
    results = extractor(tf.constant(content_image))

    # Create a fresh optimizer for each layer. Adam keeps per-variable slots
    # and a global step counter. Sharing one instance across these independent
    # runs would carry that state (including bias-correction progress) into
    # the next layer. Newer Keras optimizers also refuse variables they were
    # not built with.
    opt = tf.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=1e-1)

    # The variable to optimise, initialised with uniform noise.
    image = tf.Variable(tf.random.uniform(content_image.shape))

    # Reuse the forward pass above instead of running the extractor twice.
    content_targets = results['content']
    content_weights = tf.constant([1e10])

    # Redefined on every iteration so that each trace captures the current
    # extractor, optimizer and targets.
    @tf.function()
    def train_step(image):
        with tf.GradientTape() as tape:
            outputs = extractor(image)
            loss = content_loss(outputs['content'], content_targets, content_weights)

        grad = tape.gradient(loss, image)
        opt.apply_gradients([(grad, image)])
        # Project helper; presumably clamps pixels back into [0, 1] after
        # the update (confirm in losses.py).
        image.assign(clip_0_1(image))

    start = time.time()

    step = 0
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            step += 1
            train_step(image)
            print(".", end='')
        # Replace the previous epoch's preview with the current image.
        display.clear_output(wait=True)
        imshow(image.read_value())
        plt.title("Train step: {}".format(step))
        plt.show()

    end = time.time()
    print("Total time: {:.1f}".format(end - start))

    save_img(image[0], 'content_{}.png'.format(content_layer))