In [4]:
import tensorflow as tf
from tensorflow.keras.layers import Layer

## Custom Layer with Activation

In [5]:
class SimpleDense(Layer):
    """Fully connected layer computing ``activation(inputs @ w + b)``.

    Args:
        units: Number of output units (size of the last output dimension).
        activation: Anything accepted by ``tf.keras.activations.get``: a name
            such as ``'relu'``, a callable, or ``None`` for the linear
            (identity) activation.
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``dtype``,
            ``trainable``, ...), forwarded to the base class so the layer can
            be named and reconstructed from its config.
    """

    def __init__(self, units=32, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Resolve strings like 'relu' to the built-in Keras activation
        # function; None resolves to the linear (identity) activation.
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        """Create the kernel and bias once the input feature size is known."""
        # add_weight registers the variables with the layer and creates them in
        # the layer's variable dtype instead of a hard-coded 'float32'.
        self.w = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.units),
            initializer=tf.random_normal_initializer(),
            trainable=True,
        )
        self.b = self.add_weight(
            name="bias",
            shape=(self.units,),
            initializer=tf.zeros_initializer(),
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        """Return ``activation(inputs @ w + b)``."""
        return self.activation(tf.matmul(inputs, self.w) + self.b)

    def get_config(self):
        """Return the constructor arguments so the layer can be saved/reloaded."""
        config = super().get_config()
        config.update({
            "units": self.units,
            "activation": tf.keras.activations.serialize(self.activation),
        })
        return config

In [3]:
# Seed TensorFlow's global RNG so weight initialisation (including the
# SimpleDense kernel) and dropout are repeatable on a fresh kernel.
# NOTE(review): GPU kernels may still add small run-to-run differences.
SEED = 42
tf.random.set_seed(SEED)

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Rescale pixel values by 1/255 (presumably uint8 0-255 -> [0, 1]).
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    SimpleDense(128, activation='relu'),  # the custom layer under test
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Integer class labels + softmax outputs -> sparse categorical cross-entropy.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)

Train on 60000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


[0.07649828119501471, 0.9773]

In [18]:
# Default construction: 32 units and activation=None (linear), so the output is
# inputs @ w + b with a freshly random-initialised kernel and a zero bias.
x = SimpleDense()
x(tf.constant([[1.0, 2.0, 3.0]]))

<tf.Tensor: shape=(1, 32), dtype=float32, numpy=
array([[ 0.14905033, -0.10889474, -0.22530735, -0.16234002, -0.02111359,
        -0.00377623,  0.11889336, -0.2083413 , -0.01421928, -0.032706  ,
         0.12195873, -0.01101046,  0.1242689 , -0.32487822,  0.05884798,
         0.27692515,  0.19730459, -0.2507856 ,  0.16286388, -0.02388777,
         0.02844812,  0.3027302 ,  0.08558559, -0.10599934, -0.27116987,
        -0.30830002, -0.243434  , -0.15631768, -0.17491907,  0.06462561,
        -0.08592308, -0.0228786 ]], dtype=float32)>

In [19]:
# A class is an ordinary object: binding it to a second name creates an alias,
# not a copy, so calling renamedClass() builds a regular SimpleDense layer.
renamedClass = SimpleDense
x = renamedClass()
x(tf.constant([[1.0, 2.0, 3.0]]))

<tf.Tensor: shape=(1, 32), dtype=float32, numpy=
array([[ 0.40951863, -0.18145406,  0.04423184,  0.08229198,  0.25369936,
        -0.1285687 , -0.21048641, -0.22881882, -0.08406536,  0.02499339,
        -0.0929079 , -0.03146731, -0.2177437 ,  0.09126538,  0.1089728 ,
         0.10752554,  0.27490938, -0.18165825,  0.03771023,  0.08802094,
         0.02851969,  0.16809234,  0.16256362, -0.27289653,  0.21033329,
         0.09239422, -0.19682926,  0.09811161, -0.12423241, -0.49919328,
        -0.27128595,  0.02388945]], dtype=float32)>