# AdaIN-Keras
Keras implementation of [Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization](https://arxiv.org/abs/1703.06868) with [Colab](https://colab.research.google.com/).
Written by [KOSMOS](https://github.com/kukosmos), Korea University programming club.


You can run this code directly in Colab [here](https://colab.research.google.com/github/kukosmos/adain-keras-2019/blob/master/colab.ipynb).

# 0. Settings & Utils

You can specify the version of TensorFlow to use with the following command.
Here it selects TensorFlow 2.x.

In [0]:
%tensorflow_version 2.x

These are the configurable variables used for training.

In [0]:
# data
content_path = 'data/coco2017train'
style_path = 'data/wikiart'
image_size = 512
crop_size = 256
n_per_epoch = 1000
batch_size = 8
# loss
style_weight = 10.0
content_weight =  1.0
# optimizer
learning_rate = 1e-4
learning_rate_decay = 5e-5
# log
model_dir = 'models/kaiser'
# training
epochs = 1280

This script uses our own helper functions and classes, which are available on [GitHub](https://github.com/kukosmos/adain-keras-2019).

In [0]:
import os
import types
import requests

In [0]:
# borrowed from https://stackoverflow.com/a/34491349
def import_from_github(uri, name=None):
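  if not name:
    # splitext drops the extension; rstrip('.py') would strip any trailing '.', 'p' or 'y' characters
    name = os.path.splitext(os.path.basename(uri))[0].lower()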
  
  r = requests.get(uri)
  r.raise_for_status()

  codeobj = compile(r.content, uri, 'exec')
  module = types.ModuleType(name)
  exec(codeobj, module.__dict__)
  return module
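
The compile-and-exec mechanism can be tried without any network access. The following is a minimal sketch in which a hypothetical in-memory source string (`source`) and module name (`example`) stand in for a file downloaded from GitHub; neither comes from the repository.

```python
import types

# Hypothetical source text standing in for a file fetched over HTTP.
source = "def greet(name):\n    return 'hello ' + name\n"

# Compile the text, then execute it inside a fresh module's namespace.
codeobj = compile(source, "<remote: example.py>", "exec")
module = types.ModuleType("example")
exec(codeobj, module.__dict__)

# Names defined by the source are now attributes of the module object.
greeting = module.greet("adain")
print(greeting)
```

Because `exec` runs whatever the URL returns, this approach should only be used with sources you trust.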

This helper recovers the field values from a string that was produced from a format pattern.

In [0]:
import re

In [0]:
# borrowed from https://stackoverflow.com/a/36838374
def unformat_string(string, pattern):
  regex = re.sub(r'{(.+?)}', r'(?P<_\1>.+)', pattern)
  values = list(re.search(regex, string).groups())
  keys = re.findall(r'{(.+?)}', pattern)
  return dict(zip(keys, values))
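
As a small usage sketch, the helper can recover the epoch number from a checkpoint file name. The function body is repeated here only so that the example runs on its own; the file name `trainer.epoch-12.h5` is an illustrative value.

```python
import re

def unformat_string(string, pattern):
    # Same logic as the cell above, repeated so the example is self-contained.
    regex = re.sub(r'{(.+?)}', r'(?P<_\1>.+)', pattern)
    values = list(re.search(regex, string).groups())
    keys = re.findall(r'{(.+?)}', pattern)
    return dict(zip(keys, values))

fields = unformat_string('trainer.epoch-12.h5', 'trainer.epoch-{epoch}.h5')
epoch = int(fields['epoch'])  # extracted values are strings, so convert explicitly
print(fields, epoch)
```

Note that `re.search` returns `None` when the string does not follow the pattern, so the helper raises `AttributeError` for unrelated names.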

# 1. Data preparation

First, mount the Google Drive that contains the images for training.

In [0]:
from google.colab import drive
drive.mount('/content/drive')

In [0]:
from pathlib import Path

In [0]:
gdrive = Path('/content/drive/My Drive')

Before creating the dataset, configure Pillow to accept very large images and truncated files instead of raising errors while loading.

In [0]:
from PIL import Image
from PIL import ImageFile

In [0]:
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True

Create the dataset from the content data and style data in your Google Drive.

In [0]:
dataloader = import_from_github('https://raw.githubusercontent.com/kukosmos/adain-keras-2019/master/dataloader.py')

In [0]:
dataset = dataloader.ContentStyleLoader(
  content_root=gdrive / content_path,
  content_image_shape=(image_size, image_size),
  content_crop='random',
  content_crop_size=crop_size,
  style_root=gdrive / style_path,
  style_image_shape=(image_size, image_size),
  style_crop='random',
  style_crop_size=crop_size,
  n_per_epoch=n_per_epoch,
  batch_size=batch_size
)

If *OSError: [Errno 5]* occurs while creating the dataset,
create subdirectories and move your images into them, with about 10,000 images per folder.
Alternatively, simply re-running the cell may work, because it can use cached data.
See [here](https://research.google.com/colaboratory/faq.html#drive-timeout) for the cause of the error.
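
One possible way to do the relocation is sketched below. The function name `shard_directory` and the `part-0000` folder naming are assumptions made for illustration, and the sketch assumes the dataloader also finds images inside subdirectories. The demonstration runs on a temporary folder with empty placeholder files rather than on real data.

```python
import shutil
import tempfile
from pathlib import Path

def shard_directory(root, per_dir=10000):
    """Move files directly under `root` into numbered subfolders of at most `per_dir` files."""
    root = Path(root)
    # Materialize the listing first, so folders created below are not iterated.
    files = sorted(p for p in root.iterdir() if p.is_file())
    for index, path in enumerate(files):
        shard = root / f'part-{index // per_dir:04d}'
        shard.mkdir(exist_ok=True)
        shutil.move(str(path), str(shard / path.name))
    return len(files)

# Demonstration on a throwaway folder with five empty files, two per subfolder.
with tempfile.TemporaryDirectory() as tmp:
    for i in range(5):
        (Path(tmp) / f'img{i}.jpg').touch()
    moved = shard_directory(tmp, per_dir=2)
    folders = sorted(p.name for p in Path(tmp).iterdir())
    sizes = [len(list((Path(tmp) / name).iterdir())) for name in folders]
    print(moved, folders, sizes)
```

On real data you would call it once per image root, for example on the content folder and on the style folder in your drive.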


# 2. Model creation

The stylizer model takes two inputs, content images and style images, and produces stylized output.

In [0]:
from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Lambda
from tensorflow.keras.models import load_model
from tensorflow.keras.models import Model
from tensorflow.keras.utils import CustomObjectScope
import tensorflow.keras.backend as K
import tensorflow.keras.optimizers as optimizers

utils = import_from_github('https://raw.githubusercontent.com/kukosmos/adain-keras-2019/master/utils.py')
network = import_from_github('https://raw.githubusercontent.com/kukosmos/adain-keras-2019/master/network.py')

The stylizer takes two inputs, a content image $c$ and a style image $s$,
and generates a stylized image $g(t)$ in the following steps.
First, encode both images and compute the normalized feature $t$ using **adaptive instance normalization**.
$$
t = AdaIN(\phi_L(c), \phi_L(s))
  = \sigma(\phi_L(s)) \left( \frac{\phi_L(c) - \mu(\phi_L(c))}{\sigma(\phi_L(c))} \right) + \mu(\phi_L(s))
$$
where $\phi_i$ is the $i^{th}$ encoder output, $L$ is the number of encoder outputs, and $\mu$ and $\sigma$ are the per-channel mean and standard deviation taken over spatial positions.
Then, create a new image from $t$ with the decoder $g$.
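
The AdaIN formula can be checked numerically with a few lines of NumPy. This is a sketch of the equation only, not the `network.AdaIN` layer from the repository; the channels-last layout, the `epsilon` term for numerical stability, and the random test features are assumptions.

```python
import numpy as np

def adain(content_feat, style_feat, epsilon=1e-5):
    """Align the per-channel mean/std of content features to those of style features.

    Both inputs are assumed to be (batch, height, width, channels) arrays.
    """
    axes = (1, 2)  # statistics over spatial positions, per sample and channel
    c_mean = content_feat.mean(axis=axes, keepdims=True)
    c_std = np.sqrt(content_feat.var(axis=axes, keepdims=True) + epsilon)
    s_mean = style_feat.mean(axis=axes, keepdims=True)
    s_std = np.sqrt(style_feat.var(axis=axes, keepdims=True) + epsilon)
    return s_std * (content_feat - c_mean) / c_std + s_mean

rng = np.random.default_rng(0)
content = rng.normal(0.0, 1.0, size=(1, 8, 8, 4))  # random stand-in for phi_L(c)
style = rng.normal(3.0, 2.0, size=(1, 8, 8, 4))    # random stand-in for phi_L(s)
t = adain(content, style)
print(t.mean(axis=(1, 2)), style.mean(axis=(1, 2)))
```

After the transform, the per-channel mean and standard deviation of `t` match those of the style features up to the small effect of `epsilon`, while the spatial structure still comes from the content features.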

To train the model we need to define the loss function.
The loss function is composed of two parts: the style loss and the content loss.

The style loss is calculated as follows:
$$
L_{style}(s, g(t)) =
\sum_{i=1}^{L} || \mu( \phi_i( g(t) ) ) - \mu( \phi_i( s ) ) ||_2
+ \sum_{i=1}^{L} || \sigma( \phi_i( g(t) ) ) - \sigma( \phi_i( s ) ) ||_2
$$

In [0]:
def calculate_style_loss(x, epsilon=1e-5):
  y_trues, y_preds = x
  loss = [
    utils.mse_loss(K.mean(y_true, axis=(1, 2)), K.mean(y_pred, axis=(1, 2)))
    + utils.mse_loss(K.sqrt(K.var(y_true, axis=(1, 2)) + epsilon), K.sqrt(K.var(y_pred, axis=(1, 2)) + epsilon))
    for y_true, y_pred in zip(y_trues, y_preds)
  ]
  return K.sum(loss)
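
To see how this loss behaves, here is a NumPy sketch that mirrors the cell above. It assumes that `utils.mse_loss` is a plain mean of squared differences, which is not confirmed here; the helper `mse` and the random feature maps are illustrative only.

```python
import numpy as np

def mse(a, b):
    # Assumed stand-in for utils.mse_loss: mean of squared differences.
    return float(np.mean((a - b) ** 2))

def style_loss(style_feats, generated_feats, epsilon=1e-5):
    """Sum, over encoder outputs, of the mismatch in per-channel mean and std."""
    total = 0.0
    for s, g in zip(style_feats, generated_feats):
        total += mse(s.mean(axis=(1, 2)), g.mean(axis=(1, 2)))
        total += mse(np.sqrt(s.var(axis=(1, 2)) + epsilon),
                     np.sqrt(g.var(axis=(1, 2)) + epsilon))
    return total

rng = np.random.default_rng(0)
# Two hypothetical encoder outputs with different resolutions and channel counts.
feats = [rng.normal(size=(2, 16, 16, 8)), rng.normal(size=(2, 8, 8, 16))]
shifted = [f + 1.0 for f in feats]  # same spread, every channel mean moved by 1

same = style_loss(feats, feats)
moved = style_loss(feats, shifted)
print(same, moved)
```

Identical statistics give a loss of zero. Shifting every feature by a constant changes only the means, so each of the two layers contributes a mean term of about 1 and the total is about 2.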

The content loss is calculated as follows:
$$
L_{content}(t, g(t)) = || \phi_L(g(t)) - t ||_2
$$

In [0]:
def calculate_content_loss(x):
  y_true, y_pred = x
  return utils.mse_loss(y_true, y_pred)

The total loss is the weighted sum of the style loss and the content loss.
$$
Loss = \lambda_{style} \cdot L_{style} + \lambda_{content} \cdot L_{content}
$$

Then, we can create a model. While training, we will fix the encoder's parameters.

In [0]:
def make_trainer():

  encoder = network.Encoder(input_tensor=Input(shape=(crop_size, crop_size, 3)))
  for l in encoder.layers:
    l.trainable = False
  adain = network.AdaIN(alpha=1.0)
  decoder = network.Decoder(name='decoder')

  content_input = Input(shape=(crop_size, crop_size, 3), name='content_input')
  style_input = Input(shape=(crop_size, crop_size, 3), name='style_input')

  content_features = encoder(content_input)
  style_features = encoder(style_input)
  normalized_feature = adain([content_features[-1], style_features[-1]])
  generated = decoder(normalized_feature)

  generated_features = encoder(generated)
  content_loss = Lambda(calculate_content_loss, name='content_loss')([normalized_feature, generated_features[-1]])
  style_loss = Lambda(calculate_style_loss, name='style_loss')([style_features, generated_features])
  loss = Lambda(lambda x: content_weight * x[0] + style_weight * x[1], name='loss')([content_loss, style_loss])

  trainer = Model(inputs=[content_input, style_input], outputs=[loss])
  optim = optimizers.Adam(learning_rate=learning_rate)
  trainer.compile(optimizer=optim, loss=lambda _, y_pred: y_pred)

  return trainer

To continue training from the last saved epoch, check the model directory,
and if it contains one or more trainer models, load the latest one.
Otherwise, create a trainer model from scratch.

In [0]:
model_dir = gdrive / model_dir
trainer_name = 'trainer.epoch-{epoch}.h5'

In [0]:
if not model_dir.exists():
  model_dir.mkdir(parents=True, exist_ok=True)

In [0]:
lastest_epoch = 0
for candidate in model_dir.glob('*'):
  # skip directories and anything that is not an .h5 file
  if candidate.is_dir() or candidate.suffix != '.h5':
    continue
  print(candidate)
  epoch = int(unformat_string(candidate.name, trainer_name)['epoch'])
  if epoch > lastest_epoch:
    lastest_epoch = epoch
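
The loop above still raises an error for an `.h5` file whose name does not follow the trainer pattern, such as the `decoder.h5` written in section 4. A more defensive variant is sketched below. It is an alternative written for illustration, not the notebook's code: the function name `find_latest_epoch` and the explicit regular expression are assumptions, and the demonstration uses empty placeholder files in a temporary folder.

```python
import re
import tempfile
from pathlib import Path

def find_latest_epoch(directory, pattern=r'trainer\.epoch-(\d+)\.h5'):
    """Return the highest epoch among matching checkpoint files, or 0 if none match."""
    latest = 0
    for candidate in Path(directory).iterdir():
        match = re.fullmatch(pattern, candidate.name)
        # Ignore directories and files whose names do not match the pattern.
        if candidate.is_file() and match:
            latest = max(latest, int(match.group(1)))
    return latest

with tempfile.TemporaryDirectory() as tmp:
    for name in ['trainer.epoch-3.h5', 'trainer.epoch-12.h5', 'decoder.h5', 'notes.txt']:
        (Path(tmp) / name).touch()
    (Path(tmp) / 'logs').mkdir()
    latest = find_latest_epoch(tmp)
    empty = find_latest_epoch(Path(tmp) / 'logs')
    print(latest, empty)
```

Comparing epochs as integers matters here: as strings, `'3'` would sort after `'12'`.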

In [0]:
if lastest_epoch == 0:
  trainer = make_trainer()
else:
  custom_layers = {
    'Encoder': network.Encoder,
    'AdaIN': network.AdaIN,
    'Decoder': network.Decoder,
    'ReflectionPad': network.ReflectionPad,
    '<lambda>': lambda _, y_pred: y_pred
  }
  with CustomObjectScope(custom_layers):
    trainer = load_model(str(model_dir / trainer_name.format(epoch=lastest_epoch)))

In [0]:
trainer.summary()

# 3. Train

Before we begin training, we need some callbacks.

In [0]:
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.callbacks import ModelCheckpoint

The first callback is a learning rate scheduler that decays the optimizer's learning rate.
The second callback saves the trainer after every epoch.

In [0]:
callbacks = [
  LearningRateScheduler(lambda epoch, _: learning_rate / (1.0 + learning_rate_decay * n_per_epoch * epoch)),
  ModelCheckpoint(str(model_dir / trainer_name), save_freq='epoch')
]
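
To get a feel for the schedule, the same formula can be evaluated in plain Python with the settings from section 0. The function name `scheduled_lr` and the sampled epochs are chosen only for illustration.

```python
learning_rate = 1e-4
learning_rate_decay = 5e-5
n_per_epoch = 1000

def scheduled_lr(epoch):
    # Inverse-time decay; the decay term grows with n_per_epoch * epoch.
    return learning_rate / (1.0 + learning_rate_decay * n_per_epoch * epoch)

schedule = {epoch: scheduled_lr(epoch) for epoch in (0, 20, 100, 1280)}
for epoch, lr in schedule.items():
    print(f'epoch {epoch:4d}: lr = {lr:.3e}')
```

With these values the denominator is `1 + 0.05 * epoch`, so the learning rate is halved after 20 epochs and is divided by 65 at epoch 1280.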

Now we can start training!

In [0]:
# fit_generator is deprecated since TensorFlow 2.1, where Model.fit accepts generators directly
trainer.fit_generator(dataset, epochs=epochs, workers=4, callbacks=callbacks, initial_epoch=lastest_epoch)

# 4. Reproduce

To reproduce images with our script, you need to extract the decoder weights from the trained model.
The following commands extract the weights from the trained model of the last epoch.

In [0]:
model_path = model_dir / trainer_name.format(epoch=epochs)
decoder_path = model_dir / 'decoder.h5'

In [0]:
custom_layers = {
  'Encoder': network.Encoder,
  'AdaIN': network.AdaIN,
  'Decoder': network.Decoder,
  'ReflectionPad': network.ReflectionPad,
  '<lambda>': lambda _, y_pred: y_pred
}
with CustomObjectScope(custom_layers):
  model = load_model(str(model_path))

In [0]:
model.get_layer('decoder').save_weights(str(decoder_path), overwrite=True)