[View in Colaboratory](https://colab.research.google.com/github/oduerr/dl_tutorial/blob/master/tensorflow/keras/using_tf_in_keras.ipynb)

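### Creating Custom Layers

Keras is great, but sometimes you need the flexibility of the low-level framework.

For example, Keras currently (spring 2018) does not support dropout at prediction time, because its `Dropout` layer is only active during training. If you want dropout in the forward pass at prediction time, as in Monte Carlo (MC) dropout, you can build your own function.

#### Stateless custom layers
If you don't need variables or other state in your custom layer, you can use the `Lambda` layer, which wraps an ordinary function of the input tensor.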

```python
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import tensorflow as tf
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.layers.core import Lambda  # needed to build the custom layer
from keras import backend as K  # gives access to the backend (TensorFlow, Theano, ...)

tf.__version__, keras.__version__
```

```
('1.6.0', '2.1.5')
```

##### Definition of the custom layer
We use the `dropout` function of the backend (`K.dropout`) to create a custom function.

```python
def mcdropout(x):
    # return tf.nn.dropout(x=x, keep_prob=0.5)  # TensorFlow-only equivalent (keep_prob = 1 - level)
    return K.dropout(x, level=0.5)  # backend-agnostic: drops 50% of the entries
```
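What does `K.dropout(x, level=0.5)` compute? Below is a minimal NumPy sketch, assuming the backend implements inverted dropout (as TensorFlow's `tf.nn.dropout` does): each entry is kept with probability `1 - level`, and surviving entries are scaled by `1 / (1 - level)` so that the expected output equals the input. The function name `inverted_dropout` is hypothetical and only for illustration. With `level=0.5`, an input of ones therefore turns into a mix of 0s and 2s, which matches the predictions shown further down.

```python
import numpy as np

def inverted_dropout(x, level, rng):
    """Hypothetical NumPy sketch of inverted dropout.

    Assumption: this mirrors the rule used by K.dropout / tf.nn.dropout.
    level is the fraction of entries to drop (0.5 drops about half).
    Kept entries are scaled by 1 / (1 - level) so E[output] == x.
    """
    keep_prob = 1.0 - level
    mask = rng.random(x.shape) < keep_prob  # True where the entry is kept
    return np.where(mask, x / keep_prob, 0.0)

rng = np.random.default_rng(0)
y = inverted_dropout(np.ones((2, 5)), level=0.5, rng=rng)
print(y)  # every entry is either 0.0 or 2.0
```

Scaling at training time (rather than at inference time) is why the same function can be reused unchanged for MC dropout: no rescaling step is needed afterwards.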

##### Integrating the custom function as a layer

We now wrap the custom function in a `Lambda` layer.

```python
model = Sequential()
model.add(Lambda(mcdropout, input_shape=(5,)))
# model.add(Dense(10))
# ... usually you would have many more layers
model.compile(loss='categorical_crossentropy', optimizer='adam')
```

```python
model.predict(np.ones((2, 5)))
```

```
array([[2., 2., 0., 2., 2.],
       [0., 2., 0., 0., 0.]], dtype=float32)
```
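Because dropout stays active, every call to `model.predict` returns a different sample. In MC dropout these samples are typically aggregated by calling the model many times and taking the mean (as the prediction) and the spread (as a rough uncertainty estimate). The sketch below assumes that usage; `mc_predict` and `fake_predict` are hypothetical helpers, not Keras APIs, and `fake_predict` is a NumPy stand-in for `model.predict` so the example runs without Keras.

```python
import numpy as np

def mc_predict(stochastic_predict, x, n_samples=100):
    """Hypothetical MC dropout helper (not part of Keras).

    stochastic_predict: a function mapping x to a prediction with dropout
    still active, e.g. model.predict for the Lambda model above.
    Returns the per-entry mean and standard deviation over n_samples passes.
    """
    samples = np.stack([stochastic_predict(x) for _ in range(n_samples)])
    return samples.mean(axis=0), samples.std(axis=0)

# NumPy stand-in for model.predict: inverted dropout with level=0.5.
rng = np.random.default_rng(42)

def fake_predict(x):
    keep = rng.random(x.shape) < 0.5  # keep each entry with probability 0.5
    return np.where(keep, x / 0.5, 0.0)

mean, std = mc_predict(fake_predict, np.ones((2, 5)), n_samples=2000)
print(mean)  # every entry close to 1.0
print(std)   # every entry close to 1.0
```

With the Keras model above, the corresponding call would be `mc_predict(model.predict, np.ones((2, 5)))`.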

#### All in one

If you really want to show off your Python skills, you can also do this in one line with a `lambda`.

```python
model = Sequential()
model.add(Lambda(lambda x: K.dropout(x, level=0.5), input_shape=(5,)))
# model.add(Dense(10))
# ... usually you would have many more layers
model.compile(loss='categorical_crossentropy', optimizer='adam')
```

```python
model.predict(np.ones((2, 5)))  # a new random dropout mask is drawn on every call
```

```
array([[0., 2., 0., 2., 0.],
       [0., 0., 2., 2., 2.]], dtype=float32)
```