In [5]:
from keras.layers import Input, Conv2D, Conv2DTranspose, LeakyReLU, Flatten, Dense, Reshape, Lambda
from keras.backend import random_normal
from keras.models import Model
from keras import optimizers
import numpy as np
import tensorflow as tf
import keras.backend as K
import keras

In [7]:
def _parse_function(filename):
    r_image_string = tf.read_file(filename)
    r_image_decoded = tf.image.decode_jpeg(r_image_string,channels=3)
    r_image_decoded = tf.reshape(r_image_decoded,[256,256,3])
    r_image_decoded = tf.image.convert_image_dtype(r_image_decoded,tf.float32)

    # r_image_decoded = tf.divide(r_image_decoded,[255])


    r_image_decoded_std = tf.image.per_image_standardization(r_image_decoded)

    r_image_decoded_std = tf.image.resize_images(r_image_decoded_std,[64,64])

    return r_image_decoded_std



def parse():
    rainy_start = 637
    rainy_end = 1207

    rainy_files = []

    for i in range(rainy_start,rainy_end):
        file_id = "{0:0=4d}".format(i)
        rainy_files.append('./rainy/000'+str(file_id)+'.jpeg')


    rainy_filenames = tf.constant(rainy_files)

    dataset = tf.data.Dataset.from_tensor_slices(rainy_filenames)
    dataset = dataset.map(_parse_function).repeat().shuffle(buffer_size=50).batch(4,True)
    return dataset


In [8]:
def vae_loss(y_true, y_pred):
    kl = K.sum(0.5 * K.sum(K.exp(dense_std) + K.square(dense_mean) - 1. - dense_std, axis=1))
    
    
    recon = K.sum(K.binary_crossentropy(y_pred, y_true), axis=1)

    
    return kl + recon


def sampling(args):
    """Reparameterization trick by sampling from an isotropic unit Gaussian.
    # Arguments
        args (tensor): mean and log of variance of Q(z|X)
    # Returns
        z (tensor): sampled latent vector
    """

    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean = 0 and std = 1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

In [9]:
batch_size = 4
dataset = parse().make_one_shot_iterator()
inputt = Input(tensor=dataset.get_next())

# encoder
conv1 = Conv2D(32, kernel_size=2, strides=2)(inputt)
conv1 = LeakyReLU(alpha=0.2)(conv1)
conv2 = Conv2D(64, kernel_size=2, strides=2)(conv1)
conv2 = LeakyReLU(alpha=0.2)(conv2)

conv3 = Conv2D(128, kernel_size=2, strides=2)(conv2)
conv3 = LeakyReLU(alpha=0.2)(conv3)

# latent space
flattened = Flatten()(conv3)
dense_mean = Dense(1024)(flattened)
dense_std = Dense(1024)(flattened)

# random_normal_samples = random_normal([batch_size,1024])
z = Lambda(sampling, output_shape=(1024,), name='z')([dense_mean, dense_std])
# decoder
conv3_shape = conv3.shape.as_list()
dense_transpose = Dense(conv3_shape[1]*conv3_shape[2]*conv3_shape[3])(z)
conv3t = Reshape((conv3_shape[1],conv3_shape[2],conv3_shape[3]))(dense_transpose)
conv2t = Conv2DTranspose(64,kernel_size=2,strides=2)(conv3t)
conv1t = Conv2DTranspose(32,kernel_size=2,strides=2)(conv2t)

output = Conv2DTranspose(3,kernel_size=2,strides=2)(conv1t)



Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.


In [14]:
tb = keras.callbacks.TensorBoard(log_dir='./ckpt', write_grads=True, write_graph=True, write_images=True)


model = Model(inputt, output)
model.compile(optimizers.Adam(), loss=vae_loss, target_tensors=[dataset.get_next()])
model.fit(steps_per_epoch=100, epochs=5, verbose=2, callbacks=[tb])

Epoch 1/5
 - 2s - loss: -5.3629e+14
Epoch 2/5
 - 2s - loss: -7.3473e+14
Epoch 3/5
 - 2s - loss: -7.1989e+14
Epoch 4/5
 - 2s - loss: -7.4765e+14
Epoch 5/5
 - 2s - loss: -8.8723e+14


<keras.callbacks.History at 0x7fb21c3c33c8>

In [214]:
model.summary()


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_113 (InputLayer)          (4, 128, 128, 3)     0                                            
__________________________________________________________________________________________________
conv2d_300 (Conv2D)             (4, 64, 64, 32)      416         input_113[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_292 (LeakyReLU)     (4, 64, 64, 32)      0           conv2d_300[0][0]                 
__________________________________________________________________________________________________
conv2d_301 (Conv2D)             (4, 32, 32, 64)      8256        leaky_re_lu_292[0][0]            
__________________________________________________________________________________________________
leaky_re_l