## Installations

In [None]:
# Pin the TensorFlow 2.1 environment this notebook was written against.
# NOTE(review): both `tensorflow` and `tensorflow-gpu` are installed at 2.1.0 (first and
# last lines); presumably only the GPU build is needed on Colab — confirm before dropping one.
!pip install --upgrade tensorflow==2.1.0
# NOTE(review): neither googleapiclient nor apiclient is imported anywhere in this
# notebook; these two installs look unused here — verify before removing.
!pip install google-api-python-client
!pip install apiclient
# tensorflow-addons provides InstanceNormalization (imported by the model cells).
# It is installed with --no-deps, so its runtime dependency typeguard is pinned
# explicitly on the next line (presumably for that reason — TODO confirm).
!pip install  --no-deps tensorflow-addons~=0.6
!pip install typeguard==2.7.1
!pip install -U tensorflow-gpu==2.1.0 grpcio

In [None]:
import tensorflow as tf
print(tf.__version__)

# Fail fast when the runtime has no GPU attached: this notebook installs the
# GPU build of TensorFlow and is meant to train on a GPU.
device_name = tf.test.gpu_device_name()
if device_name == '/device:GPU:0':
  print('GPU device found!')
else:
  raise SystemError('GPU device not found')

## Get Data

In [None]:
# This will free up some disk space before we import the dataset, which is ~13 GB.
# Deletes preinstalled runtimes (Python 2.7, Swift, a TF 1.15.2 tree) that this
# notebook never uses. NOTE(review): these paths are specific to the Colab image
# this was written on — confirm they still exist and are safe to remove.
!rm -rf /usr/local/lib/python2.7
!rm -rf /swift
!rm -rf /tensorflow-1.15.2/python2.7

In [None]:
# Create the target directory for the unzipped COCO images. `-p` creates the
# parent /tmp/data as needed and does not fail if the directories already
# exist, so re-running this cell (e.g. Restart & Run All) is safe.
!mkdir -p /tmp/data/train2014

In [None]:
# Retrieve the COCO 2014 training images (~13 GB zip). I saved a copy of the zip
# file to Google Drive to avoid making requests each time I opened Colab.
# The legacy msvocds.blob.core.windows.net mirror is no longer the official
# location; images.cocodataset.org is the download host listed by COCO.
!wget http://images.cocodataset.org/zips/train2014.zip -O /tmp/train2014.zip

In [None]:
# Extract quietly into /tmp/data. Presumably the archive holds a top-level
# train2014/ folder, giving /tmp/data/train2014 (the content_file_path the
# Trainer reads below) — confirm if the archive layout changes.
!unzip -q /tmp/train2014.zip -d /tmp/data

In [None]:
# Delete the archive after extraction to reclaim disk space.
!rm /tmp/train2014.zip

In [None]:
# Download the style images into /tmp. udnie.jpg is the style used by the
# Trainer below; mosaic / the_scream / rain_princess / wave are alternative
# styles for style_path. chicago.jpg is the content image used in Inference.
!wget -q https://static.wixstatic.com/media/507339_37d143096d6249f3b7253a4b474a1a3a~mv2.jpg -O /tmp/udnie.jpg
!wget -q https://static.wixstatic.com/media/507339_e9064ed1c95d4e1eb76d31ff13117d9e~mv2.jpg -O /tmp/mosaic.jpg
!wget -q https://static.wixstatic.com/media/507339_f11f19449cd54141a1e95ad4b4419ccf~mv2.jpg -O /tmp/the_scream.jpg
!wget -q https://static.wixstatic.com/media/507339_62306b21cef045d599ac272aa410ca3a~mv2.jpg -O /tmp/rain_princess.jpg
!wget -q https://static.wixstatic.com/media/507339_b3b8b65faa854e83b6924344594b2807~mv2.jpg -O /tmp/wave.jpg
!wget -q https://static.wixstatic.com/media/507339_0d4a0400a92c4db28353d9daab20bad0~mv2.jpg -O /tmp/chicago.jpg

## Models

In [None]:
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from skimage import transform

def get_image(img_path, img_size=False):
  """Loads an image file as an array, optionally resizing it.

  Args:
    img_path: path of the image file, read with Keras `load_img`.
    img_size: target size forwarded to `resize_img`, e.g. (256, 256) or
      (256, 256, 3). Any falsy value (the default False, or None) skips
      resizing.

  Returns:
    The float32 array from `img_to_array`, or, when resizing, the array
    returned by `resize_img`.
  """
  img = load_img(img_path)
  # dtype is given by name so this cell does not depend on `np`, which is
  # only imported by a later cell of the notebook.
  img = img_to_array(img, dtype="float32")
  # Truthiness instead of `!= False`: None (or an empty size) now means
  # "don't resize" rather than crashing inside resize_img.
  if img_size:
    img = resize_img(img, img_size)
  return img

def resize_img(img, size):
  """Resizes an image array with skimage, keeping the original value range.

  Args:
    img: image array to resize.
    size: target (height, width) or (height, width, channels). A 2-element
      size gets a channel count of 3 appended.

  Returns:
    The result of `skimage.transform.resize(img, size, preserve_range=True)`.
  """
  # Build a new tuple rather than `size += (3,)`: when `size` is a list,
  # `+=` extends the caller's list in place, so reusing the same list across
  # calls would keep growing it.
  size = tuple(size)
  if len(size) == 2:
    size = size + (3,)
  return transform.resize(img, size, preserve_range=True)

In [None]:
from tensorflow.keras.layers import Activation, Add, BatchNormalization, Conv2D, Conv2DTranspose, Layer
from tensorflow_addons.layers import InstanceNormalization


class ConvLayer(Layer):
  """Conv2D -> InstanceNormalization -> optional ReLU, wrapped as one Keras layer.

  Args:
    filters: number of output channels of the convolution.
    kernel: convolution kernel size.
    padding: padding mode passed to Conv2D.
    strides: convolution strides.
    activate: when True, a ReLU is applied after normalization.
    name: name given to the inner Conv2D (the wrapper itself is auto-named).
    weight_initializer: kernel initializer for the convolution.
  """

  def __init__(self, filters,
               kernel=(3,3), padding='same',
               strides=(1,1), activate=True, name="",
               weight_initializer="glorot_uniform"
               ):
    super(ConvLayer, self).__init__()
    self.activate = activate
    # No conv bias: instance norm subtracts the per-channel mean, and its own
    # learned beta (center=True) supplies the offset instead.
    self.conv = Conv2D(filters,
                       kernel_size=kernel,
                       strides=strides,
                       padding=padding,
                       use_bias=False,
                       kernel_initializer=weight_initializer,
                       trainable=True,
                       name=name)
    norm_options = dict(axis=3,
                        center=True,
                        scale=True,
                        beta_initializer="zeros",
                        gamma_initializer="ones",
                        trainable=True)
    self.inst_norm = InstanceNormalization(**norm_options)
    if activate:
      self.relu_layer = Activation('relu', trainable=False)

  def call(self, x):
    normalized = self.inst_norm(self.conv(x))
    return self.relu_layer(normalized) if self.activate else normalized


class ResBlock(Layer):
  """Residual block: returns x + conv2(conv1(x)).

  conv1 is a ConvLayer with ReLU; conv2 is a ConvLayer without activation.
  The Add requires conv2's output to match x's shape, so x presumably already
  has `filters` channels (it does in TransformNet, where it follows a
  128-filter ConvLayer).

  Args:
    filters: channels of both inner convolutions.
    kernel: kernel size of both inner convolutions.
    padding: padding mode of both inner convolutions.
    weight_initializer: kernel initializer of both inner convolutions.
    prefix: prefix for the inner layer names ("<prefix>_conv_1", ...).
  """

  def __init__(self, filters, kernel=(3,3), padding='same', weight_initializer="glorot_uniform", prefix=""):
    super(ResBlock, self).__init__()
    self.prefix_name = prefix + "_"
    common = dict(filters=filters,
                  kernel=kernel,
                  padding=padding,
                  weight_initializer=weight_initializer)
    self.conv1 = ConvLayer(name=self.prefix_name + "conv_1", **common)
    self.conv2 = ConvLayer(activate=False, name=self.prefix_name + "conv_2", **common)
    self.add = Add(name=self.prefix_name + "add")

  def call(self, x):
    residual = self.conv2(self.conv1(x))
    return self.add([x, residual])


class ConvTLayer(Layer):
  """Conv2DTranspose -> InstanceNormalization -> optional ReLU, as one Keras layer.

  Args:
    filters: number of output channels of the transposed convolution.
    kernel: kernel size.
    padding: padding mode passed to Conv2DTranspose.
    strides: strides of the transposed convolution.
    activate: when True, a ReLU is applied after normalization.
    name: name given to the inner Conv2DTranspose (the wrapper is auto-named).
    weight_initializer: kernel initializer for the transposed convolution.
  """

  def __init__(self, filters, kernel=(3,3), padding='same', strides=(1,1), activate=True, name="",
               weight_initializer="glorot_uniform" 
               ):
    super(ConvTLayer, self).__init__()
    self.activate = activate
    # Bias omitted for the same reason as in ConvLayer: the instance norm's
    # learned beta provides the per-channel offset.
    self.conv_t = Conv2DTranspose(filters,
                                  kernel_size=kernel,
                                  strides=strides,
                                  padding=padding,
                                  use_bias=False,
                                  kernel_initializer=weight_initializer,
                                  name=name)
    norm_options = dict(axis=3,
                        center=True,
                        scale=True,
                        beta_initializer="zeros",
                        gamma_initializer="ones",
                        trainable=True)
    self.inst_norm = InstanceNormalization(**norm_options)
    if activate:
      self.relu_layer = Activation('relu')

  def call(self, x):
    normalized = self.inst_norm(self.conv_t(x))
    return self.relu_layer(normalized) if self.activate else normalized


In [None]:
import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization, Conv2D, Add, Layer, Conv2DTranspose, Activation

class TransformNet:
  """Feed-forward image transformation network trained by Trainer.

  The layers run in this order:
    - three ConvLayers (32 filters 9x9 stride 1, then 64 and 128 filters 3x3 stride 2);
    - five 128-filter ResBlocks;
    - two stride-2 ConvTLayers (64, then 32 filters);
    - a 9x9 ConvLayer to 3 channels without ReLU;
    - a tanh, whose [-1, 1] output is rescaled to [0, 255].

  `self.model` is the assembled Keras Model, accepting (batch, H, W, 3) inputs.
  """

  def __init__(self):
    self.conv1 = ConvLayer(32, (9,9), strides=(1,1), padding='same', name="conv_1")
    self.conv2 = ConvLayer(64, (3,3), strides=(2,2), padding='same', name="conv_2")
    self.conv3 = ConvLayer(128, (3,3), strides=(2,2), padding='same', name="conv_3")
    self.res1 = ResBlock(128, prefix="res_1")
    self.res2 = ResBlock(128, prefix="res_2")
    self.res3 = ResBlock(128, prefix="res_3")
    self.res4 = ResBlock(128, prefix="res_4")
    self.res5 = ResBlock(128, prefix="res_5")
    self.convt1 = ConvTLayer(64, (3,3), strides=(2,2), padding='same', name="conv_t_1")
    self.convt2 = ConvTLayer(32, (3,3), strides=(2,2), padding='same', name="conv_t_2")
    self.conv4 = ConvLayer(3, (9,9), strides=(1,1), padding='same', activate=False, name="conv_4")
    self.tanh = Activation('tanh')
    self.model = self._get_model()

  def _get_model(self):
    """Chains the layers above into a functional Keras Model."""
    inputs = tf.keras.Input(shape=(None,None,3))
    stages = [self.conv1, self.conv2, self.conv3,
              self.res1, self.res2, self.res3, self.res4, self.res5,
              self.convt1, self.convt2,
              self.conv4, self.tanh]
    features = inputs
    for stage in stages:
      features = stage(features)
    # Map the tanh range [-1, 1] onto pixel values [0, 255].
    pixels = (features + 1) * (255. / 2)
    return tf.keras.Model(inputs, pixels, name="transformnet")

  def get_variables(self):
    """Returns the trainable variables of the assembled model."""
    return self.model.trainable_variables

  def preprocess(self, img):
    """Scales raw [0, 255] pixel values to [0, 1] before feeding the model."""
    return img / 255.0

  def postprocess(self, img):
    """Clips model output to the valid pixel range [0, 255]."""
    return tf.clip_by_value(img, 0.0, 255.0)


In [None]:
from tensorflow.keras.applications import VGG19
from tensorflow.keras import Model
from collections import namedtuple


VGG_Output = namedtuple('VGG_Output', 'content_output style_output')

class VGGModel:
  """Frozen VGG19 feature extractor for the perceptual (content/style) losses.

  Wraps `VGG19(include_top=False, weights='imagenet')` in a non-trainable
  Model whose outputs are the requested content layers followed by the
  requested style layers. `forward` splits them back into a `VGG_Output`.

  Args:
    content_layers: short layer names (keys of `self.layers`) used for the
      content loss.
    style_layers: short layer names used for the style loss.

  Raises:
    ValueError: if no layers are requested, or a name does not refer to a
      layer that exists in VGG19 without its classifier head.
  """

  def __init__(self,
               content_layers=("conv4_2",),
               style_layers=("conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv5_1")
               ):
    # Copy into fresh lists so neither the defaults nor the caller's sequences
    # are aliased by this object.
    self.content_layers = list(content_layers)
    self.style_layers = list(style_layers)
    self.total_output_layers = self.content_layers + self.style_layers
    if not self.total_output_layers:
      raise ValueError("VGGModel needs at least one content or style layer.")
    # Outputs are ordered content-first, so this index separates the groups.
    self.partition_idx = len(self.content_layers)

    self.vgg = VGG19(include_top=False, weights='imagenet')
    # Short layer name -> index into self.vgg.layers. The flatten/fc/predictions
    # entries only exist when VGG19 is built with include_top=True, so they
    # are rejected by _get_outputs for this headless model.
    self.layers = {
      "input"  : 0,
      "conv1_1": 1,
      "conv1_2": 2,
      "pool1"  : 3,
      "conv2_1": 4,
      "conv2_2": 5,
      "pool2"  : 6,
      "conv3_1": 7,
      "conv3_2": 8,
      "conv3_3": 9,
      "conv3_4": 10,
      "pool3"  : 11,
      "conv4_1": 12,
      "conv4_2": 13,
      "conv4_3": 14,
      "conv4_4": 15,
      "pool4"  : 16,
      "conv5_1": 17,
      "conv5_2": 18,
      "conv5_3": 19,
      "conv5_4": 20,
      "pool5"  : 21,
      "flatten": 22,
      "fc1"    : 23,
      "fc2"    : 24,
      "predictions": 25,
    }
    self.model = Model(self.vgg.inputs, self._get_outputs(), trainable=False)

  def forward(self, X):
    """Runs preprocessed images through VGG19.

    Returns:
      VGG_Output whose `content_output` and `style_output` are lists of
      activations, in the order of content_layers and style_layers.
    """
    outputs = self.model(X)
    # A Keras model with a single output may return a bare tensor instead of
    # a list; slicing that would cut the batch axis rather than pick layers.
    if not isinstance(outputs, (list, tuple)):
      outputs = [outputs]
    return VGG_Output(outputs[:self.partition_idx], outputs[self.partition_idx:])

  def _get_outputs(self):
    """Returns the symbolic VGG outputs for total_output_layers, in order."""
    num_vgg_layers = len(self.vgg.layers)
    outputs = []
    for layer_name in self.total_output_layers:
      idx = self.layers.get(layer_name)
      if idx is None or idx >= num_vgg_layers:
        available = [n for n, i in sorted(self.layers.items(), key=lambda kv: kv[1])
                     if i < num_vgg_layers]
        raise ValueError("Unknown or unavailable VGG19 layer %r; choose from %s"
                         % (layer_name, available))
      outputs.append(self.vgg.layers[idx].output)
    return outputs

  def preprocess(self, images):
    """Applies Keras' VGG19 input preprocessing and casts the result to float32.

    Callers in this notebook pass raw pixel values in [0, 255].
    """
    images = tf.keras.applications.vgg19.preprocess_input(images)
    images = tf.cast(images, tf.float32)
    return images
  

## Training
The main training loop. 

Note: The logging and saving routines (`_log_protocol` and `_save_protocol`) were mainly used during development. In other versions of this notebook, I extended them to copy the SavedModels and log files to Google Drive so I could monitor the model while it trained; this is why I export SavedModels rather than checkpoints. Feel free to copy the notebook and do the same.

In [None]:
from tensorflow.keras.applications import VGG19
from tensorflow.keras.preprocessing.image import img_to_array, load_img 
from tensorflow.keras.applications.vgg19 import preprocess_input
from collections import namedtuple
from glob import glob
import os
import datetime
import numpy as np
import tensorflow as tf

Loss = namedtuple('Loss', 'total_loss style_loss content_loss tv_loss')

class Trainer:
  """Trains a TransformNet to apply one style image to arbitrary content images.

  Each step stylizes a batch of content images. It then scores the output with:
    - a content loss: VGG19 content-layer activations vs. the content images;
    - a style loss: Gram matrices of the style-layer activations vs. those of
      the style image;
    - a total-variation penalty.
  Only the TransformNet variables are optimized, using Adam.

  Args:
    style_path: path of the style image; its basename without extension is
      stored as `style_name`.
    content_file_path: directory whose `*.jpg` files form the training set.
    epochs: number of passes over the training set.
    batch_size: images per step. Images that do not fill a final full batch
      are skipped.
    content_weight, style_weight, tv_weight: multipliers for each loss term.
    learning_rate: Adam learning rate. The default is 1e-3; it was
      previously the typo 1e3, which is far too large for Adam.
    log_period, save_period: log / export a SavedModel every this many
      iterations (and always at the end of each epoch).
    content_layers, style_layers: VGGModel layer names used by each loss.
    content_layer_weights, style_layer_weights: per-layer loss weights.

  Raises:
    ValueError: if a layer list and its weight list differ in length.
  """

  def __init__(self, 
               style_path, 
               content_file_path, 
               epochs=2, 
               batch_size=8,
               content_weight=1e0,
               style_weight=4e1,
               tv_weight=2e2,
               learning_rate=1e-3,
               log_period=100,
               save_period=1000,
               content_layers=("conv4_2",),
               style_layers=("conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv5_1"),
               content_layer_weights=(1,),
               style_layer_weights=(0.2, 0.2, 0.2, 0.2, 0.2)):
    self.style_path = style_path
    # e.g. "/tmp/udnie.jpg" -> "udnie"
    self.style_name = style_path.split("/")[-1].split(".")[0]
    self.content_file_path = content_file_path
    # Copy every sequence so defaults and callers' lists are never shared.
    self.content_layers = list(content_layers)
    self.content_layer_weights = list(content_layer_weights)
    self.style_layers = list(style_layers)
    self.style_layer_weights = list(style_layer_weights)
    # Explicit checks (rather than assert) so validation survives `python -O`.
    if len(self.content_layers) != len(self.content_layer_weights):
      raise ValueError("content_layers and content_layer_weights must have the same length")
    if len(self.style_layers) != len(self.style_layer_weights):
      raise ValueError("style_layers and style_layer_weights must have the same length")
    self.epochs = epochs
    self.batch_size = batch_size
    self.log_period = log_period
    self.save_period = save_period
    self.saved_model_path = "/tmp/saved_models"

    self.style_weight = style_weight
    self.content_weight = content_weight
    self.tv_weight = tv_weight

    self.transform = TransformNet()
    self.vgg = VGGModel(self.content_layers, self.style_layers)
    self.learning_rate = learning_rate
    # Backward-compatible alias for the previously misspelled attribute.
    self.learing_rate = learning_rate
    self.train_optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)

  def run(self):
    """Runs the full training loop, logging and exporting periodically.

    Raises:
      ValueError: if `content_file_path` holds fewer than `batch_size`
        `.jpg` files, so no full batch can be formed.
    """
    content_images = glob(os.path.join(self.content_file_path, "*.jpg"))
    # Skip the trailing partial batch so every step has exactly batch_size
    # images; the loss terms divide by the static batch size from get_shape().
    num_images = len(content_images) - (len(content_images) % self.batch_size)
    print("Training on %d images" % num_images)
    if num_images == 0:
      # Without this check the epoch-end logging below would fail with a
      # NameError on `loss`, since no training step ever ran.
      raise ValueError("Need at least batch_size=%d .jpg images in %s, found %d."
                       % (self.batch_size, self.content_file_path, len(content_images)))

    self.S_outputs = self._get_S_outputs()
    self.S_style_grams = [self._gram_matrix(tf.convert_to_tensor(m, tf.float32))
                          for m in self.S_outputs[self.vgg.partition_idx:]]

    batches = [content_images[i:i + self.batch_size]
               for i in range(0, num_images, self.batch_size)]

    self.iteration = 0

    for _ in range(self.epochs):
      for batch in batches:
        content_imgs = np.array([get_image(img_path, (256,256,3)) for img_path in batch])
        content_tensors = tf.convert_to_tensor(content_imgs)

        loss = self._train_step(content_tensors)

        if self.iteration % self.log_period == 0:
          self._log_protocol(loss)
        if self.iteration % self.save_period == 0:
          self._save_protocol()

        self.iteration += 1

      self._log_protocol(loss)
      self._save_protocol()
      print("Epoch complete.")
    print("Training finished.")

  def _get_S_outputs(self):
    """Returns the VGG activations (content layers first) of the style image."""
    img = tf.convert_to_tensor(get_image(self.style_path), tf.float32)
    img = tf.expand_dims(img, 0)
    img = self.vgg.preprocess(img)
    outputs = self.vgg.model(img)
    # A single-output Keras model may return a bare tensor instead of a list;
    # run() slices this result by layer.
    if not isinstance(outputs, (list, tuple)):
      outputs = [outputs]
    return outputs

  @tf.function
  def _train_step(self, content_tensors):
    """Performs one Adam step on a batch of raw-pixel content images.

    Returns:
      Loss namedtuple of the weighted loss terms for this batch.
    """
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      # Only the transform network is trained; VGG stays fixed.
      tape.watch(self.transform.get_variables())
      C = self.transform.preprocess(content_tensors)
      X = self.transform.model(C)
      X = self.transform.postprocess(X)

      X_vgg = self.vgg.preprocess(X)
      Y_hat = self.vgg.forward(X_vgg) 
      Y_hat_content = Y_hat.content_output
      Y_hat_style = Y_hat.style_output

      C_vgg = self.vgg.preprocess(content_tensors)
      Y = self.vgg.forward(C_vgg)
      Y_content = Y.content_output

      L = self._get_loss(Y_hat_content, Y_hat_style, Y_content, X)
    grads = tape.gradient(L.total_loss, self.transform.get_variables())
    self.train_optimizer.apply_gradients(zip(grads, self.transform.get_variables()))
    return L

  def _get_loss(self, transformed_content, transformed_style, content, transformed_img):
    """Combines the weighted content, style and total-variation losses."""
    content_loss = self._get_content_loss(transformed_outputs=transformed_content, content_outputs=content)
    style_loss = self._get_style_loss(transformed_style)
    tv_loss = self._get_total_variation_loss(transformed_img)

    L_style = style_loss * self.style_weight
    L_content = content_loss * self.content_weight
    L_tv = tv_loss * self.tv_weight

    total_loss = L_style + L_content + L_tv

    return Loss(total_loss=total_loss, 
                style_loss=L_style, 
                content_loss=L_content,
                tv_loss=L_tv)

  def _get_content_loss(self, transformed_outputs, content_outputs):
    """Weighted mean squared feature difference, summed over content layers."""
    content_loss = 0
    assert(len(transformed_outputs) == len(content_outputs))
    for i, output in enumerate(transformed_outputs):
      weight = self.content_layer_weights[i]
      B, H, W, CH = output.get_shape()
      HW = H * W
      # 2 * l2_loss(d) == sum(d ** 2); dividing by the element count gives the MSE.
      loss_i = weight * 2 * tf.nn.l2_loss(output-content_outputs[i]) / (B*HW*CH)
      content_loss += loss_i
    return content_loss

  def _get_style_loss(self, transformed_outputs):
    """Weighted squared Gram-matrix difference vs. the style image, over style layers."""
    style_loss = 0
    assert(len(transformed_outputs) == len(self.S_style_grams))
    for i, output in enumerate(transformed_outputs):
      weight = self.style_layer_weights[i]
      B, H, W, CH = output.get_shape()
      G = self._gram_matrix(output)
      # A comes from the single style image; it broadcasts over the batch of G.
      A = self.S_style_grams[i]
      style_loss += weight * 2 * tf.nn.l2_loss(G - A) / (B * (CH ** 2))
    return style_loss

  def _gram_matrix(self, input_tensor, shape=None):
    """Returns per-image channel Gram matrices, shape (batch, C, C), divided by H*W*C.

    `shape` is unused and kept only for interface compatibility.
    """
    result = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
    input_shape = input_tensor.get_shape()
    num_locations = input_shape[1] * input_shape[2] * input_shape[3]
    num_locations = tf.cast(num_locations, tf.float32)
    return result / num_locations

  def _get_total_variation_loss(self, img):
    """Batch-summed total variation, divided by the image's pixel count (H*W)."""
    B, H, W, CH = img.get_shape()
    return tf.reduce_sum(tf.image.total_variation(img)) / (W*H)

  def _log_protocol(self, L):
    """Prints the current iteration and each loss term."""
    tf.print("iteration: %d, total_loss: %f, style_loss: %f, content_loss: %f, tv_loss: %f" \
                  % (self.iteration, L.total_loss, L.style_loss, L.content_loss, L.tv_loss))

  def _save_protocol(self):
    """Exports the transform model as a SavedModel, overwriting saved_model_path."""
    tf.keras.models.save_model(model=self.transform.model, filepath=self.saved_model_path)

  

In [None]:
trainer = Trainer(
  # Data: the style to learn and the directory of COCO content images.
  style_path="/tmp/udnie.jpg",
  content_file_path="/tmp/data/train2014",
  # Optimization schedule.
  epochs=2,
  batch_size=8,
  learning_rate=1e-3,
  # Relative weights of the three loss terms.
  content_weight=1e0,
  style_weight=4e1,
  tv_weight=2e2,
  # VGG19 layers (with per-layer weights) used by the content and style losses.
  content_layers=["conv4_2"],
  content_layer_weights=[1],
  style_layers=["conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv5_1"],
  style_layer_weights=[0.2, 0.2, 0.2, 0.2, 0.2],
  # Logging and SavedModel-export cadence, in iterations.
  log_period=100,
  save_period=1000,
)

In [None]:
# Long-running: trains for trainer.epochs passes over the content images.
# Losses are logged every log_period iterations, and a SavedModel is exported
# to /tmp/saved_models every save_period iterations and at each epoch's end.
trainer.run()

## Inference

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def post_process(img):
  """Converts a model output into an integer image array for plt.imshow.

  Args:
    img: eager tensor or numpy array of pixel values, e.g. the (1, H, W, 3)
      output of the transform model.

  Returns:
    A numpy int array, clipped to [0, 255], with all size-1 axes (such as the
    batch axis) removed. Values are truncated, not rounded, by the int cast.
  """
  # np.asarray accepts eager tensors and numpy arrays alike; the previous
  # version called .numpy(), which plain ndarrays do not have, and bounced
  # between numpy and tf twice.
  pixels = np.clip(np.asarray(img), 0, 255)
  pixels = np.squeeze(pixels)
  return pixels.astype(int)

In [None]:
# Load the content image at its native size (no resize) and prepend a batch
# axis so it matches the model's (batch, H, W, 3) input.
img = get_image("/tmp/chicago.jpg")
img_tensor = tf.expand_dims(tf.convert_to_tensor(img), axis=0)

In [None]:
# During training the network only ever saw inputs scaled to [0, 1] by
# TransformNet.preprocess (see Trainer._train_step). Apply the same scaling
# here instead of feeding raw [0, 255] pixels.
res = trainer.transform.model(trainer.transform.preprocess(img_tensor))
res = post_process(res)

In [None]:
# Show the stylized image in figure number 1, sized 10x10 inches.
plt.figure(num=1, figsize=(10, 10))
plt.imshow(res)