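# TensorFlow 2 Tutorial: Custom Callbacks

Custom callbacks are a powerful tool for customizing the behavior of a Keras model during training, evaluation, and inference, including reading and modifying the Keras model.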

In [1]:
from __future__ import absolute_import, division,print_function, unicode_literals
import tensorflow as tf

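## 1 Introduction to Keras callbacks
In Keras, a Callback is a Python class meant to be subclassed to provide specific functionality. It defines a set of methods that are called at various stages of training (including the start and end of each batch and epoch), testing, and prediction.

Callbacks are passed as a list through the callbacks argument, and their methods are then called at the corresponding stages of training, evaluation, and inference.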
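The callback-list idea can be pictured with a toy, framework-free sketch: a loop fires named hooks, and a wrapper forwards each hook call to every callback in the list, in order. ToyCallback, ToyCallbackList, and toy_fit are hypothetical names made up for illustration; this is not Keras' actual implementation, and no training happens.

```python
class ToyCallback:
    """Hypothetical base class: every hook is a no-op, so subclasses override only what they need."""

    def on_train_begin(self, logs=None):
        pass

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_train_batch_end(self, batch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass


class ToyCallbackList:
    """Hypothetical wrapper that forwards each hook call to every callback, in list order."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def call(self, hook_name, *args, **kwargs):
        for cb in self.callbacks:
            getattr(cb, hook_name)(*args, **kwargs)


def toy_fit(num_epochs, steps_per_epoch, callbacks):
    """A stand-in training loop that only fires hooks (no real training)."""
    cbs = ToyCallbackList(callbacks)
    cbs.call("on_train_begin")
    for epoch in range(num_epochs):
        cbs.call("on_epoch_begin", epoch)
        for batch in range(steps_per_epoch):
            cbs.call("on_train_batch_end", batch, logs={"loss": 0.0})
        cbs.call("on_epoch_end", epoch, logs={"loss": 0.0})
    cbs.call("on_train_end")


class EventRecorder(ToyCallback):
    """Records the order in which the hooks fire."""

    def __init__(self):
        self.events = []

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(("epoch_begin", epoch))

    def on_train_batch_end(self, batch, logs=None):
        self.events.append(("batch_end", batch))

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(("epoch_end", epoch))


recorder = EventRecorder()
toy_fit(num_epochs=2, steps_per_epoch=2, callbacks=[recorder])
print(recorder.events)
```

Because the base class provides no-op hooks, a subclass only overrides the stages it cares about, which is the same design the Keras Callback base class follows.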

Build a model

In [2]:
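def get_model():
    # A single linear unit on flattened 28x28 MNIST images (784 features)
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(784,)),
        tf.keras.layers.Dense(1, activation='linear'),
    ])
    # learning_rate replaces the deprecated lr argument
    model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.1),
                  loss='mean_squared_error',
                  metrics=['mae'])
    return model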

Load the data

In [3]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

Define a simple custom callback that reports when each training and evaluation batch begins and ends.

In [4]:
import datetime

class MyCustomCallback(tf.keras.callbacks.Callback):

    def on_train_batch_begin(self, batch, logs=None):
        print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

    def on_train_batch_end(self, batch, logs=None):
        print('Training: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

    def on_test_batch_begin(self, batch, logs=None):
        print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

    def on_test_batch_end(self, batch, logs=None):
        print('Evaluating: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

Pass the callback to fit() during training.

In [5]:
model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          epochs=1,
          steps_per_epoch=5,
          verbose=0,
          callbacks=[MyCustomCallback()])

Training: batch 0 begins at 22:19:00.580488
Training: batch 0 ends at 22:19:01.284324
Training: batch 1 begins at 22:19:01.284324
Training: batch 1 ends at 22:19:01.286318
Training: batch 2 begins at 22:19:01.286318
Training: batch 2 ends at 22:19:01.288312
Training: batch 3 begins at 22:19:01.288312
Training: batch 3 ends at 22:19:01.290307
Training: batch 4 begins at 22:19:01.290307
Training: batch 4 ends at 22:19:01.291305


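### 1.1 Methods that accept callbacks
fit(), fit_generator()
Train the model on in-memory data, or on data yielded by a generator.

evaluate(), evaluate_generator()
Evaluate the model on in-memory data, or on data yielded by a generator.

predict(), predict_generator()
Generate predictions from in-memory data, or from data yielded by a generator.

Since TF 2.1, fit_generator(), evaluate_generator(), and predict_generator() are deprecated, because fit(), evaluate(), and predict() accept generators directly.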
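As a minimal sketch of passing callbacks to predict(), the callback below records the index of each prediction batch through on_predict_batch_end. PredictBatchRecorder, the tiny 4-feature model, and the random input are hypothetical, chosen only so the snippet runs on its own without MNIST.

```python
import numpy as np
import tensorflow as tf


class PredictBatchRecorder(tf.keras.callbacks.Callback):
    """Records the index of every prediction batch as it finishes."""

    def __init__(self):
        super().__init__()
        self.batches = []

    def on_predict_batch_end(self, batch, logs=None):
        self.batches.append(batch)


# Hypothetical tiny model and random data, used only to exercise predict()
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
x = np.random.rand(256, 4).astype('float32')

recorder = PredictBatchRecorder()
preds = model.predict(x, batch_size=64, verbose=0, callbacks=[recorder])
print(preds.shape, recorder.batches)
```

With 256 samples and batch_size=64, predict() should run 4 batches, so recorder.batches should end up as [0, 1, 2, 3].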

In [6]:
_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=5,
          callbacks=[MyCustomCallback()])

Evaluating: batch 0 begins at 22:19:01.337182
Evaluating: batch 0 ends at 22:19:01.428937
Evaluating: batch 1 begins at 22:19:01.429935
Evaluating: batch 1 ends at 22:19:01.430940
Evaluating: batch 2 begins at 22:19:01.430940
Evaluating: batch 2 ends at 22:19:01.432925
Evaluating: batch 3 begins at 22:19:01.432925
Evaluating: batch 3 ends at 22:19:01.434921
Evaluating: batch 4 begins at 22:19:01.434921
Evaluating: batch 4 ends at 22:19:01.436914


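## 2 Overview of callback methods
### 2.1 Methods common to training, testing, and prediction
The following methods can be overridden to hook into training, testing, and prediction.

on_(train|test|predict)_begin(self, logs=None)
Called at the beginning of fit/evaluate/predict.

on_(train|test|predict)_end(self, logs=None)
Called at the end of fit/evaluate/predict.

on_(train|test|predict)_batch_begin(self, batch, logs=None)
Called right before a batch is processed during training/testing/prediction. In early TF 2 releases, logs here is a dict with the keys batch and size (the current batch index and the batch size); newer releases pass no data in this argument.

on_(train|test|predict)_batch_end(self, batch, logs=None)
Called at the end of training/testing/prediction for a batch. Here, logs is a dict containing the metric results.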

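### 2.2 Training-specific methods
In addition, the following methods are provided for training.

on_epoch_begin(self, epoch, logs=None)
Called at the beginning of an epoch during training.

on_epoch_end(self, epoch, logs=None)
Called at the end of an epoch during training.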
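A minimal sketch that pairs on_epoch_begin with on_epoch_end to time each epoch. EpochTimer and the tiny model with random data are hypothetical, used only to keep the example self-contained; the durations are wall-clock times from time.perf_counter and will vary between runs.

```python
import time

import numpy as np
import tensorflow as tf


class EpochTimer(tf.keras.callbacks.Callback):
    """Measures the wall-clock duration of each training epoch."""

    def on_train_begin(self, logs=None):
        self.durations = []

    def on_epoch_begin(self, epoch, logs=None):
        self._start = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        elapsed = time.perf_counter() - self._start
        self.durations.append(elapsed)
        print('Epoch {} took {:.3f} s'.format(epoch, elapsed))


# Hypothetical tiny model and random data
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer='sgd', loss='mse')
x = np.random.rand(64, 4).astype('float32')
y = np.random.rand(64, 1).astype('float32')

timer = EpochTimer()
model.fit(x, y, batch_size=16, epochs=3, verbose=0, callbacks=[timer])
```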

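### 2.3 Using the logs dict
The logs dict contains the loss value and all the metrics at the end of each batch and epoch. In this example, these are the loss and the mean absolute error (mae).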
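To see what the logs dict holds at the end of each epoch, the sketch below copies every entry into a dict of lists. LogsCollector and the small random dataset are hypothetical; the model is compiled with metrics=['mae'] as in get_model(), so the keys are expected to be loss and mae.

```python
import numpy as np
import tensorflow as tf


class LogsCollector(tf.keras.callbacks.Callback):
    """Accumulates every value in the epoch-end logs dict, keyed by metric name."""

    def on_train_begin(self, logs=None):
        self.history = {}

    def on_epoch_end(self, epoch, logs=None):
        for key, value in (logs or {}).items():
            self.history.setdefault(key, []).append(float(value))


# Hypothetical tiny model and random data, with the same loss and metric as get_model()
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer='rmsprop', loss='mean_squared_error', metrics=['mae'])
x = np.random.rand(64, 4).astype('float32')
y = np.random.rand(64, 1).astype('float32')

collector = LogsCollector()
model.fit(x, y, batch_size=16, epochs=2, verbose=0, callbacks=[collector])
print(sorted(collector.history))
```

This is essentially what the History object returned by fit() already records in its history attribute; a custom collector is mainly useful when values need to be processed as soon as they arrive.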

In [7]:
class LossAndErrorPrintingCallback(tf.keras.callbacks.Callback):

    def on_train_batch_end(self, batch, logs=None):
        print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))

    def on_test_batch_end(self, batch, logs=None):
        print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))

    def on_epoch_end(self, epoch, logs=None):
        print('The average loss for epoch {} is {:7.2f} and mean absolute error is {:7.2f}.'.format(epoch, logs['loss'], logs['mae']))

model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          steps_per_epoch=5,
          epochs=3,
          verbose=0,
          callbacks=[LossAndErrorPrintingCallback()])

For batch 0, loss is   29.96.
For batch 1, loss is  406.26.
For batch 2, loss is  279.36.
For batch 3, loss is  213.05.
For batch 4, loss is  171.98.
The average loss for epoch 0 is  171.98 and mean absolute error is    8.11.
For batch 0, loss is    5.96.
For batch 1, loss is    5.39.
For batch 2, loss is    5.70.
For batch 3, loss is    5.54.
For batch 4, loss is    5.79.
The average loss for epoch 1 is    5.79 and mean absolute error is    1.97.
For batch 0, loss is    7.96.
For batch 1, loss is    7.76.
For batch 2, loss is   10.37.
For batch 3, loss is   13.20.
For batch 4, loss is   14.84.
The average loss for epoch 2 is   14.84 and mean absolute error is    3.18.
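The numbers above hint at how batch-level logs behave in the TF version used here: within each epoch, the value reported at batch 4 equals the epoch average, which is consistent with on_train_batch_end receiving the running mean of the loss over the batches seen so far, rather than the loss of the current batch alone. Assuming that, the per-batch loss can be recovered as loss_k = (k + 1) * m_k - k * m_(k-1), where m_k is the value printed for batch k. The helper per_batch_from_running_means is hypothetical, and its results are only approximate because the printed values are rounded to two decimals.

```python
def per_batch_from_running_means(running_means):
    """Recover individual batch values from running means m_0, m_1, ...

    Assumes m_k is the mean of the first k + 1 batch values, so the
    k-th value is (k + 1) * m_k - k * m_(k-1).
    """
    values = []
    for k, m_k in enumerate(running_means):
        prev = running_means[k - 1] if k > 0 else 0.0
        values.append((k + 1) * m_k - k * prev)
    return values


# Batch-end losses printed for epoch 0 above (rounded to 2 decimals)
epoch0 = [29.96, 406.26, 279.36, 213.05, 171.98]
print([round(v, 2) for v in per_batch_from_running_means(epoch0)])
```

Under this assumption, the large running mean in epoch 0 comes mostly from batch 1 (roughly 782.6), after which the per-batch loss drops quickly.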


Likewise, callbacks can be passed to evaluate().

In [8]:
_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=20,
          callbacks=[LossAndErrorPrintingCallback()])

For batch 0, loss is   28.84.
For batch 1, loss is   28.11.
For batch 2, loss is   27.57.
For batch 3, loss is   28.24.
For batch 4, loss is   28.06.
For batch 5, loss is   28.64.
For batch 6, loss is   28.82.
For batch 7, loss is   28.68.
For batch 8, loss is   28.79.
For batch 9, loss is   28.84.
For batch 10, loss is   28.65.
For batch 11, loss is   28.76.
For batch 12, loss is   28.77.
For batch 13, loss is   28.50.
For batch 14, loss is   28.26.
For batch 15, loss is   27.95.
For batch 16, loss is   28.18.
For batch 17, loss is   28.04.
For batch 18, loss is   28.13.
For batch 19, loss is   28.27.


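## 3 Keras callback examples
### 3.1 Early stopping at minimum loss
The first example shows a Callback that stops Keras training once the loss reaches its minimum, by setting the model attribute model.stop_training (a boolean) to True. Optionally, the user can pass an argument patience to specify how many epochs without improvement to wait before training stops.

Note: tf.keras.callbacks.EarlyStopping provides a more complete and general implementation.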
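For comparison, here is a minimal sketch of the built-in tf.keras.callbacks.EarlyStopping monitoring the training loss. The tiny model and random data are hypothetical, and min_delta is set absurdly high only to make the demonstration deterministic: no epoch after the first can count as an improvement, so with patience=1 training should stop after the second epoch. In real use, min_delta is a small threshold (its default is 0).

```python
import numpy as np
import tensorflow as tf

# Hypothetical tiny model and random data
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer='sgd', loss='mse')
x = np.random.rand(64, 4).astype('float32')
y = np.random.rand(64, 1).astype('float32')

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='loss',              # quantity to watch
    min_delta=1e9,               # demo only: no later epoch can beat the first by this much
    patience=1,                  # epochs without improvement before stopping
    restore_best_weights=True,   # roll back to the best epoch's weights when stopping
)
history = model.fit(x, y, batch_size=16, epochs=20, verbose=0, callbacks=[early_stop])
print('Trained for {} of 20 epochs'.format(len(history.history['loss'])))
```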

In [9]:
import numpy as np

class EarlyStoppingAtMinLoss(tf.keras.callbacks.Callback):
    """Stop training when the loss stops decreasing.

    patience: number of epochs without improvement to wait before stopping.
    """

    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        self.best_weights = None  # weights from the epoch with the lowest loss

    def on_train_begin(self, logs=None):
        # Number of epochs waited since the loss last decreased
        self.wait = 0
        # Epoch at which training was stopped
        self.stopped_epoch = 0
        # Best loss so far, initialized to infinity (np.Inf was removed in NumPy 2.0)
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get('loss')
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights seen so far
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print('Restoring model weights from the end of the best epoch.')
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))

In [10]:
model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          steps_per_epoch=5,
          epochs=30,
          verbose=0,
          callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()])

For batch 0, loss is   29.74.
For batch 1, loss is  474.35.
For batch 2, loss is  324.32.
For batch 3, loss is  245.87.
For batch 4, loss is  198.54.
The average loss for epoch 0 is  198.54 and mean absolute error is    8.51.
For batch 0, loss is    8.64.
For batch 1, loss is    6.84.
For batch 2, loss is    6.66.
For batch 3, loss is    6.49.
For batch 4, loss is    6.61.
The average loss for epoch 1 is    6.61 and mean absolute error is    2.10.
For batch 0, loss is    6.12.
For batch 1, loss is    6.46.
For batch 2, loss is    5.86.
For batch 3, loss is    5.81.
For batch 4, loss is    5.79.
The average loss for epoch 2 is    5.79 and mean absolute error is    1.91.
For batch 0, loss is    4.88.
For batch 1, loss is    6.08.
For batch 2, loss is    7.50.
For batch 3, loss is    8.78.
For batch 4, loss is   11.92.
The average loss for epoch 3 is   11.92 and mean absolute error is    2.73.
Restoring model weights from the end of the best epoch.
Epoch 00004: early stopping
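The stopping point in the run above can be checked by replaying the same patience rule in plain Python on the printed epoch-average losses (198.54, 6.61, 5.79, 11.92). simulate_early_stopping is a hypothetical helper that mirrors the logic of EarlyStoppingAtMinLoss without TensorFlow: with the default patience=0, the first epoch whose loss does not improve (index 3) triggers the stop, which the callback reports as 00004 because it adds 1, and the restored weights come from epoch index 2.

```python
import math


def simulate_early_stopping(epoch_losses, patience=0):
    """Replay the patience rule of EarlyStoppingAtMinLoss on a list of epoch losses.

    Returns (stopped_epoch, best_epoch) as 0-based indices; stopped_epoch is
    None if the rule never triggers.
    """
    best = math.inf      # best loss so far, like self.best
    best_epoch = None    # epoch whose weights would be restored
    wait = 0             # epochs since the last improvement, like self.wait
    for epoch, loss in enumerate(epoch_losses):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Epoch-average losses printed in the run above
losses = [198.54, 6.61, 5.79, 11.92]
stopped, best_epoch = simulate_early_stopping(losses)
print('Stopped at epoch index {} (reported as {:05d}); best epoch index {}'.format(
    stopped, stopped + 1, best_epoch))
```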


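### 3.2 Learning rate scheduling
A common thing to do during model training is to change the learning rate as training progresses through epochs. The Keras backend exposes the get_value and set_value APIs, which can be used to read and set variables. In this example, we show how a custom callback can be used to change the learning rate dynamically.

Note: this is only an example implementation; see callbacks.LearningRateScheduler and keras.optimizers.schedules for more general implementations.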
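A minimal sketch of the built-in tf.keras.callbacks.LearningRateScheduler, which accepts a function schedule(epoch, lr) with the same signature as the lr_schedule function defined later in this section. The halve_every_epoch schedule, the tiny model, and the random data are hypothetical. Epochs 0, 1, and 2 set the rate to 0.1, 0.05, and 0.025, so the optimizer should finish training at 0.025.

```python
import numpy as np
import tensorflow as tf


def halve_every_epoch(epoch, lr):
    """Hypothetical schedule: start at 0.1 and halve once per epoch (ignores the incoming lr)."""
    return 0.1 * (0.5 ** epoch)


# Hypothetical tiny model and random data
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1), loss='mse')
x = np.random.rand(64, 4).astype('float32')
y = np.random.rand(64, 1).astype('float32')

scheduler = tf.keras.callbacks.LearningRateScheduler(halve_every_epoch)
model.fit(x, y, batch_size=16, epochs=3, verbose=0, callbacks=[scheduler])

# The optimizer stores the rate as a variable; convert it to a Python float
final_lr = float(np.asarray(model.optimizer.learning_rate))
print('Learning rate after training: {:.4f}'.format(final_lr))
```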

In [11]:
class LearningRateScheduler(tf.keras.callbacks.Callback):
    """Sets the optimizer's learning rate at the start of each epoch from schedule(epoch, lr)."""

    def __init__(self, schedule):
        super(LearningRateScheduler, self).__init__()
        self.schedule = schedule

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'learning_rate'):
            raise ValueError('Optimizer must have a "learning_rate" attribute.')
        # Read the current learning rate from the model's optimizer
        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        # Ask the schedule for this epoch's learning rate and write it back
        scheduled_lr = self.schedule(epoch, lr)
        tf.keras.backend.set_value(self.model.optimizer.learning_rate, scheduled_lr)
        print('Epoch %05d: Learning rate is %6.4f.' % (epoch, scheduled_lr))

Define a schedule that changes the learning rate at specific epochs

In [12]:
LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001)
]

def lr_schedule(epoch, lr):
    """Return the scheduled learning rate for this epoch, or keep the current one."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr

model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          steps_per_epoch=5,
          epochs=15,
          verbose=0,
          callbacks=[LossAndErrorPrintingCallback(), LearningRateScheduler(lr_schedule)])

Epoch 00000: Learning rate is 0.1000.
For batch 0, loss is   22.01.
For batch 1, loss is  364.13.
For batch 2, loss is  248.63.
For batch 3, loss is  189.12.
For batch 4, loss is  153.57.
The average loss for epoch 0 is  153.57 and mean absolute error is    7.55.
Epoch 00001: Learning rate is 0.1000.
For batch 0, loss is   11.59.
For batch 1, loss is   11.58.
For batch 2, loss is   11.24.
For batch 3, loss is   10.42.
For batch 4, loss is    9.61.
The average loss for epoch 1 is    9.61 and mean absolute error is    2.53.
Epoch 00002: Learning rate is 0.1000.
For batch 0, loss is    6.96.
For batch 1, loss is    5.86.
For batch 2, loss is    5.69.
For batch 3, loss is    6.14.
For batch 4, loss is    7.78.
The average loss for epoch 2 is    7.78 and mean absolute error is    2.27.
Epoch 00003: Learning rate is 0.0500.
For batch 0, loss is   26.20.
For batch 1, loss is   16.84.
For batch 2, loss is   13.06.
For batch 3, loss is   11.28.
For batch 4, loss is    9.98.
The average loss for epoch 3 is    9.98 and mean absolute error is    2.