In [15]:
import keras
from keras import layers
from keras import backend as k
from keras.models import Model
import numpy as np
k.clear_session()

In [16]:
img_shape = (28,28,1)
batch_size = 16
latent_dim = 2 #latent space shape is a 2-d plane corresponding to 2 node hidden layer values

In [17]:
input_img = keras.Input(shape = img_shape)
x = layers.Conv2D(32,3,padding = 'same', activation = 'relu')(input_img)
x = layers.Conv2D(64,3, padding = 'same', activation = 'relu', 
                  strides = (2,2))(x)
x = layers.Conv2D(64,3, padding = 'same', activation = 'relu')(x)
x = layers.Conv2D(64,3, padding = 'same', activation = 'relu')(x)
shape_before_flattening = k.int_shape(x)
x = layers.Flatten()(x)
x = layers.Dense(32, activation = 'relu')(x)
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)

In [18]:
def sampling(args):
    z_mean, z_log_var = args
    epsilon = k.random_normal(shape = (k.shape(z_mean)[0], latent_dim), mean = 0., stddev = 1.)
    return z_mean + k.exp(z_log_var) * epsilon

In [19]:
z = layers.Lambda(sampling)([z_mean, z_log_var])
print(z)

Tensor("lambda_1/add:0", shape=(?, 2), dtype=float32)


In [33]:
decoder_input = layers.Input(k.int_shape(z)[1:])
x = layers.Dense(np.prod(shape_before_flattening[1:]), activation = 'relu')(decoder_input)
x = layers.Reshape(shape_before_flattening[1:])(x)
x = layers.Conv2DTranspose(32,3,padding = 'same', activation = 'relu', strides = (2,2))(x)
x = layers.Conv2D(1,3,padding = 'same', activation = 'sigmoid')(x)
print(x)

Tensor("conv2d_6/Sigmoid:0", shape=(?, ?, ?, 1), dtype=float32)


In [34]:
decoder = Model(decoder_input, x)
z_decoded = decoder(z)
print(z_decoded)

Tensor("model_4/conv2d_6/Sigmoid:0", shape=(?, ?, ?, 1), dtype=float32)


In [22]:
class CustomVariationalLayer(keras.layers.Layer):
    def vae_loss(self, x, z_decoded):
        x = k.flatten(x)
        z_decoded = k.flatten(z_decoded)
        xent_loss = keras.metrics.binary_crossentropy(x,z_decoded)
        kl_loss = -5e-4 * k.mean(1+z_log_var - k.square(z_mean) - k.exp(z_log_var), axis = -1)
        return(k.mean(xent_loss + kl_loss)) #loss is average of classification error and the kullback liebler divergence
    
    def call(self, inputs):  # custom layer call override
        x = inputs[0]
        z_decoded = inputs[1]
        loss = self.vae_loss(z, z_decoded)
        self.add_loss(loss, inputs = inputs)
        return(x)        

In [23]:
y = CustomVariationalLayer()([input_img, z_decoded])

In [24]:
from keras.datasets import mnist

In [29]:
vae = Model(input_img, y)
vae. compile(optimizer = 'rmsprop', loss = None)
vae.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 28, 28, 32)   320         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 14, 14, 64)   18496       conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 14, 14, 64)   36928       conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (

  


In [30]:
(x_train, _), (x_test, y_test) = mnist.load_data()

In [31]:
x_train = x_train.astype('float32')/255.
x_train = x_train.reshape(x_train.shape + (1,))

x_test = x_test.astype('float32')/255.
x_test = x_test.reshape(x_test.shape + (1,))

In [32]:
vae.fit(x = x_train, y = None, shuffle = True, epochs = 10, batch_size = batch_size, validation_data = (x_test, None))

Train on 60000 samples, validate on 10000 samples
Epoch 1/10


InvalidArgumentError: Incompatible shapes: [12544] vs. [32]
	 [[Node: custom_variational_layer_1/logistic_loss/mul = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"](custom_variational_layer_1/Log, custom_variational_layer_1/Reshape)]]
	 [[Node: loss_1/add/_223 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_1172_loss_1/add", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'custom_variational_layer_1/logistic_loss/mul', defined at:
  File "c:\users\brad\appdata\local\programs\python\python36\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\traitlets\config\application.py", line 658, in launch_instance
    app.start()
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\ipykernel\kernelapp.py", line 486, in start
    self.io_loop.start()
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\tornado\ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\zmq\eventloop\zmqstream.py", line 450, in _handle_events
    self._handle_recv()
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\zmq\eventloop\zmqstream.py", line 480, in _handle_recv
    self._run_callback(callback, msg)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\zmq\eventloop\zmqstream.py", line 432, in _run_callback
    callback(*args, **kwargs)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\ipykernel\kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\ipykernel\kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\ipykernel\kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\ipykernel\ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\ipykernel\zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\IPython\core\interactiveshell.py", line 2728, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\IPython\core\interactiveshell.py", line 2850, in run_ast_nodes
    if self.run_code(code, result):
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\IPython\core\interactiveshell.py", line 2910, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-23-389811409cd3>", line 1, in <module>
    y = CustomVariationalLayer()([input_img, z_decoded])
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\keras\engine\topology.py", line 617, in __call__
    output = self.call(inputs, **kwargs)
  File "<ipython-input-22-123958ef496f>", line 12, in call
    loss = self.vae_loss(z, z_decoded)
  File "<ipython-input-22-123958ef496f>", line 5, in vae_loss
    xent_loss = keras.metrics.binary_crossentropy(x,z_decoded)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\keras\losses.py", line 77, in binary_crossentropy
    return K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\keras\backend\tensorflow_backend.py", line 3069, in binary_crossentropy
    logits=output)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\tensorflow\python\ops\nn_impl.py", line 181, in sigmoid_cross_entropy_with_logits
    relu_logits - logits * labels,
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\tensorflow\python\ops\math_ops.py", line 971, in binary_op_wrapper
    return func(x, y, name=name)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\tensorflow\python\ops\math_ops.py", line 1198, in _mul_dispatch
    return gen_math_ops.mul(x, y, name=name)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\tensorflow\python\ops\gen_math_ops.py", line 4991, in mul
    "Mul", x=x, y=y, name=name)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\tensorflow\python\framework\ops.py", line 3290, in create_op
    op_def=op_def)
  File "c:\users\brad\appdata\local\programs\python\python36\lib\site-packages\tensorflow\python\framework\ops.py", line 1654, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Incompatible shapes: [12544] vs. [32]
	 [[Node: custom_variational_layer_1/logistic_loss/mul = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"](custom_variational_layer_1/Log, custom_variational_layer_1/Reshape)]]
	 [[Node: loss_1/add/_223 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_1172_loss_1/add", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
