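# Style Transfer

Interactive neural style transfer with a pretrained VGG16. To use it:

- Upload a content ("input") image and a style image.
- Press each **Update** button.
- Choose the number of optimisation rounds.
- Press **Generate image**.

The result image is optimised with L-BFGS so that:

- its VGG16 `block2_conv2` features stay close to the content image's;
- its Gram matrices at `block1_conv1` … `block5_conv1` approach the style image's;
- a total-variation term keeps it smooth.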

In [2]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import keras
import h5py
import math
from keras.preprocessing.image import load_img, img_to_array
from keras.models import Sequential
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD
from keras.layers import *
from keras.applications import vgg19

import tensorflow as tf
import numpy as np
from PIL import Image
from IPython.display import display
from io import BytesIO

from keras import backend
from keras.models import Model
from keras.applications.vgg16 import VGG16

from scipy.optimize import fmin_l_bfgs_b

import ipywidgets as widgets
from ipywidgets import interact, interact_manual
from PIL import Image
import io
from ipywidgets import TwoByTwoLayout
from ipywidgets import GridspecLayout

import ssl
# SECURITY NOTE(review): this disables TLS certificate verification for HTTPS
# connections opened through urllib / http.client without an explicit context,
# for the whole kernel. It was presumably added so the VGG16 weight download
# works on a machine with a broken certificate store. Prefer fixing the local
# certificates and removing this override.
setattr(ssl, "_create_default_https_context", ssl._create_unverified_context)

Using TensorFlow backend.


In [3]:
# Hyperparameters

# VGG (ImageNet) per-channel mean pixel values, listed in R, G, B order.
MEAN_RGB_VALUES = [123.68, 116.779, 103.939]

# Working resolution: both uploaded images are resized to IMAGE_WIDTH x IMAGE_HEIGHT.
IMAGE_SIZE = 500
IMAGE_WIDTH = IMAGE_HEIGHT = IMAGE_SIZE

# Default number of optimisation rounds (the Generate button reads the slider instead).
ITERATIONS = 5

# Total-variation (smoothness) regulariser: overall weight and the exponent
# applied to the summed squared neighbour differences.
TOTAL_VARIATION_WEIGHT = 1
TOTAL_VARIATION_LOSS_FACTOR = 1.3

# Content / style trade-off; the two weights always add up to 100.
CONTENT_WEIGHT = 0.1
STYLE_WEIGHT = 100 - CONTENT_WEIGHT

In [4]:
# --- Content ("input") image controls ------------------------------------
inputLabel = widgets.Label("Upload input image")
inputFile = widgets.FileUpload(button_style='success', accept='image/*')
inputButton = widgets.Button(description="Update input image")

# --- Style image controls -------------------------------------------------
styleLabel = widgets.Label("Upload style image")
styleFile = widgets.FileUpload(button_style='success', accept='image/*')
styleButton = widgets.Button(description="Update style image")

# Placeholder previews. NOTE(review): the "Update ..." callbacks place new
# Image widgets in the grid rather than updating these two.
inputImage = widgets.Image(height=400, width=400)
styleImage = widgets.Image(height=400, width=400)

# --- Result controls --------------------------------------------------------
resultProgress = widgets.IntProgress(
    description='Progress',
    value=0, min=0, max=100, step=1,
    bar_style='success',
    orientation='horizontal',
)

resultSlider = widgets.IntSlider(
    description='Iterations:',
    value=5, min=0, max=200, step=5,
    continuous_update=False,
    disabled=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)

resultLabel = widgets.Label("Result")
resultButton = widgets.Button(button_style='warning', description="Generate image")
resultImage = widgets.Image(height=400, width=400)

# Sentinels: replaced by PIL images once the "Update ..." buttons are clicked.
input_image = style_image = 0

def on_input_button_clicked(b):
    """Preview the uploaded content image and store it as ``input_image``.

    Button callback for ``inputButton``. Decodes the first file in
    ``inputFile``, shows it in the grid, and stores an RGB copy resized to
    ``(IMAGE_WIDTH, IMAGE_HEIGHT)`` in the global ``input_image`` for the
    style-transfer step.

    Args:
        b: the clicked button (unused; required by ``Button.on_click``).
    """
    global input_image
    if not inputFile.data:
        # Nothing uploaded yet: indexing data[0] would raise IndexError.
        return

    raw_bytes = inputFile.data[0]
    image = Image.open(io.BytesIO(raw_bytes))

    grid[4:6, 0:2] = widgets.Image(
        value=raw_bytes,
        # Label the preview with the real encoding (JPEG uploads are common).
        format=(image.format or 'png').lower(),
        width=400,
        height=400,
    )

    # Force 3-channel RGB. RGBA, palette and grayscale uploads would otherwise
    # not have the (height, width, 3) shape the VGG16 input needs.
    input_image = image.convert('RGB').resize((IMAGE_WIDTH, IMAGE_HEIGHT))


inputButton.on_click(on_input_button_clicked)

def on_style_button_clicked(b):
    """Preview the uploaded style image and store it as ``style_image``.

    Button callback for ``styleButton``. Decodes the first file in
    ``styleFile``, shows it in the grid, and stores an RGB copy resized to
    ``(IMAGE_WIDTH, IMAGE_HEIGHT)`` in the global ``style_image`` for the
    style-transfer step.

    Args:
        b: the clicked button (unused; required by ``Button.on_click``).
    """
    global style_image
    if not styleFile.data:
        # Nothing uploaded yet: indexing data[0] would raise IndexError.
        return

    raw_bytes = styleFile.data[0]
    image = Image.open(io.BytesIO(raw_bytes))

    grid[4:6, 2:4] = widgets.Image(
        value=raw_bytes,
        # Label the preview with the real encoding (JPEG uploads are common).
        format=(image.format or 'png').lower(),
        width=400,
        height=400,
    )

    # Force 3-channel RGB. RGBA, palette and grayscale uploads would otherwise
    # not have the (height, width, 3) shape the VGG16 input needs.
    style_image = image.convert('RGB').resize((IMAGE_WIDTH, IMAGE_HEIGHT))

styleButton.on_click(on_style_button_clicked)

In [5]:
def on_result_button_clicked(b):
    global input_image, style_image
    
    ITERATIONS = resultSlider.value
    
    input_image_array = np.asarray(input_image, dtype = "float32")
    input_image_array = np.expand_dims(input_image_array, axis = 0)
    input_image_array[:, :, :, 2] -= MEAN_RGB_VALUES[0]
    input_image_array[:, :, :, 1] -= MEAN_RGB_VALUES[1]
    input_image_array[:, :, :, 0] -= MEAN_RGB_VALUES[2]

    input_image_array = input_image_array[:, :, :, ::-1]

    style_image_array = np.asarray(style_image, dtype = "float32")
    style_image_array = np.expand_dims(style_image_array, axis = 0)
    style_image_array[:, :, :, 2] -= MEAN_RGB_VALUES[0]
    style_image_array[:, :, :, 1] -= MEAN_RGB_VALUES[1]
    style_image_array[:, :, :, 0] -= MEAN_RGB_VALUES[2]

    style_image_array = style_image_array[:, :, :, ::-1]

    # Build VGG16 Model
    input_image = tf.keras.backend.variable(input_image_array)
    style_image = tf.keras.backend.variable(style_image_array)
    comb_image = backend.placeholder((1, IMAGE_HEIGHT, IMAGE_SIZE, 3))

    input_t = backend.concatenate([input_image, style_image, comb_image], axis = 0)
    model = VGG16(input_tensor = input_t, include_top = False)

    def compute_content_loss(content, combination):
        return backend.sum(backend.square(combination - content))

    # Get list of layers as dictionary
    layers = dict([(layer.name, layer.output) for layer in model.layers])

    content_layer = "block2_conv2"
    content_layer_features = layers[content_layer]
    content_image_features = content_layer_features[0, :, :, :]
    comb_features = content_layer_features[2, :, :, :]

    loss = backend.variable(0.)
    loss += CONTENT_WEIGHT * compute_content_loss(content_image_features, comb_features)

    def get_gram_matrix(x):
        features = backend.batch_flatten(backend.permute_dimensions(x, (2, 0, 1)))
        gram = backend.dot(features, backend.transpose(features))
    
        return gram

    def compute_style_loss(style_features, combination_features):
        style = get_gram_matrix(style_features)
        combination = get_gram_matrix(combination_features)
        size = IMAGE_HEIGHT * IMAGE_WIDTH

        return backend.sum(backend.square(style - combination)) / (4. * (3 ** 2) * (size ** 2))

    layers_for_style = ["block1_conv1", "block2_conv1", "block3_conv1", "block4_conv1", "block5_conv1"]

    for layer in layers_for_style:
        layer_features = layers[layer]
        style_features = layer_features[1, :, :, :]
        combination_features = layer_features[2, :, :, :]

        style_loss = compute_style_loss(style_features, combination_features)
        # Set equal weight for each style layer
        loss += (STYLE_WEIGHT / len(layers_for_style)) * style_loss

    def compute_total_variation_loss(x):
        a = backend.square(x[:, :IMAGE_HEIGHT-1, :IMAGE_WIDTH-1, :] - x[:, 1:, :IMAGE_WIDTH-1, :])
        b = backend.square(x[:, :IMAGE_HEIGHT-1, :IMAGE_WIDTH-1, :] - x[:, :IMAGE_HEIGHT-1, 1:, :])
        return backend.sum(backend.pow(a + b, TOTAL_VARIATION_LOSS_FACTOR))

    loss += TOTAL_VARIATION_WEIGHT * compute_total_variation_loss(comb_image)
    

    outputs = [loss]
    outputs += backend.gradients(loss, comb_image)

    class LossGrad:

        def loss(self, x):
            x = x.reshape((1, IMAGE_HEIGHT, IMAGE_WIDTH, 3))
            out = backend.function([comb_image], outputs)([x])

            loss = out[0]
            gradients = out[1].flatten().astype("float64")

            self._gradients = gradients
            return loss

        def gradients(self, x):
            return self._gradients

    loss_grad = LossGrad()

    def displayImage(image_x, num):
        image_x = image_x.reshape((IMAGE_HEIGHT, IMAGE_WIDTH, 3))
        image_x = image_x[:, :, ::-1]
        image_x[:, :, 2] += MEAN_RGB_VALUES[0]
        image_x[:, :, 1] += MEAN_RGB_VALUES[1]
        image_x[:, :, 0] += MEAN_RGB_VALUES[2]
        image_x = np.clip(image_x, 0, 255).astype("uint8")

        output_image = Image.fromarray(image_x)
        # output_image.save("output_%d.png" % num)
        # display(output_image)
        
        pil_im = output_image
        b = io.BytesIO()
        pil_im.save(b, 'jpeg')
        im_bytes = b.getvalue()
        
        resultImage = widgets.Image(
            value=im_bytes,
            format='png',
            width=400,
            height=400
        )
        
        grid[4:7, 5:8] = resultImage;
        

    x = np.random.uniform(0, 255, (1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)) - 128.

    for i in range(ITERATIONS):
        # print('Iteration: ', i)
        
        resultProgress = widgets.IntProgress(
            value=i,
            max=ITERATIONS,
            step=1,
            description="Iter: %d" % i,
            bar_style='success',
            orientation='horizontal'
        )
        
        # resultNum = widgets.Label("Current iter: %d" % i)
        # resultProgress.value = i
        grid[3, 5:6] = resultProgress

        x, loss, info = fmin_l_bfgs_b(loss_grad.loss, x.flatten(), fprime = loss_grad.gradients, maxfun = 25)
        # print("Iteration: %d - Loss: %d" % (i, loss))
        image_x = copy.deepcopy(x)
        
        displayImage(image_x, i)
    
# Run style transfer whenever the "Generate image" button is pressed.
resultButton.on_click(callback=on_result_button_clicked)

In [6]:
grid = GridspecLayout(12, 9, height='600px')

# Each control column is stacked top-down starting at row 0:
# content controls in column 0, style controls in column 2,
# result controls in column slice 5:6.
_column_layout = (
    (0, (inputLabel, inputFile, inputButton)),
    (2, (styleLabel, styleFile, styleButton)),
    (slice(5, 6), (resultLabel, resultSlider, resultButton, resultProgress)),
)
for _col, _controls in _column_layout:
    for _row, _control in enumerate(_controls):
        grid[_row, _col] = _control

grid

GridspecLayout(children=(Label(value='Upload input image', layout=Layout(grid_area='widget001')), FileUpload(v…

Instructions for updating:
Colocations handled automatically by placer.


UnknownError: Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.
	 [[{{node block1_conv1/convolution}}]]
	 [[{{node gradients/AddN_16}}]]