## 准备数据

In [1]:
import warnings
warnings.simplefilter("ignore")

In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, models, callbacks
from tensorflow.keras.datasets import mnist

import os
# Number of train / test examples kept so the demo trains quickly.
N_SAMPLES = 1000

# Absolute path of the local MNIST archive handed to `load_data`.
file_path = os.path.abspath('./mnist.npz')

(train_x, train_y), (test_x, test_y) = datasets.mnist.load_data(path=file_path)

# Keep the first N_SAMPLES examples, flatten each 28x28 image into a
# 784-length row (matching `input_shape=(784,)` in `create_model`), and divide
# by 255 to rescale pixel values (presumably 0-255 integers) into [0, 1].
train_y, test_y = train_y[:N_SAMPLES], test_y[:N_SAMPLES]
train_x = train_x[:N_SAMPLES].reshape(-1, 28 * 28) / 255.0
test_x = test_x[:N_SAMPLES].reshape(-1, 28 * 28) / 255.0

## 搭建模型

In [3]:
def create_model():
    """Build and compile a small fully-connected MNIST classifier.

    Architecture: 784 inputs -> Dense(512, relu) -> Dropout(0.2)
    -> Dense(10, softmax). Compiled with the Adam optimizer,
    sparse categorical cross-entropy loss and an accuracy metric.

    Returns:
        A newly constructed, compiled ``models.Sequential`` model.
    """
    stack = [
        layers.Dense(512, activation='relu', input_shape=(784,)),
        layers.Dropout(0.2),
        layers.Dense(10, activation='softmax'),
    ]
    net = models.Sequential(stack)
    net.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return net

def evaluate(target_model, x=None, y=None):
    """Evaluate a compiled model and print its accuracy as a percentage.

    Args:
        target_model: a compiled Keras model whose ``evaluate`` returns
            ``(loss, accuracy)``, i.e. compiled with ``metrics=['accuracy']``
            as ``create_model`` does.
        x: evaluation inputs; defaults to the notebook-level ``test_x``.
        y: evaluation labels; defaults to the notebook-level ``test_y``.
            Both defaults keep existing ``evaluate(model)`` calls unchanged.
    """
    # Resolve defaults lazily so the function no longer hard-wires the
    # globals, yet still works when called with just a model.
    if x is None:
        x = test_x
    if y is None:
        y = test_y
    _, acc = target_model.evaluate(x, y)
    print("Restore model, accuracy: {:5.2f}%".format(100 * acc))

## 自动保存 checkpoints

In [4]:
# Checkpoint file-name template; `{epoch:04d}` is filled in with str.format syntax.
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Save weights only, once every 10 epochs.
# NOTE(review): TF warns that `period` is deprecated in favour of `save_freq`,
# which in this version counts *samples seen*; that unit may differ in other
# TF releases, so confirm before switching or checkpoints may never be written.
cp_callback = callbacks.ModelCheckpoint(
    checkpoint_path,
    save_weights_only=True,
    period=10,
    verbose=1,
)

model = create_model()
# Also store the untrained weights as the epoch-0 checkpoint.
initial_ckpt = checkpoint_path.format(epoch=0)
model.save_weights(initial_ckpt)
model.fit(
    train_x, train_y,
    epochs=50,
    validation_data=(test_x, test_y),
    callbacks=[cp_callback],
    verbose=0,
)

W0713 00:06:20.997914 140735530943360 callbacks.py:859] `period` argument is deprecated. Please use `save_freq` to specify the frequency in number of samples seen.
W0713 00:06:21.173481 140735530943360 deprecation.py:323] From /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where



Epoch 00010: saving model to training_2/cp-0010.ckpt

Epoch 00020: saving model to training_2/cp-0020.ckpt

Epoch 00030: saving model to training_2/cp-0030.ckpt

Epoch 00040: saving model to training_2/cp-0040.ckpt

Epoch 00050: saving model to training_2/cp-0050.ckpt


<tensorflow.python.keras.callbacks.History at 0x13de39e80>

In [5]:
# Rebuild a fresh model and load the newest checkpoint found in
# `checkpoint_dir` (in the recorded run: 'training_2/cp-0050.ckpt').
model = create_model()
latest = tf.train.latest_checkpoint(checkpoint_dir)
model.load_weights(latest)
evaluate(model)

Restore model, accuracy: 87.20%


## 手动保存权重

In [6]:
# Save the current weights by hand, then restore them into a brand-new model.
manual_ckpt = './checkpoints/mannul_checkpoint'
model.save_weights(manual_ckpt)
model = create_model()
model.load_weights(manual_ckpt)
evaluate(model)

Restore model, accuracy: 87.20%


## 保存整个模型

### HDF5

In [7]:
# Persist the whole model to one HDF5 file so it can be reloaded without create_model().
models.save_model(model, 'my_model.h5')

In [8]:
# Reconstruct the model entirely from the HDF5 file, then check its accuracy.
h5_path = 'my_model.h5'
new_model = models.load_model(h5_path)
evaluate(new_model)

W0713 00:06:28.440529 140735530943360 hdf5_format.py:192] Error in loading the saved optimizer state. As a result, your model is starting with a freshly initialized optimizer.


Restore model, accuracy: 87.20%


### saved_model

In [9]:
import time

# Export to a timestamped SavedModel directory so repeated runs don't collide.
saved_model_path = "./saved_models/{}".format(int(time.time()))
tf.keras.experimental.export_saved_model(model, saved_model_path)
new_model = tf.keras.experimental.load_from_saved_model(saved_model_path)

# Sanity-check the *reloaded* model (the original checked `model`, which says
# nothing about the export/import round-trip): expect one row of 10 softmax
# outputs per test example.
new_model.predict(test_x).shape

W0713 00:06:29.094913 140735530943360 deprecation.py:323] From /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:253: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
W0713 00:06:29.096133 140735530943360 export_utils.py:182] Export includes no default signature!
W0713 00:06:29.277875 140735530943360 util.py:244] Unresolved object in checkpoint: (root).optimizer.iter
W0713 00:06:29.278687 140735530943360 util.py:244] Unresolved object in checkpoint: (root).optimizer.beta_1
W0713 00:06:29.279356 140735530943360 util.py:244] Unresolved object in checkpoint: (root).optimizer.beta_2
W0713 00:06:29.280627 140735530943360 util.py:244] Unres

(1000, 10)

In [10]:
# The SavedModel-loaded model is compiled here before evaluation (it presumably
# comes back without a loss/metric attached); reuse the original's optimizer.
new_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=model.optimizer,
    metrics=['accuracy'],
)
evaluate(new_model)

Restore model, accuracy: 87.20%
