## Tensorflow tutorial 1: Models and training

Don't panic, read the docs (https://www.tensorflow.org/beta)

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tqdm.notebook import tqdm

### 1) Classifying hand-written digits / MNIST dataset
Contents:
- `tf.keras.models.Sequential`
- `tf.keras.layers.*`

- Obtaining data
- Constructing a model
- Choosing a loss function
- Training
- Validation

In [None]:
mnist = tf.keras.datasets.mnist

Load data

In [None]:
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = np.array(x_train, dtype=np.float32)
x_test = np.array(x_test, dtype=np.float32)

The first dimension indexes the samples:

In [None]:
x_train.shape, y_train.shape

In [None]:
plt.matshow(x_train[0])

In [None]:
y_train[0]

Formulate a model $f_w(x)$ and attach a loss function and metrics (observables):

In [None]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(512, activation=tf.nn.relu),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()
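
`sparse_categorical_crossentropy` expects integer class labels (like `y_train`) rather than one-hot vectors. For each sample it is $-\log$ of the probability the model assigns to the true class, averaged over the batch. A minimal sketch with made-up probabilities (`toy_probs` and `toy_labels` are illustrative, not MNIST data) compares the manual formula with the sparse loss and with the one-hot `CategoricalCrossentropy`:

```python
import numpy as np
import tensorflow as tf

# Made-up predicted probabilities for 2 samples and 3 classes (each row sums to 1)
toy_probs = np.array([[0.7, 0.2, 0.1],
                      [0.1, 0.3, 0.6]], dtype=np.float32)
toy_labels = np.array([0, 2])  # integer class indices, as in y_train

# Manual formula: mean over the batch of -log(probability of the true class)
manual = -np.mean(np.log(toy_probs[np.arange(len(toy_labels)), toy_labels]))

# Keras loss objects (default reduction: mean over the batch)
sparse_value = tf.keras.losses.SparseCategoricalCrossentropy()(toy_labels, toy_probs).numpy()
dense_value = tf.keras.losses.CategoricalCrossentropy()(
    tf.one_hot(toy_labels, depth=3), toy_probs).numpy()

print(manual, sparse_value, dense_value)  # all three agree (about 0.434)
```

The "sparse" variant only saves you the one-hot encoding of the labels; the value is the same.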

Note that we use the __Sequential__ API here, i.e. the model is a linear stack of layers whose computation graph is defined when `model` is constructed. TensorFlow (Keras) also supports a __functional__ API (https://www.tensorflow.org/beta/guide/keras/functional), in which you build the graph step by step (potentially with branches) and then construct a model from its inputs and outputs.
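
For comparison, a minimal sketch of the same dense MNIST model written with the functional API (the names `fn_inputs`, `fn_hidden`, `fn_outputs` are chosen here only to avoid overwriting notebook variables). Both versions contain the same layers, so their parameter counts agree:

```python
import tensorflow as tf

# The dense MNIST model from above, rebuilt with the functional API
fn_inputs = tf.keras.Input(shape=(28, 28))
fn_hidden = tf.keras.layers.Flatten()(fn_inputs)
fn_hidden = tf.keras.layers.Dense(512, activation="relu")(fn_hidden)
fn_hidden = tf.keras.layers.Dropout(0.2)(fn_hidden)
fn_outputs = tf.keras.layers.Dense(10, activation="softmax")(fn_hidden)
functional_model = tf.keras.Model(inputs=fn_inputs, outputs=fn_outputs)

# (784 * 512 + 512) + (512 * 10 + 10) parameters, as in model.summary() above
print(functional_model.count_params())  # 407050
```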

The model is randomly initialized, so applying the forward pass to the image of the five will not yield a meaningful prediction.

In [None]:
prediction = model(x_train[0:1])
plt.bar(range(0,10), prediction[0])
plt.xlabel("Digit")
plt.ylabel("Assignment probability")

Train the model

In [None]:
model.fit(x_train, y_train, epochs=3)

Apply model again to the image of the five

In [None]:
prediction = model(x_train[0:1])
plt.bar(range(0,10), prediction[0])
plt.xlabel("Digit")
plt.ylabel("Assignment probability");

More interestingly, we can evaluate the performance on unseen (test) data:

In [None]:
model.evaluate(x_test, y_test)

### 2) MNIST using convolutional layers
Contents:
- functional API to construct models, `tf.keras.Input`
- `tf.keras.layers.Conv2D`
- `tf.keras.optimizers`
- `tf.keras.losses`
- `tf.keras.metrics`
- `tf.data.Dataset`

- Constructing a model (convolutions are better suited to images than dense layers)
- Validation

![](https://upload.wikimedia.org/wikipedia/commons/6/63/Typical_cnn.png)
Image taken from Wikipedia.

We use a convolutional layer with 32 filters, each of size 3x3.
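
With the default `padding="valid"` and stride 1, a 3x3 convolution shrinks each spatial dimension by 2, and each filter has $3\cdot3\cdot1$ weights plus one bias. A quick check on a blank placeholder input:

```python
import tensorflow as tf

# 32 filters of size 3x3; defaults: padding="valid", strides=1, no activation
conv = tf.keras.layers.Conv2D(32, 3)
conv_out = conv(tf.zeros((1, 28, 28, 1)))  # one blank 28x28 single-channel image

# "valid" padding: 28 - 3 + 1 = 26 per spatial axis, one output channel per filter
print(conv_out.shape)        # (1, 26, 26, 32)
# Per filter: 3 * 3 * 1 weights + 1 bias, so 32 * 10 = 320 parameters
print(conv.count_params())   # 320
```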

In [None]:
inputs = tf.keras.Input(shape=(28, 28, 1))
inputs.shape

In [None]:
functional_api = True
if not functional_api:
    conv_model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
else:
    inputs = tf.keras.Input(shape=(28, 28, 1))
    x = tf.keras.layers.Conv2D(32, 3, activation='relu')(inputs)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(128, activation='relu')(x)
    outputs = tf.keras.layers.Dense(10, activation='softmax')(x)

    conv_model = tf.keras.Model(inputs=inputs, outputs=outputs)

conv_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                   loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                   metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

In [None]:
conv_model.summary()

We now use `tf.data.Dataset`, which provides a common interface for data coming from different sources. Here we simply consume our `np.ndarray` objects. A `Dataset` supports transformations such as [`apply`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset#apply), [`batch`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset#shuffle) and [`map`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset#map).
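
A minimal sketch of such a pipeline on ten placeholder images (`toy_images`, `toy_labels`, `toy_ds` are illustrative names), using the same `map`, `shuffle` and `batch` steps as the notebook. Note that `batch` keeps a smaller final batch unless `drop_remainder=True` is passed:

```python
import numpy as np
import tensorflow as tf

# Ten placeholder 28x28 "images" with integer labels 0..9
toy_images = np.zeros((10, 28, 28), dtype=np.float32)
toy_labels = np.arange(10)

toy_ds = tf.data.Dataset.from_tensor_slices((toy_images, toy_labels))  # one element per sample
toy_ds = toy_ds.map(lambda x, y: (tf.expand_dims(x, -1), y))           # (28, 28) -> (28, 28, 1)
toy_ds = toy_ds.shuffle(buffer_size=10)                                # shuffle samples, not batches
toy_ds = toy_ds.batch(4)                                               # 4 + 4 + 2 samples

batch_shapes = [tuple(x.shape) for x, _ in toy_ds]
print(batch_shapes)  # [(4, 28, 28, 1), (4, 28, 28, 1), (2, 28, 28, 1)]
```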

In [None]:
# Construct the datasets from numpy arrays
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

# Add a channel dimension: (28, 28) -> (28, 28, 1)
train_dataset = train_dataset.map(lambda x, y: (tf.expand_dims(x, -1), y))
test_dataset = test_dataset.map(lambda x, y: (tf.expand_dims(x, -1), y))

# The equivalent operation applied directly to the numpy arrays
x_train2 = x_train[..., np.newaxis]
x_test2 = x_test[..., np.newaxis]

# Shuffle the training data, then batch both datasets
train_dataset = train_dataset.shuffle(len(x_train)).batch(64)
test_dataset = test_dataset.batch(64)

In [None]:
image, label = next(iter(train_dataset))
prediction = conv_model(image[0:1])
plt.bar(range(0,10), prediction[0])
plt.xlabel("Digit")
plt.ylabel("Assignment probability")
print("Label", label[0:1])

In [None]:
conv_model.fit(train_dataset, epochs=2)

In [None]:
conv_model.evaluate(test_dataset)

### 3) Linear regression, example: Hooke's law for a two-dimensional oscillator
- subclass `tf.keras.Model`
- `tf.keras.Model.save_weights`
- `tf.keras.Model.load_weights`
- Constructing a (linear) model

Consider a point mass attached to two springs such that it can oscillate independently in two orthogonal directions. Given measurements of the force on the point mass at given displacements $x$, we want to find a model that predicts the force from $x$. Our model for the restoring force $F\in\mathbb{R}^2$ depends linearly on the displacement vector $x\in\mathbb{R}^2$, i.e. $F(x)=Wx$, where the matrix $W\in\mathbb{R}^{2\times2}$ contains the spring constants. (The model is linear both in $x$ and in $W$.)
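
Because the model is linear and fitted with a squared error, $W$ can also be obtained in closed form by least squares, which is a handy sanity check for the gradient-based fit. A minimal numpy sketch on synthetic, noise-free data (`W_true` is a made-up matrix, not the content of `2d-hooke.npz`). With samples stored as rows, $F(x) = Wx$ reads `ys = xs @ W.T`:

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up spring-constant matrix (negative diagonal: the force pulls back towards 0)
W_true = np.array([[-2.0, 0.5],
                   [0.1, -1.0]])
xs_toy = rng.uniform(-0.5, 0.5, size=(100, 2))  # 100 displacements, one per row
ys_toy = xs_toy @ W_true.T                       # F(x) = W x, applied row by row

# Least squares solves xs_toy @ B ~ ys_toy; the solution B is W^T
B, residuals, rank, _ = np.linalg.lstsq(xs_toy, ys_toy, rcond=None)
W_est = B.T
print(np.round(W_est, 6))
```

With noisy measurements the recovered matrix is only approximately equal to the true one, but the same call still gives the least-squares optimum.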

In [None]:
# alternatively try the '2d-hooke-rotated.npz'
with np.load("2d-hooke.npz") as file:
    xs, ys = file["xs"], file["ys"]

Let's inspect the data

In [None]:
xs.shape, ys.shape

In [None]:
_ = plt.hist2d(xs[:,0], xs[:,1])

In [None]:
_ = plt.hist2d(ys[:,0], ys[:,1])

We can also inspect correlations in the data

In [None]:
import pandas as pd
import seaborn as sns
df = pd.DataFrame({"x1": xs[:,0], "x2": xs[:,1], "y1": ys[:,0], "y2": ys[:,1]})
sns.pairplot(df)

In [None]:
n_samples = len(xs)
n_test = n_samples // 10
n_train = n_samples - n_test
n_samples, n_train, n_test

Here we formulate our own model by subclassing `tf.keras.Model`. It consists of some trainable parameters (contained in `self.dense`) and an implementation of the forward pass, `call(self, inputs)`.

In [None]:
class MyModel(tf.keras.Model):

    def __init__(self):
        super(MyModel, self).__init__()
        self.dense = tf.keras.layers.Dense(2, use_bias=False)

    def call(self, inputs):
        return self.dense(inputs)

model = MyModel()
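
One detail worth knowing before inspecting `model.dense.weights`: a Keras `Dense` layer computes `inputs @ kernel` with samples as rows, so the learned kernel corresponds to $W^\top$ rather than $W$. A minimal sketch with a made-up matrix `W_toy` set by hand (`toy_dense` and `x_unit` are illustrative names):

```python
import numpy as np
import tensorflow as tf

# Same layer as in MyModel: 2 outputs, no bias
toy_dense = tf.keras.layers.Dense(2, use_bias=False)
toy_dense.build((None, 2))

W_toy = np.array([[1.0, 2.0],
                  [3.0, 4.0]], dtype=np.float32)  # made-up spring-constant matrix
# Dense computes inputs @ kernel, so kernel = W^T realizes F(x) = W x
toy_dense.set_weights([W_toy.T])

x_unit = np.eye(2, dtype=np.float32)  # unit displacements along x1 and x2, one per row
print(toy_dense(x_unit).numpy())      # rows are W @ e1 = [1, 3] and W @ e2 = [2, 4]
```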

In order to optimize the model to fit our data, we have to specify an optimization criterion. In our case we aim to minimize the mean squared error of our model $F_W=f_w$ on the observed data $(x_n, y_n), ~ n \in [N]$ with respect to the parameters $w$:

$$w_{*} = \arg\min_{w} \tfrac{1}{N} \sum_{n=1}^{N}\lVert f_{w}(x_n) - y_n\rVert^{2}$$
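
Since the targets $y_n\in\mathbb{R}^2$ are vectors, the squared error is the squared Euclidean norm, summed over both force components. Keras' `mean_squared_error` instead averages over the last axis, so its value is this objective divided by the number of components (here 2); the minimizer $w_*$ is unchanged. A small check on made-up random arrays (`y_true`, `y_pred`):

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(1)
y_true = rng.normal(size=(5, 2)).astype(np.float32)  # made-up 2D targets
y_pred = rng.normal(size=(5, 2)).astype(np.float32)  # made-up 2D predictions

# Objective above: mean over samples of the squared Euclidean norm of the error
objective = np.mean(np.sum((y_pred - y_true) ** 2, axis=-1))

# Keras: per-sample mean over the last axis (the 2 components), then mean over samples
keras_mse = tf.keras.losses.mean_squared_error(y_true, y_pred).numpy().mean()

print(objective, keras_mse, objective / keras_mse)  # ratio of about 2
```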

In [None]:
model.compile(optimizer=tf.keras.optimizers.Adam(1e-2),
              loss=tf.keras.losses.mean_squared_error,
              metrics=[tf.keras.metrics.mean_squared_error])
model.build((None, 2))
model.summary()

In [None]:
model.fit(xs[:n_train], ys[:n_train], epochs=200,
          validation_data=(xs[n_train:], ys[n_train:]))

In [None]:
model.metrics[0].result().numpy()

In [None]:
model.dense.weights

In [None]:
n = 100
x = np.linspace(-0.5,0.5,n)
y = np.linspace(-0.5,0.5,n)
X, Y = np.meshgrid(x,y)

samples = np.stack((X,Y), axis=-1)
samples_flat = np.reshape(samples, (n*n, 2))
res_flat = model(samples_flat.astype(np.float32)).numpy()
res = np.reshape(res_flat, (n,n,2))
z1 = res[..., 0]
z2 = res[..., 1]

In [None]:
pcol = plt.pcolor(X, Y, z1, vmin=-2.0,vmax=2.0)
plt.contour(X, Y, z1, colors="white", linestyles="-")
cbar = plt.colorbar(pcol)
cbar.ax.set_ylabel(r"Prediction $y_1$")
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")

In [None]:
pcol = plt.pcolor(X, Y, z2, vmin=-2.0,vmax=2.0)
plt.contour(X, Y, z2, colors="white", linestyles="-")
cbar = plt.colorbar(pcol)
cbar.ax.set_ylabel(r"Prediction $y_2$")
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")

Let's save the weights to a file for later use. Since the path does not end in `.h5`, `save_weights` writes a TensorFlow [checkpoint](https://www.tensorflow.org/beta/guide/checkpoints), i.e. a set of files sharing the prefix `hooke-model`.

In [None]:
model.save_weights("./hooke-model")

In [None]:
!ls | grep hooke-model

In [None]:
loaded = MyModel()
loaded.build((None, 2))
loaded.load_weights("./hooke-model")
loaded.dense.weights

### 4) Overfitting
- Checkpointing with `tf.keras.callbacks.ModelCheckpoint`
- Visualize training metrics using `history`

In the previous example we knew exactly how many parameters are required to reconstruct the given data (up to noise). In the general case, the optimal complexity of a model is not known _a priori_.

To demonstrate this, let's use a model with many more parameters than there are data points. In other words, the model is much too expressive for the problem and the amount of data we have.
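
A quick count for the model defined next, assuming each `Dense` layer has (inputs + 1 bias) times outputs parameters and that training uses `x_train[::1000]` as in the `fit` call below:

```python
# Dense layer sizes of the model defined next: 784 inputs -> 256 -> 256 -> 256 -> 10
layer_sizes = [28 * 28, 256, 256, 256, 10]
# Each Dense layer has (n_in + 1 bias) * n_out parameters
n_params = sum((n_in + 1) * n_out for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))

# x_train[::1000] keeps every 1000th of the 60000 MNIST training images
n_train_samples = len(range(0, 60000, 1000))

print(n_params, n_train_samples)  # 335114 60
```

That is more than 5000 parameters per training image.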

In [None]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(256, activation=tf.nn.relu),
    tf.keras.layers.Dense(256, activation=tf.nn.relu),
    tf.keras.layers.Dense(256, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

In [None]:
model.summary()

In [None]:
checkpoint_dir = "./overfit_ckpts"
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(checkpoint_dir, "of-model"))

In [None]:
history = model.fit(x_train[::1000], y_train[::1000], epochs=40, validation_data=(x_test, y_test), 
                    callbacks=[model_checkpoint], batch_size=64)

In [None]:
history.epoch

In [None]:
history.history.keys()

In [None]:
fig, axes = plt.subplots(1,2, figsize=(8,3))

plt.sca(axes[0])
plt.plot(history.epoch, history.history["loss"], label="train loss")
plt.plot(history.epoch, history.history["val_loss"], label="val loss")
plt.legend()
plt.xlabel("epochs")
plt.ylabel("loss")
#plt.yscale("log")

plt.sca(axes[1])
plt.plot(history.epoch, history.history["sparse_categorical_accuracy"],
         label="train accuracy")
plt.plot(history.epoch, history.history["val_sparse_categorical_accuracy"], 
         label="val accuracy")
plt.legend()
plt.xlabel("epochs")
plt.ylabel("accuracy")

fig.tight_layout()

While the performance on the training set keeps improving, the performance on "unseen" data declines after some time. This indicates __overfitting__: the model is too complex (expressive, rich in parameters) and can learn the training set by heart. For the same reason, overfitting also occurs when there is too little training data.

The textbook analogy is a polynomial fit with as many coefficients (trainable parameters) as data points: the curve passes through every training point but can oscillate strongly in between.

<img src=https://qph.fs.quoracdn.net/main-qimg-28d4d605380ee139f5079e18bacdf630 width="300">
Image taken from Quora.
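
A minimal numpy sketch of this analogy (the sine curve, the noise level and the degrees 3 and 7 are arbitrary illustrative choices). With 8 noisy points, the degree-7 polynomial has 8 coefficients and passes through every training point, so its training error drops to round-off level; its error on points in between can be much larger. Compare the printed test errors of the two fits:

```python
import numpy as np

rng = np.random.default_rng(0)
n_points = 8
x_pts = np.linspace(0.0, 1.0, n_points)
y_pts = np.sin(2 * np.pi * x_pts) + 0.1 * rng.normal(size=n_points)  # noisy training data

x_new = np.linspace(0.0, 1.0, 200)    # "unseen" inputs in the same range
y_new = np.sin(2 * np.pi * x_new)     # noise-free ground truth at those inputs

results = {}
for degree in (3, n_points - 1):      # degree 7 has 8 coefficients = number of points
    coeffs = np.polyfit(x_pts, y_pts, degree)
    train_mse = np.mean((np.polyval(coeffs, x_pts) - y_pts) ** 2)
    test_mse = np.mean((np.polyval(coeffs, x_new) - y_new) ** 2)
    results[degree] = (train_mse, test_mse)
    print(f"degree {degree}: train MSE {train_mse:.2e}, test MSE {test_mse:.2e}")
```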

Generally, overfitting can be avoided by choosing a simpler model.

### 5) Underfitting
Let's consider a very simple model with few parameters, i.e. a model that is not expressive enough for the task at hand.

In [None]:
model = tf.keras.models.Sequential([
    tf.keras.layers.MaxPool2D(pool_size=(8,8), input_shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

In [None]:
model.summary()

In [None]:
history = model.fit(x_train2, y_train, epochs=10, validation_data=(x_test2, y_test))

In [None]:
fig, axes = plt.subplots(1,2, figsize=(8,3))

plt.sca(axes[0])
plt.plot(history.epoch, history.history["loss"], label="train loss")
plt.plot(history.epoch, history.history["val_loss"], label="val loss")
plt.legend()
plt.xlabel("epochs")
plt.ylabel("loss")
#plt.yscale("log")

plt.sca(axes[1])
plt.plot(history.epoch, history.history["sparse_categorical_accuracy"],
         label="train accuracy")
plt.plot(history.epoch, history.history["val_sparse_categorical_accuracy"], 
         label="val accuracy")
plt.legend()
plt.xlabel("epochs")
plt.ylabel("accuracy")

fig.tight_layout()

Here we certainly do not overfit, but we cannot even make good predictions on the training data. This is __underfitting__: the model we chose is not complex enough, i.e. it is too strongly biased.
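
Part of the reason is visible in the model's first layer: 8x8 max pooling with the default stride (equal to the pool size) and `"valid"` padding reduces each 28x28 image to a 3x3 grid (the last 4 rows and columns are dropped), so the final `Dense` layer sees only 9 numbers per image. A quick check of the shapes and the parameter count:

```python
import tensorflow as tf

# First layer of the underfitting model; strides default to pool_size, padding to "valid"
pool = tf.keras.layers.MaxPool2D(pool_size=(8, 8))
pooled = pool(tf.zeros((1, 28, 28, 1)))
print(pooled.shape)  # (1, 3, 3, 1): floor((28 - 8) / 8) + 1 = 3 per axis

# The classifier then sees 3 * 3 = 9 numbers: 9 * 10 weights + 10 biases
head = tf.keras.layers.Dense(10)
head.build((None, 9))
print(head.count_params())  # 100
```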

For example, consider the prediction for the image of the five:

In [None]:
plt.bar(range(10), model(x_train2[0:1]).numpy()[0])

### 6) Regularization
- `tf.train.Checkpoint`
- writing your own training procedure
- eager execution

Often it is not clear which simple model is the right one, i.e. the optimal model bias is not known. Starting from a complex model, __regularization__ prevents overfitting by introducing a systematic bias. The most important regularization methods are:
- Dropout
- Early stopping (`tf.keras.callbacks.EarlyStopping`)
- L1 or L2 penalty on parameters
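
A minimal sketch of the last two methods with standard Keras APIs: `kernel_regularizer=tf.keras.regularizers.l2(...)` adds an L2 penalty on a layer's weights to the training loss, and the `tf.keras.callbacks.EarlyStopping` callback stops `fit()` once the validation loss stops improving. The layer sizes, the penalty strength `1e-4`, `patience=3` and the random toy data are arbitrary illustrative choices, not tuned values:

```python
import numpy as np
import tensorflow as tf

# Hidden layer with an L2 penalty 1e-4 * sum(kernel ** 2) (strength chosen arbitrarily)
reg_model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(64, activation="relu",
                          kernel_regularizer=tf.keras.regularizers.l2(1e-4)),
    tf.keras.layers.Dense(10, activation="softmax"),
])
reg_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# The penalty is collected in reg_model.losses and added to the loss during training
print(len(reg_model.losses))  # 1

# Stop once val_loss has not improved for 3 epochs and roll back to the best weights
early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=3,
                                                  restore_best_weights=True)

# Random toy data (no real signal), only to show how the callback is passed to fit()
rng = np.random.default_rng(0)
x_toy = rng.random((256, 28, 28)).astype(np.float32)
y_toy = rng.integers(0, 10, size=256)
history_toy = reg_model.fit(x_toy, y_toy, epochs=30, validation_split=0.25,
                            callbacks=[early_stopping], verbose=0)
print(len(history_toy.epoch), "epochs run")
```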

We go back to the very first MNIST model, which used dropout (here with a higher dropout rate of 0.4).

Note that we now write our own training procedure, which lets us use TensorFlow's `tf.train.Checkpoint` and `tf.train.CheckpointManager` directly (without relying on Keras' callback mechanism).
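
A minimal sketch of how `tf.train.CheckpointManager` numbers and prunes checkpoints, using a single `tf.Variable` in place of the model and optimizer state and a temporary directory (both illustrative choices). With `max_to_keep=2` only the two newest checkpoints survive; the training loop in this section passes `max_to_keep=None`, which keeps all of them:

```python
import tempfile
import tensorflow as tf

step = tf.Variable(0.0)  # stands in for the model and optimizer variables
toy_checkpoint = tf.train.Checkpoint(step=step)
toy_dir = tempfile.mkdtemp()
toy_manager = tf.train.CheckpointManager(toy_checkpoint, directory=toy_dir, max_to_keep=2)

for value in (1.0, 2.0, 3.0):
    step.assign(value)
    toy_manager.save()  # writes ckpt-1, ckpt-2, ckpt-3 in toy_dir

print(len(toy_manager.checkpoints))  # 2: ckpt-1 has been deleted
toy_checkpoint.restore(toy_manager.checkpoints[0])  # oldest remaining checkpoint
print(step.numpy())  # 2.0
```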

In [None]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(512, activation=tf.nn.relu),
  tf.keras.layers.Dropout(0.4),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.summary()

Gather the `model`, `optimizer`, loss and metrics. Note that the optimizer and the metrics are _stateful_: e.g. the metric `tf.keras.metrics.SparseCategoricalCrossentropy` is not simply a function but an object that accumulates a running average over all batches it has seen until it is reset. This is why we need a separate metric object for the training loss and for the validation loss.
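
A small sketch of this statefulness with made-up predictions: each `update_state` call adds one batch, and `result()` returns the mean over everything seen so far. (The training loop in this section resets the metrics with `reset_states()` at the end of each epoch; newer TensorFlow releases also provide `reset_state()`.)

```python
import numpy as np
import tensorflow as tf

loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()
metric_probs = np.array([[0.5, 0.5],
                         [0.25, 0.75]], dtype=np.float32)  # made-up predictions, 2 classes

loss_metric.update_state([0], metric_probs[:1])  # "batch" 1: -log(0.5)
loss_metric.update_state([1], metric_probs[1:])  # "batch" 2: -log(0.75)

# result() is the running mean over both batches, not the value of the last batch
expected = -(np.log(0.5) + np.log(0.75)) / 2
print(loss_metric.result().numpy(), expected)  # both about 0.4904
```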

In [None]:
model.build((64, 28, 28, 1))

optimizer = tf.keras.optimizers.Adam()

loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

train_loss = tf.keras.metrics.SparseCategoricalCrossentropy(name='loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='sparse_categorical_accuracy')

test_loss = tf.keras.metrics.SparseCategoricalCrossentropy(name='val_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='val_sparse_categorical_accuracy')

In [None]:
def train_step(images, labels):
    """
    Predicts the output of `images`, calculates and applies gradients to model parameters.
    Also calculates train metrics
    """
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # calculate metrics
    train_accuracy(labels, predictions)
    train_loss(labels, predictions)
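
`train_step` relies on `tf.GradientTape`: operations executed eagerly inside the `with` block are recorded, and `tape.gradient` differentiates the recorded loss with respect to the given variables. A minimal sketch of the same mechanism on a made-up one-parameter problem (fit `w_toy` in $y = 3x$), with the optimizer replaced by a hand-written gradient-descent update:

```python
import tensorflow as tf

x_data = tf.constant([1.0, 2.0, 3.0])
y_data = 3.0 * x_data          # made-up data following y = 3 x
w_toy = tf.Variable(0.0)       # the single trainable parameter
learning_rate = 0.05

for _ in range(100):
    with tf.GradientTape() as tape:                  # records the forward computation
        toy_loss = tf.reduce_mean((w_toy * x_data - y_data) ** 2)
    grad = tape.gradient(toy_loss, w_toy)            # d loss / d w = mean(2 x (w x - y))
    w_toy.assign_sub(learning_rate * grad)           # plain gradient-descent update

print(w_toy.numpy())  # close to 3.0
```

In `train_step` the same roles are played by `model.trainable_variables` and `optimizer.apply_gradients`. Decorating `train_step` and `test_step` with `@tf.function` usually makes them faster by tracing them into a graph, at the cost of harder debugging.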

In [None]:
def test_step(images, labels):
    """Evaluates prediction on given data and calculates test metrics"""
    predictions = model(images, training=False)
    # calculate metrics
    test_loss(labels, predictions)
    test_accuracy(labels, predictions)

In [None]:
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)

In [None]:
checkpoint_dir = "./regularization_ckpts/"
manager = tf.train.CheckpointManager(checkpoint, directory=checkpoint_dir, max_to_keep=None)
status = checkpoint.restore(manager.latest_checkpoint)

In [None]:
# poor man's history
history = {"epoch": [], "train_loss": [], "train_accuracy": [], "test_loss": [], "test_accuracy": []}

for epoch in tqdm(range(10)):
    for images, labels in train_dataset:
        train_step(images, labels)

    for test_images, test_labels in test_dataset:
        test_step(test_images, test_labels)

    ckpt_path = manager.save()
    template = 'Epoch {}, Loss: {:.4f}, Accuracy: {:.4f}, Test Loss: {:.4f}, Test Accuracy: {:.4f}, ckpt {}'
    print(template.format(epoch + 1,
                          float(train_loss.result()),
                          float(train_accuracy.result()),
                          float(test_loss.result()),
                          float(test_accuracy.result()),
                          ckpt_path))

    # store plain Python floats so the history is easy to plot
    history["epoch"].append(epoch)
    history["train_loss"].append(float(train_loss.result()))
    history["train_accuracy"].append(float(train_accuracy.result()))
    history["test_loss"].append(float(test_loss.result()))
    history["test_accuracy"].append(float(test_accuracy.result()))
    
    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

In [None]:
image, label = next(iter(train_dataset))
prediction = model(image[0:1], training=False)
plt.bar(range(0,10), prediction[0])
plt.xlabel("Digit")
plt.ylabel("Assignment probability")
print("Label", label[0:1])

In [None]:
model.weights[3]

In [None]:
manager.checkpoints

In [None]:
checkpoint.restore(manager.checkpoints[0])

In [None]:
model.weights[3]

In [None]:
image, label = next(iter(train_dataset))
prediction = model(image[0:1], training=False)
plt.bar(range(0,10), prediction[0])
plt.xlabel("Digit")
plt.ylabel("Assignment probability")
print("Label", label[0:1])

In [None]:
fig, axes = plt.subplots(1,2, figsize=(8,3))

plt.sca(axes[0])
plt.plot(history["epoch"], history["train_loss"], label="train loss")
plt.plot(history["epoch"], history["test_loss"], label="val loss")
plt.legend()
plt.xlabel("epochs")
plt.ylabel("loss")
#plt.yscale("log")

plt.sca(axes[1])
plt.plot(history["epoch"], history["train_accuracy"],
         label="train accuracy")
plt.plot(history["epoch"], history["test_accuracy"], 
         label="val accuracy")
plt.legend()
plt.xlabel("epochs")
plt.ylabel("accuracy")

fig.tight_layout()

### 7) MNIST Autoencoder, fill-in-the-blank exercise
![](https://upload.wikimedia.org/wikipedia/commons/thumb/3/37/Autoencoder_schema.png/220px-Autoencoder_schema.png)
Image taken from Wikipedia.

We now want to learn a compressed representation of the MNIST dataset by building an __Autoencoder__, which consists of:
- an Encoder, which takes the input image and compresses it to a lower-dimensional (latent) representation
- a Decoder, which takes the output of the encoder and expands it back into the original pixel representation

The goal is to minimize the reconstruction error of this feed-forward model. This is an __unsupervised__ method, which means that we do not use the labels.
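
Written out, with notation introduced here only for illustration (encoder $e_\theta$, decoder $d_\phi$, $N$ training images $x_n$ with $P = 28\cdot 28 = 784$ pixels each), the mean squared reconstruction error minimized by the `MeanSquaredError` loss is

$$\theta_*, \phi_* = \arg\min_{\theta,\phi} \tfrac{1}{NP} \sum_{n=1}^{N} \lVert d_\phi(e_\theta(x_n)) - x_n \rVert^{2}$$

No labels appear in this objective; the input image is its own target.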

The latent representation should be a vector with only a few dimensions (far fewer than the 28x28 = 784 pixels of the original image).

Hints: experiment with the activation functions (ReLU, leaky ReLU, or sigmoid) and with the dimension of the latent space.

Additionally, we want to visualize the latent representation of our data using t-SNE (this can also be done with TensorBoard's embedding projector, e.g. https://projector.tensorflow.org/).

In [None]:
class Encoder(tf.keras.Model):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        self.... = ### FILL
    
    def call(self, inputs):
        ### FILL
        return ### FILL
    
class Decoder(tf.keras.Model):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.... = ### FILL
    
    def call(self, inputs):
        ### FILL
        return ### FILL

class Autoencoder(tf.keras.Model):
    def __init__(self, latent_dim):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def call(self, inputs):
        ### FILL
        return ### FILL

In [None]:
train_ds = tf.data.Dataset.from_tensor_slices((x_train, x_train)).shuffle(1024).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, x_test)).shuffle(1024).batch(32)

In [None]:
dim = ### FILL
ae = Autoencoder(dim)
ae.build((32, 28, 28))
ae.compile(optimizer=tf.keras.optimizers.Adam(),
           loss=tf.keras.losses.MeanSquaredError(),
           metrics=[tf.keras.metrics.MeanSquaredError()])
ae.summary()

In [None]:
history = ae.fit(train_ds, epochs=5, validation_data=test_ds)

In [None]:
plt.plot(history.epoch, history.history["loss"], label="train loss")
plt.plot(history.epoch, history.history["val_loss"], label="val loss")
plt.legend()
plt.xlabel("epochs")
plt.ylabel("loss")

Let's look at some predictions:

In [None]:
for images, _ in train_ds.take(1):
    n_images = len(images)
    n_cols = 8
    n_rows = 2*n_images // n_cols
    
    fig, axes = plt.subplots(n_rows,n_cols, figsize=(n_cols*1.5,n_rows*1.5))
    axes = axes.flatten()
    
    predictions = ae(images, training=False)
    
    for i, image in enumerate(images):
        axes[2*i].matshow(image)
        axes[2*i].set_yticklabels([])
        axes[2*i].set_xticklabels([])
        axes[2*i+1].matshow(predictions[i])
        axes[2*i+1].set_yticklabels([])
        axes[2*i+1].set_xticklabels([])

Get the latent representation of the dataset

In [None]:
latent = ### FILL (get 1200 images and transform them to the latent representation)
latent.shape

In [None]:
labels = ### FILL (get the corresponding 1200 labels)
labels.shape

In [None]:
import pandas as pd
import seaborn as sns

In [None]:
df = pd.DataFrame({f"{i}": latent[:,i] for i in range(dim)})
df.head()

In [None]:
sns.pairplot(df)

In [None]:
import tsne

In [None]:
proj = tsne.pca(latent.numpy().astype(np.float64), no_dims=2)

In [None]:
scatter = plt.scatter(proj[:,0], proj[:,1], c=labels, cmap=plt.get_cmap("tab10"))
plt.legend(*scatter.legend_elements())

In [None]:
proj = tsne.tsne(latent.numpy().astype(np.float64), max_iter=300)

In [None]:
proj.shape

In [None]:
scatter = plt.scatter(proj[:,0], proj[:,1], c=labels, cmap=plt.get_cmap("tab10"))
plt.legend(*scatter.legend_elements())

## Summary
TensorFlow offers many ways to create models and train them:
- Higher-level methods follow the Keras API.
- Lower-level methods give more control (write your own model and training loop).

These approaches interoperate well and can be combined.

In detail we have learned:
- Construct models of type `tf.keras.Model` via the sequential API, the functional API, or writing your own subclass
- Choose loss function, optimizer, and metrics
- Supply data directly from numpy arrays or `tf.data.Dataset`
- Train via `model.fit()` or by writing your own training loop, e.g.
```python
for i in range(n_epochs):
    for xs, ys in dataset:
        with tf.GradientTape() as tape:
            predictions = model(xs, training=True)
            loss = loss_object(ys, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
```
- The history of training metrics is returned by `model.fit()` (as a `History` object). In your own training loop you have full control, e.g. you can record the metrics in lists yourself. (Tomorrow you'll learn about TensorBoard.)
- Saving a model's state is done with checkpoints, either
    - via `tf.keras.callbacks.ModelCheckpoint` and `model.fit(..., callbacks=...)`, or
    - by writing your own loop and using `tf.train.Checkpoint` and `tf.train.CheckpointManager`