In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras import backend as K

In [2]:
class SimpleQuadratic(Layer):
    '''Dense-like layer computing ``activation(square(x) @ a + x @ b + c)``.

    The input is squared element-wise and projected by ``a``. A linear
    projection of the input by ``b`` and a bias ``c`` are added, and the
    configured activation is applied to the sum.

    Args:
        units: number of output features (size of the output's last axis).
        activation: anything accepted by ``tf.keras.activations.get`` (a name
            such as ``'relu'``, a callable, or ``None`` for a linear output).
        **kwargs: standard ``Layer`` keyword arguments (``name``, ``dtype``,
            ``trainable``, ...), forwarded unchanged to the base class.
    '''

    def __init__(self, units=32, activation=None, **kwargs):
        '''Store the layer configuration; weights are created lazily in ``build``.'''
        super().__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        '''Create ``a`` and ``b`` of shape ``(in_features, units)`` and ``c`` of shape ``(units,)``.'''
        in_features = input_shape[-1]
        ab_init = tf.random_normal_initializer()
        # add_weight (instead of a raw tf.Variable) registers the weights through
        # the Keras Layer API and uses the layer's own dtype rather than a
        # hard-coded "float32".
        self.a = self.add_weight(name="a", shape=(in_features, self.units),
                                 initializer=ab_init, trainable=True)
        self.b = self.add_weight(name="b", shape=(in_features, self.units),
                                 initializer=ab_init, trainable=True)
        self.c = self.add_weight(name="c", shape=(self.units,),
                                 initializer=tf.zeros_initializer(), trainable=True)

    def call(self, inputs):
        '''Return ``activation(square(inputs) @ a + inputs @ b + c)``.'''
        quadratic_term = tf.matmul(tf.math.square(inputs), self.a)
        linear_term = tf.matmul(inputs, self.b)
        return self.activation(quadratic_term + linear_term + self.c)

    def get_config(self):
        '''Return the constructor arguments so the layer can be re-created when a saved model is loaded.'''
        config = super().get_config()
        config.update({
            "units": self.units,
            "activation": tf.keras.activations.serialize(self.activation),
        })
        return config

In [3]:
# Configuration for this run. Seeding TF's global RNG makes weight
# initialisation and dropout repeatable across Restart & Run All
# (some GPU kernels may still be nondeterministic).
SEED = 42
EPOCHS = 10
tf.random.set_seed(SEED)

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixels from [0, 255] to [0, 1]. Casting to float32 before dividing avoids
# materialising float64 copies (twice the memory); the model runs in float32 by default.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

model = tf.keras.models.Sequential([
    # Declare the input shape with an explicit Input rather than passing
    # `input_shape=` to the first layer.
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    SimpleQuadratic(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax'),
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(x_train, y_train, epochs=EPOCHS)
model.evaluate(x_test, y_test)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


[0.08425131440162659, 0.977400004863739]