# TensorFlow Keras subclassing

### Layer subclassing

In [None]:
class MyLinear(tf.keras.layers.Layer):
    """Fully connected layer that computes ``inputs @ w + b``.

    Args:
        units: Number of output features (columns of ``w``).
        input_dim: Number of input features (rows of ``w``).
    """

    def __init__(self, units=32, input_dim=32):
        # Zero-argument super() avoids naming the class explicitly; the
        # previous code referred to a non-existent ``Linear`` class and
        # raised NameError on instantiation.
        super().__init__()

        # add_weight registers the variables with the layer, so they are
        # reported through ``self.weights`` / ``self.trainable_weights``.
        self.w = self.add_weight(
            shape=(input_dim, units),
            initializer="random_normal",
            dtype="float32",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(units,),
            initializer="zeros",
            dtype="float32",
            trainable=True,
        )

    def call(self, inputs):
        """Apply the affine transformation to ``inputs``."""
        return tf.matmul(inputs, self.w) + self.b

In [None]:
# Quick check of the custom layer: pass a 2x2 tensor of ones through a
# MyLinear layer with 2 input features and 4 units, then print the output.
x = tf.ones(shape=(2, 2))
myL = MyLinear(units=4, input_dim=2)
y = myL(x)
print(y)

### Model subclassing
The model class inherits from `tf.keras.Model`.

In [None]:
class ModelSubClassing(tf.keras.Model):
    """Small convolutional classifier built by subclassing ``tf.keras.Model``.

    The network is two convolutional blocks followed by dropout, global
    average pooling and a dense classifier. The final dense layer has no
    activation, so the model returns raw, unnormalized scores (logits).

    Args:
        num_classes: Number of units in the final dense layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Layer block 1
        self.conv1 = tf.keras.layers.Conv2D(32, 3, strides=2, activation="relu")
        self.max1 = tf.keras.layers.MaxPooling2D(3)
        self.bn1 = tf.keras.layers.BatchNormalization()

        # Layer block 2
        self.conv2 = tf.keras.layers.Conv2D(64, 3, activation="relu")
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.drop = tf.keras.layers.Dropout(0.3)

        # GAP, followed by classifier
        self.gap = tf.keras.layers.GlobalAveragePooling2D()
        self.dense = tf.keras.layers.Dense(num_classes)

    def call(self, input_tensor, training=False):
        """Run the forward pass.

        Args:
            input_tensor: Input batch fed to the first convolution.
            training: Whether the call is made in training mode. It is passed
                explicitly to the batch-normalization and dropout layers,
                which behave differently in training and inference.

        Returns:
            The output of the final dense layer (no activation applied).
        """
        # Forward pass: block 1
        x = self.conv1(input_tensor)
        x = self.max1(x)
        x = self.bn1(x, training=training)

        # Forward pass: block 2
        x = self.conv2(x)
        x = self.bn2(x, training=training)

        # Dropout followed by GAP and classifier
        x = self.drop(x, training=training)
        x = self.gap(x)
        return self.dense(x)

In [None]:
model = ModelSubClassing(10)
model.compile(
    # The model's final Dense layer has no activation, so it outputs logits;
    # the loss must be told so, otherwise it treats them as probabilities.
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    # compile() expects a collection of metrics, so wrap the metric in a list.
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
    optimizer=tf.keras.optimizers.Adam(),
)
# The trailing semicolon suppresses the notebook's echo of the History object.
model.fit(x_train, y_train, batch_size=128, epochs=1);