In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import time
import IPython.display as display

from util import imshow, load_img, save_img, clip_0_1
from model import StyleTransferModel, print_stats
from losses import style_content_loss, total_variation_loss

# Load input images.
# Candidate images, keyed by a short name. Each value is the
# (cache file name, download URL) pair passed to tf.keras.utils.get_file.
# Switch images by changing CONTENT_KEY / STYLE_KEY rather than commenting lines in and out.
CONTENT_IMAGES = {
    'neckarfront': ('neckarfront.jpg', 'https://upload.wikimedia.org/wikipedia/commons/0/00/Tuebingen_Neckarfront.jpg'),
    'stata': ('stata.jpg', 'https://raw.githubusercontent.com/lengstrom/fast-style-transfer/master/examples/content/stata.jpg'),
    'lake2': ('lake2.jpg', 'https://i.imgur.com/DDMNAUP.jpg'),
    'forest': ('forest.jpg', 'https://www.positive.news/wp-content/uploads/2019/03/feat-1800x0-c-center.jpg'),
}
STYLE_IMAGES = {
    'kandinsky': ('kandinsky.jpg', 'https://storage.googleapis.com/download.tensorflow.org/example_images/Vassily_Kandinsky%2C_1913_-_Composition_7.jpg'),
    'starry-night': ('starry-night.jpg', 'https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg'),
    'udnie': ('udnie.jpg', 'https://raw.githubusercontent.com/jcjohnson/fast-neural-style/master/images/styles/udnie.jpg'),
    'abstract2': ('abstract2.jpg', 'https://i.imgur.com/9y4UfcK.jpg'),
    'candy': ('candy.jpg', 'https://raw.githubusercontent.com/jcjohnson/fast-neural-style/master/images/styles/candy.jpg'),
    'wave': ('wave.jpg', 'https://raw.githubusercontent.com/lengstrom/fast-style-transfer/master/examples/style/wave.jpg'),
}

CONTENT_KEY = 'neckarfront'
STYLE_KEY = 'kandinsky'
# Size limit passed to load_img for both images.
MAX_DIM = 512

# get_file downloads once and returns the local cached path afterwards.
content_path = tf.keras.utils.get_file(*CONTENT_IMAGES[CONTENT_KEY])
style_path = tf.keras.utils.get_file(*STYLE_IMAGES[STYLE_KEY])

content_image = load_img(content_path, max_dim=MAX_DIM)
style_image = load_img(style_path, max_dim=MAX_DIM)

# Show the two inputs side by side.
plt.subplot(1, 2, 1)
imshow(content_image, 'Content Image')

plt.subplot(1, 2, 2)
imshow(style_image, 'Style Image')

plt.show()

In [None]:
# Layers whose activations are compared against the content image,
# with one weight per layer.
content_layers = ['block5_conv1']
content_weights = [1.0]

# Layers whose activations are compared against the style image.
style_layers = ['block1_conv1',
                'block2_conv1',
                'block3_conv1',
                'block4_conv1',
                'block5_conv1']

# Relative per-layer style weights; normalised below so they sum to 1.
style_weights = [1.0, 1.0, 1.0, 1.0, 1.0]

# Each layer needs exactly one weight; nothing else in this cell checks that,
# so catch a mismatch here rather than deep inside the loss computation.
assert len(content_weights) == len(content_layers), "need one content weight per content layer"
assert len(style_weights) == len(style_layers), "need one style weight per style layer"

style_weight_total = sum(style_weights)
style_weights = [w / style_weight_total for w in style_weights]

# Optional total-variation regularisation weight (currently disabled).
# total_variation_weight = 2e2

extractor = StyleTransferModel(style_layers, content_layers)

# Fixed targets: features of the original content and style images.
content_targets = extractor(content_image)['content']
style_targets = extractor(style_image)['style']

# Global loss weights handed to style_content_loss: alpha is passed with the
# content terms, beta with the style terms.
alpha = 1e-2
beta = 1

In [None]:
# Optimiser that updates the image pixels directly.
opt = tf.optimizers.Adam(epsilon=1e-1, beta_1=0.99, learning_rate=0.02)

# The image being optimised starts as a copy of the content image.
# (Alternative: start from noise via tf.random.uniform(content_image.shape).)
image = tf.Variable(content_image)

@tf.function()
def train_step(image):
  """Run one optimisation step on `image` in place and return the loss.

  Computes the combined style/content loss of the current image against the
  global targets, applies one `opt` update to the image, then reassigns it
  through clip_0_1 (presumably clamping pixel values to [0, 1]).

  Args:
    image: the tf.Variable being optimised; it is updated in place.

  Returns:
    The loss tensor computed before this step's update. Returning it lets
    callers log convergence; callers that ignore it behave as before.
  """
  with tf.GradientTape() as tape:
    outputs = extractor(image)
    loss = style_content_loss(outputs['content'], content_targets, content_weights, alpha,
                              outputs['style'], style_targets, style_weights, beta)
    # loss += total_variation_weight*total_variation_loss(image)

  grad = tape.gradient(loss, image)
  opt.apply_gradients([(grad, image)])
  image.assign(clip_0_1(image))
  return loss

# Optimisation schedule: `epochs` rounds of `steps_per_epoch` updates each,
# showing a preview of the current image after every round.
epochs = 20
steps_per_epoch = 100

# perf_counter is monotonic, so the measured duration is not skewed by
# wall-clock adjustments the way time.time() can be.
start = time.perf_counter()

step = 0
for n in range(epochs):
  for m in range(steps_per_epoch):
    step += 1
    train_step(image)
    # Flush so each progress dot shows up immediately rather than when the
    # output buffer is next drained.
    print(".", end='', flush=True)
  display.clear_output(wait=True)
  imshow(image.read_value())
  plt.title("Train step: {}".format(step))
  plt.show()

end = time.perf_counter()
print("Total time: {:.1f}".format(end-start))

# Save the first entry along axis 0 of the variable (presumably a batch of
# one image - confirm against load_img).
save_img(image[0], 'output.png')