**tf.train.Checkpoint：變數的保存與還原**

很多時候，我們希望在模型訓練完成後能將訓練好的參數（變數）保存起來，之後在需要使用模型的地方載入這些參數，就能直接得到訓練好的模型。可能你第一個想到的是用 Python 的序列化模組 pickle 儲存 model.variables，但不幸的是，TensorFlow 的變數類型 ResourceVariable 並不能被序列化。為此，TensorFlow 提供了 tf.train.Checkpoint 類別，可透過其 save() 與 restore() 方法保存與還原模型（以及優化器等）物件中的變數。

以下提供一個實例，以前一章的多層感知器模型為例，展示模型變數的保存與載入：

In [None]:
import tensorflow as tf
import numpy as np

class MNISTLoader():
    """Hold the MNIST dataset in memory and serve random training mini-batches.

    Images are rescaled from uint8 values in [0, 255] to float32 values in
    [0, 1] and get a trailing channel axis; labels are cast to int32.
    """

    def __init__(self):
        (raw_train_x, raw_train_y), (raw_test_x, raw_test_y) = tf.keras.datasets.mnist.load_data()

        def _to_unit_float_images(images):
            # uint8 pixels -> float32 in [0, 1], then append a colour-channel axis.
            return np.expand_dims(images.astype(np.float32) / 255.0, axis=-1)

        self.train_data = _to_unit_float_images(raw_train_x)   # [60000, 28, 28, 1]
        self.test_data = _to_unit_float_images(raw_test_x)     # [10000, 28, 28, 1]
        self.train_label = raw_train_y.astype(np.int32)        # [60000]
        self.test_label = raw_test_y.astype(np.int32)          # [10000]
        self.num_train_data = self.train_data.shape[0]
        self.num_test_data = self.test_data.shape[0]

    def get_batch(self, batch_size):
        """Return ``batch_size`` training (images, labels) drawn at random, with replacement."""
        picks = np.random.randint(0, self.num_train_data, batch_size)
        images = self.train_data[picks, :]
        labels = self.train_label[picks]
        return images, labels

class MLP(tf.keras.Model):
    """Multilayer perceptron: Flatten -> Dense(ReLU) -> Dense -> softmax.

    Args:
        hidden_units: width of the hidden ReLU layer (default 100).
        num_classes: number of output classes (default 10).
    """

    def __init__(self, hidden_units=100, num_classes=10):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()    # flattens every axis except the first (batch_size)
        self.dense1 = tf.keras.layers.Dense(units=hidden_units, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=num_classes)

    def call(self, inputs):         # [batch_size, 28, 28, 1]
        """Return class probabilities of shape [batch_size, num_classes]."""
        x = self.flatten(inputs)    # [batch_size, 784]
        x = self.dense1(x)          # [batch_size, hidden_units]
        x = self.dense2(x)          # [batch_size, num_classes]
        # Outputs are probabilities (softmax applied here), not logits.
        return tf.nn.softmax(x)

# Training hyperparameters.
num_epochs, batch_size, learning_rate = 5, 50, 0.001

# Model, data source and optimizer used by the training loop.
model = MLP()
data_loader = MNISTLoader()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


# train.py -- model training stage
# NOTE: MyModel is a placeholder for your own tf.keras.Model subclass; it is not defined in the cells shown here.
model = MyModel()
# Create a Checkpoint that tracks `model` (the optimizer can be added as another keyword argument if its state should be saved too)
checkpoint = tf.train.Checkpoint(myModel=model)
# ... (model training code)
# After training, save the variables to files (you can also save periodically during training)
checkpoint.save('./save/model.ckpt')

In [None]:
# Interrupt this cell midway to simulate an interrupted run; the most recent
# checkpoints written so far remain in ./save.
checkpoint = tf.train.Checkpoint(myAwesomeModel=model)
# Let tf.train.CheckpointManager manage the checkpoints; only the 3 newest are kept.
manager = tf.train.CheckpointManager(checkpoint, directory='./save', max_to_keep=3)

num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
# Batches between checkpoints. Saving after every batch rewrites the whole
# model thousands of times per run, which is pure I/O overhead.
save_every = 100

for batch_index in range(num_batches):
    X, y = data_loader.get_batch(batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
    print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))

    if batch_index % save_every == 0:
        path = manager.save(checkpoint_number=batch_index)
        print("model saved to %s" % path)

# Also save the final weights; the last batch is generally not a multiple of save_every.
path = manager.save(checkpoint_number=num_batches)
print("model saved to %s" % path)

[1;30;43m串流輸出內容已截斷至最後 5000 行。[0m
batch 1094: loss 0.196923
model saved to ./save/ckpt-1094
batch 1095: loss 0.140449
model saved to ./save/ckpt-1095
batch 1096: loss 0.064256
model saved to ./save/ckpt-1096
batch 1097: loss 0.252140
model saved to ./save/ckpt-1097
batch 1098: loss 0.058045
model saved to ./save/ckpt-1098
batch 1099: loss 0.095853
model saved to ./save/ckpt-1099
batch 1100: loss 0.147519
model saved to ./save/ckpt-1100
batch 1101: loss 0.239354
model saved to ./save/ckpt-1101
batch 1102: loss 0.111969
model saved to ./save/ckpt-1102
batch 1103: loss 0.315267
model saved to ./save/ckpt-1103
batch 1104: loss 0.181788
model saved to ./save/ckpt-1104
batch 1105: loss 0.207790
model saved to ./save/ckpt-1105
batch 1106: loss 0.138629
model saved to ./save/ckpt-1106
batch 1107: loss 0.253511
model saved to ./save/ckpt-1107
batch 1108: loss 0.265074
model saved to ./save/ckpt-1108
batch 1109: loss 0.085185
model saved to ./save/ckpt-1109
batch 1110: loss 0.362376
model saved

KeyboardInterrupt: ignored

# test.py -- model usage stage

# NOTE: MyModel is a placeholder for your own tf.keras.Model subclass; it is not defined in the cells shown here.
model = MyModel()
checkpoint = tf.train.Checkpoint(myModel=model)             # Create a Checkpoint whose restore target is `model` (same keyword, myModel, as in train.py)
checkpoint.restore(tf.train.latest_checkpoint('./save'))    # Restore the model variables from the newest checkpoint in ./save
# Code that uses the model

In [None]:
# Restore the previously trained weights into a fresh model and evaluate it.
model = MLP()
checkpoint = tf.train.Checkpoint(myAwesomeModel=model)
latest = tf.train.latest_checkpoint('./save')
if latest is None:
    # restore(None) would silently leave the model randomly initialised.
    raise FileNotFoundError("no checkpoint found in ./save")
# Keras creates the layer variables on the model's first call; those not yet
# created here are restored when they are created.
checkpoint.restore(latest)

sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

# Step through the test set in batch_size strides; slicing past the end is safe,
# so a final partial batch is evaluated instead of dropped.
for start_index in range(0, data_loader.num_test_data, batch_size):
    end_index = start_index + batch_size
    # Calling the model directly avoids model.predict's per-call overhead on small batches.
    y_pred = model(data_loader.test_data[start_index: end_index])
    sparse_categorical_accuracy.update_state(y_true=data_loader.test_label[start_index: end_index], y_pred=y_pred)
print("test accuracy: %f" % sparse_categorical_accuracy.result())

test accuracy: 0.970600
