In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

# Show the TensorFlow version this notebook is running against.
print(tf.__version__)

2.0.0-rc1


## Checkpoint によるモデルの保存

`tf.train.Checkpoint` を使って、モデル（および optimizer）の状態を保存・復元するやり方のメモ。

```
# TODO: もう少し説明を付け加える
```

In [2]:
# まずは簡単なデータセットを作っておく

def make_toy_dataset(num_data, seed=None):
    """Generate a small synthetic regression dataset.

    Each sample has three standard-normal features ``x0, x1, x2``; the target
    is ``3*x0 - 2*x1**3 + 2*x2**2`` plus Gaussian noise scaled by 0.5.

    Args:
        num_data: Number of samples to generate.
        seed: Optional integer seed. When given, a private
            ``np.random.RandomState(seed)`` is used, so repeated calls with
            the same seed return identical data. When ``None`` (the default)
            samples are drawn from NumPy's global random state, exactly as
            before this parameter existed.

    Returns:
        Tuple ``(X, y)`` of float32 arrays with shapes ``(num_data, 3)`` and
        ``(num_data, 1)``.
    """
    # The np.random module and RandomState instances expose the same randn() API.
    rng = np.random if seed is None else np.random.RandomState(seed)

    X = rng.randn(num_data, 3)
    noise = 0.5 * rng.randn(num_data)
    y = 3 * X[:, 0] - 2 * X[:, 1]**3 + 2 * X[:, 2]**2 + noise
    y = y[:, np.newaxis]  # add a trailing axis -> shape (num_data, 1)
    return X.astype(np.float32), y.astype(np.float32)

# Number of training samples.
N_train = 400

# Draw the toy training samples.
X_train, y_train = make_toy_dataset(N_train)

# Wrap the arrays in a tf.data pipeline: shuffle with a 10000-element buffer
# (larger than the 400 samples) and group the result into batches of 64.
train_dataset = (
    tf.data.Dataset
    .from_tensor_slices((X_train, y_train))
    .shuffle(buffer_size=10000)
    .batch(64)
)

In [3]:
# 説明用のモデルと訓練関数

def build_model():
    """Build the small fully connected network used in this notebook.

    Returns:
        A ``keras.models.Sequential`` made of Dense(3, relu) -> Dense(10, relu)
        -> Dense(1), with the first layer declared for 3 input features.
    """
    layers = [
        keras.layers.Dense(3, activation="relu", input_dim=3),
        keras.layers.Dense(10, activation="relu"),
        keras.layers.Dense(1),
    ]
    return keras.models.Sequential(layers)

def train(model, optimizer, loss_fn, X_batch, y_batch):
    """Run a single optimisation step on one mini-batch.

    Args:
        model: Keras model to update; called on ``X_batch`` to get predictions.
        optimizer: Keras optimizer used to apply the gradients.
        loss_fn: Callable invoked as ``loss_fn(y_batch, y_pred)``.
        X_batch: Input features for this batch.
        y_batch: Targets for this batch.

    Returns:
        The loss computed on this batch (before the weights were updated).
    """
    with tf.GradientTape() as tape:
        # Call the model in training mode so that layers whose behaviour
        # differs between training and inference (e.g. Dropout,
        # BatchNormalization) act correctly inside the training step.
        y_pred = model(X_batch, training=True)
        loss = loss_fn(y_batch, y_pred)

    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    return loss

In [4]:
# Objects used by the training loop: network, optimizer, loss function and a
# running mean that collects the per-batch losses.
model = build_model()
optimizer = keras.optimizers.SGD(learning_rate=0.01)
loss_fn = keras.losses.MeanSquaredError()
train_loss_metric = keras.metrics.Mean()

# Show the freshly initialised first-layer kernel for comparison after training.
print(
    "kernel weights of first layer:\n",
    model.get_weights()[0],
)

kernel weights of first layer:
 [[-0.33032894 -0.37823057  0.8178873 ]
 [ 0.3777125   0.27058053  0.6953354 ]
 [ 0.86268115 -0.51806164 -0.26143527]]


In [5]:
# Train while saving checkpoints.

# Prefix for the checkpoint files; save() appends a counter, giving ./ckpt_dir/ckpt-<N>.
checkpoint_prefix = os.path.join(".", "ckpt_dir", "ckpt")
# Track both the model and the optimizer so both end up in the checkpoint.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)

# Save once before any training step has been taken.
checkpoint.save(checkpoint_prefix)
for epoch in range(2):
    for X_batch, y_batch in train_dataset:
        loss = train(model, optimizer, loss_fn, X_batch, y_batch)
        # Feed the batch loss into the running mean for this epoch.
        train_loss_metric(loss)
        
    print("train loss:", train_loss_metric.result())
    print("kernel weights of first layer:\n", model.get_weights()[0])
    print("\n")
    
    # Clear the running mean so the next epoch starts from scratch.
    train_loss_metric.reset_states()
    
    # Save another checkpoint at the end of every epoch.
    checkpoint.save(checkpoint_prefix)

train loss: tf.Tensor(125.74383, shape=(), dtype=float32)
kernel weights of first layer:
 [[-0.3704103  -0.30586207  0.89951116]
 [ 0.4682661   0.6152704   0.47357768]
 [ 0.8369156  -0.4063372  -0.252726  ]]


train loss: tf.Tensor(103.58496, shape=(), dtype=float32)
kernel weights of first layer:
 [[-0.4631116  -0.39243665  0.9965098 ]
 [ 0.60907483  0.89838797  0.40517837]
 [ 0.75219005 -0.29455265 -0.27950624]]




In [6]:
# os.listdir() returns entries in arbitrary order; sort them so the listing is
# stable across runs and platforms.
sorted(os.listdir(os.path.join(".", "ckpt_dir")))

['checkpoint',
 'ckpt-1.data-00000-of-00001',
 'ckpt-1.index',
 'ckpt-2.data-00000-of-00001',
 'ckpt-2.index',
 'ckpt-3.data-00000-of-00001',
 'ckpt-3.index']

In [7]:
# Restore the checkpoint written after the last epoch (ckpt-3) and keep the
# returned status object so the restore can be verified afterwards.
status = checkpoint.restore(os.path.join(".", "ckpt_dir", "ckpt-3"))

In [8]:
# The weights match the ones printed after training.
print("kernel weights of first layer:\n", model.get_weights()[0])

kernel weights of first layer:
 [[-0.4631116  -0.39243665  0.9965098 ]
 [ 0.60907483  0.89838797  0.40517837]
 [ 0.75219005 -0.29455265 -0.27950624]]


In [9]:
# Check that the restore succeeded; this raises an error if it did not.
status.assert_consumed()

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x1c1dcdd4748>

新しく別のモデルを作り、保存した checkpoint からそのモデルに重みを復元してみる。

In [10]:
# Build a second, freshly initialised model and display its first-layer kernel
# before anything is restored into it.
model1 = build_model()
model1.get_weights()[0]

array([[-0.40004373, -0.39759374,  0.6263609 ],
       [-0.40927815, -0.60612726, -0.6868794 ],
       [-0.18071747, -0.99379706,  0.54345655]], dtype=float32)

In [11]:
# This Checkpoint tracks only the new model (no optimizer); restore the same
# ckpt-3 file into it and keep the status object for the checks that follow.
checkpoint1 = tf.train.Checkpoint(model=model1)
status1 = checkpoint1.restore(os.path.join(".", "ckpt_dir", "ckpt-3"))

In [12]:
# Same as the weights of the trained model.
print("kernel weights of first layer:\n", model1.get_weights()[0])

kernel weights of first layer:
 [[-0.4631116  -0.39243665  0.9965098 ]
 [ 0.60907483  0.89838797  0.40517837]
 [ 0.75219005 -0.29455265 -0.27950624]]


In [13]:
# The checkpoint also contains optimizer state, but checkpoint1 tracks only the
# model, so assert_consumed() is expected to fail here. Catch the error and
# print it so the notebook still runs top to bottom ("Restart & Run All").
try:
    status1.assert_consumed()
except AssertionError as err:
    print("AssertionError:", err)

AssertionError: Unresolved object in checkpoint (root).optimizer: children {
  node_id: 8
  local_name: "iter"
}
children {
  node_id: 9
  local_name: "decay"
}
children {
  node_id: 10
  local_name: "learning_rate"
}
children {
  node_id: 11
  local_name: "momentum"
}


In [14]:
# This assertion, on the other hand, passes.
# NOTE(review): presumably because it only requires the objects tracked by
# checkpoint1 (just the model) to be matched — confirm against the TF docs.
status1.assert_existing_objects_matched()

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x1c1e9301f98>