# A Guide to TensorFlow Callbacks
For more details on each function, check out the corresponding post [on the Paperspace blog](https://blog.paperspace.com/tensorflow-callbacks).

### Import Libraries

In [None]:
!pip install tensorflow_datasets

In [None]:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_datasets as tfds

### Load data
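We use the MNIST handwritten-digit dataset: the first 95% of the official training split is used for training, the remaining 5% is held out for validation, and the official test split is kept separate.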

In [None]:
# 95% of the official training split for training, the remaining 5% for
# validation, and the official test split kept aside. as_supervised=True makes
# each element an (image, label) pair rather than a feature dict.
train, valid, test = (
    tfds.load(name="mnist", split=split_spec, as_supervised=True)
    for split_spec in ("train[:95%]", "train[95%:]", "test")
)

### Preprocessing the dataset
Normalize, shuffle, batch, and prefetch the data (prefetching prepares the next batches while the current one is training, for faster execution).

In [None]:
def normalize(image, label):
    """Scale pixel values into [0, 1] and pass the label through unchanged.

    Dividing by 255 assumes 0-255 pixel intensities (presumably uint8 from tfds
    MNIST). The cast targets float32 rather than float16: under the default
    float32 Keras policy, layers autocast floating inputs to float32 anyway, so
    a half-precision cast would only throw away precision.
    """
    return tf.cast(image, tf.float32) / 255.0, label


AUTOTUNE = tf.data.experimental.AUTOTUNE

# Normalize per example (in parallel), shuffle within a 128-element buffer,
# batch, then prefetch so input preparation overlaps with training.
train = (
    train.map(normalize, num_parallel_calls=AUTOTUNE)
    .shuffle(128)
    .batch(64)
    .prefetch(AUTOTUNE)
)

# Same normalize-then-batch order as `train`, so both pipelines apply the
# preprocessing identically. Validation data is not shuffled.
valid = (
    valid.map(normalize, num_parallel_calls=AUTOTUNE)
    .batch(64)
    .prefetch(AUTOTUNE)
)

### Define model

In [None]:
model = keras.models.Sequential([
    # Unroll each 28x28x1 image into a flat vector for the dense layers.
    keras.layers.Flatten(input_shape=(28, 28, 1)),
    keras.layers.Dense(256, activation='relu'),
    # One softmax probability per digit class (10 classes).
    keras.layers.Dense(10, activation='softmax'),
])

## CALLBACKS

### CSVLogger

In [None]:
# Write epoch results to csv.log; append=True extends an existing file instead
# of overwriting it, so rows from earlier runs are kept.
csv_callback = tf.keras.callbacks.CSVLogger(filename="csv.log", separator=",", append=True)

### EarlyStopping

In [None]:
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    mode='auto',                 # improvement direction inferred from the metric name
    min_delta=0.01,              # val_loss must drop by at least this much to count
    patience=1,                  # stop after one epoch without improvement
    baseline=None,
    restore_best_weights=False,  # keep the last epoch's weights, not the best epoch's
    verbose=0,
)

### Tensorboard

In [None]:
import datetime
import os

import tensorboard  # not referenced below; importing it surfaces a missing install early

# Each run logs into its own timestamped sub-directory of ./logs.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join("logs", run_stamp)
# histogram_freq=1: compute weight histograms every epoch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir, histogram_freq=1)

### LearningRateScheduler

In [None]:
learning_rate = 0.002


def updateLearningRate(epoch):
    """Return the learning rate to use for ``epoch`` (Keras passes a 0-based index).

    Epochs 0-3 use the base ``learning_rate``; from epoch 4 onward the rate is
    cut tenfold. The decayed value is derived from ``learning_rate`` rather than
    a hard-coded 0.002, so changing the base rate above also moves the decayed
    rate.
    """
    if epoch > 3:
        return learning_rate * 0.1
    return learning_rate
# At the start of each epoch Keras asks updateLearningRate for the rate;
# verbose=1 prints the value it sets.
learningRate_callback = tf.keras.callbacks.LearningRateScheduler(updateLearningRate, verbose=1)

### LambdaCallback

In [None]:
def printCustom(batch, logs):
    """Append the batch index and its logs to CustomLogs.txt.

    Despite the name, nothing goes to stdout. Each call opens the file in
    append mode and adds exactly two lines.
    """
    entry = f"Batch is {batch} \n" + f"Logs  {logs} \n"
    with open("CustomLogs.txt", "a+") as log_file:
        log_file.write(entry)

# Invoke printCustom(batch, logs) after every training batch.
lambda_callback = tf.keras.callbacks.LambdaCallback(
    on_batch_end=printCustom,
)

### ReduceLROnPlateau

In [None]:
reduceLR_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,        # new_lr = lr * 0.1
    patience=1,        # epochs without improvement before reducing
    min_delta=0.0001,  # minimum change that counts as an improvement
    cooldown=2,        # epochs to wait after a reduction before monitoring again
    min_lr=0,          # lower bound on the learning rate
    verbose=1,
)

### RemoteMonitor

In [None]:
# POSTs the logs at the end of each epoch to http://localhost:8000/ (a server must be
# listening there; the callback needs the `requests` package). With
# send_as_json=True the logs form the JSON request body, and `field` is only
# used in form-encoded mode (send_as_json=False).
remote_callback = tf.keras.callbacks.RemoteMonitor(
    root='http://localhost:8000',
    path='/',
    field='data',
    send_as_json=True,
)

### Compile the Model

In [None]:
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',  # labels are integer class ids, not one-hot
    metrics=['accuracy'],
)

### Train the model

In [None]:
# Pass whichever callbacks you want to try, e.g. csv_callback,
# tensorboard_callback, lambda_callback, reduceLR_callback or remote_callback.
# NOTE(review): Keras epoch indices start at 0, so with epochs=3
# updateLearningRate never takes its `epoch > 3` branch and the rate stays at
# the base value. Use epochs=5 or more to see the decay. Also,
# updateLearningRate ignores the current rate, so combining learningRate_callback
# with reduceLR_callback would undo the plateau reductions each epoch.
history = model.fit(
    train,
    epochs=3,
    validation_data=valid,
    callbacks=[learningRate_callback, early_stopping],
)

### Checking the history object

In [None]:
# Dict mapping each loss/metric name (training and validation) to its per-epoch values from fit().
history.history

In [None]:
# Parameters fit() ran with (e.g. verbosity, number of epochs, steps per epoch).
history.params