# Keras回调函数详解

本教程介绍Keras的回调机制，用于在训练过程中执行自定义操作。

## 学习目标

1. 理解回调函数的工作机制
2. 掌握常用内置回调的使用方法
3. 学会实现自定义回调
4. 了解回调函数的最佳实践

## 1. 环境配置与数据准备

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Seed NumPy and TensorFlow with one shared value so repeated runs are reproducible
RANDOM_SEED = 42
for seed_fn in (np.random.seed, tf.random.set_seed):
    seed_fn(RANDOM_SEED)

print(f"TensorFlow版本: {tf.__version__}")

In [None]:
# Load California housing, hold out 20% as the test set, then take 25% of the
# remainder as validation -> 60% / 20% / 20% train / valid / test
housing = fetch_california_housing()

X_train_full, X_test, y_train_full, y_test = train_test_split(
    housing.data,
    housing.target,
    random_state=RANDOM_SEED,
    test_size=0.2,
)
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train_full,
    y_train_full,
    random_state=RANDOM_SEED,
    test_size=0.25,
)

# Standardise features with statistics fitted on the training split only,
# then apply the same transform to validation and test
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_valid, X_test = scaler.transform(X_valid), scaler.transform(X_test)

print(f"训练集: {X_train.shape[0]} 样本")
print(f"验证集: {X_valid.shape[0]} 样本")
print(f"测试集: {X_test.shape[0]} 样本")

In [None]:
def create_model(n_features=8, hidden_units=30, learning_rate=1e-3):
    """
    Build and compile a small fully-connected regression network.

    Architecture: Input(n_features) -> Dense(hidden_units, relu)
    -> Dense(hidden_units, relu) -> Dense(1) (linear output).
    Compiled with MSE loss, the Adam optimizer and an MAE metric.

    Parameters
    ----------
    n_features : int, default 8
        Number of input features per sample (8 for the California housing
        data used in this notebook).
    hidden_units : int, default 30
        Width of each of the two hidden layers.
    learning_rate : float, default 1e-3
        Initial learning rate for Adam.

    Returns
    -------
    keras.Sequential
        A freshly initialised, compiled model.
    """
    model = keras.Sequential([
        # Declare the input with keras.Input rather than passing input_shape to
        # the first Dense layer; Keras 3 warns about the latter form.
        keras.Input(shape=(n_features,)),
        keras.layers.Dense(hidden_units, activation='relu'),
        keras.layers.Dense(hidden_units, activation='relu'),
        keras.layers.Dense(1),
    ])
    model.compile(
        loss='mse',
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        metrics=['mae'],
    )
    return model

## 2. ModelCheckpoint - 模型检查点

自动保存训练过程中的最佳模型或定期保存检查点。

In [None]:
model = create_model()

# Write a checkpoint only when val_loss reaches a new minimum
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    'checkpoints/best_model.h5',
    monitor='val_loss',       # quantity being tracked
    mode='min',               # smaller val_loss counts as better
    save_best_only=True,      # overwrite the file only on improvement
    save_weights_only=False,  # store the full model, not just the weights
    verbose=1,
)

history = model.fit(
    X_train,
    y_train,
    validation_data=(X_valid, y_valid),
    epochs=30,
    callbacks=[checkpoint_cb],
    verbose=1,
)

In [None]:
# Reload the best checkpoint from disk and score it on the held-out test set
best_model = keras.models.load_model('checkpoints/best_model.h5')
test_scores = best_model.evaluate(X_test, y_test, verbose=0)
test_loss, test_mae = test_scores  # [loss (MSE), mae] in compile() order
print(f"最佳模型测试集 - MSE: {test_loss:.4f}, MAE: {test_mae:.4f}")

## 3. EarlyStopping - 早停

当监控指标不再改善时提前停止训练，防止过拟合。

In [None]:
model = create_model()

# Stop once val_loss has failed to improve by at least 0.001 for 10 epochs in
# a row, and roll the model back to the weights of its best epoch
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=0.001,
    patience=10,
    restore_best_weights=True,
    verbose=1,
)

# The epoch budget is deliberately generous; EarlyStopping decides when to stop
history = model.fit(
    X_train,
    y_train,
    validation_data=(X_valid, y_valid),
    epochs=100,
    callbacks=[early_stopping_cb],
    verbose=1,
)

print(f"\n实际训练轮数: {len(history.history['loss'])}")

## 4. 组合使用回调函数

同时使用ModelCheckpoint和EarlyStopping是常见的最佳实践：前者把验证集上表现最好的模型保存到磁盘，后者在验证指标长期不再改善时提前结束训练。

In [None]:
model = create_model()

# Pair a best-model checkpoint with early stopping, both driven by val_loss
callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath='checkpoints/combined_best.h5',
        monitor='val_loss',
        save_best_only=True,
        verbose=1,
    ),
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1,
    ),
]

history = model.fit(
    X_train,
    y_train,
    validation_data=(X_valid, y_valid),
    epochs=100,
    callbacks=callbacks,
    verbose=1,
)

# Score the trained model on the held-out test set
test_loss, test_mae = model.evaluate(X_test, y_test, verbose=0)
print(f"\n测试集 - MSE: {test_loss:.4f}, MAE: {test_mae:.4f}")

## 5. ReduceLROnPlateau - 学习率调度

当指标停止改善时自动降低学习率。

In [None]:
model = create_model()

# Halve the learning rate whenever val_loss stalls for 5 epochs, never going
# below 1e-6
reduce_lr_cb = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
    verbose=1,
)

# EarlyStopping's patience (15) is longer than the scheduler's (5), giving a
# reduced learning rate a few epochs to take effect before training stops
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor='val_loss',  # the default, spelled out for clarity
        patience=15,
        restore_best_weights=True,
    ),
    reduce_lr_cb,
]

history = model.fit(
    X_train,
    y_train,
    validation_data=(X_valid, y_valid),
    epochs=100,
    callbacks=callbacks,
    verbose=1,
)

## 6. LearningRateScheduler - 自定义学习率调度

按照自定义函数调整学习率。

In [None]:
def exponential_decay(epoch, lr, hold_epochs=10, decay_rate=0.9):
    """
    Exponential-decay learning-rate schedule with an initial hold period.

    Leaves the learning rate unchanged for the first ``hold_epochs`` epochs,
    then multiplies the *current* rate by ``decay_rate`` at every later epoch.
    When the previously returned rate is fed back in as ``lr`` (as
    ``keras.callbacks.LearningRateScheduler`` does), the rate after the hold
    is ``lr_0 * decay_rate ** (epoch - hold_epochs + 1)``.

    Parameters
    ----------
    epoch : int
        Index of the current epoch (0-based).
    lr : float
        Learning rate currently set on the optimizer.
    hold_epochs : int, default 10
        Number of initial epochs during which the rate is not changed.
    decay_rate : float, default 0.9
        Multiplicative factor applied once per epoch after the hold period.

    Returns
    -------
    float
        The learning rate to use for this epoch.
    """
    if epoch < hold_epochs:
        return lr
    return lr * decay_rate

model = create_model()

# Ask exponential_decay for a new learning rate at every epoch and log it
lr_scheduler_cb = keras.callbacks.LearningRateScheduler(
    exponential_decay,
    verbose=1,
)

history = model.fit(
    X_train,
    y_train,
    validation_data=(X_valid, y_valid),
    epochs=30,
    callbacks=[lr_scheduler_cb],
    verbose=1,
)

In [None]:
# Plot the per-epoch learning rate recorded by LearningRateScheduler.
# The history key depends on the Keras version: tf.keras (Keras 2) records it
# as 'lr', while Keras 3 records it as 'learning_rate'. Look for both so the
# plot is not silently skipped on newer TensorFlow releases.
lr_key = next(
    (key for key in ('learning_rate', 'lr') if key in history.history),
    None,
)
if lr_key is not None:
    plt.figure(figsize=(10, 4))
    plt.plot(history.history[lr_key])
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.title('Learning Rate Schedule')
    plt.grid(True, alpha=0.3)
    plt.show()
else:
    print("history中未找到学习率记录")

## 7. 自定义回调函数

通过继承`keras.callbacks.Callback`并重写`on_train_begin`、`on_epoch_end`、`on_train_end`等钩子方法来创建自定义回调。

In [None]:
class TrainingMonitor(keras.callbacks.Callback):
    """
    Custom callback that tracks the best validation loss seen during training.

    Behaviour:
    - resets its tracking state at the start of every ``fit()`` call;
    - prints a message whenever an epoch's ``val_loss`` beats the best so far;
    - prints the best ``val_loss`` and its (1-based) epoch when training ends.

    Epochs whose logs contain no ``'val_loss'`` (for example when ``fit()`` is
    called without validation data) are skipped instead of crashing on a
    ``None < float`` comparison.
    """

    def on_train_begin(self, logs=None):
        """Reset the best-so-far tracking state; called when training starts."""
        self.best_val_loss = float('inf')
        self.best_epoch = 0  # 0-based index of the best epoch
        print("训练开始...")

    def on_epoch_end(self, epoch, logs=None):
        """Update the best validation loss after each (0-based) epoch."""
        # The hook signature allows logs=None, so guard before calling .get()
        val_loss = (logs or {}).get('val_loss')
        if val_loss is None:
            # No validation metric this epoch: nothing to compare against
            return
        if val_loss < self.best_val_loss:
            if self.best_val_loss == float('inf'):
                # First recorded value: there is no previous loss to improve
                # on, so report the value itself rather than an "inf" delta
                print(f"  Epoch {epoch+1}: 初始验证损失 {val_loss:.4f}")
            else:
                improvement = self.best_val_loss - val_loss
                print(f"  Epoch {epoch+1}: 验证损失改善 {improvement:.4f}")
            self.best_val_loss = val_loss
            self.best_epoch = epoch

    def on_train_end(self, logs=None):
        """Print a summary of the best validation loss observed."""
        print("\n训练完成!")
        if self.best_val_loss == float('inf'):
            # val_loss never appeared in the logs (e.g. no validation data)
            print("未记录到验证损失 (val_loss)")
        else:
            print(f"最佳验证损失: {self.best_val_loss:.4f} (Epoch {self.best_epoch+1})")

In [None]:
model = create_model()

custom_cb = TrainingMonitor()

# verbose=0 turns off Keras' own progress output so only TrainingMonitor prints
history = model.fit(
    X_train,
    y_train,
    validation_data=(X_valid, y_valid),
    epochs=20,
    callbacks=[custom_cb],
    verbose=0,
)

## 8. 可视化训练过程

In [None]:
# Final run with the full set of callbacks: checkpointing, early stopping and
# learning-rate reduction, all keyed on val_loss
model = create_model()

callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath='checkpoints/final_best.h5',
        save_best_only=True,
        monitor='val_loss',
    ),
    keras.callbacks.EarlyStopping(
        restore_best_weights=True,
        patience=10,
        monitor='val_loss',
    ),
    keras.callbacks.ReduceLROnPlateau(
        min_lr=1e-6,
        patience=5,
        factor=0.5,
        monitor='val_loss',
    ),
]

history = model.fit(
    X_train,
    y_train,
    validation_data=(X_valid, y_valid),
    epochs=100,
    callbacks=callbacks,
    verbose=1,
)

In [None]:
# Draw training vs. validation curves for loss (MSE) and MAE side by side
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# (training key, validation key, y-axis label, panel title) for each subplot
panels = (
    ('loss', 'val_loss', 'Loss (MSE)', 'Loss Curves'),
    ('mae', 'val_mae', 'MAE', 'MAE Curves'),
)
for ax, (train_key, valid_key, ylabel, title) in zip(axes, panels):
    ax.plot(history.history[train_key], label='Training')
    ax.plot(history.history[valid_key], label='Validation')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Score the trained model on the held-out test set
test_loss, test_mae = model.evaluate(X_test, y_test, verbose=0)
print(f"\n最终模型测试集 - MSE: {test_loss:.4f}, MAE: {test_mae:.4f}")

## 小结

### 常用回调函数

| 回调 | 功能 | 典型用法 |
|------|------|----------|
| ModelCheckpoint | 保存模型 | 保存最佳模型 |
| EarlyStopping | 早停 | 防止过拟合 |
| ReduceLROnPlateau | 降低学习率 | 突破训练瓶颈 |
| LearningRateScheduler | 自定义学习率 | 实现特定调度策略 |
| TensorBoard | 可视化 | 监控训练过程（本教程未演示） |

### 最佳实践

1. **始终使用ModelCheckpoint**: 避免因意外中断丢失进度
2. **结合EarlyStopping使用restore_best_weights=True**: 确保获得最佳模型
3. **设置合理的patience值**: 太小会过早停止，太大会浪费时间
4. **监控val_loss而非loss**: 关注泛化能力而非训练表现