In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
import scipy.io
import pdb
import functools
import gc
import matplotlib.pyplot as plt
import time

from scipy.misc import imread, imresize, imsave, fromimage, toimage
from PIL import Image

from tensorflow.python.layers import utils

from keras import optimizers
from keras.models import Model, load_model
from keras.layers import Input, ZeroPadding2D,merge, Lambda, concatenate
from keras.layers.convolutional import Convolution2D, AveragePooling2D, MaxPooling2D,Deconvolution2D 
from keras.layers.convolutional import Conv2D,UpSampling2D,Cropping2D, Conv2DTranspose
from keras.layers.normalization import BatchNormalization
from keras.layers.merge import add, concatenate
from keras.layers.core import Activation
from keras.initializers import RandomNormal

from keras.layers.advanced_activations import LeakyReLU
from keras.regularizers import Regularizer

from keras import regularizers
from keras import initializers
from keras import constraints

from keras import backend as K
from keras.preprocessing import image
from keras.engine.topology import Layer
from keras.engine import InputSpec

from keras.applications.vgg19 import preprocess_input
from keras.applications.vgg19 import VGG19
from keras.applications.vgg16 import VGG16

import warnings
warnings.filterwarnings('ignore')

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# Weight handed to add_total_variation_loss by the InverseNet_* builders.
tv_weight = 1e-6

# Default target size used by the image loading / display helpers.
IMG_WIDTH=512
IMG_HEIGHT=512

In [3]:
def load_image(image_path, IMG_WIDTH=IMG_WIDTH, IMG_HEIGHT=IMG_HEIGHT):
    """Read an image file into a float32 batch of one, without VGG preprocessing.

    The file is decoded as RGB, resized with ``imresize`` to
    ``(IMG_WIDTH, IMG_HEIGHT)``, moved to channels-first when the backend
    reports "th" ordering, and given a leading batch axis.  The final shape
    is printed before the array is returned.
    """
    pixels = imread(image_path, mode="RGB")
    pixels = imresize(pixels, (IMG_WIDTH, IMG_HEIGHT)).astype('float32')

    channels_first = K.image_dim_ordering() == "th"
    if channels_first:
        pixels = pixels.transpose((2, 0, 1)).astype('float32')

    batch = np.expand_dims(pixels, axis=0)
    print(batch.shape)
    return batch

def preprocess_image(image_path, IMG_WIDTH=IMG_WIDTH, IMG_HEIGHT=IMG_HEIGHT):
    """Read an image file and return it as a VGG-preprocessed float32 batch of one.

    Same pipeline as a plain load (decode as RGB, resize to
    ``(IMG_WIDTH, IMG_HEIGHT)``), followed by the Keras VGG19
    ``preprocess_input`` transform, an optional move to channels-first for
    "th" ordering, and a leading batch axis.  The final shape is printed.
    """
    pixels = imread(image_path, mode="RGB")
    pixels = imresize(pixels, (IMG_WIDTH, IMG_HEIGHT)).astype('float32')

    # The Keras built-in performs the VGG input conversion for us.
    pixels = preprocess_input(pixels)

    channels_first = K.image_dim_ordering() == "th"
    if channels_first:
        pixels = pixels.transpose((2, 0, 1)).astype('float32')

    batch = np.expand_dims(pixels, axis=0)
    print(batch.shape)
    return batch

def deprocess_image(x, IMG_WIDTH=IMG_WIDTH, IMG_HEIGHT=IMG_HEIGHT):
    """Convert a VGG-style preprocessed array back into a uint8 image.

    Args:
        x: NumPy array holding ``3 * IMG_WIDTH * IMG_HEIGHT`` values.  It is
            reshaped to ``(3, IMG_WIDTH, IMG_HEIGHT)`` for "th" ordering or
            ``(IMG_WIDTH, IMG_HEIGHT, 3)`` otherwise.
        IMG_WIDTH: first spatial extent used for the reshape.
        IMG_HEIGHT: second spatial extent used for the reshape.

    Returns:
        A ``(IMG_WIDTH, IMG_HEIGHT, 3)`` uint8 array with the per-channel
        offsets added back, the channel axis reversed and values clipped to
        the 0..255 range.

    The input array is left untouched.
    """
    # reshape()/transpose() hand back views, so the in-place additions below
    # would otherwise write straight into the caller's array.
    x = x.copy()

    if K.image_dim_ordering() == "th":
        x = x.reshape((3, IMG_WIDTH, IMG_HEIGHT))
        x = x.transpose((1, 2, 0))
    else:
        x = x.reshape((IMG_WIDTH, IMG_HEIGHT, 3))

    # Add back the per-channel offsets.
    x[:, :, 0] += 103.939
    x[:, :, 1] += 116.779
    x[:, :, 2] += 123.68

    # BGR -> RGB
    x = x[:, :, ::-1]

    # Bring values back into 0..255; processing can push them outside that range.
    x = np.clip(x, 0, 255).astype('uint8')
    return x

def get_ratio(image):
    """Return the height / width ratio of an image array as a float.

    The array is wrapped in a PIL image (converted to RGB) and the ratio is
    taken from its ``size`` attribute.
    """
    width, height = Image.fromarray(image).convert('RGB').size
    return float(height) / width


def show(x, IMG_WIDTH=IMG_WIDTH, IMG_HEIGHT=IMG_HEIGHT, save=False, name='', iterate=0):
    """Deprocess a network output and either display it or save it as a JPEG.

    Args:
        x: array accepted by ``deprocess_image``; a copy is deprocessed so the
            caller's data is not modified.
        IMG_WIDTH: first spatial extent of ``x``; forwarded to
            ``deprocess_image`` and used as the first element of the resize
            target.
        IMG_HEIGHT: second spatial extent of ``x``; forwarded to
            ``deprocess_image``.
        save: when true, write ``output/<name>_<iterate>.jpg`` instead of
            plotting.
        name: file-name prefix used when saving.
        iterate: integer suffix used when saving.
    """
    # Post-process the tensor back into an image.  The requested size must be
    # forwarded: deprocess_image would otherwise reshape with the module
    # defaults and fail for any other size.
    img = deprocess_image(x.copy(), IMG_WIDTH, IMG_HEIGHT)

    # Aspect ratio (height / width) of the deprocessed array.
    aspect_ratio = get_ratio(img)
    img_ht = int(IMG_WIDTH * aspect_ratio)
    img = imresize(img, (IMG_WIDTH, img_ht), interp='bilinear')
    im = toimage(img)
    if save:
        filename = 'output/%s_%d.jpg' % (name, iterate)
        imsave(filename, im)
    else:
        plt.imshow(im)
    
def show_without_deprocess(x, IMG_WIDTH=IMG_WIDTH, IMG_HEIGHT=IMG_HEIGHT, save=False, name='', iterate=0):
    """Display or save an array that is already in plain 0..255 pixel space.

    ``x`` is reshaped to ``(IMG_WIDTH, IMG_HEIGHT, 3)``, clipped to 0..255 and
    cast to uint8, then resized according to its own aspect ratio.  The
    resize target and the resized array are printed.  With ``save`` set the
    result is written to ``output/<name>_<iterate>.jpg``; otherwise it is
    shown with matplotlib.
    """
    pixels = np.clip(x.reshape((IMG_WIDTH, IMG_HEIGHT, 3)), 0, 255).astype('uint8')

    # Aspect ratio (height / width) of the clipped array.
    ratio = get_ratio(pixels)
    scaled_height = int(IMG_WIDTH * ratio)
    print("Rescaling Image to (%d, %d)" % (IMG_WIDTH, scaled_height))
    pixels = imresize(pixels, (IMG_WIDTH, scaled_height), interp='bilinear')
    print(pixels)

    picture = toimage(pixels)
    if not save:
        plt.imshow(picture)
        return
    imsave('output/%s_%d.jpg' % (name, iterate), picture)

In [4]:
class AdaptiveInstanceNormalize(Layer):
    """Adaptive instance normalisation of content features with style statistics.

    The layer is called on a list ``[content_features, style_features]``.
    The content features are normalised with their own mean / variance over
    axes 1 and 2 and then rescaled and shifted with the style mean and
    standard deviation computed over the same axes.  The result is blended
    with the untouched content features using ``alpha``.
    """

    def __init__(self, **kwargs):
        # Must name this class: super(InstanceNormalize, self) raises a
        # TypeError because instances are not InstanceNormalize objects.
        super(AdaptiveInstanceNormalize, self).__init__(**kwargs)
        self.epsilon = 1e-5   # added to the content variance before the sqrt
        self.alpha = 0.8      # blend factor: 1.0 = fully stylised, 0.0 = content only

    def call(self, x, mask=None):
        """Apply the normalisation.

        Args:
            x: list or tuple ``[content_features, style_features]``.
            mask: unused.

        Returns:
            Tensor with the shape of ``content_features``.

        Raises:
            ValueError: if ``x`` is not a pair of tensors.
        """
        if not isinstance(x, (list, tuple)) or len(x) != 2:
            raise ValueError('AdaptiveInstanceNormalize expects a list of two tensors: '
                             '[content_features, style_features].')
        content_features, style_features = x

        style_mean, style_variance = tf.nn.moments(style_features, [1,2], keep_dims=True)
        content_mean, content_variance = tf.nn.moments(content_features, [1,2], keep_dims=True)

        # batch_normalization(x, mean, variance, offset, scale, eps): the style
        # mean is the offset and the style standard deviation is the scale.
        normalized_content_features = tf.nn.batch_normalization(content_features, content_mean,
                                                            content_variance, style_mean,
                                                            tf.sqrt(style_variance), self.epsilon)

        normalized_content_features = self.alpha * normalized_content_features + (1 - self.alpha) * content_features
        return normalized_content_features

    def compute_output_shape(self,input_shape):
        # For a two-tensor call Keras passes a list of shapes; the output
        # follows the content (first) input.
        if isinstance(input_shape, list):
            return input_shape[0]
        return input_shape

class InstanceNormalize(Layer):
    """Parameter-free instance normalisation over axes 1 and 2.

    Each sample is centred with its own mean and divided by
    ``sqrt(variance + epsilon)``; the output shape equals the input shape.
    """

    def __init__(self, **kwargs):
        super(InstanceNormalize, self).__init__(**kwargs)
        self.epsilon = 1e-3

    def call(self, x, mask=None):
        mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
        centered = tf.subtract(x, mean)
        spread = tf.sqrt(tf.add(var, self.epsilon))
        return tf.div(centered, spread)

    def compute_output_shape(self,input_shape):
        return input_shape
    
class ReflectionPadding2D(Layer):
    """Reflection padding for 4-D channels-last ('tf' ordering) tensors.

    Args:
        padding: one of
            * a 2-sequence ``(rows, cols)``: symmetric padding per axis,
            * a 4-sequence ``(top, bottom, left, right)``,
            * a dict with any of the keys ``top_pad``, ``bottom_pad``,
              ``left_pad``, ``right_pad`` (missing keys default to 0).
        dim_ordering: ``'default'`` (use the backend setting) or ``'tf'``.

    Raises:
        ValueError: for an unexpected dict key or a non-'tf' ordering.
        TypeError: for a padding sequence whose length is not 2 or 4.
    """

    def __init__(self, padding=(1, 1), dim_ordering='default', **kwargs):
        super(ReflectionPadding2D, self).__init__(**kwargs)

        if dim_ordering == 'default':
            dim_ordering = K.image_dim_ordering()

        self.padding = padding
        if isinstance(padding, dict):
            if set(padding.keys()) <= {'top_pad', 'bottom_pad', 'left_pad', 'right_pad'}:
                self.top_pad = padding.get('top_pad', 0)
                self.bottom_pad = padding.get('bottom_pad', 0)
                self.left_pad = padding.get('left_pad', 0)
                self.right_pad = padding.get('right_pad', 0)
            else:
                raise ValueError('Unexpected key found in `padding` dictionary. '
                                 'Keys have to be in {"top_pad", "bottom_pad", '
                                 '"left_pad", "right_pad"}.'
                                 'Found: ' + str(padding.keys()))
        else:
            padding = tuple(padding)
            if len(padding) == 2:
                self.top_pad = padding[0]
                self.bottom_pad = padding[0]
                self.left_pad = padding[1]
                self.right_pad = padding[1]
            elif len(padding) == 4:
                self.top_pad = padding[0]
                self.bottom_pad = padding[1]
                self.left_pad = padding[2]
                self.right_pad = padding[3]
            else:
                raise TypeError('`padding` should be tuple of int '
                                'of length 2 or 4, or dict. '
                                'Found: ' + str(padding))

        if dim_ordering not in {'tf'}:
            raise ValueError('dim_ordering must be in {tf}.')
        self.dim_ordering = dim_ordering
        self.input_spec = [InputSpec(ndim=4)]

    def call(self, x, mask=None):
        """Pad axis 1 (rows) with top/bottom and axis 2 (cols) with left/right."""
        # The order must mirror compute_output_shape: rows grow by top+bottom,
        # cols by left+right.  The two pairs used to be swapped here, which
        # produced tensors whose real shape disagreed with the declared one
        # whenever the vertical and horizontal paddings differed.
        paddings = [[0, 0],
                    [self.top_pad, self.bottom_pad],
                    [self.left_pad, self.right_pad],
                    [0, 0]]

        return tf.pad(x, paddings, mode='REFLECT', name=None)

    def compute_output_shape(self,input_shape):
        """Return the padded shape; unknown (None) spatial dims stay None."""
        if self.dim_ordering == 'tf':
            rows = input_shape[1] + self.top_pad + self.bottom_pad if input_shape[1] is not None else None
            cols = input_shape[2] + self.left_pad + self.right_pad if input_shape[2] is not None else None

            return (input_shape[0],
                    rows,
                    cols,
                    input_shape[3])
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

    def get_config(self):
        """Serialise the layer; the original ``padding`` argument is stored as given."""
        config = {'padding': self.padding}
        base_config = super(ReflectionPadding2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
    
    
class UnPooling2D(UpSampling2D):
    """Nearest-neighbour upsampling via ``tf.image.resize_nearest_neighbor``.

    Args:
        size: pair of integer factors applied to axes 1 and 2 of the input.
        **kwargs: forwarded to ``UpSampling2D`` (e.g. ``name``,
            ``data_format``).  Accepting them is needed so the layer can be
            named and rebuilt from its config.
    """

    def __init__(self, size=(2, 2), **kwargs):
        super(UnPooling2D, self).__init__(size, **kwargs)

    def call(self, x, mask=None):
        # Uses the static shape, so axes 1 and 2 must be known at graph
        # construction time.  Assumes a channels-last input - TODO confirm.
        shapes = x.get_shape().as_list()
        new_rows = self.size[0] * shapes[1]
        new_cols = self.size[1] * shapes[2]
        return tf.image.resize_nearest_neighbor(x, (new_rows, new_cols))
    
class VGGNormalize(Layer):
    '''
    Custom layer that passes its input through the VGG19
    ``preprocess_input`` function so the tensor can be fed to the VGG
    network.  The output shape equals the input shape.
    '''

    def __init__(self, **kwargs):
        super(VGGNormalize, self).__init__(**kwargs)

    def build(self, input_shape):
        # No weights to create.
        pass

    def call(self, x, mask=None):
        # Delegates the whole conversion (commented in the original as
        # 'RGB'->'BGR') to the Keras helper.
        return preprocess_input(x)

    def compute_output_shape(self,input_shape):
        return input_shape
    
class TanhNormalize(Layer):
    '''
    Custom layer that rescales its input with ``(x + 1) * 127.5``, i.e. a
    value of -1 maps to 0 and a value of 1 maps to 255.  The output shape
    equals the input shape.
    '''

    def __init__(self, **kwargs):
        super(TanhNormalize, self).__init__(**kwargs)

    def build(self, input_shape):
        # No weights to create.
        pass

    def call(self, x, mask=None):
        half_range = 255.0 / 2
        return (x + 1) * half_range

    def compute_output_shape(self,input_shape):
        return input_shape

In [8]:
class InstanceNormalization(Layer):
    """Instance normalisation with optional learned scale (gamma) and shift (beta).

    Args:
        axis: feature axis that is kept out of the reduction and receives one
            gamma/beta entry per index; ``None`` reduces over every axis but
            the batch axis and uses a single scalar gamma/beta.  0 is rejected.
        epsilon: constant added to the standard deviation before dividing.
        center: whether to add the learned ``beta`` offset.
        scale: whether to multiply by the learned ``gamma`` factor.
        beta_initializer, gamma_initializer: weight initialisers.
        beta_regularizer, gamma_regularizer: optional weight regularisers.
        beta_constraint, gamma_constraint: optional weight constraints.
    """

    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True

        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale

        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Validate ``axis`` and create gamma (first) and beta (second)."""
        rank = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')
        if self.axis is not None and rank == 2:
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=rank)

        # One parameter per feature when an axis is given, otherwise a scalar.
        param_shape = (1,) if self.axis is None else (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=param_shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None

        if self.center:
            self.beta = self.add_weight(shape=param_shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None

        self.built = True

    def call(self, inputs, training=None):
        """Normalise ``inputs`` per sample, then apply gamma / beta if enabled."""
        static_shape = K.int_shape(inputs)
        rank = len(static_shape)

        # Reduce over everything except the feature axis (if any) and the
        # batch axis; the feature axis is removed first, then position 0.
        axes = list(range(rank))
        if self.axis is not None:
            axes.pop(self.axis)
        axes.pop(0)

        mean = K.mean(inputs, axes, keepdims=True)
        stddev = K.std(inputs, axes, keepdims=True) + self.epsilon
        result = (inputs - mean) / stddev

        # Shape that lets the 1-D parameters broadcast along the feature axis.
        param_view = [1] * rank
        if self.axis is not None:
            param_view[self.axis] = static_shape[self.axis]

        if self.scale:
            result = result * K.reshape(self.gamma, param_view)
        if self.center:
            result = result + K.reshape(self.beta, param_view)
        return result

    def get_config(self):
        """Return the base layer config extended with this layer's settings."""
        merged = dict(super(InstanceNormalization, self).get_config())
        merged.update({
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        })
        return merged

In [12]:
def build_encode_net(input_shape):
    """Build a frozen VGG19 encoder truncated at ``vgg.layers[7]``.

    A headless VGG19 is created for ``input_shape``, a model from its input
    to the output of layer index 7 is cut out (presumably block3_conv1 -
    confirm against the Keras VGG19 layer list), every layer is marked
    non-trainable and the model is compiled with Adam / MSE.
    """
    backbone = VGG19(include_top=False, input_shape=input_shape)

    encoder = Model(backbone.input, backbone.layers[7].output)
    for frozen in encoder.layers:
        frozen.trainable = False

    encoder.compile(optimizer='adam', loss='mse')
    return encoder

def build_encode_net_with_swap_3_1(input_shape):
    """Build a frozen two-input VGG19 encoder followed by ``style_swap_layer``.

    The content and style inputs are concatenated along axis 0 and run
    through one headless VGG19.  The output of ``vgg.layers[-15]``
    (presumably block3_conv1 - confirm) is handed to a Lambda wrapping
    ``style_swap_layer`` with a declared output shape of ``(64, 64, 256)``.
    All layers are frozen and the model is compiled with Adam / MSE.
    """
    content_input = Input(shape=input_shape, name='content_input')
    style_input = Input(shape=input_shape, name='style_input')
    stacked = concatenate([content_input, style_input], axis=0)

    backbone = VGG19(include_top=False, input_tensor=stacked)
    features = backbone.layers[-15].output
    swapped = Lambda(style_swap_layer, output_shape=(64, 64, 256))(features)

    encoder = Model([content_input, style_input], swapped)
    for frozen in encoder.layers:
        frozen.trainable = False

    encoder.compile(optimizer='adam', loss='mse')
    return encoder

def build_encode_net_with_swap(input_shape):
    """Build a frozen two-input VGG19 encoder followed by ``benben_swap_layer``.

    The content and style inputs are concatenated along axis 0 and run
    through one headless VGG19.  The output of ``vgg.layers[-13]``
    (presumably block3_conv3 - confirm) is handed to a Lambda wrapping
    ``benben_swap_layer``; no output shape is declared.  All layers are
    frozen and the model is compiled with Adam / MSE.
    """
    content_input = Input(shape=input_shape, name='content_input')
    style_input = Input(shape=input_shape, name='style_input')
    stacked = concatenate([content_input, style_input], axis=0)

    backbone = VGG19(include_top=False, input_tensor=stacked)
    features = backbone.layers[-13].output
    swapped = Lambda(benben_swap_layer)(features)

    encoder = Model([content_input, style_input], swapped)
    for frozen in encoder.layers:
        frozen.trainable = False

    encoder.compile(optimizer='adam', loss='mse')
    return encoder

def build_encode_net_vgg16_3_1(input_shape):
    """Build a frozen two-input VGG16 encoder followed by ``style_swap_layer``.

    The content and style inputs are concatenated along axis 0 and run
    through one headless VGG16.  The output of ``vgg.layers[-12]``
    (presumably block3_conv1 - confirm) is handed to a Lambda wrapping
    ``style_swap_layer`` with a declared output shape of ``(64, 64, 256)``.
    All layers are frozen and the model is compiled with Adam / MSE.
    """
    content_input = Input(shape=input_shape, name='content_input')
    style_input = Input(shape=input_shape, name='style_input')
    stacked = concatenate([content_input, style_input], axis=0)

    backbone = VGG16(include_top=False, input_tensor=stacked)
    features = backbone.layers[-12].output
    swapped = Lambda(style_swap_layer, output_shape=(64, 64, 256))(features)

    encoder = Model([content_input, style_input], swapped)
    for frozen in encoder.layers:
        frozen.trainable = False

    encoder.compile(optimizer='adam', loss='mse')
    return encoder

def build_encode_net_with_swap_3_3(input_shape):
    """Build a frozen two-input VGG19 encoder followed by ``wct_style_swap``.

    The content and style inputs are concatenated along axis 0 and run
    through one headless VGG19.  The output of ``vgg.layers[-13]``
    (presumably block3_conv3 - confirm) is handed to a Lambda wrapping
    ``wct_style_swap`` with a declared output shape of ``(64, 64, 256)``.
    All layers are frozen and the model is compiled with Adam / MSE.
    """
    content_input = Input(shape=input_shape, name='content_input')
    style_input = Input(shape=input_shape, name='style_input')
    stacked = concatenate([content_input, style_input], axis=0)

    backbone = VGG19(include_top=False, input_tensor=stacked)
    features = backbone.layers[-13].output
    swapped = Lambda(wct_style_swap, output_shape=(64, 64, 256))(features)

    encoder = Model([content_input, style_input], swapped)
    for frozen in encoder.layers:
        frozen.trainable = False

    encoder.compile(optimizer='adam', loss='mse')
    return encoder

In [13]:
def conv_in_relu(nb_filter, nb_row, nb_col,stride):
    """Return a function applying Conv2D ('same') -> InstanceNormalization -> ReLU.

    ``nb_filter`` is the number of output channels, ``(nb_row, nb_col)`` the
    kernel size and ``stride`` the convolution strides.  The layers are only
    created when the returned function is called on a tensor.
    """
    kernel = (nb_row, nb_col)

    def conv_func(x):
        convolved = Conv2D(nb_filter, kernel, strides=stride , padding='same')(x)
        normalised = InstanceNormalization()(convolved)
        return Activation("relu")(normalised)

    return conv_func

def res_conv(nb_filter, nb_row, nb_col,stride=(1,1)):
    """Return a function building a two-convolution residual block.

    The branch is Conv2D -> InstanceNormalize -> ReLU -> Conv2D ->
    InstanceNormalize (both convolutions use 'same' padding and ``stride``);
    its output is added to the block input.
    """
    kernel = (nb_row, nb_col)

    def _res_func(x):
        branch = Conv2D(nb_filter, kernel, strides=stride, padding='same')(x)
        branch = InstanceNormalize()(branch)
        branch = Activation("relu")(branch)
        branch = Conv2D(nb_filter, kernel, strides=stride, padding='same')(branch)
        branch = InstanceNormalize()(branch)
        return add([x, branch])

    return _res_func

def dconv_bn_nolinear(nb_filter, nb_row, nb_col,stride=(2,2),activation="relu"):
    """Return a function building an upsample-then-convolve block.

    The block is UnPooling2D(size=stride) -> ReflectionPadding2D(padding=stride)
    -> Conv2D ('valid') -> InstanceNormalization -> ``activation``.
    Note that ``stride`` is used both as the upsampling factor and as the
    reflection padding amount.
    """
    kernel = (nb_row, nb_col)

    def _dconv_bn(x):
        upsampled = UnPooling2D(size=stride)(x)
        padded = ReflectionPadding2D(padding=stride)(upsampled)
        convolved = Conv2D(nb_filter, kernel, padding='valid')(padded)
        normalised = InstanceNormalization()(convolved)
        return Activation(activation)(normalised)

    return _dconv_bn


def add_total_variation_loss(transform_output_layer,weight):
    """Attach a total-variation penalty to ``transform_output_layer``.

    A ``TVRegularizer`` with the given ``weight`` is evaluated on the layer
    and the resulting loss tensor is registered through ``add_loss``.
    """
    penalty = TVRegularizer(weight)(transform_output_layer)
    transform_output_layer.add_loss(penalty)
    
    
class TVRegularizer(Regularizer):
    """ Enforces smoothness in image output.

    Calling the regularizer on a layer returns
    ``weight * sum((dx**2 + dy**2) ** 1.25)`` computed on ``layer.output``,
    where ``dx`` / ``dy`` are differences between neighbouring positions
    along the two spatial axes.
    """

    def __init__(self, weight=1.0):
        self.weight = weight
        self.uses_learning_phase = False
        super(TVRegularizer, self).__init__()

    def __call__(self, x):
        """Return the total-variation loss tensor for layer ``x``.

        Args:
            x: a layer whose ``output`` is a 4-D tensor.
        """
        assert K.ndim(x.output) == 4
        x_out = x.output

        shape = K.shape(x_out)
        if K.image_dim_ordering() == 'th':
            # Channels-first: the spatial axes are 2 and 3.  Reading them from
            # axes 1 and 2 (as the 'tf' branch does) would slice the spatial
            # axes with the channel count.
            img_width, img_height = shape[2], shape[3]
            a = K.square(x_out[:, :, :img_width - 1, :img_height - 1] - x_out[:, :, 1:, :img_height - 1])
            b = K.square(x_out[:, :, :img_width - 1, :img_height - 1] - x_out[:, :, :img_width - 1, 1:])
        else:
            img_width, img_height = shape[1], shape[2]
            a = K.square(x_out[:, :img_width - 1, :img_height - 1, :] - x_out[:, 1:, :img_height - 1, :])
            b = K.square(x_out[:, :img_width - 1, :img_height - 1, :] - x_out[:, :img_width - 1, 1:, :])
        loss = self.weight * K.sum(K.pow(a + b, 1.25))
        return loss


In [14]:
def wct_style_swap(x, alpha=0.8, patch_size=3, stride=1, eps=1e-8):
    '''Modified Whiten-Color Transform that performs style swap on whitened content/style encodings before coloring
       Assume that content/style encodings have shape 1xHxWxC

       Args:
           x: tensor whose index 0 along the first axis is the content encoding
              and whose index 1 is the style encoding.
           alpha: blend factor between the stylised result (alpha) and the
              original content encoding (1 - alpha).
           patch_size: side length of the square style patches.
           stride: stride used for patch extraction, matching and reconstruction.
           eps: value added to the diagonal of both covariance matrices.

       Returns:
           The blended encoding, reshaped to 1xHxWxC using the content's
           dynamic H, W and C.
    '''
    
    # Split the stacked input back into 1xHxWxC content / style tensors.
    content = K.expand_dims(x[0], 0)
    style = K.expand_dims(x[1], 0)
    
    # 1xHxWxC -> CxHxW
    content_t = tf.transpose(tf.squeeze(content), (2, 0, 1))
    style_t = tf.transpose(tf.squeeze(style), (2, 0, 1))

    # Dynamic (run-time) dimensions.
    Cc, Hc, Wc = tf.unstack(tf.shape(content_t))
    Cs, Hs, Ws = tf.unstack(tf.shape(style_t))

    # CxHxW -> CxH*W
    content_flat = tf.reshape(content_t, (Cc, Hc*Wc))
    style_flat = tf.reshape(style_t, (Cs, Hs*Ws))

    # Content covariance
    mc = tf.reduce_mean(content_flat, axis=1, keep_dims=True)
    fc = content_flat - mc
    fcfc = tf.matmul(fc, fc, transpose_b=True) / (tf.cast(Hc*Wc, tf.float32) - 1.) + tf.eye(Cc)*eps

    # Style covariance
    ms = tf.reduce_mean(style_flat, axis=1, keep_dims=True)
    fs = style_flat - ms
    fsfs = tf.matmul(fs, fs, transpose_b=True) / (tf.cast(Hs*Ws, tf.float32) - 1.) + tf.eye(Cs)*eps

    # tf.svd is slower on GPU, see https://github.com/tensorflow/tensorflow/issues/13603
    with tf.device('/cpu:0'):  
        Sc, Uc, _ = tf.svd(fcfc)
        Ss, Us, _ = tf.svd(fsfs)

    ## Uncomment to perform SVD for content/style with np in one call
    ## This is slower than CPU tf.svd but won't segfault for ill-conditioned matrices
    # @jit
    # def np_svd(content, style):
    #     '''tf.py_func helper to run SVD with NumPy for content/style cov tensors'''
    #     Uc, Sc, _ = np.linalg.svd(content)
    #     Us, Ss, _ = np.linalg.svd(style)
    #     return Uc, Sc, Us, Ss
    # Uc, Sc, Us, Ss = tf.py_func(np_svd, [fcfc, fsfs], [tf.float32, tf.float32, tf.float32, tf.float32])
    
    # Number of singular values above 1e-5; only those components are used below.
    k_c = tf.reduce_sum(tf.cast(tf.greater(Sc, 1e-5), tf.int32))
    k_s = tf.reduce_sum(tf.cast(tf.greater(Ss, 1e-5), tf.int32))

    ### Whiten content
    Dc = tf.diag(tf.pow(Sc[:k_c], -0.5))

    fc_hat = tf.matmul(tf.matmul(tf.matmul(Uc[:,:k_c], Dc), Uc[:,:k_c], transpose_b=True), fc)

    # Reshape before passing to style swap, CxH*W -> 1xHxWxC
    whiten_content = tf.expand_dims(tf.transpose(tf.reshape(fc_hat, [Cc,Hc,Wc]), [1,2,0]), 0)

    ### Whiten style before swapping
    Ds = tf.diag(tf.pow(Ss[:k_s], -0.5))
    whiten_style = tf.matmul(tf.matmul(tf.matmul(Us[:,:k_s], Ds), Us[:,:k_s], transpose_b=True), fs)
    # Reshape before passing to style swap, CxH*W -> 1xHxWxC
    whiten_style = tf.expand_dims(tf.transpose(tf.reshape(whiten_style, [Cs,Hs,Ws]), [1,2,0]), 0)

    ### Style swap whitened encodings
    #ss_feature = ori_style_swap_layer(whiten_content, whiten_style, patch_size, stride)
    
    ###############################################
    # Inlined style swap on the whitened encodings (dynamic shapes throughout).
    nC = tf.shape(whiten_style)[-1]  # Num channels of input content feature and style-swapped output

    ### Extract patches from style image that will be used for conv/deconv layers
    style_patches = tf.extract_image_patches(whiten_style, [1,patch_size,patch_size,1], [1,stride,stride,1], [1,1,1,1], 'VALID')
    before_reshape = tf.shape(style_patches)  # NxRowsxColsxPatch_size*Patch_size*nC
    style_patches = tf.reshape(style_patches, [before_reshape[1]*before_reshape[2],patch_size,patch_size,nC])
    style_patches = tf.transpose(style_patches, [1,2,3,0])  # Patch_sizexPatch_sizexIn_CxOut_c

    # Normalize each style patch
    style_patches_norm = tf.nn.l2_normalize(style_patches, dim=3)

    # Compute cross-correlation/nearest neighbors of patches by using style patches as conv filters
    ss_enc = tf.nn.conv2d(whiten_content,
                          style_patches_norm,
                          [1,stride,stride,1],
                          'VALID')

    # For each spatial position find index of max along channel/patch dim  
    ss_argmax = tf.argmax(ss_enc, axis=3)
    encC = tf.shape(ss_enc)[-1]  # Num channels in intermediate conv output, same as # of patches
    
    # One-hot encode argmax with same size as ss_enc, with 1's in max channel idx for each spatial pos
    ss_oh = tf.one_hot(ss_argmax, encC, 1., 0., 3)

    # Calc size of transposed conv out
    deconv_out_H = utils.deconv_output_length(tf.shape(ss_oh)[1], patch_size, 'valid', stride)
    deconv_out_W = utils.deconv_output_length(tf.shape(ss_oh)[2], patch_size, 'valid', stride)
    deconv_out_shape = tf.stack([1,deconv_out_H,deconv_out_W,nC])

    # Deconv back to original content size with highest matching (unnormalized) style patch swapped in for each content patch
    ss_dec = tf.nn.conv2d_transpose(ss_oh,
                                    style_patches,
                                    deconv_out_shape,
                                    [1,stride,stride,1],
                                    'VALID')

    ### Interpolate to average overlapping patch locations
    ss_oh_sum = tf.reduce_sum(ss_oh, axis=3, keep_dims=True)

    filter_ones = tf.ones([patch_size,patch_size,1,1], dtype=tf.float32)
    
    deconv_out_shape = tf.stack([1,deconv_out_H,deconv_out_W,1])  # Same spatial size as ss_dec with 1 channel

    # Transposed conv of the one-hot sums with an all-ones filter counts how
    # many patches cover each output position.
    counting = tf.nn.conv2d_transpose(ss_oh_sum,
                                         filter_ones,
                                         deconv_out_shape,
                                         [1,stride,stride,1],
                                         'VALID')

    counting = tf.tile(counting, [1,1,1,nC])  # Repeat along channel dim to make same size as ss_dec

    ss_feature = tf.divide(ss_dec, counting)
    ###############################################
    
    # HxWxC -> CxH*W
    ss_feature = tf.transpose(tf.reshape(ss_feature, [Hc*Wc,Cc]), [1,0])

    ### Color style-swapped encoding with style 
    Ds_sq = tf.diag(tf.pow(Ss[:k_s], 0.5))
    fcs_hat = tf.matmul(tf.matmul(tf.matmul(Us[:,:k_s], Ds_sq), Us[:,:k_s], transpose_b=True), ss_feature)
    fcs_hat = fcs_hat + ms

    ### Blend style-swapped & colored encoding with original content encoding
    blended = alpha * fcs_hat + (1 - alpha) * (fc + mc)
    # CxH*W -> CxHxW
    blended = tf.reshape(blended, (Cc,Hc,Wc))
    # CxHxW -> 1xHxWxC
    blended = tf.expand_dims(tf.transpose(blended, (1,2,0)), 0)

    return blended


In [15]:
def style_swap_layer(x, patch_size=3, stride=1):
    '''Efficiently swap content feature patches with nearest-neighbor style patches
       Original paper: https://arxiv.org/abs/1612.04337
       Adapted from: https://github.com/rtqichen/style-swap/blob/master/lib/NonparametricPatchAutoencoderFactory.lua

       Args:
           x: tensor whose index 0 along the first axis is the content feature
              map and whose index 1 is the style feature map.
           patch_size: side length of the square style patches.
           stride: stride used for patch extraction, matching and reconstruction.

       Returns:
           The reconstructed feature map: each content patch replaced by its
           best-matching style patch, divided by the per-position patch count.

       Static shapes (``.shape``) are used throughout, so the spatial and
       channel dimensions must be known when the graph is built.
    '''
    # Split the stacked input back into two batch-of-one tensors.
    content = K.expand_dims(x[0], 0)
    style = K.expand_dims(x[1], 0)
        
        
    nC = style.shape[-1]  # Num channels of input content feature and style-swapped output

    # Channel-first views and their dimensions; not used by the code below.
    content_t = tf.transpose(tf.squeeze(content), (2, 0, 1))
    style_t = tf.transpose(tf.squeeze(style), (2, 0, 1))

    Cc, Hc, Wc = tf.unstack(content_t.shape)
    Cs, Hs, Ws = tf.unstack(style_t.shape)


    ### Extract patches from style image that will be used for conv/deconv layers
    style_patches = tf.extract_image_patches(style, [1, patch_size, patch_size, 1], 
                                                 [1, stride, stride, 1], [1, 1, 1, 1], 'VALID')

    before_reshape = style_patches.shape  # NxRowsxColsxPatch_size*Patch_size*nC

    style_patches = tf.reshape(style_patches, [before_reshape[1]*before_reshape[2], patch_size, patch_size, nC])

    style_patches = tf.transpose(style_patches, [1, 2, 3, 0])  # Patch_sizexPatch_sizexIn_CxOut_c

    # Normalize each style patch
    style_patches_norm = tf.nn.l2_normalize(style_patches, dim=3)

    # Compute cross-correlation/nearest neighbors of patches by using style patches as conv filters
    ss_enc = tf.nn.conv2d(content,
                              style_patches_norm,
                              [1, stride, stride, 1],
                              'VALID')

    # For each spatial position find index of max along channel/patch dim  
    ss_argmax = tf.argmax(ss_enc, axis=3)
    encC = ss_enc.shape[-1]  # Num channels in intermediate conv output, same as # of patches

    # One-hot encode argmax with same size as ss_enc, with 1's in max channel idx for each spatial pos
    ss_oh = tf.one_hot(ss_argmax, encC, 1., 0., 3)

    # Calc size of transposed conv out
    deconv_out_H = utils.deconv_output_length(ss_oh.shape[1], patch_size, 'valid', stride)
    deconv_out_W = utils.deconv_output_length(ss_oh.shape[2], patch_size, 'valid', stride)
    deconv_out_shape = tf.stack([1, deconv_out_H, deconv_out_W, nC])


    # Deconv back to original content size with highest matching (unnormalized) style patch swapped in for each content patch
    ss_dec = tf.nn.conv2d_transpose(ss_oh,
                                        style_patches,
                                        deconv_out_shape,
                                        [1, stride, stride, 1],
                                        'VALID')

    ### Interpolate to average overlapping patch locations
    ss_oh_sum = tf.reduce_sum(ss_oh, axis=3, keep_dims=True)

    filter_ones = tf.ones([patch_size, patch_size, 1, 1], dtype=tf.float32)

    deconv_out_shape = tf.stack([1, deconv_out_H, deconv_out_W, 1])  # Same spatial size as ss_dec with 1 channel

    # Transposed conv of the one-hot sums with an all-ones filter counts how
    # many patches cover each output position.
    counting = tf.nn.conv2d_transpose(ss_oh_sum,
                                             filter_ones,
                                             deconv_out_shape,
                                             [1, stride, stride,1],
                                             'VALID')

    counting = tf.tile(counting, [1, 1, 1, nC])  # Repeat along channel dim to make same size as ss_dec

    interpolated_dec = tf.divide(ss_dec, counting)
        
    return interpolated_dec

In [16]:
def benben_swap_layer(x, cell_size = 3):
    """Style swap that uses non-overlapping style cells as matching filters.

    Args:
        x: tensor whose index 0 along the first axis is the content feature
            map and whose index 1 is the style feature map.  The style map
            needs statically known height and width.
        cell_size: side length of the square style cells.  Rows / columns
            left over after dividing the style map into cells are dropped.

    Returns:
        A tensor with the content feature map's shape in which every
        ``cell_size`` x ``cell_size`` content patch has been replaced by its
        best-matching style cell, divided by ``cell_size ** 2``.

    Shapes of the intermediate tensors are printed while the graph is built.
    """
    content_feature = tf.expand_dims(x[0], 0)
    style_feature = tf.expand_dims(x[1], 0)

    style_amount = style_feature.get_shape()[0].value

    print("style_amount:", style_amount)

    style_rows = style_feature.get_shape()[1].value
    style_cols = style_feature.get_shape()[2].value

    # Cut the style map into full cells; the trailing remainder chunk of each
    # split is discarded by the [:-1].
    rows = tf.split(style_feature, num_or_size_splits=list(
            [cell_size] * (style_rows // cell_size) + [style_rows % cell_size]), axis=1)[:-1]
    cells = [tf.split(row, num_or_size_splits=list(
            [cell_size] * (style_cols // cell_size) + [style_cols % cell_size]), axis=2)[:-1]
                 for row in rows]

    print("row shape:" , np.array(cells).shape)

    stacked_cells = [tf.stack(row_cell, axis=4) for row_cell in cells]

    print("stacked_cells:" , np.array(stacked_cells).shape)

    # Cells end up on the last axis, i.e. as the output channels of a filter bank.
    filters = tf.concat(stacked_cells, axis=-1)

    print("filters:" , filters.get_shape())

    # Number of cells overlapping an interior position after the stride-1
    # transposed convolution.  This was hard-coded as 9.0, which is only
    # correct for the default cell_size of 3.
    overlap = float(cell_size * cell_size)

    swaped_list = []
    for style_filter in tf.unstack(filters, axis=0, num=style_amount):

        normalized_filters = tf.nn.l2_normalize(style_filter, dim=3)

        print("normalized_filters:" , normalized_filters.get_shape())
        print("content_feature:" , content_feature.get_shape())

        """ change the strides to see difference"""
        similarity = tf.nn.conv2d(content_feature, normalized_filters, strides=[1, 1, 1, 1], padding="VALID")

        # Keep only the best-matching cell at every position.
        arg_max_filter = tf.argmax(similarity, axis=-1)
        one_hot_filter = tf.one_hot(arg_max_filter, depth=similarity.get_shape()[-1].value)

        swap = tf.nn.conv2d_transpose(one_hot_filter, style_filter, output_shape=tf.shape(content_feature),
                                      strides=[1, 1, 1, 1], padding="VALID")

        swaped_list.append(swap / overlap)

    layer_out = tf.concat(swaped_list, axis=0)
    print(layer_out)
    return layer_out

In [11]:
def InverseNet_4(feature):
    """Build a decoder that performs ``wct_style_swap`` itself.

    ``feature`` is the per-sample shape of both the content and the style
    feature inputs.  The two inputs are concatenated along axis 0, swapped
    inside a Lambda and decoded by conv/upsample stages down to a 3-channel
    convolution.  A total-variation loss weighted by ``tv_weight`` is
    attached to the last layer.
    """
    content_input = Input(shape=feature, name='content_input')
    style_input = Input(shape=feature, name='style_input')
    stacked = concatenate([content_input, style_input], axis=0)

    decoded = Lambda(wct_style_swap, output_shape=feature)(stacked)

    # Filter count per conv_in_relu stage; None marks a 2x upsampling step.
    for filters in (128, None, 128, 64, None, 64):
        if filters is None:
            decoded = UpSampling2D()(decoded)
        else:
            decoded = conv_in_relu(filters, 3, 3, stride=(1,1))(decoded)

    inverse_net_output = Conv2D(3, (3, 3), padding='same', name='inverse_net_output')(decoded)

    model = Model(inputs=[content_input, style_input], outputs=inverse_net_output)
    add_total_variation_loss(model.layers[-1], tv_weight)
    return model

def InverseNet_5(feature):
    """Build a single-input decoder that uses ``UnPooling2D`` between conv stages.

    Args:
        feature: shape tuple (without the batch axis) of the ``swapped_input``
            tensor.

    Returns:
        A Keras ``Model`` mapping ``swapped_input`` to the 3-filter
        ``inverse_net_output`` convolution.
    """
    swapped_input = Input(shape=feature, name='swapped_input')

    # Filter counts per stage; an UnPooling2D layer is inserted between stages.
    net = swapped_input
    for stage_index, widths in enumerate(((128,), (128, 64), (64,))):
        if stage_index:
            net = UnPooling2D()(net)
        for width in widths:
            net = conv_in_relu(width, 3, 3, stride=(1, 1))(net)

    inverse_net_output = Conv2D(3, (3, 3), padding='same', name='inverse_net_output')(net)
    model = Model(inputs=[swapped_input], outputs=inverse_net_output)

    # Hand the final layer to add_total_variation_loss together with the
    # module-level tv_weight.
    add_total_variation_loss(model.layers[-1], tv_weight)
    return model

def InverseNet_3_3(feature):
    """Build a single-input decoder with 4 + 2 + 2 ``conv_in_relu`` blocks.

    The stack is four 256-filter blocks, an ``UpSampling2D``, two 128-filter
    blocks, another ``UpSampling2D``, two 64-filter blocks, and a final
    3-filter ``Conv2D`` named ``inverse_net_output``.

    Args:
        feature: shape tuple (without the batch axis) of the ``swapped_input``
            tensor.

    Returns:
        A Keras ``Model`` mapping ``swapped_input`` to ``inverse_net_output``.
    """
    swapped_input = Input(shape=feature, name='swapped_input')

    # (filters, number of consecutive conv blocks) per stage; stages after the
    # first are preceded by an UpSampling2D layer.
    stage_plan = ((256, 4), (128, 2), (64, 2))
    net = swapped_input
    for stage_index, (width, repeats) in enumerate(stage_plan):
        if stage_index > 0:
            net = UpSampling2D()(net)
        for _ in range(repeats):
            net = conv_in_relu(width, 3, 3, stride=(1, 1))(net)

    inverse_net_output = Conv2D(3, (3, 3), padding='same', name='inverse_net_output')(net)
    model = Model(inputs=[swapped_input], outputs=inverse_net_output)

    # Hand the final layer to add_total_variation_loss together with the
    # module-level tv_weight.
    add_total_variation_loss(model.layers[-1], tv_weight)
    return model

def InverseNet_3_1(feature):
    """Build a single-input decoder with 1 + 2 + 2 ``conv_in_relu`` blocks.

    The stack is one 256-filter block, then for 128 and 64 filters in turn an
    ``UpSampling2D`` followed by two blocks, and a final 3-filter ``Conv2D``
    named ``inverse_net_output``.

    Args:
        feature: shape tuple (without the batch axis) of the ``swapped_input``
            tensor.

    Returns:
        A Keras ``Model`` mapping ``swapped_input`` to ``inverse_net_output``.
    """
    swapped_input = Input(shape=feature, name='swapped_input')

    net = conv_in_relu(256, 3, 3, stride=(1, 1))(swapped_input)

    # Each remaining stage upsamples once and then applies two conv blocks.
    for width in (128, 64):
        net = UpSampling2D()(net)
        for _ in range(2):
            net = conv_in_relu(width, 3, 3, stride=(1, 1))(net)

    inverse_net_output = Conv2D(3, (3, 3), padding='same', name='inverse_net_output')(net)
    model = Model(inputs=[swapped_input], outputs=inverse_net_output)

    # Hand the final layer to add_total_variation_loss together with the
    # module-level tv_weight.
    add_total_variation_loss(model.layers[-1], tv_weight)
    return model

def InverseNet_3_3_res(feature):
    """Build a single-input decoder with ``res_conv`` blocks and a tanh output.

    The stack is a 256-filter ``conv_in_relu``, three 256-filter ``res_conv``
    blocks with kernel sizes 3, 5 and 7, an ``UpSampling2D``, a 128-filter 5x5
    ``conv_in_relu``, another ``UpSampling2D``, 64-filter ``conv_in_relu``
    blocks with kernel sizes 5 and 3, a 3-filter ``Conv2D`` with tanh
    activation, and finally a ``TanhNormalize`` layer.

    Args:
        feature: shape tuple (without the batch axis) of the ``swapped_input``
            tensor.

    Returns:
        A Keras ``Model`` mapping ``swapped_input`` to the ``TanhNormalize``
        output.
    """
    swapped_input = Input(shape=feature, name='swapped_input')

    net = conv_in_relu(256, 3, 3, stride=(1, 1))(swapped_input)
    for kernel in (3, 5, 7):
        net = res_conv(256, kernel, kernel, stride=(1, 1))(net)

    net = UpSampling2D()(net)
    net = conv_in_relu(128, 5, 5, stride=(1, 1))(net)

    net = UpSampling2D()(net)
    for kernel in (5, 3):
        net = conv_in_relu(64, kernel, kernel, stride=(1, 1))(net)

    inverse_net_output = Conv2D(3, (3, 3), padding='same', name='inverse_net_output',
                                activation="tanh")(net)
    normalized_output = TanhNormalize()(inverse_net_output)

    model = Model(inputs=[swapped_input], outputs=normalized_output)

    # Hand the final layer to add_total_variation_loss together with the
    # module-level tv_weight.
    add_total_variation_loss(model.layers[-1], tv_weight)
    return model

In [26]:
# Content image paths.
test_content_image_1 = 'images/content/101.jpg'
handsome = 'images/content/handsomeman.jpg'
lucechapel = 'images/content/lucechapel.jpg'
cat = 'images/content/gilbert.jpg'
bike = 'images/content/bike.jpg'
uchiha = "images/content/uchiha.jpg"
autumn = "images/content/autumn.jpg"
street = "images/content/street.jpg"

# Style image paths.
snow = 'images/style/snow.jpg'
starry_night = 'images/style/starry_night.jpg'
crystal = 'images/style/crystal.jpg'
la_muse = 'images/style/la_muse.jpg'
udnie = 'images/style/udnie.jpg'
water = 'images/style/water.jpg'
night = 'images/style/101night.jpg'
chingmin = 'images/style/Chingmin.jpg'
colorhole = 'images/style/colorhole.jpg'
des_glaneuses = 'images/style/des_glaneuses.jpg'
jingdo = 'images/style/jingdo.jpg'
monalisa = 'images/style/monalisa.jpg'
mountainwater = 'images/style/mountainwater.jpg'
picassoself = 'images/style/picassoself.jpg'
wave_crop = 'images/style/wave_crop.jpg'
tiger = 'images/style/tiger.jpg'
small_world = 'images/style/small_world.jpg'
composition_x = 'images/style/composition_x.jpg'
sky = 'images/style/sky.jpg'

# Preprocess every content image; the calls run in the order the paths are listed.
(
    bridge_image, lucechapel_image, handsome_image, cat_image,
    bike_image, uchiha_image, autumn_image, street_image,
) = [
    preprocess_image(content_path)
    for content_path in (
        test_content_image_1, lucechapel, handsome, cat,
        bike, uchiha, autumn, street,
    )
]

# Preprocess every style image, again in listing order.
(
    snow_image, starry_night_image, crystal_image, la_muse_image,
    udnie_image, water_image, night_image, chingmin_image,
    colorhole_image, des_glaneuses_image, jingdo_image, monalisa_image,
    mountainwater_image, picassoself_image, wave_crop_image, tiger_image,
    small_world_image, composition_x_image, sky_image,
) = [
    preprocess_image(style_path)
    for style_path in (
        snow, starry_night, crystal, la_muse,
        udnie, water, night, chingmin,
        colorhole, des_glaneuses, jingdo, monalisa,
        mountainwater, picassoself, wave_crop, tiger,
        small_world, composition_x, sky,
    )
]

(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)
(1, 512, 512, 3)


In [13]:
#wct_style_swap(K.concatenate([K.variable(creepy), K.variable(starry_night)], axis=0))

In [14]:
#benben_swap_layer(K.concatenate([K.variable(creepy), K.variable(starry_night)], axis=0))

In [17]:
# Build the encoder used by the single-image demo cells; it is called later as
# encode_net.predict([content_image, style_image]).
#encode_net = build_encode_net_vgg16_3_1((IMG_WIDTH, IMG_HEIGHT, 3))
encode_net = build_encode_net_with_swap_3_1((IMG_WIDTH, IMG_HEIGHT, 3))

In [None]:
# Second encoder. transform_images and transform_iterate_images read this
# global, so this cell must be executed before either of them is called;
# otherwise they raise "NameError: name 'encode_net_2' is not defined".
encode_net_2 = build_encode_net_with_swap_3_3((IMG_WIDTH, IMG_HEIGHT, 3))

In [18]:
# Print the layer-by-layer table for the encoder built above.
encode_net.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
content_input (InputLayer)      (None, 512, 512, 3)  0                                            
__________________________________________________________________________________________________
style_input (InputLayer)        (None, 512, 512, 3)  0                                            
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 512, 512, 3)  0           content_input[0][0]              
                                                                 style_input[0][0]                
__________________________________________________________________________________________________
block1_conv1 (Conv2D)           (None, 512, 512, 64) 1792        concatenate_1[0][0]              
__________

In [19]:
# The decoder consumes a feature map at a quarter of the image size with 256 channels.
inverse_net = InverseNet_3_1((IMG_WIDTH // 4, IMG_HEIGHT // 4, 256))

In [20]:
# Load pretrained decoder weights. With by_name=True Keras assigns weights to
# layers by matching layer names against those stored in the file.
# NOTE(review): the decoder summary shows auto-generated names (conv2d_1, ...),
# which depend on how many such layers were created earlier in the session;
# confirm they line up with the names saved in the .h5 file.
#inverse_net.load_weights("models/inverse_net_vgg16.h5", by_name=True)
inverse_net.load_weights("models/inverse_net_vgg19.h5", by_name=True)

In [21]:
# Compile the decoder with the Adam optimizer and mean-squared-error loss,
# then print its layer table.
inverse_net.compile(optimizer="adam", loss='mse')
inverse_net.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
swapped_input (InputLayer)   (None, 128, 128, 256)     0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 128, 128, 256)     590080    
_________________________________________________________________
instance_normalization_1 (In (None, 128, 128, 256)     2         
_________________________________________________________________
activation_1 (Activation)    (None, 128, 128, 256)     0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 256, 256, 256)     0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 256, 256, 128)     295040    
_________________________________________________________________
instance_normalization_2 (In (None, 256, 256, 128)     2         
__________

In [22]:

# bridge_image = preprocess_image(test_content_image_1)
# lucechapel_image = preprocess_image(lucechapel)
# handsome_image = preprocess_image(handsome)
# cat_image

# snow_image = preprocess_image(snow)
# starry_night_image = preprocess_image(starry_night)
# crystal_image = preprocess_image(crystal)
# la_muse_image = preprocess_image(la_muse)
# udnie_image = preprocess_image(udnie)
# water_image = preprocess_image(water)
# night_image = preprocess_image(night)
# chingmin_image = preprocess_image(chingmin)
# colorhole_image = preprocess_image(colorhole)
# des_glaneuses_image = preprocess_image(des_glaneuses)
# jingdo_image = preprocess_image(jingdo)
# monalisa_image = preprocess_image(monalisa)
# mountainwater_image = preprocess_image(mountainwater)
# picassoself_image = preprocess_image(picassoself)
# wave_crop_image = preprocess_image(wave_crop)

In [23]:
def transform_images(content, content_name, save=False):
    """Stylize one content image with every preloaded style and show each result.

    For each style, ``[content, style_image]`` is encoded by the global
    ``encode_net_2`` and the encoded features are decoded by the global
    ``inverse_net``. The decoded batch is passed to ``show_without_deprocess``
    under the name ``"<content_name>_<style_name>"``.

    Args:
        content: preprocessed content image batch, as returned by
            ``preprocess_image``.
        content_name: prefix used to build each output name.
        save: forwarded unchanged to ``show_without_deprocess``.

    Returns:
        None.
    """
    # Each name is stored next to its image so the two cannot drift out of
    # sync the way two separately maintained parallel lists can.
    named_styles = [
        ("snow", snow_image),
        ("starry_night", starry_night_image),
        ("crystal", crystal_image),
        ("la_muse", la_muse_image),
        ("udnie", udnie_image),
        ("water", water_image),
        ("night", night_image),
        ("chingmin", chingmin_image),
        ("colorhole", colorhole_image),
        ("des_glaneuses", des_glaneuses_image),
        ("jingdo", jingdo_image),
        ("monalisa", monalisa_image),
        ("mountainwater", mountainwater_image),
        ("picassoself", picassoself_image),
        ("wave_crop", wave_crop_image),
        ("tiger", tiger_image),
        ("small_world", small_world_image),
        ("composition_x", composition_x_image),
        ("sky", sky_image),
    ]

    # Encode and decode one style at a time so only a single encoded feature
    # map is alive at once, instead of collecting all of them in a list first.
    for style_name, style_image in named_styles:
        feature = encode_net_2.predict([content, style_image])
        show_without_deprocess(x=inverse_net.predict([feature]), save=save,
                               name=content_name + "_" + style_name)
        
def transform_iterate_images(content_name, iterations=3, iterate=0, save=False):
    """Run one more stylization pass over images read from ``output/``.

    For each style, the file ``output/<content_name>_<style_name>_<iterate>.jpg``
    is loaded with ``preprocess_image``, encoded together with that style by the
    global ``encode_net_2``, decoded by the global ``inverse_net``, and passed
    to ``show_without_deprocess`` with ``iterate=iterate + 1``.

    Args:
        content_name: prefix of the input file names and of the output names.
        iterations: not used by this function; kept so existing callers that
            pass it keep working.
        iterate: pass index embedded in the input file name; results are
            tagged with ``iterate + 1``.
        save: forwarded unchanged to ``show_without_deprocess``.

    Returns:
        None.
    """
    # Each name is stored next to its image so the two cannot drift out of
    # sync the way two separately maintained parallel lists can.
    named_styles = [
        ("snow", snow_image),
        ("starry_night", starry_night_image),
        ("crystal", crystal_image),
        ("la_muse", la_muse_image),
        ("udnie", udnie_image),
        ("water", water_image),
        ("night", night_image),
        ("chingmin", chingmin_image),
        ("colorhole", colorhole_image),
        ("des_glaneuses", des_glaneuses_image),
        ("jingdo", jingdo_image),
        ("monalisa", monalisa_image),
        ("mountainwater", mountainwater_image),
        ("picassoself", picassoself_image),
        ("wave_crop", wave_crop_image),
        ("tiger", tiger_image),
        ("small_world", small_world_image),
        ("composition_x", composition_x_image),
        ("sky", sky_image),
    ]

    # Load, encode and decode one style at a time so only a single encoded
    # feature map is alive at once.
    for style_name, style_image in named_styles:
        content = preprocess_image("output/%s_%s_%d.jpg" % (content_name, style_name, iterate))
        feature = encode_net_2.predict([content, style_image])
        show_without_deprocess(x=inverse_net.predict([feature]), save=save,
                               name=content_name + "_" + style_name, iterate=iterate + 1)
        
    

In [None]:
# test_iterate1 = 'output/101_iterate2.jpg'
# test_iterate2 = 'output/cat_iterate2.jpg'
# test_iterate1_img = preprocess_image(test_iterate1)
# test_iterate2_img = preprocess_image(test_iterate2)

In [27]:
# Encode the bridge content image against six styles, in the order listed.
(
    image1_feature, image2_feature, image3_feature,
    image4_feature, image5_feature, image6_feature,
) = [
    encode_net.predict([bridge_image, style_image])
    for style_image in (
        snow_image, starry_night_image, crystal_image,
        la_muse_image, udnie_image, water_image,
    )
]

# image1_feature = encode_net_2.predict([test_iterate1_img, water_image])
# image2_feature = encode_net_2.predict([test_iterate2_img, water_image])
# image3_feature = encode_net_2.predict([handsome_image, uchiha_image])

In [28]:
# Decode each encoded feature map with the decoder, in the order listed.
(
    predict_1, predict_2, predict_3,
    predict_4, predict_5, predict_6,
) = [
    inverse_net.predict([encoded_feature])
    for encoded_feature in (
        image1_feature, image2_feature, image3_feature,
        image4_feature, image5_feature, image6_feature,
    )
]

# predict_1 = inverse_net.predict([image1_feature])
# predict_2 = inverse_net.predict([image2_feature])
# predict_3 = inverse_net.predict([image3_feature])
# show_without_deprocess(predict_1[0], save=True, name='test2_iterate1_img')
# show_without_deprocess(predict_2[0], save=True, name='test2_iterate2_img')
# show_without_deprocess(predict_3[0], save=True, name='test2_iterate3_img')

In [28]:
# Sanity check: shape of the first decoded batch.
predict_1.shape

(1, 512, 512, 3)

In [29]:
# Time a full transform_images run over the "street" content image.
# transform_images reads the global encode_net_2, so the cell that builds it
# must have been executed first; otherwise this call raises NameError.
start = time.time()
transform_images(street_image, "street", True)
end = time.time()

NameError: name 'encode_net_2' is not defined

In [None]:
# Run a further stylization pass over the "street" results for every style.
# With iterate=1 the function reads output/street_<style>_1.jpg.
# NOTE(review): confirm those files exist, i.e. that the earlier save step
# wrote its results with the _1 suffix.
transform_iterate_images("street", save=True, iterate=1)

In [None]:
# Average wall-clock seconds per style for the timed transform_images run.
# transform_images iterates over 19 styles, so the divisor is 19 (the previous
# value of 17 did not match the style list).
print((end - start) / 19)

In [29]:
# Show and save the first image of each decoded batch under its output name.
for decoded_batch, output_name in (
    (predict_1, '101_snow_image'),
    (predict_2, '101_starry_night_image'),
    (predict_3, '101_crystal_image'),
    (predict_4, '101_la_muse_image'),
    (predict_5, '101_udnie_image'),
    (predict_6, '101_water_image'),
):
    show_without_deprocess(decoded_batch[0], save=True, name=output_name)

Rescaling Image to (512, 512)
[[[148 126 131]
  [197 179 187]
  [178 171 171]
  ...
  [157 157 182]
  [172 167 185]
  [189 181 185]]

 [[167 139 148]
  [188 173 183]
  [168 171 171]
  ...
  [147 156 196]
  [137 129 168]
  [190 176 191]]

 [[183 152 153]
  [189 173 168]
  [176 181 164]
  ...
  [165 181 195]
  [162 154 180]
  [187 169 179]]

 ...

 [[185 163 170]
  [181 178 187]
  [179 195 188]
  ...
  [147 165 155]
  [164 153 151]
  [192 169 166]]

 [[202 168 164]
  [216 177 186]
  [205 175 181]
  ...
  [159 156 163]
  [180 145 157]
  [222 183 185]]

 [[183 149 141]
  [204 154 160]
  [183 133 142]
  ...
  [135 110 124]
  [142  92 108]
  [159 117 115]]]
Rescaling Image to (512, 512)
[[[181 146 136]
  [187 137 142]
  [189 157 159]
  ...
  [174 160 147]
  [181 156 138]
  [173 152 137]]

 [[176 131 139]
  [117  78 100]
  [126 113 135]
  ...
  [133 125 132]
  [158 131 124]
  [190 165 147]]

 [[190 171 183]
  [148 155 184]
  [133 171 198]
  ...
  [111 127 148]
  [143 136 136]
  [191 175 157]]

In [None]:
# Pass the first image of predict_3 to show_without_deprocess with save=False.
# NOTE(review): predict_3 is decoded from bridge_image + crystal_image in the
# cells above, so the 'cat_water' name looks stale; confirm.
show_without_deprocess(predict_3[0], save=False, name='cat_water')