In [7]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential, activations
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dropout, Flatten, Layer

In [4]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train / 255.
X_test = X_test / 255.

In [5]:
class SimpleDense(Layer):
    def __init__(self, units=32, activation=None):
        super().__init__()
        self.units = units
        self.activation = activations.get(activation)
        
    def build(self, input_shape): 
        # init layer state (e.g., weights)
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(
            name='weights', 
            initial_value=w_init(shape=(input_shape[-1], self.units), 
                                 dtype='float32'),
            trainable=True)
        b_init = tf.zeros_initializer()
        self.b = tf.Variable(
            name='bias', 
            initial_value=b_init(shape=(self.units,), dtype='float32'),
            trainable=True)
        
    def call(self, inputs):
        # Forward pass
        return self.activation(tf.matmul(inputs, self.w) + self.b)

In [8]:
mod = Sequential([Flatten(input_shape=(28, 28)),
                  SimpleDense(128, activation='relu'),
                  Dropout(0.2),
                  SimpleDense(10, activation='softmax')])

In [9]:
mod.compile(optimizer='adam', 
            loss='sparse_categorical_crossentropy', 
            metrics=['accuracy'])

In [10]:
mod.fit(X_train, y_train, epochs=5)
mod.evaluate(X_test, y_test)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


[0.07487912476062775, 0.9765999913215637]

In [25]:
mod.predict([10.])

array([[18.981302]], dtype=float32)

In [26]:
my_layer.variables

[<tf.Variable 'sequential_3/simple_dense_5/weights:0' shape=(1, 1) dtype=float32, numpy=array([[1.9972901]], dtype=float32)>,
 <tf.Variable 'sequential_3/simple_dense_5/bias:0' shape=(1,) dtype=float32, numpy=array([-0.99159825], dtype=float32)>]