# Deep Pensieve™
A Residual Multi-Stage Maximum Mean Discrepancy Variational Resize-Convolution Auto-Encoder with Group Normalization (RMSMMDVRCAECwGN)

## Imports

In [None]:
import time
import json
import random
import numpy as np
import tensorflow as tf

from libs import utils, gif
from libs.group_norm import GroupNormalization

from keras.models import Model, load_model, model_from_json
from keras.layers import Input, Flatten, Reshape, Add, Multiply, Activation, Lambda
from keras.layers import Dense, Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization
from keras.objectives import mean_squared_error, mean_absolute_error
from keras.callbacks import ModelCheckpoint, LambdaCallback

from keras import optimizers
from keras import backend as K

## Load Images

In [None]:
# dataset directory handed to utils.load_images
DIRECTORY = 'roadtrip'

SIZE = 256                         # target width/height passed to utils.load_images
CHANNELS = 3                       # colour channels per pixel
FEATURES = SIZE * SIZE * CHANNELS  # length of one flattened image vector

# prefix shared by every saved artifact (architecture json, weights, encoder, generator, gifs)
MODEL_NAME = '{}-{}'.format(DIRECTORY, SIZE)

In [None]:
# load every image in DIRECTORY resized to SIZE x SIZE
imgs, xs, ys = utils.load_images(directory=DIRECTORY, rx=SIZE, ry=SIZE)

# rescale to [-1, 1] to match the tanh outputs of the model
# (assumes utils.load_images returns 0..255 pixel values — confirm)
IMGS = imgs / 127.5 - 1
# one flattened row per image
FLAT = np.reshape(IMGS, (-1, SIZE * SIZE * CHANNELS))
# nine random images whose reconstructions the gif callback tracks during training
SAMPLES = np.random.permutation(FLAT)[:9]
# number of loaded images
TOTAL_BATCH = IMGS.shape[0]

# report what was loaded
for label, value in (("MODEL: ", MODEL_NAME),
                     ("IMGS: ", IMGS.shape),
                     ("FLAT: ", FLAT.shape),
                     ("SAMPLES: ", SAMPLES.shape)):
    print(label, value)

## Maximum Mean Discrepancy 

In [None]:
def compute_kernel(x, y):
    """Gaussian kernel matrix between the rows of ``x`` and the rows of ``y``.

    Entry (i, j) equals exp(-mean_k((x[i, k] - y[j, k]) ** 2) / dim), where
    ``dim`` is the number of columns of ``x``. ``x`` and ``y`` are 2-D tensors
    with the same number of columns but possibly different numbers of rows.

    Returns:
        Tensor of shape (rows of x, rows of y).
    """
    dim = tf.cast(tf.shape(x)[1], tf.float32)
    # (n, 1, d) - (1, m, d) broadcasts to every pairwise row difference (n, m, d)
    pairwise_diff = tf.expand_dims(x, 1) - tf.expand_dims(y, 0)
    mean_sq = tf.reduce_mean(tf.square(pairwise_diff), axis=2)
    return tf.exp(-mean_sq / dim)

def compute_mmd(x, y):
    """Squared Maximum Mean Discrepancy between sample sets ``x`` and ``y``.

    Uses the kernel from compute_kernel and averages over full kernel matrices
    (diagonal terms included):
        mean k(x, x) + mean k(y, y) - 2 * mean k(x, y)
    """
    within_x = tf.reduce_mean(compute_kernel(x, x))
    within_y = tf.reduce_mean(compute_kernel(y, y))
    across = tf.reduce_mean(compute_kernel(x, y))
    return within_x + within_y - 2 * across

## Encoder

In [None]:
def encode(x):
    """Build the convolutional encoder on the flat image tensor ``x``.

    ``x`` is reshaped to (SIZE, SIZE, CHANNELS). For every entry of FILTERS the
    features pass through two 3x3 Conv2D -> GroupNormalization -> Activation
    units and a MaxPooling2D. The result is flattened and projected to
    LATENT_DIM with a Dense layer.

    Returns:
        (z, shape): ``z`` is the latent tensor; ``shape`` is the
        (height, width, channels) of the last pooled feature map, which
        decode() needs to invert the projection.
    """
    def conv_unit(tensor, n_filters):
        # same-padded 3x3 convolution, group norm with one group per channel, nonlinearity
        tensor = Conv2D(n_filters, 3, padding='SAME', kernel_initializer=INITIALIZER)(tensor)
        tensor = GroupNormalization(groups=n_filters, axis=-1)(tensor)
        return Activation(ACTIVATION)(tensor)

    features = Reshape((SIZE, SIZE, CHANNELS))(x)
    for n_filters in FILTERS:
        features = conv_unit(features, n_filters)
        features = conv_unit(features, n_filters)
        features = MaxPooling2D()(features)

    # static shape of the deepest feature map (leading entry is the batch axis)
    _, height, width, depth = features.get_shape().as_list()

    flattened = Flatten()(features)
    z = Dense(LATENT_DIM)(flattened)

    return z, (height, width, depth)

## Decoder

In [None]:
def decode(z,z_g,shape=None):
    """Build the mirrored decoder on two latent streams with shared weights.

    Every layer is instantiated once and applied first to ``z`` (encoder output)
    and then to ``z_g`` (free generator input), so both streams use the same
    weights. The latent is projected by a Dense layer and reshaped to ``shape``.
    Then, for each entry of FILTERS in reverse order, the features are
    upsampled and passed through two 3x3 Conv2D -> GroupNormalization ->
    Activation units. A 1x1 Conv2D to CHANNELS with tanh produces the image,
    which is flattened.

    Args:
        z: latent tensor from the encoder.
        z_g: latent input tensor for the generator path.
        shape: (height, width, channels) as returned by encode(); required.

    Returns:
        (reconstruction, generated): flattened coarse images for ``z`` and ``z_g``.
    """
    def both(layer, a, b):
        # one layer instance called on both streams -> shared weights
        return layer(a), layer(b)

    units = shape[0]*shape[1]*shape[2]
    recon, gen = both(Dense(units), z, z_g)
    recon, gen = both(Reshape(shape), recon, gen)

    for n_filters in reversed(FILTERS):
        recon, gen = both(UpSampling2D(), recon, gen)
        for _ in range(2):
            conv = Conv2D(n_filters,3,padding='SAME',kernel_initializer=INITIALIZER)
            norm = GroupNormalization(groups=n_filters,axis=-1)
            act = Activation(ACTIVATION)
            for layer in (conv, norm, act):
                recon, gen = both(layer, recon, gen)

    # output head: 1x1 convolution to image channels, squashed into [-1, 1]
    head_conv = Conv2D(CHANNELS,1,padding='SAME')
    head_act = Activation('tanh')
    recon, gen = both(head_conv, recon, gen)
    recon, gen = both(head_act, recon, gen)

    return both(Flatten(), recon, gen)

## Residual

In [None]:
def residual(x,x_g):
    """One residual block applied to both streams with shared weights.

    Each stream goes through Conv2D(R_FILTERS, 3x3) -> Activation(ACTIVATION)
    -> Conv2D(R_FILTERS, 3x3). The branch is scaled by R_SCALING and added to
    the block's input (the skip connection). When the input channel count
    differs from R_FILTERS, the *skip connection* is projected with a shared
    1x1 Conv2D so the two tensors can be added.

    NOTE: the projection kernel is (1, 1, input_channels, R_FILTERS). Weights
    saved while the projection was (incorrectly) applied to the residual branch
    have a different kernel shape for this layer.

    Args:
        x: 4-D feature tensor of the reconstruction stream.
        x_g: 4-D feature tensor of the generator stream (same shape as ``x``).

    Returns:
        (out, out_g): block outputs for the two streams.
    """
    current_layer = x ; generator = x_g

    # skip connections start from the block inputs
    shortcut = x
    shortcut_g = x_g

    # conv 1
    c1 = Conv2D(R_FILTERS,3,padding='SAME',kernel_initializer=INITIALIZER)
    current_layer = c1(current_layer) ; generator = c1(generator)

    # activation 1
    a1 = Activation(ACTIVATION)
    current_layer = a1(current_layer) ; generator = a1(generator)

    # conv 2
    c2 = Conv2D(R_FILTERS,3,padding='SAME',kernel_initializer=INITIALIZER)
    current_layer = c2(current_layer) ; generator = c2(generator)

    # residual scaling: damp the branch before the add. The factor is passed via
    # `arguments` so it is stored in the layer config and survives save/load.
    scale = Lambda(lambda t, factor: t * factor, arguments={'factor': R_SCALING})
    current_layer = scale(current_layer) ; generator = scale(generator)

    # project the skip connection (not the residual branch) when channel counts differ
    if(shortcut.shape[-1] != current_layer.shape[-1]):
        s = Conv2D(R_FILTERS,1,padding='SAME',kernel_initializer=INITIALIZER)
        shortcut = s(shortcut) ; shortcut_g = s(shortcut_g)

    # merge shortcut
    merge = Add()
    current_layer = merge([current_layer, shortcut]) ; generator = merge([generator, shortcut_g])

    return current_layer, generator

## Refiner

In [None]:
def refine(x,x_g):
    """Residual refinement stage shared by the reconstruction and generator streams.

    Both flat coarse images are reshaped to (SIZE, SIZE, CHANNELS), passed
    through R_LAYERS residual() blocks, then through a 1x1 Conv2D to CHANNELS
    with tanh. The same layer instances process both streams.

    Returns:
        (refined, refined_g): flattened refined images for ``x`` and ``x_g``.
    """
    to_image = Reshape((SIZE,SIZE,CHANNELS))
    fine, fine_g = to_image(x), to_image(x_g)

    for _ in range(R_LAYERS):
        fine, fine_g = residual(fine, fine_g)

    # output head: 1x1 convolution back to image channels, squashed into [-1, 1]
    project = Conv2D(CHANNELS,1,padding='SAME')
    squash = Activation('tanh')
    fine, fine_g = project(fine), project(fine_g)
    fine, fine_g = squash(fine), squash(fine_g)

    return Flatten()(fine), Flatten()(fine_g)

## Callbacks

In [None]:
# montage frames collected by gifit(); assembled into an animated gif after training
RECONS = []

def gifit(epoch=None):
    """Epoch-end callback: every GIF_STEPS epochs, snapshot reconstructions of SAMPLES.

    The refined AUTOENCODER output for the fixed SAMPLES batch is mapped from
    [-1, 1] to [0, 255], tiled with utils.montage, and appended to RECONS
    as uint8.
    """
    if epoch % GIF_STEPS != 0:
        return

    print('saving gif ...')
    refined, _latent, _coarse = AUTOENCODER.predict_on_batch(SAMPLES)
    pixels = np.clip(127.5 * (refined + 1).reshape((-1, SIZE, SIZE, CHANNELS)), 0, 255)
    RECONS.append(utils.montage(pixels).astype(np.uint8))
        
def saveit(epoch=None):
    """Epoch-end callback: save the architecture once and checkpoints periodically.

    On epoch 0 the AUTOENCODER architecture is written to
    ``MODEL_NAME-model.json``. Every MODEL_STEPS epochs (including epoch 0)
    the AUTOENCODER weights, the ENCODER model and the GENERATOR model are saved.
    """
    if epoch == 0:
        print('saving model ...')
        # to_json() already returns a JSON string; json.dump stores it as a quoted
        # JSON string, so a reader must json.load the file before model_from_json.
        with open(MODEL_NAME+'-model.json', 'w') as f:
            json.dump(AUTOENCODER.to_json(), f, ensure_ascii=False)

    if epoch % MODEL_STEPS != 0:
        return

    checkpoints = (
        ('saving weights ...', AUTOENCODER.save_weights, MODEL_NAME+'-weights.h5'),
        ('saving encoder ...', ENCODER.save, MODEL_NAME+'-encoder.hdf5'),
        ('saving generator ...', GENERATOR.save, MODEL_NAME+'-generator.hdf5'),
    )
    for message, save, target in checkpoints:
        print(message)
        save(target)

## Model

In [None]:
# Encoder filter counts, one entry per stage; decode() walks them in reverse.
# Each stage ends in a MaxPooling2D, halving height and width.
# Setting kept for 512px inputs:
#     FILTERS = [64,80,96,112,96,80,64]
# Setting for 256px inputs:
FILTERS = [64, 96, 128, 160, 128, 64]

# Residual refiner: number of residual() blocks, their filter count, and the
# residual-branch scale (0.1, the same factor residual() applies)
R_LAYERS, R_FILTERS, R_SCALING = 16, 64, .1

# Kernel initializer and activation shared by the convolutional layers
INITIALIZER, ACTIVATION = 'he_normal', 'elu'

# Size of the latent vector produced by encode()
LATENT_DIM = 1024

## Training

In [None]:
# optimisation schedule; epoch indices run 0..10000, so the final epoch
# also lands on a multiple of the callback step sizes below
EPOCHS, BATCH_SIZE = 10001, 8

# saveit()/gifit() act when epoch % <steps> == 0
MODEL_STEPS = GIF_STEPS = 50

In [None]:
# input
X = Input(shape=(FEATURES,))

# latent
Z, shape = encode(X)

# latent loss
epsilon = tf.random_normal(tf.stack([BATCH_SIZE, LATENT_DIM]))
latent_loss = compute_mmd(epsilon, Z)

# generator input
Z_G = Input(shape=(LATENT_DIM,))

# coarse reconstruction
Y, YG = decode(Z,Z_G,shape)
coarse_loss = mean_squared_error(X,Y)

# fine reconstruction
IMG, IMG_G = refine(Y,YG)
fine_loss = mean_absolute_error(X,IMG)

# define autoencoder
AUTOENCODER = Model(inputs=[X], outputs=[IMG,Z,Y])
AUTOENCODER.add_loss(latent_loss)
AUTOENCODER.add_loss(coarse_loss)
AUTOENCODER.add_loss(fine_loss)

# define encoder
ENCODER = Model(inputs=[X], outputs=[Z])

# define generator
GENERATOR = Model(inputs=[Z_G], outputs=[IMG_G])

# define optimizer
ADAM = optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8, decay=0.0, amsgrad=True)

# compile models
AUTOENCODER.compile(optimizer=ADAM)
ENCODER.compile(optimizer=ADAM,loss='mse')
GENERATOR.compile(optimizer=ADAM,loss='mse')

# print summary
AUTOENCODER.summary()

# callbacks
giffer = LambdaCallback(on_epoch_end=lambda epoch, logs: gifit(epoch))
saver = LambdaCallback(on_epoch_end=lambda epoch, logs: saveit(epoch))

# fit model
AUTOENCODER.fit(x=FLAT,batch_size=BATCH_SIZE,verbose=1,epochs=EPOCHS,shuffle=True,callbacks=[giffer,saver])

# save animated gif
gif.build_gif(RECONS, saveto=MODEL_NAME+'-final'+ "-"+str(time.time())+'.gif')

print("done")

## Load Models

In [None]:
# GroupNormalization is not a built-in Keras layer. Pass it explicitly so
# load_model can rebuild both saved models without relying on libs.group_norm
# having registered it globally at import time.
_custom_objects = {'GroupNormalization': GroupNormalization}

print('loading encoder ...')
ENCODER = load_model(MODEL_NAME+'-encoder.hdf5', custom_objects=_custom_objects)

print('loading generator ...')
GENERATOR = load_model(MODEL_NAME+'-generator.hdf5', custom_objects=_custom_objects)

print('done')

## Reconstruction 

In [None]:
def reconstruct(index=0):
    """Round-trip image ``index`` through ENCODER and GENERATOR and report quality.

    Both the original image and the reconstruction are mapped from the model's
    [-1, 1] range to [0, 1]. Their PSNR and MS-SSIM (max_val=1.) are printed.

    Returns:
        (original, reconstruction) as (SIZE, SIZE, CHANNELS) arrays.
    """
    def to_unit(values):
        # [-1, 1] -> [0, 1]
        return values/2 + .5

    flat_input = FLAT[index].reshape((-1, FEATURES))
    latent = np.reshape(ENCODER.predict_on_batch(flat_input), (-1, LATENT_DIM))
    decoded = np.reshape(GENERATOR.predict_on_batch(latent), (-1, FEATURES))

    original = to_unit(IMGS[index])
    recon = np.reshape(to_unit(decoded[0]), (SIZE, SIZE, CHANNELS))

    batch_shape = (1, SIZE, SIZE, CHANNELS)
    print("PSNR: %.3f" % utils.psnr(original, recon))
    print("MS-SSIM: %.3f" % utils.MultiScaleSSIM(np.reshape(original, batch_shape),
                                                 np.reshape(recon, batch_shape),
                                                 max_val=1.))

    return original, recon

In [None]:
# pick a random valid image index. randrange excludes TOTAL_BATCH, whereas
# randint(0, TOTAL_BATCH) includes it and could index one past the end of FLAT.
r = random.randrange(TOTAL_BATCH)
print(r)
orig, img = reconstruct(r)
utils.showImagesHorizontally(images=[orig,img])

## Latent Animation

In [None]:
def random_latents(n_imgs=3,path='linear',steps=30,slices=1,directory='roadtrip'):
    """Animate latent interpolations through ``n_imgs`` randomly chosen images.

    The images are drawn from FLAT after a random permutation and passed to
    latent_animation with ``steps``, ``slices`` and ``path`` unchanged.
    ``directory`` is accepted but not used.
    """
    picked = np.random.permutation(FLAT)[:n_imgs]
    latent_animation(imgs=picked, steps=steps, slices=slices, path=path)

def latent_animation(imgs=None,steps=None,slices=None,path=None,filename="latent-animation-"):
    """Encode a sequence of images and render GIF(s) interpolating between their latents.

    Each consecutive pair (imgs[k], imgs[k+1]) is encoded with ENCODER. The
    latent segment between them is sampled at ``steps`` linearly spaced points
    (starting at the first endpoint), bracketed by the two endpoint latents.
    The frames are decoded with GENERATOR and written with gif.build_gif.

    Args:
        imgs: sequence of flattened images, each reshapeable to (1, FEATURES).
        steps: number of interpolation frames per image pair (required).
        slices: if truthy, a GIF is also flushed every int(steps / slices)
            frames (at least 1) within each pair. If None or 0, everything is
            rendered into a single GIF at the end.
        path: accepted for compatibility. Only linear interpolation is
            implemented, so its value is ignored.
        filename: prefix for every saved GIF. A timestamp and "-final.gif"
            are appended.
    """
    # get encodings
    print('getting latent vectors ...')
    latents = [ENCODER.predict_on_batch(np.reshape(img,(-1,FEATURES))) for img in imgs]

    if len(latents) < 2:
        print('need at least two images to animate')
        return

    # frames per intermediate GIF; max(1, ...) avoids a zero modulus when slices > steps
    chunk = max(1, int(steps/slices)) if slices else None

    # calculate latent transitions
    print('calculating latent manifold path ...')
    frames = []
    for pair in range(len(latents)-1):
        print("IMG: " + str(pair))
        start = latents[pair]
        end = latents[pair+1]
        delta = end - start

        frames.append(start)
        for step in range(steps):
            frames.append(start + step*delta/steps)

            # periodically flush the accumulated frames into their own GIF
            if chunk is not None and step > 1 and (step+1) % chunk == 0:
                print('reconstructing ... ',step)
                _render_latent_gif(frames, filename)
                frames = []
        frames.append(end)

    # decode whatever is left into the final GIF
    print('reconstructing ... ')
    _render_latent_gif(frames, filename)


def _render_latent_gif(latents, prefix):
    """Decode ``latents`` with GENERATOR, save them as a timestamped GIF, and return its stem."""
    batch = np.reshape(latents,(-1,LATENT_DIM))
    decoded = GENERATOR.predict_on_batch(batch)

    # de-normalize the tanh output from [-1, 1] to pixel values and clip
    final = np.clip((127.5*(decoded+1)).reshape((-1,SIZE,SIZE,CHANNELS)),0,255)

    # build the gif; the prefix itself is never modified, so every chunk keeps it
    name = prefix+str(time.time())
    gif.build_gif([utils.montage([r]).astype(np.uint8) for r in final], saveto=name+"-final.gif",dpi=SIZE)

    print(name)
    return name


In [None]:
# three random images, 160 interpolation frames per pair, split into GIFs of int(160/4)=40 frames
random_latents(slices=4, steps=160, n_imgs=3)

In [None]:
# animate every consecutive pair of images in a random order. A distinct name
# keeps the raw `imgs` array loaded earlier from being overwritten.
shuffled = np.random.permutation(FLAT)
# stop one early: each iteration also reads shuffled[i+1]
for i in range(TOTAL_BATCH - 1):
    print(i)
    latent_animation([shuffled[i], shuffled[i+1]], 40)
    