# Using Callbacks to Control Training

- We will use the Callbacks API to stop training once a specified metric reaches a threshold.
- This is a useful feature: once the threshold is reached, training ends without having to run every epoch.

# Load and Normalize the Fashion MNIST dataset

```python
import tensorflow as tf

# Instantiate the dataset API
fmnist = tf.keras.datasets.fashion_mnist

# Load the dataset
(x_train, y_train), (x_test, y_test) = fmnist.load_data()

# Normalize the pixel values from [0, 255] to [0, 1]
x_train, x_test = x_train / 255.0, x_test / 255.0
```

```text
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
```
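Fashion MNIST images are stored as `uint8` arrays with values from 0 to 255, so dividing by `255.0` rescales every pixel into [0, 1] and produces floating-point values. The sketch below checks this on a small synthetic NumPy array, which is an assumption standing in for `x_train` so that it runs without downloading anything.

```python
import numpy as np

# Synthetic stand-in for a batch of 28x28 grayscale images (assumption: not real Fashion MNIST data)
rng = np.random.default_rng(0)
images = rng.integers(0, 255, size=(4, 28, 28), dtype=np.uint8, endpoint=True)
images[0, 0, 0] = 0      # make sure both extremes of the pixel range are present
images[0, 0, 1] = 255

# True division promotes the uint8 array to a floating-point dtype
normalized = images / 255.0

print(images.dtype, normalized.dtype)
print(normalized.min(), normalized.max())
```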


# Creating a Callback class

- Create a callback by defining a class that inherits from the `tf.keras.callbacks.Callback` base class.
- Here we use `on_epoch_end()` to check, at the end of each training epoch, whether the target accuracy has been reached.

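```python
class myCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    # Avoid a mutable default argument; Keras passes a dict of epoch metrics
    logs = logs or {}

    # Check whether the accuracy target has been reached
    accuracy = logs.get('accuracy')
    if accuracy is not None and accuracy > 0.6:

      # Stop if threshold is met
      print("\nReached 60% accuracy so canceling training!")
      self.model.stop_training = True

# Instantiate the class
callbacks = myCallback()
```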
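To see why setting `self.model.stop_training = True` ends training early, here is a simplified, framework-free sketch of the control flow: a toy loop that calls `on_epoch_end` after each epoch and checks the flag before starting the next one. `ToyModel`, `ThresholdCallback`, and `toy_fit` are hypothetical names and the per-epoch accuracy values are made up; this is not Keras' actual `fit` implementation, only an illustration of the same idea as `myCallback`.

```python
class ToyModel:
    """Hypothetical stand-in for a Keras model: it only carries the stop flag."""
    def __init__(self):
        self.stop_training = False


class ThresholdCallback:
    """Hypothetical TensorFlow-free version of myCallback's logic."""
    def __init__(self, threshold=0.6):
        self.threshold = threshold
        self.model = None  # attached by toy_fit, similar in spirit to Keras' set_model()

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        accuracy = logs.get("accuracy")
        if accuracy is not None and accuracy > self.threshold:
            self.model.stop_training = True


def toy_fit(model, epoch_accuracies, epochs, callbacks):
    """Run up to `epochs` epochs and return how many actually ran.

    `epoch_accuracies[i]` plays the role of the accuracy measured after epoch i.
    """
    for cb in callbacks:
        cb.model = model
    epochs_run = 0
    for epoch in range(epochs):
        logs = {"accuracy": epoch_accuracies[epoch]}  # this epoch's metrics
        epochs_run += 1
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)
        if model.stop_training:  # the flag is only consulted between epochs
            break
    return epochs_run


# Made-up accuracy curve that first exceeds 0.6 after the third epoch
accs = [0.45, 0.55, 0.62, 0.70, 0.75, 0.78, 0.80, 0.81, 0.82, 0.83]
print(toy_fit(ToyModel(), accs, epochs=10, callbacks=[ThresholdCallback(0.6)]))  # prints 3
```

Because the flag is only checked after `on_epoch_end`, at least one full epoch always runs, which is consistent with the output below where training stopped after `Epoch 1/10`. Note also that the comparison is strict, so an accuracy of exactly 0.6 does not stop training.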

# Define and compile the model

```python
# Define the model
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

# Compile the model
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
```
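The `'sparse_categorical_crossentropy'` loss fits this dataset because `y_train` holds integer class ids (0 to 9) rather than one-hot vectors. The NumPy sketch below uses made-up softmax outputs to show that the sparse form (the negative log of the probability assigned to the true class, averaged over the batch) equals ordinary categorical crossentropy computed on one-hot labels. It illustrates the formula only; Keras' own implementation additionally handles numerical details such as clipping probabilities.

```python
import numpy as np

# Made-up softmax outputs for a batch of 3 examples over 10 classes (each row sums to 1)
rng = np.random.default_rng(1)
logits = rng.normal(size=(3, 10))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# Integer class labels, in the same format as y_train
labels = np.array([2, 0, 9])

# Sparse form: take the probability of the true class in each row
sparse_ce = -np.log(probs[np.arange(len(labels)), labels]).mean()

# Equivalent categorical form using one-hot encoded labels
one_hot = np.eye(10)[labels]
categorical_ce = -(one_hot * np.log(probs)).sum(axis=1).mean()

print(sparse_ce, categorical_ce)
```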

# Train the model
- Instead of running all 10 epochs, training stops as soon as accuracy exceeds 60% at the end of an epoch.

```python
# Pass the callback instances in a list; Keras calls their hooks (here, on_epoch_end) during training
model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])
```

```text
Epoch 1/10
Reached 60% accuracy so canceling training!

<keras.callbacks.History at 0x7f0da6504150>
```