In [6]:
import numpy as np
from keras.layers import (Lambda, Input, Reshape,
                          Dense, 
                          Conv2D, Conv2DTranspose,
                          Flatten, MaxPool2D,)
from keras.losses import mse, mae, binary_crossentropy
from keras.models import Model
import keras.backend as K

In [12]:
def sampling(args):
    """Draw a latent vector with the reparameterization trick.

    The random draw is taken from an isotropic unit Gaussian and then
    shifted and scaled by the encoder outputs, so the result stays a
    differentiable function of those outputs.

    # Arguments
        args (tensor): pair ``(z_mean, z_log_var)`` holding the mean and the
            log of the variance of Q(z|X)

    # Returns
        z (tensor): sampled latent vector, ``z_mean + sigma * epsilon``
    """
    mu, log_variance = args

    # The batch size is read from the runtime shape, the latent width from
    # the static shape.
    n_rows = K.shape(mu)[0]
    n_latent = K.int_shape(mu)[1]

    # random_normal defaults to mean=0 and std=1.0
    noise = K.random_normal(shape=(n_rows, n_latent))

    # exp(0.5 * log(var)) is the standard deviation
    sigma = K.exp(0.5 * log_variance)
    return mu + sigma * noise

In [19]:
# ---- hyperparameters ---------------------------------------------------
# If you want 1D fft inputs, make the input_shape (nfft, 1, 1).
input_shape = (1024, 1, 1)
filters = 1          # running channel count; doubled before every encoder conv
layers = 2           # number of strided convolution layers
kernel_size = [5, 5]
strides = [2, 2]
dilation = [1, 1]
intermediate = 16    # width of the dense layer ahead of the latent heads
latent_dim = 4

# ---- encoder: X -> (z_mean, z_log_var, z) ------------------------------
inputs = Input(shape=input_shape)

# Settings shared by every encoder convolution.
conv_options = dict(kernel_size=kernel_size,
                    strides=strides,
                    dilation_rate=dilation,
                    activation='relu',
                    padding='same')

x = inputs
for i in range(layers):
    filters = filters * 2
    x = Conv2D(filters, **conv_options)(x)

# Shape of the last conv output; the decoder model needs it to size its
# first layers.
shape = K.int_shape(x)

# Dense head that generates the parameters of Q(z|X).
x = Flatten()(x)
x = Dense(intermediate, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

# Reparameterization trick: the sampling is pushed out into a Lambda layer
# that takes the two heads as input.
# ("output_shape" isn't necessary with the TensorFlow backend.)
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            (None, 1024, 1, 1)   0                                            
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 512, 1, 2)    52          input_5[0][0]                    
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 256, 1, 4)    204         conv2d_13[0][0]                  
__________________________________________________________________________________________________
flatten_5 (Flatten)             (None, 1024)         0           conv2d_14[0][0]                  
__________________________________________________________________________________________________
dense_10 (

In [20]:
# ---- decoder: z -> reconstructed X -------------------------------------
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')

# Project the latent vector up to the size of the encoder's last conv output
# and restore that output's (rows, cols, channels) layout.
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)

# Mirror the encoder: start from the channel count of its last conv layer
# (shape[3]) and halve it after every transposed convolution.
# The count lives in a local variable on purpose. Halving the notebook-level
# `filters` in place made this cell non-idempotent: running it a second time
# without re-running the encoder cell started from the already-halved value
# and ended up requesting 0 filters.
decoder_filters = shape[3]
for i in range(layers):
    x = Conv2DTranspose(filters=decoder_filters,
                        kernel_size=kernel_size,
                        strides=strides,
                        activation='relu',
                        padding='same')(x)
    decoder_filters //= 2

# The VAE is trained with binary_crossentropy, which expects predictions in
# [0, 1]; sigmoid guarantees that range. With relu the output can be exactly
# 0 or exceed 1, where Keras clips the prediction and the gradient vanishes.
x = Conv2DTranspose(filters=1,
                    kernel_size=kernel_size,
                    activation='sigmoid',
                    padding='same',
                    name='decoder_output')(x)

# If the input is 1D we need to max-pool dim 2 to collapse the extra bins the
# strided Conv2DTranspose layers put in that dimension.
if input_shape[1] == 1:
    out_shape = K.int_shape(x)
    outputs = MaxPool2D(pool_size=(1, out_shape[2]), strides=None, padding='valid')(x)
else:
    outputs = x

decoder = Model(latent_inputs, outputs)
decoder.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
z_sampling (InputLayer)      (None, 4)                 0         
_________________________________________________________________
dense_11 (Dense)             (None, 1024)              5120      
_________________________________________________________________
reshape_6 (Reshape)          (None, 256, 1, 4)         0         
_________________________________________________________________
conv2d_transpose_15 (Conv2DT (None, 512, 2, 4)         404       
_________________________________________________________________
conv2d_transpose_16 (Conv2DT (None, 1024, 4, 2)        202       
_________________________________________________________________
decoder_output (Conv2DTransp (None, 1024, 4, 1)        51        
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 1024, 1, 1)        0         
Total para

In [15]:
# ---- end-to-end VAE ----------------------------------------------------
# Only the sampled z (third encoder output) is fed to the decoder.
latent_sample = encoder(inputs)[2]
outputs = decoder(latent_sample)
vae = Model(inputs, outputs, name='vae')

# Reconstruction term: binary cross-entropy over the flattened tensors,
# scaled by input_shape[0] * input_shape[1].
reconstruction_loss = binary_crossentropy(K.flatten(inputs), K.flatten(outputs))
reconstruction_loss = reconstruction_loss * (input_shape[0] * input_shape[1])

# KL term: -0.5 * sum(1 + log_var - mean^2 - exp(log_var)) over the latent axis.
kl_loss = K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1) * -0.5

# The loss is attached to the model itself, so compile() takes no loss argument.
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)
vae.compile(optimizer='rmsprop')

In [11]:
# Synthetic training set: n_samples arrays of shape input_shape with values
# drawn uniformly from [0, 1).
n_samples = 10000
# Use a dedicated, seeded generator so that Restart & Run All reproduces the
# same data (the global numpy RNG was previously used unseeded).
data_seed = 0
X = np.random.RandomState(data_seed).uniform(0, 1, (n_samples,) + input_shape)

In [90]:
# Training run. No targets are passed: the VAE's loss was attached with
# add_loss, so fit() only needs the inputs.
batch_size = 1000
epochs = 1
vae.fit(x=X, batch_size=batch_size, epochs=epochs)

Epoch 1/1


<keras.callbacks.History at 0x7f0c2c77c3c8>