In [None]:
#Reference: https://www.tensorflow.org/tutorials/generative/style_transfer
import os
import tensorflow as tf
# Ask tensorflow_hub to load models in their compressed format; set here,
# ahead of the hub.load() call later in the notebook.
os.environ.update(TFHUB_MODEL_LOAD_FORMAT='COMPRESSED')

In [None]:
import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
# Default figure size and no grid lines when displaying images.
mpl.rcParams.update({
    'figure.figsize': (12, 12),
    'axes.grid': False,
})

import numpy as np
# Fix NumPy's global RNG. Note: this does not seed Python's `random` module.
np.random.seed(seed=42)

import PIL.Image
import time
import functools
import cv2 as cv

In [None]:
def tensor_to_uint8_array(tensor):
  """Convert a float image tensor in [0, 1] to a uint8 NumPy array.

  Args:
    tensor: NumPy array or eager tf.Tensor holding an image whose values are
      nominally in [0, 1]. Presumably (H, W, 3) RGB, optionally with a
      leading batch axis of size 1 -- TODO confirm for non-notebook callers.

  Returns:
    np.ndarray of dtype uint8 with the leading batch axis removed when the
    input had rank > 3.

  Raises:
    AssertionError: if the input has rank > 3 and its first axis is not 1.
  """
  scaled = np.asarray(tensor, dtype=np.float64) * 255
  # Casting out-of-range floats to uint8 is undefined in NumPy (values may
  # wrap around), so saturate at 0/255 first. astype() then truncates toward
  # zero exactly like the original np.array(..., dtype=np.uint8) did.
  pixels = np.clip(scaled, 0, 255).astype(np.uint8)
  if pixels.ndim > 3:
    assert pixels.shape[0] == 1
    pixels = pixels[0]
  return pixels


def tensor_to_image(tensor):
  """Convert a float image tensor in [0, 1] to a PIL image.

  Args:
    tensor: image as accepted by `tensor_to_uint8_array`.

  Returns:
    PIL.Image.Image built from the uint8 pixel array.
  """
  return PIL.Image.fromarray(tensor_to_uint8_array(tensor))


def save_tensor_as_image(tensor, file_path):
  """Write a float RGB image tensor in [0, 1] to `file_path` with OpenCV.

  Args:
    tensor: 3-channel RGB image as accepted by `tensor_to_uint8_array`.
    file_path: destination path; OpenCV picks the format from its extension.

  Raises:
    OSError: if cv.imwrite reports failure (e.g. the directory is missing).
  """
  rgb = tensor_to_uint8_array(tensor)
  # OpenCV stores pixels in BGR order; without this swap the red and blue
  # channels of the RGB tensor would be exchanged in the saved file.
  bgr = cv.cvtColor(rgb, cv.COLOR_RGB2BGR)
  # imwrite signals many failures only through its boolean return value.
  if not cv.imwrite(file_path, bgr):
    raise OSError(f"cv.imwrite failed to write {file_path!r}")

In [None]:
def load_img(path_to_img, max_dim=512):
  """Read an image file and return it as a float32 batch of one image.

  The image is decoded to 3 channels and converted to float32 in [0, 1]. It
  is then resized, keeping its aspect ratio, so that its longer side is
  `max_dim` pixels. Images smaller than `max_dim` are upscaled as well.

  Args:
    path_to_img: path to a BMP, GIF, JPEG or PNG file.
    max_dim: target length, in pixels, of the image's longer side.

  Returns:
    tf.Tensor of shape (1, new_height, new_width, 3) and dtype float32.
  """
  raw = tf.io.read_file(path_to_img)
  # Without expand_animations=False a GIF decodes to a rank-4 stack of
  # frames, which breaks the height/width arithmetic below; this keeps
  # the first frame only.
  img = tf.image.decode_image(raw, channels=3, expand_animations=False)
  img = tf.image.convert_image_dtype(img, tf.float32)

  # (height, width) as floats so the scale factor is not integer-divided.
  shape = tf.cast(tf.shape(img)[:-1], tf.float32)
  scale = max_dim / tf.reduce_max(shape)
  new_shape = tf.cast(shape * scale, tf.int32)

  img = tf.image.resize(img, new_shape)
  # Add the leading batch axis expected by the stylization model.
  return img[tf.newaxis, :]

In [None]:
def imshow(image, title=None):
  """Show an image with matplotlib, dropping a leading batch axis if present.

  Args:
    image: image tensor, presumably (H, W, C) or (1, H, W, C).
    title: optional axes title; nothing is set when it is falsy.
  """
  # Only rank-4 input carries the batch axis that has to be squeezed out.
  shown = tf.squeeze(image, axis=0) if len(image.shape) > 3 else image
  plt.imshow(shown)
  if not title:
    return
  plt.title(title)

In [None]:
import random
import tensorflow_hub as hub

# The model to use for style transfer
hub_model = hub.load('https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2')

# How many times the content image is passed through the model with the same
# style image (each pass re-stylizes the previous output). Must be >= 1.
ITER = 3
# Seed for the per-image style choice, so a rerun produces the same pairings.
STYLE_SEED = 42
# Set to True to preview content / style / result for every image.
SHOW_PLOTS = False
# File extensions treated as input images (matched case-insensitively);
# anything else in the folders (e.g. .DS_Store) is skipped.
IMAGE_EXTENSIONS = ('.bmp', '.jpeg', '.jpg', '.png')

CONTENT_PATH = "path/to/the/content_dataset"
STYLE_PATH = "path/to/the/style_dataset"
SAVE_PATH = "path/to/output"


def list_images(directory):
    """Return the sorted names of image files directly inside `directory`.

    Sorting makes the order independent of os.listdir's arbitrary ordering,
    which the seeded style choice below relies on for reproducibility.
    """
    return sorted(
        name for name in os.listdir(directory)
        if name.lower().endswith(IMAGE_EXTENSIONS)
    )


def stylize(content_image, style_image, n_iter=ITER):
    """Run `hub_model` `n_iter` times, feeding each output back as content.

    Args:
        content_image: batched image tensor from `load_img`.
        style_image: batched image tensor from `load_img`.
        n_iter: number of stylization passes; must be >= 1.

    Returns:
        The stylized image tensor produced by the last pass.

    Raises:
        ValueError: if `n_iter` < 1.
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter}")
    stylized = content_image
    for _ in range(n_iter):
        stylized = hub_model(tf.constant(stylized), tf.constant(style_image))[0]
    return stylized


content_files = list_images(CONTENT_PATH)
# List the style folder once instead of once per content image.
style_files = list_images(STYLE_PATH)
if not style_files:
    raise FileNotFoundError(f"No style images found in {STYLE_PATH!r}")
# Dedicated RNG: np.random.seed does not affect the `random` module.
style_rng = random.Random(STYLE_SEED)
os.makedirs(SAVE_PATH, exist_ok=True)

# For every content image, stylize it with a randomly chosen style image.
for content_file in content_files:
    content_path = os.path.join(CONTENT_PATH, content_file)
    style_file = style_rng.choice(style_files)
    style_path = os.path.join(STYLE_PATH, style_file)

    content_image = load_img(content_path)
    style_image = load_img(style_path)
    stylized_image = stylize(content_image, style_image)

    if SHOW_PLOTS:
        plt.subplot(1, 3, 1)
        imshow(content_image, 'Content Image')

        plt.subplot(1, 3, 2)
        imshow(style_image, 'Style Reference Image')

        plt.subplot(1, 3, 3)
        imshow(stylized_image, 'Styled Image')

        plt.show()

    # The output keeps the content image's file name.
    save_path = os.path.join(SAVE_PATH, content_file)
    save_tensor_as_image(stylized_image, save_path)