In [9]:
import tensorflow as tf

fmnist = tf.keras.datasets.fashion_mnist

(x_trains, y_trains), (t_tests, y_tests) = fmnist.load_data()

x_trains = x_trains/255
y_trains = y_trains/255

# Creating a Callback class
You can create a callback by defining a class that inherits the <b>tf.keras.callbacks.Callback</b> base class. From there, you can define available methods to set where the callback will be executed. For instance below, you will use the <b>on_epoch_end()</b> method to check the loss at each training epoch.

In [10]:
class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs={}):
        '''
        Halts the training when the loss falls below 0.4

        Args:
            epoch (integer) - index of epoch (required but unused in the function definition below)
            logs (dict) - metric results from the training epoch
        '''
        #check the loss:
        if(logs.get('loss') < 0.4):
           
            #stop if threshold is met
            print("\nLoss is lower than 0.4 so cancelling training!")
            self.model.stop_training = True

# Instantiate class
callbacks = myCallback()

## define and compile the model

In [11]:
# Define the model
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(512, activation=tf.nn.relu),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

# Compile the model
model.compile(optimizer=tf.optimizers.Adam(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

In [14]:
model.fit(x_trains, y_trains, epochs=10, callbacks=[callbacks])

Epoch 1/10
Loss is lower than 0.4 so cancelling training!


<keras.src.callbacks.History at 0x21e8a7d5940>