# Journal 2022-09-02
Working through the [Flax Getting Started Tutorial](https://flax.readthedocs.io/en/latest/getting_started.html) while trying a new workflow: the code is edited in a `journal_20220902.py` file in PyCharm to make use of code completion and linting, then imported into this notebook. AWS CodeWhisperer was also enabled, and interestingly it filled in much of the tutorial's example model by itself.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from flax import serialization
import os
import jax
import jax.numpy as jnp
import tensorflow as tf
import functools

In [3]:
from journal_20220902 import get_datasets, CNN, run_training, CnnParams

In [4]:
ds_train, ds_test = get_datasets()

In [5]:
train = False
if train:
    print('Training model and saving trained parameters to file')
    state = run_training((ds_train, ds_test))
    cnn_params = CnnParams(state.params)
    cnn_params.to_file('mnist_flax_20220902.bin')
else:
    print('Loading model parameters from file')
    cnn_params = CnnParams.from_file('mnist_flax_20220902.bin')



Loading model parameters from file


# Predictions from Trained Model
Use the trained model to predict the labels of the first ten test images and compare them with the true labels.

In [6]:
cnn = CNN()

In [7]:
def predict(p, x):
    return jnp.argmax(cnn.apply({'params': p}, x), axis=1)

In [8]:
predict(cnn_params.params, ds_test['image'][:10])

DeviceArray([2, 0, 4, 8, 7, 6, 0, 6, 3, 1], dtype=int32)

In [9]:
ds_test['label'][:10, ...]

array([2, 0, 4, 8, 7, 6, 0, 6, 3, 1])
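The ten predictions above match the labels by eye. To check the whole test set, a batched accuracy helper like the sketch below could be used; `accuracy` and its `batch_size` default are hypothetical, and `predict_fn` is any function mapping an image batch to integer labels.

```python
import jax.numpy as jnp


def accuracy(predict_fn, images, labels, batch_size=1000):
    """Fraction of `images` whose predicted label equals the true label.

    Evaluating in batches keeps memory bounded on the full test set.
    """
    correct = 0
    for start in range(0, len(images), batch_size):
        batch_preds = predict_fn(images[start:start + batch_size])
        batch_labels = labels[start:start + batch_size]
        correct += int(jnp.sum(batch_preds == batch_labels))
    return correct / len(images)
```

With the cells above this would be called as `accuracy(functools.partial(predict, cnn_params.params), ds_test['image'], ds_test['label'])`.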

# TensorFlow Lite
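Try exporting the trained model to TFLite following [JAX Model Conversion for TFLite](https://www.tensorflow.org/lite/examples/jax_conversion/overview). The converter takes functions of the model inputs only, so the trained parameters are bound into `predict` with `functools.partial` and are baked into the exported model as constants.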

In [11]:
serving_fn = functools.partial(predict, cnn_params.params)
x_input = jnp.zeros((1, 28, 28, 1))  # BHWC
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_fn], [[('input1', x_input)]]
)
tflite_model = converter.convert()

In [12]:
serving_fn(ds_test['image'][:10])

DeviceArray([2, 0, 4, 8, 7, 6, 0, 6, 3, 1], dtype=int32)

In [13]:
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

In [15]:
first_image = ds_test['image'][:1].astype(jnp.float32)  # shape (1, 28, 28, 1), BHWC
interpreter.set_tensor(input_details[0]["index"], first_image)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])
result

array([2], dtype=int32)
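The TFLite model returns label 2 for the first test image, matching the JAX prediction. Because the model was converted with a fixed `(1, 28, 28, 1)` input, comparing more images means looping one image at a time. A minimal sketch, assuming an `interpreter` on which `allocate_tensors()` has already been called; `tflite_predict` is a hypothetical helper name.

```python
import numpy as np


def tflite_predict(interpreter, images):
    """Run a batch-size-1 TFLite classifier over `images` of shape (N, 28, 28, 1).

    Returns an integer array of N predicted labels.
    """
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]
    images = np.asarray(images, dtype=np.float32)
    preds = []
    for image in images:
        # Add the leading batch axis: (28, 28, 1) -> (1, 28, 28, 1)
        interpreter.set_tensor(input_index, image[np.newaxis, ...])
        interpreter.invoke()
        # The model outputs one label per batch element, shape (1,)
        preds.append(interpreter.get_tensor(output_index)[0])
    return np.array(preds)
```

For example, `(tflite_predict(interpreter, ds_test['image'][:100]) == np.asarray(serving_fn(ds_test['image'][:100]))).mean()` should be at or very close to 1.0 if the conversion is faithful.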