In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf

In [2]:
class Net(tf.keras.Model):
    """A simple linear model: one Dense layer with 5 output units.

    The layer lives on ``self.l1``. tf.train.Checkpoint keys variables by
    attribute path, so keep this attribute name stable or existing
    checkpoints (e.g. those in ./tf_ckpts) will no longer match the model.
    """

    def __init__(self):
        super(Net, self).__init__()
        dense_layer = tf.keras.layers.Dense(5)
        self.l1 = dense_layer

    def call(self, x):
        """Return the output of the dense layer applied to ``x``."""
        projected = self.l1(x)
        return projected

In [3]:
net = Net()
# A subclassed model has no custom repr, so this shows only the default
# `<__main__.Net object at ...>` text.
print(str(net))

<__main__.Net object at 0x10b7fd4a8>


In [4]:
# Save the model's weights under the 'easy_checkpoint' prefix. Note that `net`
# has not been called on any input yet at this point in the notebook.
net.save_weights(filepath='easy_checkpoint')

In [5]:
# Column vector of inputs 0..9, shape (10, 1).
inputs = tf.range(10.)[:, tf.newaxis]
# Broadcast against a (1, 5) row so that labels[i, j] = 5 * i + j, shape (10, 5).
labels = 5. * inputs + tf.range(5.)[tf.newaxis, :]

print('inputs', inputs)
print('labels', labels)

inputs tf.Tensor(
[[0.]
 [1.]
 [2.]
 [3.]
 [4.]
 [5.]
 [6.]
 [7.]
 [8.]
 [9.]], shape=(10, 1), dtype=float32)
labels tf.Tensor(
[[ 0.  1.  2.  3.  4.]
 [ 5.  6.  7.  8.  9.]
 [10. 11. 12. 13. 14.]
 [15. 16. 17. 18. 19.]
 [20. 21. 22. 23. 24.]
 [25. 26. 27. 28. 29.]
 [30. 31. 32. 33. 34.]
 [35. 36. 37. 38. 39.]
 [40. 41. 42. 43. 44.]
 [45. 46. 47. 48. 49.]], shape=(10, 5), dtype=float32)


In [6]:
def toy_dataset():
    """Build the toy regression dataset used by the training loop.

    Each element is a dict with key 'x' (one input row, from 0..9) and key
    'y' (the matching 5-wide target row, y[j] = 5 * x + j). The 10 examples
    are repeated 10 times and then batched in pairs, giving 50 batches.
    """
    xs = tf.range(10.)[:, tf.newaxis]
    ys = 5. * xs + tf.range(5.)[tf.newaxis, :]
    examples = tf.data.Dataset.from_tensor_slices({'x': xs, 'y': ys})
    return examples.repeat(10).batch(2)

In [7]:
def train_step(net, example, optimizer):
    """Run one optimization step of `net` on a single batch.

    Args:
      net: the model to train; it is called on ``example['x']``.
      example: a batch dict with keys 'x' (inputs) and 'y' (targets).
      optimizer: optimizer whose ``apply_gradients`` performs the update.

    Returns:
      The mean absolute error for this batch, computed before the update.
    """
    targets = example['y']
    with tf.GradientTape() as tape:
        predictions = net(example['x'])
        batch_loss = tf.reduce_mean(tf.abs(predictions - targets))
    # Read the variables only after the forward pass: the model may create
    # them lazily on its first call.
    trainable = net.trainable_variables
    grads = tape.gradient(batch_loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))
    return batch_loss
        

In [13]:
# Adam with learning rate 0.1. The checkpoint tracks a `step` counter that
# starts at 1 and is advanced by the training loop, plus the optimizer and model.
opt = tf.keras.optimizers.Adam(learning_rate=0.1)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net)
# Keep only the 3 most recent checkpoints under ./tf_ckpts.
manager = tf.train.CheckpointManager(
    ckpt, directory='./tf_ckpts', max_to_keep=3)

In [14]:
def train_and_checkpoint(net, manager, checkpoint=None, optimizer=None,
                         save_every=10):
    """Train `net` on toy_dataset(), resuming from and writing checkpoints.

    Args:
      net: model to train (passed to the notebook's `train_step`).
      manager: tf.train.CheckpointManager used to locate the latest
        checkpoint and to write new ones.
      checkpoint: the tf.train.Checkpoint that `manager` wraps; it must track
        a `step` variable. Defaults to the notebook-level `ckpt`, looked up
        at call time, which reproduces the original behavior.
      optimizer: optimizer passed to `train_step`. Defaults to the
        notebook-level `opt`, also looked up at call time.
      save_every: write a checkpoint whenever `checkpoint.step` becomes a
        multiple of this value (default 10, as before).

    Raises:
      ValueError: if `save_every` is not positive.
    """
    # Resolve defaults at call time rather than def time, so a later cell
    # that re-binds `ckpt` / `opt` (as the restore demo does) is honored.
    if checkpoint is None:
        checkpoint = ckpt
    if optimizer is None:
        optimizer = opt
    if save_every <= 0:
        raise ValueError("save_every must be positive, got {}".format(save_every))

    latest = manager.latest_checkpoint
    if latest:
        checkpoint.restore(latest)
        print("Restored from {}".format(latest))
    else:
        print("Initializing from scratch.")

    for example in toy_dataset():
        loss = train_step(net, example, optimizer)
        checkpoint.step.assign_add(1)
        step = int(checkpoint.step)
        if step % save_every == 0:
            save_path = manager.save()
            print("Saved checkpoint for step {}: {}".format(step, save_path))
        print("loss {:1.2f}".format(loss.numpy()))

In [15]:
# With an empty ./tf_ckpts directory this run starts from scratch.
train_and_checkpoint(net=net, manager=manager)

Initializing from scratch.
loss 0.04
loss 0.30
loss 0.23
loss 0.18
loss 0.15
loss 0.06
loss 0.16
loss 0.33
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 0.42
loss 0.36
loss 0.08
loss 0.16
loss 0.29
loss 0.42
loss 0.46
loss 0.08
loss 0.09
loss 0.12
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 0.24
loss 0.26
loss 0.04
loss 0.08
loss 0.18
loss 0.28
loss 0.26
loss 0.08
loss 0.14
loss 0.24
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 0.35
loss 0.41
loss 0.04
loss 0.06
loss 0.13
loss 0.21
loss 0.25
loss 0.05
loss 0.08
loss 0.15
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 0.23
loss 0.20
loss 0.03
loss 0.10
loss 0.21
loss 0.31
loss 0.31
loss 0.04
loss 0.10
loss 0.20
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 0.28
loss 0.32


In [16]:
# Re-create the optimizer, model, checkpoint and manager as fresh objects,
# so the next training call has to restore state from ./tf_ckpts.
opt = tf.keras.optimizers.Adam(learning_rate=0.1)
net = Net()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net)
manager = tf.train.CheckpointManager(
    ckpt, directory='./tf_ckpts', max_to_keep=3)

In [17]:
# The fresh objects pick up where the previous run left off (latest checkpoint).
train_and_checkpoint(net=net, manager=manager)

Restored from ./tf_ckpts/ckpt-5
loss 0.06
loss 0.10
loss 0.12
loss 0.09
loss 0.14
loss 0.09
loss 0.13
loss 0.21
loss 0.25
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 0.29
loss 0.05
loss 0.06
loss 0.15
loss 0.26
loss 0.33
loss 0.05
loss 0.07
loss 0.07
loss 0.14
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.20
loss 0.09
loss 0.09
loss 0.10
loss 0.15
loss 0.17
loss 0.06
loss 0.08
loss 0.13
loss 0.16
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.16
loss 0.03
loss 0.04
loss 0.12
loss 0.19
loss 0.21
loss 0.04
loss 0.07
loss 0.13
loss 0.18
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.14
loss 0.08
loss 0.11
loss 0.24
loss 0.36
loss 0.40
loss 0.06
loss 0.06
loss 0.16
loss 0.26
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.31


In [18]:
# Only the `max_to_keep` most recent checkpoint prefixes remain tracked.
print(list(manager.checkpoints))

['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']


checkpoint                  ckpt-8.index
ckpt-10.data-00000-of-00001 ckpt-9.data-00000-of-00001
ckpt-10.index               ckpt-9.index
ckpt-8.data-00000-of-00001
