# ModelCheckpoint Callback

## 模型

In [2]:
import tensorflow as tf

# Fix the Python / NumPy / TensorFlow random seeds so that weight initialization,
# Dropout masks and per-epoch shuffling repeat across Restart & Run All.
# (Bit-exact results on GPU may additionally require op determinism.)
SEED = 42
tf.keras.utils.set_random_seed(SEED)

mnist = tf.keras.datasets.mnist

# Load the MNIST handwritten-digit dataset as (images, labels) train/test tuples.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Feature scaling with min-max normalization: (x - min) / (max - min).
# Assuming 8-bit pixel values in [0, 255], this reduces to x / 255.
x_train_norm, x_test_norm = x_train / 255.0, x_test / 255.0

# Build the model: flatten each 28x28 image, one 256-unit ReLU hidden layer
# with 20% dropout, and a 10-way softmax output (one unit per digit class).
model = tf.keras.models.Sequential([
    tf.keras.layers.Input((28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Configure the optimizer, loss function and evaluation metrics.
# sparse_categorical_crossentropy takes the integer labels in y_train directly.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])


## ModelCheckpoint callback

In [3]:
# Keras fills in the {epoch:02d} placeholder itself (str.format-style) each time
# it saves, e.g. model.01.weights.h5 -- this is a plain string, not an f-string.
checkpoint_filepath = 'model.{epoch:02d}.weights.h5'

# Save only the weights (not the full model) whenever val_accuracy improves.
# Because the filename contains the epoch number, every improvement is written
# to its own file instead of overwriting a single "best" file.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_filepath,
    monitor='val_accuracy',   # watch accuracy on the validation split
    mode='max',               # with save_best_only, higher val_accuracy is "better"
    save_best_only=True,      # skip epochs that do not beat the best value so far
    save_weights_only=True,   # write weights only, hence the .weights.h5 suffix
)

EPOCHS = 3  # number of epochs per fit() call

# Hold out 20% of the training data as the validation split the callback monitors.
model.fit(
    x_train_norm,
    y_train,
    validation_split=0.2,
    epochs=EPOCHS,
    callbacks=[model_checkpoint_callback],
)

Epoch 1/3
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2ms/step - accuracy: 0.8665 - loss: 0.4583 - val_accuracy: 0.9608 - val_loss: 0.1365
Epoch 2/3
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - accuracy: 0.9611 - loss: 0.1318 - val_accuracy: 0.9703 - val_loss: 0.1001
Epoch 3/3
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 2ms/step - accuracy: 0.9733 - loss: 0.0873 - val_accuracy: 0.9754 - val_loss: 0.0839


<keras.src.callbacks.history.History at 0x17de0e220c0>

In [4]:
# Train EPOCHS more epochs on the same model object, so training continues from
# the current weights and accuracy keeps improving.
# initial_epoch=EPOCHS makes Keras number these epochs EPOCHS+1 .. 2*EPOCHS
# (4-6) instead of restarting at 1. The {epoch:02d} checkpoint filenames
# therefore do not overwrite the files written by the previous cell.
# Assumes the previous cell has run exactly once before this one.
# NOTE(review): the same callback instance is reused; presumably it keeps the best
# val_accuracy seen in the previous run, so it only saves when that is beaten --
# confirm against the installed Keras version.
model.fit(x_train_norm, y_train, initial_epoch=EPOCHS, epochs=EPOCHS * 2,
          validation_split=0.2, callbacks=[model_checkpoint_callback])

Epoch 1/3
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - accuracy: 0.9785 - loss: 0.0685 - val_accuracy: 0.9749 - val_loss: 0.0794
Epoch 2/3
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 2ms/step - accuracy: 0.9828 - loss: 0.0551 - val_accuracy: 0.9770 - val_loss: 0.0740
Epoch 3/3
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 2ms/step - accuracy: 0.9861 - loss: 0.0442 - val_accuracy: 0.9781 - val_loss: 0.0756


<keras.src.callbacks.history.History at 0x17dc1f6fda0>