# Goal

`f(image, class_id, effect_id) --> image_with_effect_on_class_segments`

```python
def f1(image, class_id, effect_id):
    """One way to implement f"""
    image_with_global_effect = apply_effect_to_entire_image(image, effect_id)
    local_region = get_local_region(image, class_id)
    image_with_local_effect = apply_effect_to_subimage(image,
                                                       image_with_global_effect,
                                                       local_region)
    return image_with_local_effect
```

    # V1: Trained on ImageNet 20k, 288x288, black borders, JH's cropped Van Gogh image
        # Training:
            # lr=1e-3, batch_size=8, nb_epoch=2
            # lr=1e-4, batch_size=16, nb_epoch=1
        # Try with more images
        # Try with bigger images
        # Try with center cropping instead
        # Try a different training procedure

### Setup

#### Set this to the path of your local deephacks repo (keep the trailing `/`, because the other paths are built by appending to it)

In [None]:
# Root of the local deephacks checkout. The trailing '/' matters: the other
# paths in this notebook are formed by appending directly to this string.
INPUT_DIR = "/nbs/deephacks/"

#### Other

In [None]:
# Locations under the repo root (INPUT_DIR must end with '/').
WEIGHTS_DIR = f"{INPUT_DIR}weights/"
SEGMENTATION_WEIGHTS_PATH = f"{INPUT_DIR}conversion/dilation8_pascal_voc.npy"

# RESULTS_DIR lives under the repo root; the two image directories below are
# relative to the notebook's working directory.
RESULTS_DIR = f"{INPUT_DIR}results/"
OUTPUT_DIR = "images/segmentations/"
IMAGES_DIR = "images/"

TEST_PHOTO_FILENAME = "cat.jpg"

# Input photo, its precomputed segmentation mask, and the base name for outputs.
INPUT_FILE = f"{IMAGES_DIR}{TEST_PHOTO_FILENAME}"
MASK_FILE = f"{OUTPUT_DIR}{TEST_PHOTO_FILENAME}_seg.png"
OUTPUT_FILE = f"{RESULTS_DIR}{TEST_PHOTO_FILENAME}_stylized.png"

### Low-level functions

#### Visualization functions

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

def plots(ims, figsize=(12,6), rows=1, cols=1, interp=None, titles=None, cmap=None):
    """Draw images into a rows x cols grid of subplots on a new figure.

    Args:
        ims: iterable of images accepted by ``imshow``; the i-th image goes
            into subplot i + 1 (so rows * cols must be at least len(ims)).
        figsize: figure size forwarded to ``plt.figure``.
        rows, cols: subplot grid layout.
        interp: ``interpolation`` argument forwarded to ``imshow``.
        titles: optional sequence of subplot titles, indexed like ``ims``.
        cmap: colormap forwarded to ``imshow``.
    """
    fig = plt.figure(figsize=figsize)
    for i, im in enumerate(ims):
        sp = fig.add_subplot(rows, cols, i + 1)
        if titles:
            sp.set_title(titles[i], fontsize=18)
        # Draw on the subplot just created rather than relying on pyplot's
        # implicit "current axes".
        sp.imshow(im, interpolation=interp, cmap=cmap)
        sp.axis('off')

In [None]:
from copy import deepcopy
from scipy.misc import imsave
def show_results(image, mask, output_filename, effect_id="van_gogh"):
    """Visualize and save ``effect_id`` applied only where ``mask`` is 255.

    First shows the input image, the mask and the globally stylized image side
    by side. Then it writes the locally stylized image to ``output_filename``
    and shows it. The caller's ``image`` is never modified.
    """
    stylized_global = apply_effect_to_entire_image(image, effect_id=effect_id)[0]
    plots([image, mask, stylized_global], figsize=(12, 12), rows=1, cols=3)
    # apply_effect_to_subimage writes into its first argument, so give it a
    # private copy. The array just passed to plots() (shown as the original)
    # therefore stays untouched, even if the plotting backend keeps a
    # reference instead of a copy.
    # offset=3 leaves the last 3 rows and columns unstylized.
    result = apply_effect_to_subimage(deepcopy(image), stylized_global, mask, offset=3)
    # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; on newer SciPy
    # this needs a replacement such as PIL's Image.save.
    imsave(output_filename, result)
    plots([result], figsize=(8, 8), rows=1, cols=1)

#### Style transfer functions

In [None]:
import numpy as np
np.random.seed(8675309)
import tensorflow as tf
from keras.models import Model
from keras.layers import Layer, Input, InputSpec, Lambda, Convolution2D, BatchNormalization, Activation, UpSampling2D, merge
import keras.backend as K
from PIL import Image

class ReflectionPadding2D(Layer):
    """Keras layer that reflection-pads axes 1 and 2 of a 4-D tensor.

    ``padding`` is ``(pad_axis1, pad_axis2)``: the number of pixels added on
    each side of axis 1 and of axis 2. Both ``call`` and
    ``get_output_shape_for`` use that same convention.
    NOTE(review): assumes channels-last input (batch, rows, cols, channels),
    which matches padding axes 1 and 2 with ``tf.pad``.
    """
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def get_output_shape_for(self, s):
        """Return the padded shape; an unknown (None) spatial size stays None."""
        pad_axis1, pad_axis2 = self.padding
        axis1 = None if s[1] is None else s[1] + 2 * pad_axis1
        axis2 = None if s[2] is None else s[2] + 2 * pad_axis2
        return (s[0], axis1, axis2, s[3])

    def call(self, x, mask=None):
        # padding[0] applies to axis 1 and padding[1] to axis 2. Previously
        # these were swapped relative to get_output_shape_for, which only
        # mattered for asymmetric padding.
        pad_axis1, pad_axis2 = self.padding
        return tf.pad(x, [[0, 0], [pad_axis1, pad_axis1], [pad_axis2, pad_axis2], [0, 0]], 'REFLECT')
    
def conv_block(x, filters, size, stride=(2,2), mode='same', act=True):
    """Apply a size x size convolution, then BatchNormalization(mode=2), then an optional ReLU.

    ``stride`` becomes the convolution's ``subsample`` and ``mode`` its
    ``border_mode`` (Keras 1.x argument names). With ``act=False`` the
    normalized tensor is returned without the ReLU.
    """
    conv = Convolution2D(filters, size, size, subsample=stride, border_mode=mode)
    normalized = BatchNormalization(mode=2)(conv(x))
    if not act:
        return normalized
    return Activation('relu')(normalized)

def res_crop_block(ip, nf=64):
    """Residual block built from two 'valid' 3x3 convolutions plus a cropped shortcut.

    Each 'valid' 3x3 conv shrinks axes 1 and 2 by one pixel per side, so the
    shortcut is cropped by 2 pixels per side on those axes before the two
    branches are summed.
    """
    branch = conv_block(ip, nf, 3, (1, 1), 'valid')
    branch = conv_block(branch, nf, 3, (1, 1), 'valid', act=False)
    shortcut = Lambda(lambda t: t[:, 2:-2, 2:-2])(ip)
    return merge([branch, shortcut], mode='sum')

def up_block(x, filters, size):
    """Upsample with UpSampling2D, then apply a 'same' conv, BatchNormalization(mode=2) and ReLU."""
    upsampled = UpSampling2D()(x)
    convolved = Convolution2D(filters, size, size, border_mode='same')(upsampled)
    normalized = BatchNormalization(mode=2)(convolved)
    return Activation('relu')(normalized)

def make_mixer(mixer_input):
    """Build the style-transfer network as a Keras Model named "mixer".

    Pipeline:
      1. reflection padding;
      2. a 9x9 stride-1 conv block;
      3. c stride-2 conv blocks;
      4. r residual crop blocks;
      5. c upsampling blocks;
      6. a 9x9 tanh conv, rescaled from (-1, 1) to (0, 255).

    NOTE(review): load_mixer loads pretrained weights into this model, so
    changing the layer order or the hyper-parameters below will likely make
    the saved weights incompatible.
    """
    c = 2 # Number of conv blocks and up blocks
    r = 5 # Number of res blocks
    # Each res block crops 2 px per side. At 1/4 resolution (after c=2
    # stride-2 conv blocks) that equals 8 px per side at input resolution.
    # r * 8 therefore presumably offsets the total crop so the output keeps
    # the input's spatial size (verify for sizes not divisible by 4).
    r2 = r * 8 # Amount of reflection padding
    nf = 64
    x = ReflectionPadding2D((r2, r2))(mixer_input)
    x = conv_block(x, nf, 9, (1,1))
    # Downsample c times (conv_block defaults to stride (2, 2)).
    for i in range(c): x = conv_block(x, nf, 3)
    for i in range(r): x = res_crop_block(x, nf)
    # Upsample c times to undo the downsampling.
    for i in range(c): x = up_block(x, nf, 3)
    x = Convolution2D(3, 9, 9, activation='tanh', border_mode='same')(x)
    # Map tanh output (-1, 1) to pixel range (0, 255).
    mixer_output = Lambda(lambda x: (x+1)*127.5)(x)
    return Model(mixer_input, mixer_output, name="mixer")

### Medium-level functions

In [None]:
def load_mixer(shape, version="1"):
    """Build a mixer for inputs of ``shape`` and load its saved weights.

    Weights are read from ``WEIGHTS_DIR/van_gogh/v<version>.h5``.
    """
    mixer = make_mixer(Input(shape, name="mixer_input"))
    weights_file = f'{WEIGHTS_DIR}van_gogh/v{version}.h5'
    mixer.load_weights(weights_file)
    return mixer

### High-level functions

In [None]:
def f1(image, class_id, effect_id):
    """One way to implement f -- the highest-level function.

    Applies ``effect_id`` to the whole image, then keeps the effect only
    inside the region returned by ``get_local_region(image, class_id)``.
    Returns a new array; ``image`` itself is not modified.
    NOTE(review): depends on get_local_region, which is not implemented yet.
    """
    # apply_effect_to_entire_image returns a list with one result per input
    # image; a single image is passed here, so take the first element.
    image_with_global_effect = apply_effect_to_entire_image(image, effect_id)[0]
    local_region = get_local_region(image, class_id)
    # apply_effect_to_subimage writes into its first argument, so pass a copy.
    # It also requires `offset`; 0 means the whole image is eligible.
    image_with_local_effect = apply_effect_to_subimage(np.copy(image),
                                                       image_with_global_effect,
                                                       local_region,
                                                       offset=0)
    return image_with_local_effect

def get_local_region(image, class_id):
    """Return the segmentation mask of ``class_id`` in ``image`` (not implemented yet).

    Matthew is covering this. Planned steps: load the segmentation network,
    then segment ``image`` and return the mask (i.e. the local region).
    NOTE(review): apply_effect_to_subimage selects pixels whose mask value
    is 255, so the mask presumably needs that encoding.

    Raises:
        NotImplementedError: always, until the segmentation step exists.
    """
    raise NotImplementedError(
        "get_local_region is not implemented yet: it should load the "
        "segmentation network and return the mask for class_id")
    
def apply_effect_to_entire_image(images, effect_id="van_gogh"):
    """Apply an effect to every pixel of one image or a batch of images.

    Matthew is covering this.

    Args:
        images: a single image (ndarray with fewer than 4 dims), a 4-D batch
            ndarray, or a list/tuple of same-shaped images.
        effect_id: "van_gogh" runs the pretrained style-transfer mixer;
            "black" produces all-zero images.

    Returns:
        A list with one output image per input image. "van_gogh" outputs are
        uint8. "black" outputs have each input image's shape and dtype.

    Raises:
        ValueError: if ``effect_id`` is not a known effect.
    """
    if isinstance(images, (list, tuple)):
        images = np.array(images)
    elif isinstance(images, np.ndarray) and images.ndim < 4:
        # A single image: add a leading batch axis.
        images = np.expand_dims(images, 0)
    if effect_id == "van_gogh":
        # The mixer is built for one fixed input shape, taken from the first image.
        mixer = load_mixer(images[0].shape, version="1")
        raw_results = mixer.predict(images)
        # The mixer's output is (tanh + 1) * 127.5, i.e. within [0, 255].
        return [np.round(raw_result).astype('uint8') for raw_result in raw_results]
    if effect_id == "black":
        # One black image per input image, matching its shape and dtype.
        return [np.zeros_like(image) for image in images]
    raise ValueError(f"Unknown effect_id {effect_id!r}; expected 'van_gogh' or 'black'")
    
def apply_effect_to_subimage(image, image_with_global_effect, mask, offset=0):
    """Copy effect pixels into ``image`` wherever ``mask`` equals 255.

    Args:
        image: array indexed [row, col, channel]; modified in place.
        image_with_global_effect: array with the same [row, col, ...] layout.
            It supplies the replacement pixels, and its first two dimensions
            define the region examined.
        mask: 2-D array indexed [row, col]; pixels equal to 255 are replaced.
        offset: number of trailing rows and columns left untouched
            (0 = whole image).

    Returns:
        ``image`` itself (the same object), after modification.
    """
    n_rows, n_cols = image_with_global_effect.shape[:2]
    # Clamp at 0: an offset larger than the image must select nothing,
    # instead of wrapping around through a negative slice bound.
    row_end = max(n_rows - offset, 0)
    col_end = max(n_cols - offset, 0)
    selected = mask[:row_end, :col_end] == 255
    # image[:row_end, :col_end] is a view, so this assignment writes into image.
    image[:row_end, :col_end][selected] = image_with_global_effect[:row_end, :col_end][selected]
    return image

### Highest-level API so far

In [None]:
# Read the test photo and its segmentation mask into numpy arrays, in that order.
image, mask = (np.array(Image.open(path)) for path in (INPUT_FILE, MASK_FILE))

In [None]:
# Preview the unmodified input photo before applying any effect.
plt.imshow(image)

In [None]:
# Van Gogh style restricted to the masked region; result saved next to OUTPUT_FILE.
show_results(image, mask, output_filename=f"{OUTPUT_FILE}_van_gogh.png", effect_id="van_gogh")

In [None]:
# Blacken the masked region; result saved next to OUTPUT_FILE.
show_results(image, mask, output_filename=f"{OUTPUT_FILE}_black.png", effect_id="black")