In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Layer

In [2]:
class SimpleQuadratic(Layer):
  def __init__(self, units = 32, activation = None):
    super(SimpleQuadratic, self).__init__()
    self.units = units
    self.activation = tf.keras.activations.get(activation)
  def build(self, input_shape):
    a_init = tf.random_normal_initializer()
    self.a = tf.Variable(initial_value=a_init(shape = (input_shape[-1], self.units), dtype = 'float32'), trainable = True)
    b_init = tf.random_normal_initializer()
    self.b = tf.Variable(initial_value=b_init(shape = (input_shape[-1], self.units), dtype = 'float32'), trainable = True)
    c_init = tf.zeros_initializer()
    self.c = tf.Variable(c_init(shape=(self.units), dtype="float32"), trainable=True)
    super().build(input_shape)
  def call(self, inputs):
    x_squared = tf.math.square(inputs)
    x_squared_times_a = tf.matmul(x_squared, self.a)    
    x_times_b = tf.matmul(inputs, self.b)
    return self.activation(x_squared_times_a + x_times_b + self.c)

In [5]:
from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape = (28, 28)),
    SimpleQuadratic(units = 128, activation = 'relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation = 'softmax')
])
model.compile(optimizer = 'adam', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])

In [6]:
model.fit(x_train, y_train, epochs = 5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7f574f94a790>

In [7]:
model.evaluate(x_test, y_test)



[0.0789637565612793, 0.9764999747276306]