##### Copyright 2024 The AI Edge Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# JAX Model Conversion for LiteRT
## Overview
Note: This API is new, so we recommend using it with the nightly TensorFlow build (`pip install tf-nightly`). The API is still experimental and subject to change.

## Prerequisites
It's recommended to try this feature with the newest TensorFlow nightly pip build.

In [None]:
!pip install jax --upgrade
!pip install ai-edge-litert
!pip install orbax-export --upgrade
!pip install tf-nightly --upgrade

In [None]:
# Display the installed JAX version; it should be 0.4.20 or newer.
import jax
jax.__version__

In [None]:
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
from orbax.export import constants

import tensorflow as tf
from PIL import Image

import time
import functools
import itertools

import numpy as np
import numpy.random as npr

import jax.numpy as jnp
from jax import jit, grad, random
from jax.example_libraries import optimizers
from jax.example_libraries import stax

## Data Preparation
Download the MNIST data with the Keras dataset loader and run it through a pre-processing step. The dataset consists of 28x28-pixel grayscale images (a single color channel ranging from black to white) of hand-drawn digits from 0 to 9.

During the pre-processing step, the images are normalized so that each pixel value is rescaled from the 0 to 255 range to the 0.0 to 1.0 range. This decreases training time.

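The labels are also one-hot encoded: each digit label becomes a 10-element vector with a 1 at the index of the true class and 0 everywhere else. This is the target format that the loss and accuracy functions below expect.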

In [None]:
# Create a one-hot encoding of x of size k.
def _one_hot(x, k, dtype=np.float32):
  return np.array(x[:, None] == np.arange(k), dtype)

# JAX ships no dataset loader of its own, so borrow the Keras MNIST helper.
(train_images, train_labels), (test_images, test_labels) = (
    tf.keras.datasets.mnist.load_data()
)

# Rescale the pixels from 0..255 to 0.0..1.0 and store them as float32.
train_images = (train_images / 255.0).astype(np.float32)
test_images = (test_images / 255.0).astype(np.float32)

# Replace the integer digit labels with 10-wide one-hot rows.
train_labels = _one_hot(train_labels, 10)
test_labels = _one_hot(test_labels, 10)

The following code block is a simple utility to display a set of the MNIST dataset images.

In [None]:
# Preview the first training digits in a grid, each titled with its label.
import matplotlib.pyplot as plt

rows = 3
cols = 7

# Never ask for more tiles than there are images.
num_tiles = min(rows * cols, len(train_images))
for index in range(num_tiles):
  plt.subplot(rows, cols, index + 1)
  plt.imshow(train_images[index], cmap='gray')
  plt.title(f"Label: {np.argmax(train_labels[index])}")
  plt.axis('off')

plt.tight_layout()
plt.show()

## Build the MNIST model with JAX

This block defines the loss and accuracy functions used to train a new classification model, as well as the layers that make up the model.

In [None]:
# Training objective: lower values mean the predictions agree better with the labels.
def loss(params, batch):
  """Return the mean negative log-likelihood of the true classes in `batch`.

  `predict` ends in a log-softmax, so its outputs are log-probabilities; with
  one-hot `targets`, the per-example sum picks out the log-probability of the
  correct class.

  Args:
    params: Model parameters accepted by `predict`.
    batch: `(inputs, targets)` pair.

  Returns:
    Scalar loss averaged over the examples in the batch.
  """
  inputs, targets = batch
  log_probs = predict(params, inputs)
  true_class_log_prob = jnp.sum(log_probs * targets, axis=1)
  return -jnp.mean(true_class_log_prob)

# Accuracy function: fraction of examples whose predicted class matches the true class.
def accuracy(params, batch):
  """Return the share of examples in `batch` that `predict` classifies correctly.

  Args:
    params: Model parameters accepted by `predict`.
    batch: `(inputs, targets)` pair; the true class of each example is the
      position of the largest entry in its `targets` row.

  Returns:
    Scalar mean of the per-example "prediction equals target" indicator.
  """
  inputs, targets = batch
  # The class the model scores highest for each example.
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  # The class marked in each target row.
  target_class = jnp.argmax(targets, axis=1)
  hits = predicted_class == target_class
  return jnp.mean(hits)


# Network definition. `stax.serial` chains the layers and hands back two
# functions: one that initializes the parameters and one that applies the model.
init_random_params, predict = stax.serial(
    # Collapse each image into a 1-D feature vector.
    stax.Flatten,
    # Two hidden layers of 1024 units, each followed by a ReLU activation.
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    # Output layer: log-probabilities over the ten digit classes (0-9).
    stax.Dense(10),
    stax.LogSoftmax,
)

# PRNG key used to initialize the parameters.
rng = random.PRNGKey(0)

## Train & Evaluate the model

In [None]:
# Hyperparameters
step_size = 0.001  # learning rate: smaller is slower but more stable
num_epochs = 10
batch_size = 128
momentum_mass = 0.9  # mass of the momentum optimizer; helps converge faster

# Work out how many minibatches make up one pass over the training set.
num_train = train_images.shape[0]
num_complete_batches = num_train // batch_size
leftover = num_train % batch_size
# A partial final batch still counts as one batch.
num_batches = num_complete_batches + (1 if leftover else 0)

def data_stream():
  """Yield shuffled `(images, labels)` minibatches forever.

  Every pass draws a fresh permutation of the training indices from a
  generator seeded with 0 and walks it in `batch_size` steps, so the final
  batch of a pass is smaller when the set does not divide evenly.
  """
  shuffler = npr.RandomState(0)
  while True:
    order = shuffler.permutation(num_train)
    for start in range(0, num_batches * batch_size, batch_size):
      picked = order[start:start + batch_size]
      yield train_images[picked], train_labels[picked]
batches = data_stream()

# Momentum optimizer. It comes as three functions: one to create the optimizer
# state, one to advance it, and one to read the parameters back out of it.
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

@jit
def update(i, opt_state, batch):
  """Run one jit-compiled training step.

  Args:
    i: Step index forwarded to the optimizer's update function.
    opt_state: Current optimizer state.
    batch: `(inputs, targets)` minibatch on which the loss is differentiated.

  Returns:
    The optimizer state after applying the gradient of `loss`.
  """
  current_params = get_params(opt_state)
  gradients = grad(loss)(current_params, batch)
  return opt_update(i, gradients, opt_state)

# Initialize the network weights and hand them to the optimizer.
_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
# Running step counter passed to the optimizer on every update.
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
  # Time only the parameter updates, not the evaluation further down.
  start_time = time.time()
  for _ in range(num_batches):
    step = next(itercount)
    opt_state = update(step, opt_state, next(batches))
  epoch_time = time.time() - start_time

  # Score the freshly updated parameters on both splits.
  params = get_params(opt_state)
  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

## Convert to a tflite model.

Using the `orbax` library, you can export the newly trained `JAX` model to a TensorFlow `SavedModel` file. Once you have a `SavedModel`, you can convert it to a `.tflite` file that can work with the LiteRT interpreter.






In [None]:
# Bridge from JAX to TensorFlow: wrap the trained `params` (everything learned
# during training) together with `predict` (the JAX inference function).
jax_module = JaxModule(params, predict, input_polymorphic_shape='b, ...')

# Trace the module's default method for a single 28x28 float32 image.
concrete_fn = jax_module.methods[constants.DEFAULT_METHOD_KEY].get_concrete_function(
    tf.TensorSpec(shape=(1, 28, 28), dtype=tf.float32, name="input")
)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [concrete_fn],
    trackable_obj=tf.function()  # empty trackable_obj argument
)

# Convert and write the serialized model to a .tflite file.
tflite_model = converter.convert()
with open('jax_mnist.tflite', 'wb') as f:
  f.write(tflite_model)

## Check the Converted TFLite Model
Next you can compare the converted model's results with the JAX model. This first block defines a utility function that runs inference on a single image.

In [None]:
def predict_image_class(image_path, model_path):
  """Classify one image file with a LiteRT model and print the result.

  The image is converted to grayscale, resized to 28x28, scaled to the
  0.0-1.0 range and fed to the model as a float32 batch of one.

  Args:
    image_path: Path of the image to classify.
    model_path: Path of the `.tflite` model to run.

  Returns:
    The index of the highest-scoring output as an `int`, or `None` when
    anything goes wrong (the error is printed rather than raised).
  """
  try:
    # Imported here so the helper does not depend on which notebook cell
    # brings `Interpreter` into the global namespace.
    from ai_edge_litert.interpreter import Interpreter

    # Load the TFLite model and allocate tensors.
    interpreter = Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Load the test image and apply the same scaling used for training data.
    img = Image.open(image_path).convert('L').resize((28, 28))
    img_array = np.array(img)
    img_array = img_array / 255.0
    # Add the leading batch dimension.
    img_array = np.expand_dims(img_array, axis=0)
    img_array = img_array.astype(np.float32)

    # Set the tensor to the input tensor and run inference.
    interpreter.set_tensor(input_details[0]['index'], img_array)
    interpreter.invoke()

    # Get the output tensor.
    output_data = interpreter.get_tensor(output_details[0]['index'])

    # Get the predicted class and hand it back to the caller as well.
    predicted_class = int(np.argmax(output_data))
    print("Predicted class:", predicted_class)
    return predicted_class

  except Exception as e:
    # Best-effort helper: report the problem instead of aborting the notebook.
    print(f"An error occurred: {e}")
    return None

You can download a pre-drawn test image that Google has provided, or load your own hand-drawn monochromatic image into the `/content/` directory.

In [None]:
# Fetch the sample digit image and save it as /content/7.png.
!wget https://storage.googleapis.com/ai-edge/models-samples/jax_converter/jax_to_litert_conversion_test/7.png -O /content/7.png

In [None]:
from ai_edge_litert.interpreter import Interpreter

# Example usage
# Replace with your own image and model paths. The model path is relative so
# that it points at the file the conversion cell wrote to the working directory.
image_path = "/content/7.png"
model_path = "jax_mnist.tflite"

predict_image_class(image_path, model_path)

## Optimize the Model
We will provide a `representative_dataset` to do post-training quantization to optimize the model. This will reduce the model size to roughly a quarter of its original size.




In [None]:
def representative_dataset(num_samples=1000):
  """Yield calibration inputs for post-training quantization.

  Args:
    num_samples: Upper bound on the number of training images to yield.

  Yields:
    One-element lists, each holding a single training image with its leading
    batch dimension kept (the slice `train_images[i:i + 1]`).
  """
  # Cap at the dataset size so no empty slices are ever yielded.
  limit = min(num_samples, len(train_images))
  for i in range(limit):
    yield [train_images[i:i + 1]]


# Wrap the JAX function and its trained params in an orbax.export.JaxModule,
# which exposes them as a TF module.
jax_module = JaxModule(params, predict)

# Trace the module's default method for a single 28x28 float32 image.
concrete_fn = jax_module.methods[constants.DEFAULT_METHOD_KEY].get_concrete_function(
    tf.TensorSpec(shape=(1, 28, 28), dtype=tf.float32, name="input")
)

# Instantiate a tf.lite.TFLiteConverter from that concrete function.
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [concrete_fn],
    trackable_obj=tf.function()  # empty trackable_obj argument
)

# Quantization settings: default optimizations, calibrated with the
# representative dataset and restricted to the int8 builtin op set.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_quant_model = converter.convert()

# Write the serialized model to a .tflite flatbuffer file.
with open('jax_mnist_quant.tflite', 'wb') as f:
  f.write(tflite_quant_model)

## Evaluate the Optimized Model

In [None]:
# Run the same test image through the quantized model. The model path is
# relative so that it points at the file the quantization cell wrote to the
# working directory.
image_path = "/content/7.png"
model_path = "jax_mnist_quant.tflite"

predict_image_class(image_path, model_path)

## Compare the Quantized Model size
You should see that the quantized model is roughly a quarter of the size of the original model.

In [None]:
# Print the on-disk size of each exported model in human-readable units.
!du -h jax_mnist.tflite
!du -h jax_mnist_quant.tflite

In [None]:
from google.colab import files

# Trigger a browser download for each exported model, float model first.
for model_file in ('jax_mnist.tflite', 'jax_mnist_quant.tflite'):
  files.download(model_file)