##### Copyright 2018 The TensorFlow Authors.

Licensed under the Apache License, Version 2.0 (the "License").

In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TensorFlow 2.0: Train and save a model

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/2/guide/train_and_save"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/guide/train_and_save.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/r2/guide/train_and_save.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

This notebook trains a simple MNIST model to demonstrate the basic workflow for using TensorFlow 2.0 APIs:

1. Define a model
2. Preprocess your data into a `tf.data.Dataset`
3. Train the model with the dataset
  - Use `tf.GradientTape` to compute gradients
  - Use stateful `tf.keras.metrics.*` to collect metrics of interest
  - Log metrics with `tf.summary.*` APIs to view in TensorBoard
  - Use `tf.train.Checkpoint` to save and restore weights
4. Export a `SavedModel` with `tf.saved_model`. This is a portable representation of the model that can be imported into C++, JavaScript, or Python without access to the original TensorFlow code.
5. Re-import the `SavedModel` and demonstrate its usage in Python.

## Setup

Install and import the TensorFlow 2.0 preview nightly build:

In [0]:
from __future__ import absolute_import, division, print_function

import os
import time
import numpy as np

In [0]:
!pip install tf-nightly-2.0-preview
import tensorflow as tf

In [0]:
from tensorflow.python.ops import summary_ops_v2

## Define a model with the tf.keras API


Build a convolutional model using the [tf.keras API](https://www.tensorflow.org/guide/keras). This model uses the `channels_last` [data format](https://www.tensorflow.org/guide/performance/overview#data_formats).

In [0]:
from tensorflow.keras import layers


def create_model():
  """Builds the convolutional MNIST classifier.

  Each 28x28 input is reshaped to 28x28x1 (channels last). It then passes
  through two conv + max-pool stages (2 and 4 filters), a 32-unit ReLU layer
  with dropout, and a final 10-unit Dense layer with no activation, so the
  outputs are unnormalized logits.

  Returns:
    An uncompiled `tf.keras.Sequential` model.
  """
  # One pooling layer instance (it has no weights) is reused after both conv
  # layers. Every layer here is a plain sequential chain, so tf.keras.Sequential
  # gives a compact description.
  pool = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')
  stack = [
      layers.Reshape(target_shape=[28, 28, 1], input_shape=(28, 28,)),
      layers.Conv2D(2, 5, padding='same', activation=tf.nn.relu),
      pool,
      layers.Conv2D(4, 5, padding='same', activation=tf.nn.relu),
      pool,
      layers.Flatten(),
      layers.Dense(32, activation=tf.nn.relu),
      layers.Dropout(0.4),
      layers.Dense(10),
  ]
  return tf.keras.Sequential(stack)

# Stateful accuracy metric used by the training and evaluation functions
# below. It accumulates over calls until reset_states() is called.
compute_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

# The model's last layer is Dense(10) with no activation, so the loss is
# told it receives raw logits rather than probabilities.
compute_loss = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True)

Create the model and optimizer:

In [0]:
# A freshly initialized model and an SGD-with-momentum optimizer to train it.
model = create_model()
optimizer = tf.keras.optimizers.SGD(momentum=0.5, learning_rate=0.01)

## Download and create datasets

Load the MNIST dataset into a [tf.data.Dataset](https://www.tensorflow.org/guide/datasets). This provides useful transformations like batching and shuffling.

Note: Keras models can train directly on NumPy arrays for small datasets (see [basic classification](../keras/basic_classification.ipynb)). The use of `tf.data` here is to demonstrate the API for applications that need more scalability.

In [0]:
# Set up datasets
def mnist_datasets():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
  # Numpy defaults to dtype=float64; TF defaults to float32. Stick with float32.
  x_train, x_test = x_train / np.float32(255), x_test / np.float32(255)
  y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
  return train_dataset, test_dataset

In [0]:
# Batch size used for both training and evaluation.
BATCH_SIZE = 100
# Shuffle buffer as large as the training split (60000 examples), i.e. a full shuffle.
SHUFFLE_BUFFER_SIZE = 60000

train_ds, test_ds = mnist_datasets()
train_ds = train_ds.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_ds = test_ds.batch(BATCH_SIZE)

# `Dataset.output_shapes` is not available on TF 2.x datasets; the compat
# helper returns the same nested shape structure.
print('Dataset will yield tensors of the following shape: {}'.format(
    tf.compat.v1.data.get_output_shapes(train_ds)))

## Configure training

Note: Keras models include a complete training loop (see [basic classification](../keras/basic_classification.ipynb)). The training process is only defined manually here as a starting point for applications that need deeper customization.   

The `train()` function iterates over the training dataset, computing the gradients for each batch and then applying them to the model variables. It periodically prints the training loss and throughput.

In [0]:
@tf.function
def train_step(model, optimizer, images, labels):
  """Runs one optimization step on a single batch.

  Also updates the shared `compute_accuracy` metric with this batch.

  Returns:
    The loss computed by `compute_loss` for this batch.
  """
  # Only operations executed inside the tape are recorded for differentiation.
  with tf.GradientTape() as tape:
    logits = model(images, training=True)
    batch_loss = compute_loss(labels, logits)
    compute_accuracy(labels, logits)
  variables = model.trainable_variables
  gradients = tape.gradient(batch_loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return batch_loss


def train(model, optimizer, dataset, log_freq=50):
  """Trains `model` on `dataset` using `optimizer`.

  Runs `train_step` on every batch. Every `log_freq` optimizer steps it prints
  the step number, the loss of the most recent batch, and the throughput since
  the previous report, then resets the running metrics.

  Args:
    model: the Keras model to train.
    optimizer: the optimizer that applies the gradients; its `iterations`
      counter is used as the step number.
    dataset: an iterable of `(images, labels)` batches, e.g. a
      `tf.data.Dataset`.
    log_freq: reporting interval, in optimizer steps. Must be positive.

  Raises:
    ValueError: if `log_freq` is not positive.
  """
  if log_freq <= 0:
    raise ValueError('log_freq must be positive, got {}'.format(log_freq))

  # Metrics are stateful. They accumulate values and return a cumulative
  # result when you call .result(). Clear accumulated values with .reset_states()
  avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)
  start = time.time()
  # Counted explicitly so the reported rate is correct even when the first
  # interval is shorter than `log_freq` (e.g. after restoring a checkpoint).
  steps_since_log = 0

  # Datasets can be iterated over like any other Python iterable.
  for images, labels in dataset:
    loss = train_step(model, optimizer, images, labels)
    avg_loss(loss)
    steps_since_log += 1

    step = int(optimizer.iterations.numpy())
    # Report (and reset the running metrics) only every `log_freq` steps.
    if step % log_freq == 0:
      # summary_ops_v2.scalar('loss', avg_loss.result(), step=optimizer.iterations)
      # summary_ops_v2.scalar('accuracy', compute_accuracy.result(), step=optimizer.iterations)
      avg_loss.reset_states()
      compute_accuracy.reset_states()
      rate = steps_since_log / (time.time() - start)
      print('Step #%d\tLoss: %.6f (%d steps/sec)' % (step, float(loss), rate))
      start = time.time()
      steps_since_log = 0

In [0]:
def test(model, dataset, step_num):
  """Perform an evaluation of `model` on the examples from `dataset`.

  Prints the mean loss and the accuracy over every batch in `dataset`.

  Args:
    model: the model to evaluate; it is called with `training=False`.
    dataset: an iterable of `(images, labels)` batches.
    step_num: the training step this evaluation corresponds to; only used by
      the (currently disabled) summary calls at the end.
  """
  avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)
  # `compute_accuracy` is a shared stateful metric that training also updates.
  # Clear it so the reported accuracy covers only the examples in `dataset`.
  compute_accuracy.reset_states()

  for (images, labels) in dataset:
    logits = model(images, training=False)
    avg_loss(compute_loss(labels, logits))
    compute_accuracy(labels, logits)
  print('Model test set loss: {:0.4f} accuracy: {:0.2f}%'.format(
      avg_loss.result(), compute_accuracy.result() * 100))
#   summary_ops_v2.scalar('loss', avg_loss.result(), step=step_num)
#   summary_ops_v2.scalar('accuracy', compute_accuracy.result(), step=step_num)


## Configure model directory

Use one directory to save the relevant artifacts—summary logs, checkpoints, and `SavedModel` exports.

In [0]:
# Root directory for every artifact of this notebook: checkpoints,
# TensorBoard summaries, and the SavedModel export.
MODEL_DIR = '/tmp/tensorflow/mnist'


def apply_clean():
  """Recursively deletes MODEL_DIR if it exists; does nothing otherwise."""
  if not tf.io.gfile.exists(MODEL_DIR):
    return
  print('Removing existing model dir: {}'.format(MODEL_DIR))
  tf.io.gfile.rmtree(MODEL_DIR)

In [0]:
# Optional: wipe the existing directory so training starts from scratch.
# Without this, the checkpoint-restore cell further down resumes from any
# checkpoint already saved under MODEL_DIR.
apply_clean()

You can configure the output location for the training summaries. The `train()` and `test()` functions contain (currently commented-out) `summary_ops_v2.scalar(...)` calls. Any summaries emitted inside a `with summary_writer.as_default():` block are written to that writer's log directory. View the summaries with `tensorboard --logdir=<model_dir>`.

In [0]:
# TensorBoard log directories: MODEL_DIR/summaries/train and .../eval.
train_dir, test_dir = (
    os.path.join(MODEL_DIR, 'summaries', split) for split in ('train', 'eval'))

# NOTE: the summary writers are disabled, so nothing is written to the
# directories above yet.
# train_summary_writer = summary_ops_v2.create_file_writer(
#   train_dir, flush_millis=10000)
# test_summary_writer = summary_ops_v2.create_file_writer(
#   test_dir, flush_millis=10000, name='test')

## Configure checkpoints

The `tf.train.Checkpoint` object helps manage which `tf.Variable`s are saved and restored from the checkpoint files.

A checkpoint differs from a `SavedModel` in two ways. First, it additionally keeps track of training-related state, such as momentum variables for a momentum-based optimizer or the global step. Second, it stores only variable values, not the computation itself, so you'll need the original code to define the computation that uses those weights.

In [0]:
# Checkpoint files go under MODEL_DIR/checkpoints with the filename prefix 'ckpt'.
checkpoint_dir = os.path.join(MODEL_DIR, 'checkpoints')
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')

# Track both the model and the optimizer so weights and optimizer state are
# saved and restored together.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)

# Resume from the newest checkpoint in checkpoint_dir, if an earlier run left one.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

## Train

Now that `train()` and `test()` are set up, train the model for some number of epochs:

In [0]:
NUM_TRAIN_EPOCHS = 1

for epoch in range(NUM_TRAIN_EPOCHS):
  start = time.time()
  # To record training summaries, wrap this call in
  # `with train_summary_writer.as_default():` once the writers are enabled.
  train(model, optimizer, train_ds)
  end = time.time()
  # `optimizer.iterations` is a tf.Variable; convert it to a Python int so the
  # message shows the step count rather than the variable's repr.
  print('\nTrain time for epoch #{} ({} total steps): {}'.format(
      epoch + 1, int(optimizer.iterations.numpy()), end - start))
  # Evaluate on the held-out test set after every epoch.
  test(model, test_ds, optimizer.iterations)
  checkpoint.save(checkpoint_prefix)

## Export a SavedModel

In [0]:
# SavedModel export location; the restore section below reads it back from here.
export_path = os.path.join(MODEL_DIR, 'export')
tf.saved_model.save(obj=model, export_dir=export_path)

## Restore and run the SavedModel

Restore any `SavedModel` and call it without reference to the original source code. APIs for importing and transforming `SavedModel`s exist for a variety of languages. See the [SavedModel guide](https://www.tensorflow.org/guide/saved_model) for more.

In [0]:
def import_and_eval():
  """Reloads the exported SavedModel and prints its accuracy on the MNIST test set.

  The model is rebuilt from the files at `export_path` alone; `create_model`
  and the in-memory `model` are not used.
  """
  restored_model = tf.saved_model.load(export_path)
  _, (x_test, y_test) = tf.keras.datasets.mnist.load_data()
  # Same float32 scaling as the training pipeline.
  x_test = x_test / np.float32(255)
  logits = restored_model(x_test)
  # Use a fresh metric. Calling a Keras metric returns its result tensor, not
  # the metric object. The shared `compute_accuracy` may also still hold
  # state from training.
  accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
  accuracy.update_state(y_test, logits)
  print('Model accuracy: {:0.2f}%'.format(accuracy.result() * 100))


import_and_eval()