##### Copyright 2021 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# On-Device Training in TensorFlow Lite

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/lite/examples/on_device_training/overview"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/on_device_training/overview.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/on_device_training/overview.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/examples/on_device_training/overview.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

**Note:** This API is new and only available via `pip install tf-nightly`. It will be available in TensorFlow version 2.7.

To use TensorFlow Lite, a developer needs to prepare a TensorFlow model, use the converter to convert it to TensorFlow Lite model format, and run the model with the TensorFlow Lite runtime on device. This is true for inference use cases, and a similar flow can be applied to training too.

The following code illustrates the high-level flow for a training use case: preparing a TensorFlow training model, converting it to a TensorFlow Lite model, and running it in the TensorFlow Lite runtime.

The implementation is based on the [Keras classification example](https://www.tensorflow.org/tutorials/keras/classification) from the official TensorFlow tutorials.

# Setup

In [None]:
!pip uninstall -y tensorflow keras
!pip install tf-nightly

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Classify images of clothing

This Colab trains a neural network model to classify images of clothing, like sneakers and shirts.

Here, 60,000 images are used to train the network and 10,000 images to evaluate how accurately the network learned to classify images. The Fashion MNIST dataset can be imported and loaded directly from TensorFlow.

Loading the dataset returns four NumPy arrays:

* The `train_images` and `train_labels` arrays are the *training set*, the data the model uses to learn.
* The model is tested against the *test set*, the `test_images` and `test_labels` arrays.

In [None]:
fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

The images are 28x28 NumPy arrays, with pixel values ranging from 0 to 255. The *labels* are an array of integers, ranging from 0 to 9. These correspond to the *class* of clothing the image represents:

In [None]:
train_images.shape


<table>
  <tr>
    <th>Label</th>
    <th>Class</th>
  </tr>
  <tr>
    <td>0</td>
    <td>T-shirt/top</td>
  </tr>
  <tr>
    <td>1</td>
    <td>Trouser</td>
  </tr>
    <tr>
    <td>2</td>
    <td>Pullover</td>
  </tr>
    <tr>
    <td>3</td>
    <td>Dress</td>
  </tr>
    <tr>
    <td>4</td>
    <td>Coat</td>
  </tr>
    <tr>
    <td>5</td>
    <td>Sandal</td>
  </tr>
    <tr>
    <td>6</td>
    <td>Shirt</td>
  </tr>
    <tr>
    <td>7</td>
    <td>Sneaker</td>
  </tr>
    <tr>
    <td>8</td>
    <td>Bag</td>
  </tr>
    <tr>
    <td>9</td>
    <td>Ankle boot</td>
  </tr>
</table>

Each image is mapped to a single label. Since the *class names* are not included with the dataset, store them here to use later when plotting the images:

In [None]:
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

Scale the pixel values to a range of 0 to 1 before feeding them to the neural network model. To do so, divide the values by 255. It's important that the *training set* and the *testing set* be preprocessed in the same way:

In [None]:
train_images = train_images / 255.0

test_images = test_images / 255.0
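As a small illustration of the scaling step, the sketch below applies the same division to synthetic data. The array `raw_images` and the helper `scale_pixels` are hypothetical names introduced for this example only. The cast to `float32` is an assumption made here because the model defined later declares `tf.float32` inputs; the tutorial itself casts at call time.

```python
import numpy as np

# Synthetic stand-in for a batch of 28x28 grayscale images (uint8, 0..255).
rng = np.random.default_rng(0)
raw_images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)


def scale_pixels(images):
    """Hypothetical helper: map uint8 pixels in [0, 255] to float32 in [0, 1]."""
    return images.astype(np.float32) / 255.0


scaled = scale_pixels(raw_images)
print(scaled.dtype, scaled.shape, float(scaled.min()), float(scaled.max()))
```

Applying one helper to both the training and the test images is a simple way to keep the two sets preprocessed identically.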

In [None]:
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()

## TensorFlow Model for Training

Instead of converting a single TensorFlow model or `tf.function` to a TensorFlow Lite model with a single entry point, we can convert multiple `tf.function`s into one TensorFlow Lite model. To make this possible, the TensorFlow Lite converter and runtime are being extended to handle multiple signatures.

The code below constructs a `tf.Module` with four `tf.function`s, which are later exported under the signature names `train`, `infer`, `save`, and `restore`:
*   `train` trains the model with training data.
*   `predict` (exported as the `infer` signature) runs inference.
*   `save` saves the trainable weights to the file system.
*   `restore` loads the trainable weights from the file system.

The weights are serialized in the TensorFlow version 1 checkpoint file format.

In [None]:
IMG_SIZE = 28


class Model(tf.Module):

  def __init__(self):
    self.model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(IMG_SIZE, IMG_SIZE)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    self.model.compile(
        optimizer='sgd',
        # The last layer applies softmax, so the loss receives probabilities,
        # not logits.
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
        metrics=['accuracy'])
    self._LOSS_FN = tf.keras.losses.CategoricalCrossentropy()
    self._OPTIM = tf.optimizers.SGD()

  @tf.function(input_signature=[
      tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
      tf.TensorSpec([None, 10], tf.float32),
  ])
  def train(self, x, y):
    with tf.GradientTape() as tape:
      prediction = self.model(x)
      # Keras losses are called as loss(y_true, y_pred).
      loss = self._LOSS_FN(y, prediction)
    gradients = tape.gradient(loss, self.model.trainable_variables)
    self._OPTIM.apply_gradients(
        zip(gradients, self.model.trainable_variables))
    return gradients

  @tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)])
  def predict(self, x):
    return {
        "output": self.model(x)
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def save(self, checkpoint_path):
    tensor_names = [weight.name for weight in self.model.weights]
    tensors_to_save = [weight.read_value() for weight in self.model.weights]
    tf.raw_ops.Save(
        filename=checkpoint_path, tensor_names=tensor_names,
        data=tensors_to_save, name='save')
    return {
        "checkpoint_path": checkpoint_path
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def restore(self, checkpoint_path):
    restored_tensors = {}
    for var in self.model.weights:
      restored = tf.raw_ops.Restore(
          file_pattern=checkpoint_path, tensor_name=var.name, dt=var.dtype,
          name='restore')
      var.assign(restored)
      restored_tensors[var.name] = restored
    return restored_tensors
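To make the arithmetic inside `train` more concrete, here is a minimal NumPy sketch of one SGD step. It is a simplification under stated assumptions, not the Keras implementation: it uses a single dense softmax layer instead of the two-layer model above, and the helpers `softmax`, `categorical_crossentropy`, and `sgd_step` are hypothetical names. The learning rate of 0.01 matches the default of the Keras SGD optimizer.

```python
import numpy as np


def softmax(logits):
    # Subtract the row-wise max for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Batch mean of -sum(y_true * log(y_pred)); note the (y_true, y_pred) order.
    y_pred = np.clip(y_pred, eps, 1.0)
    return float(-(y_true * np.log(y_pred)).sum(axis=1).mean())


def sgd_step(weights, bias, x, y_true, learning_rate=0.01):
    # Forward pass of a single dense softmax layer.
    probs = softmax(x @ weights + bias)
    loss = categorical_crossentropy(y_true, probs)
    # Gradient of the mean cross-entropy with respect to the logits.
    grad_logits = (probs - y_true) / x.shape[0]
    # Update the parameters in place, as an optimizer would.
    weights -= learning_rate * (x.T @ grad_logits)
    bias -= learning_rate * grad_logits.sum(axis=0)
    return loss


# Synthetic data: 8 samples, 4 features, 3 classes.
rng = np.random.default_rng(0)
x = rng.random((8, 4))
y_true = np.eye(3)[rng.integers(0, 3, size=8)]
weights = np.zeros((4, 3))
bias = np.zeros(3)

losses = [sgd_step(weights, bias, x, y_true) for _ in range(5)]
print(losses)
```

With zero-initialized parameters the first loss equals `log(3)`, the cross-entropy of a uniform prediction over three classes, and each later step lowers it slightly.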

## Converting to TensorFlow Lite format

Note: the TFLiteConverter API will be extended to allow choosing which functions or signatures to convert.

In [None]:
# Export the TensorFlow model in the SavedModel format
SAVED_MODEL_DIR = "/content/saved_model"
m = Model()
tf.saved_model.save(
    m,
    SAVED_MODEL_DIR,
    signatures={
        'train':
            m.train.get_concrete_function(),
        'infer':
            m.predict.get_concrete_function(),
        'save':
            m.save.get_concrete_function(),
        'restore':
            m.restore.get_concrete_function(),
    })

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()

## Executing in TensorFlow Lite runtime

TensorFlow Lite's Interpreter capability will be extended to support multiple signatures too. Developers can choose to invoke restoring, training, saving and inferring signatures separately.


In [None]:
interpreter = tf.lite.Interpreter(model_content=tflite_model)

train = interpreter.get_signature_runner("train")
infer = interpreter.get_signature_runner("infer")
save = interpreter.get_signature_runner("save")
restore = interpreter.get_signature_runner("restore")

## Training with data set

Training can be done with the `train` signature method.

In [None]:
# Generate one-hot training labels: row i has a 1 at index train_labels[i].
processed_train_labels = np.eye(10, dtype=np.float32)[train_labels]

# Run training for a few steps.
for i in range(20):
  train(
      x=tf.constant(train_images, shape=(len(train_images), IMG_SIZE, IMG_SIZE), dtype=tf.float32),
      y=tf.constant(processed_train_labels, shape=(len(train_images), 10), dtype=tf.float32))
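The loop above passes the entire training set to every `train` call. Because the `train` signature declares a variable batch dimension (`None`), smaller batches could in principle be fed instead. The sketch below shows one way to produce shuffled mini-batches with NumPy. The helper `iterate_minibatches` and the synthetic arrays are hypothetical and are not part of the tutorial.

```python
import numpy as np


def iterate_minibatches(images, one_hot_labels, batch_size, seed=0):
    """Hypothetical helper: yield shuffled (images, labels) float32 batches."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(images))
    for start in range(0, len(images), batch_size):
        idx = order[start:start + batch_size]
        yield (images[idx].astype(np.float32),
               one_hot_labels[idx].astype(np.float32))


# Synthetic stand-ins for the scaled images and one-hot labels.
images = np.zeros((10, 28, 28))
labels = np.eye(10)[np.arange(10)]

batches = list(iterate_minibatches(images, labels, batch_size=4))
print([batch_x.shape for batch_x, _ in batches])
```

Each yielded pair could then be passed as `train(x=batch_x, y=batch_y)`. The last batch may be smaller than `batch_size`, which the variable batch dimension allows.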

## Exporting the trained weights to the checkpoint file

The checkpoint file can be generated through the `save` signature method.

In [None]:
# Export the trained weights to /tmp/model.ckpt
# np.bytes_ is used because the np.string_ alias was removed in NumPy 2.0.
save(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.bytes_))

## Restoring the trained weights from the checkpoint file

The exported checkpoint file can be restored through the `restore` signature method.

In [None]:
another_interpreter = tf.lite.Interpreter(model_content=tflite_model)

train = another_interpreter.get_signature_runner("train")
infer = another_interpreter.get_signature_runner("infer")
save = another_interpreter.get_signature_runner("save")
restore = another_interpreter.get_signature_runner("restore")

# Restore the trained weights from /tmp/model.ckpt
# np.bytes_ is used because the np.string_ alias was removed in NumPy 2.0.
restore(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.bytes_))

## Running inference with the trained weights

Developers can use the trained model to run inference through the `infer` signature method.

In [None]:
# Run the inference
result = infer(
    x=tf.constant(test_images, shape=(len(test_images), IMG_SIZE, IMG_SIZE), dtype=tf.float32))

predicted_labels = np.argmax(result["output"], axis=1)

In [None]:
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(test_images[i], cmap=plt.cm.binary)
    # Label each image with the class predicted by the TensorFlow Lite model.
    plt.xlabel(class_names[predicted_labels[i]])
plt.show()
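Beyond inspecting the plot, the `infer` output can be summarized as a single accuracy number by comparing the arg-max of each output row with the ground-truth labels. The sketch below uses a small hand-written probability array. The helper `accuracy` is a hypothetical name; with the notebook's data, one would pass `result["output"]` and the original `test_labels` array instead.

```python
import numpy as np


def accuracy(probabilities, true_labels):
    """Hypothetical helper: fraction of rows whose arg-max matches the label."""
    predicted = np.argmax(probabilities, axis=1)
    return float(np.mean(predicted == true_labels))


# Four synthetic predictions over three classes.
probs = np.array([
    [0.1, 0.7, 0.2],  # predicts class 1
    [0.8, 0.1, 0.1],  # predicts class 0
    [0.3, 0.3, 0.4],  # predicts class 2
    [0.2, 0.5, 0.3],  # predicts class 1
])
true_labels = np.array([1, 0, 1, 1])

print(accuracy(probs, true_labels))  # 0.75
```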