In [1]:
import tensorflow as tf

In [3]:
class Linear(tf.keras.layers.Layer):
    """Dense (fully connected) layer computing ``inputs @ w + b``, no activation.

    Args:
        units: Size of the output's last dimension (number of output
            features). Defaults to 32.
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create the kernel ``w`` and bias ``b``.

        Keras calls this lazily on the first ``__call__``, once the input
        shape is known.

        Args:
            input_shape: Shape of the input; only its last entry (the number
                of input features) is used, as the kernel's first dimension.
        """
        input_dim = input_shape[-1]

        # Explicit names keep variable names readable and unique in
        # checkpoints / saved models. Unnamed weights have reportedly
        # collided when saving to HDF5 in some TF 2.x releases.
        self.w = self.add_weight(
            name='w',
            shape=(input_dim, self.units),
            initializer='random_normal',
            trainable=True,
        )
        self.b = self.add_weight(
            name='b',
            shape=(self.units,),
            initializer='zeros',
            trainable=True,
        )

    def call(self, inputs):
        """Return the affine transform ``inputs @ w + b``."""
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        """Return the base layer config extended with ``units``.

        The extra key lets the layer be re-created from its serialized config.
        """
        config = super().get_config()
        config.update({'units': self.units})
        return config
    
    
# Seed TensorFlow's global RNG. The kernel uses the 'random_normal'
# initializer, so without a seed the printed `y` differs on every
# Restart & Run All.
SEED = 42
tf.random.set_seed(SEED)

x = tf.ones((2, 2))   # 2x2 input of ones; last axis (2) becomes input_dim
y = Linear(4)(x)      # first call runs build(); result has 4 columns

print(x)
print(y)

tf.Tensor(
[[1. 1.]
 [1. 1.]], shape=(2, 2), dtype=float32)
tf.Tensor(
[[ 0.09609352  0.11880723 -0.02391201 -0.00234276]
 [ 0.09609352  0.11880723 -0.02391201 -0.00234276]], shape=(2, 4), dtype=float32)
