In [1]:
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

In [2]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255.0

In [3]:
model = keras.Sequential(
    [
        layers.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), padding="same"),
        layers.ReLU(),
        layers.Conv2D(128, (3, 3), padding="same"),
        layers.ReLU(),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="model",
)

In [4]:
print(model.summary())

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 28, 28, 64)        640       
                                                                 
 re_lu (ReLU)                (None, 28, 28, 64)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 28, 28, 128)       73856     
                                                                 
 re_lu_1 (ReLU)              (None, 28, 28, 128)       0         
                                                                 
 flatten (Flatten)           (None, 100352)            0         
                                                                 
 dense (Dense)               (None, 10)                1003530   
                                                                 
Total params: 1078026 (4.11 MB)
Trainable params: 1078026 (4.

In [5]:

class CustomFit(keras.Model):
    def __init__(self, model):
        super(CustomFit, self).__init__()
        self.model = model

    def compile(self, optimizer, loss):
        super(CustomFit, self).compile()
        self.optimizer = optimizer
        self.loss = loss

    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            # Caclulate predictions
            y_pred = self.model(x, training=True)

            # Loss
            loss = self.loss(y, y_pred)

        # Gradients
        training_vars = self.trainable_variables
        gradients = tape.gradient(loss, training_vars)

        # Step with optimizer
        self.optimizer.apply_gradients(zip(gradients, training_vars))
        acc_metric.update_state(y, y_pred)

        return {"loss": loss, "accuracy": acc_metric.result()}

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_pred = self.model(x, training=False)

        # Updates the metrics tracking the loss
        loss = self.loss(y, y_pred)

        # Update the metrics.
        acc_metric.update_state(y, y_pred)
        return {"loss": loss, "accuracy": acc_metric.result()}

- compile method of the CustomFit class is used to specify the optimizer and loss function to be used during training. This method is overridden from the parent class (keras.Model). The overridden method is calling the parent method to ensure proper compilation.
- A tf.GradientTape context is created to record the operations that involve calculating gradients.
- The optimizer's apply_gradients method is used to update the model's trainable variables based on the calculated gradients.

In [6]:

acc_metric = keras.metrics.SparseCategoricalAccuracy(name="accuracy")

In [7]:
training = CustomFit(model)
training.compile(
    optimizer=keras.optimizers.Adam(learning_rate=3e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

- creating a instance of custom model class

In [8]:
training.fit(x_train, y_train, batch_size=64, epochs=2,verbose=2)
training.evaluate(x_test, y_test, batch_size=64,verbose=2)

Epoch 1/2
938/938 - 186s - loss: 0.0392 - accuracy: 0.9489 - 186s/epoch - 198ms/step
Epoch 2/2
938/938 - 188s - loss: 0.2251 - accuracy: 0.9658 - 188s/epoch - 201ms/step
157/157 - 7s - loss: 0.0028 - accuracy: 0.9673 - 7s/epoch - 42ms/step


[0.9673230648040771, 0.0028109289705753326]