In [None]:
import tensorflow as tf
import numpy as np
from google3.pyglib import gfile
from colabtools import adhoc_import
from jax.experimental import jax2tf

with adhoc_import.Google3SubmittedChangelist():
  from google3.third_party.tensorflow.lite.experimental.mlir.testing.jax.pax import convert_utils


In [None]:
def convert_tflite_model(model):
  """Converts a Keras model to a serialized TFLite flatbuffer (no file I/O).

  This only performs the conversion, with default converter settings (no
  optimization or quantization flags are set here). Use `write_tfl_model` to
  persist the returned content to disk.

  Args:
      model (tf.keras.Model): the Keras model to convert.

  Returns:
      The converted model in serialized (bytes-like) TFLite flatbuffer format.
  """
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  return converter.convert()

def write_tfl_model(model, model_path):
  """Writes serialized model content to `model_path` and logs the destination.

  Args:
      model: bytes-like serialized model, e.g. the output of
        `convert_tflite_model`; written verbatim in binary mode.
      model_path (str): destination file path, opened through `gfile.Open`.
  """
  destination = model_path
  with gfile.Open(destination, 'wb') as out_file:
    out_file.write(model)
  print(f'Wrote to {destination}')

In [None]:
# Output directory for every generated .tflite test model.
# NOTE(review): user-specific CNS location; adjust before running elsewhere.
# No trailing slash: all consumers build paths as f'{base_model_path}/<name>',
# so a trailing slash would put '//' into every saved path.
base_model_path = '/cns/dy-d/home/rewu/quantization_tool/test_models'
gfile.MakeDirs(base_model_path)

# Single OP Models


In [None]:
# @title Fully Connected
def create_single_fc_model(hidden_dim=4, input_shape=(2, 8), batch_size=1):
  """Builds an uncompiled Keras model containing a single Dense layer.

  Args:
      hidden_dim (int): number of output units of the Dense layer.
      input_shape (tuple): per-example input shape, excluding the batch axis.
      batch_size (int): fixed batch size baked into the model input.

  Returns:
      tf.keras.Sequential: model mapping (batch_size, *input_shape) inputs to
      (batch_size, ..., hidden_dim) outputs.
  """
  model = tf.keras.Sequential()
  model.add(tf.keras.Input(shape=input_shape, batch_size=batch_size))
  model.add(
      tf.keras.layers.Dense(
          hidden_dim,
          activation=None,
          use_bias=True,
          # Random (non-zero) bias instead of Keras's default zeros, so the
          # exported model carries a meaningful bias tensor.
          bias_initializer="glorot_uniform",
      )
  )
  return model

In [None]:
# Build, convert and persist the single fully connected model.
save_path = f'{base_model_path}/single_fc_bias.tflite'
fc_model = create_single_fc_model(hidden_dim=4)
tflite_model_float = convert_tflite_model(fc_model)
write_tfl_model(model=tflite_model_float, model_path=save_path)

In [None]:
# @title BMM

# Shapes: the model input is (B, X, Y); BMM_WEIGHT1 is (B, Y, Z); BMM_WEIGHT2
# is also (B, Y, Z) but is applied with adj_y=True, so the output is (B, X, Y).
B, X, Y, Z = 2, 4, 16, 8
# Seeded local generator: re-running the notebook regenerates identical
# weights (and does not disturb the global NumPy RNG state).
BMM_SEED = 0
_bmm_rng = np.random.default_rng(BMM_SEED)
BMM_WEIGHT1 = _bmm_rng.normal(size=(B, Y, Z)).astype(np.float32)
BMM_WEIGHT2 = _bmm_rng.normal(size=(B, Y, Z)).astype(np.float32)


def bmm_test_model(input_tensor):
  """Applies two chained TF BatchMatMulV3 ops to `input_tensor`, callable from JAX.

  The TF computation is wrapped with `jax2tf.call_tf`, so this function can be
  traced as a JAX function (it is exported through `convert_bmm`).

  Args:
      input_tensor: batched float input; `convert_bmm` feeds shape (B, X, Y).

  Returns:
      The call_tf result of input_tensor @ BMM_WEIGHT1 @ adjoint(BMM_WEIGHT2),
      computed per batch over the last two axes.
  """

  @tf.function
  def tf_bmm(input_tensor):
    # The NumPy weights are captured as graph constants; they are expected to
    # reappear as constant weight tensors in the converted model.
    first = tf.raw_ops.BatchMatMulV3(
        x=input_tensor, y=BMM_WEIGHT1, Tout=tf.float32, adj_y=False
    )
    # adj_y=True multiplies by the adjoint (transpose for real values) of the
    # last two axes of BMM_WEIGHT2.
    return tf.raw_ops.BatchMatMulV3(
        x=first, y=BMM_WEIGHT2, Tout=tf.float32, adj_y=True
    )

  jax_callable = jax2tf.call_tf(tf_bmm)
  return jax_callable(input_tensor)


def convert_bmm(save_path):
  """Exports `bmm_test_model` via jax2tf, converts it to TFLite and saves it.

  Args:
      save_path (str): destination path of the .tflite file.
  """

  def export_func(input_tensor):
    return bmm_test_model(input_tensor)

  # Single float32 input named 'inputs' with the static shape (B, X, Y).
  signature = [tf.TensorSpec(shape=(B, X, Y), dtype=tf.float32, name='inputs')]
  tf_callable = convert_utils.create_jax2tf_fxn(
      export_func, signature, use_stablehlo=False
  )
  write_tfl_model(convert_utils.convert2tfl(tf_callable), save_path)


# Generate bmm.tflite alongside the other test models.
save_path = f'{base_model_path}/bmm.tflite'
convert_bmm(save_path=save_path)

In [None]:
# @title embedding_lookup
import jax
import jax.numpy as jnp

NUM_SELECTIONS = 1  # length of the int32 id vector fed to the exported model
NUM_CLASS = 16  # vocabulary size: one-hot depth and number of EMB_VAR rows
EMB_DIM = 8  # embedding width: number of EMB_VAR columns
# Seeded local generator: re-running the notebook reproduces the same table
# (and does not disturb the global NumPy RNG state).
EMB_SEED = 0
_emb_rng = np.random.default_rng(EMB_SEED)
EMB_VAR = _emb_rng.normal(size=(NUM_CLASS, EMB_DIM)).astype(np.float32)


def embedding_lookup(ids):
  """Looks up rows of EMB_VAR for `ids` as a one-hot x table contraction.

  Expressed as an einsum over one_hot(ids) rather than a gather.
  NOTE(review): presumably deliberate, so the exported model contains a
  matmul-style op instead of a gather; confirm before changing.
  Per jax.nn.one_hot, ids outside [0, NUM_CLASS) produce all-zero one-hot
  rows and therefore zero embeddings.

  Args:
      ids: integer ids; `convert_embedding_lookup` feeds int32 of shape
        (NUM_SELECTIONS,).

  Returns:
      Array of shape ids.shape + (EMB_VAR.shape[1],).
  """
  return jnp.einsum('...y,yz->...z', jax.nn.one_hot(ids, NUM_CLASS), EMB_VAR)

def convert_embedding_lookup(save_path):
  """Exports `embedding_lookup` via jax2tf, converts it to TFLite and saves it.

  Args:
      save_path (str): destination path of the .tflite file.
  """

  def export_func(input_tensor):
    return embedding_lookup(input_tensor)

  # Single int32 id vector named 'inputs' with static length NUM_SELECTIONS.
  signature = [
      tf.TensorSpec(shape=(NUM_SELECTIONS,), dtype=tf.int32, name='inputs')
  ]
  tf_callable = convert_utils.create_jax2tf_fxn(
      export_func, signature, use_stablehlo=False
  )
  write_tfl_model(convert_utils.convert2tfl(tf_callable), save_path)

# Generate embedding_lookup.tflite alongside the other test models.
save_path = f'{base_model_path}/embedding_lookup.tflite'
convert_embedding_lookup(save_path=save_path)

# MNIST

In [None]:
# @title test model definition
def create_model(num_classes=10, hidden_dim=32):
  """Builds and compiles a small conv + fully connected MNIST classifier.

  Layers: Conv2D(hidden_dim // 4 filters, 3x3, relu, same padding) ->
  AveragePooling2D (default pool size) -> Flatten -> Dense(hidden_dim, relu)
  -> Dense(num_classes, softmax, no bias).

  Args:
      num_classes (int): number of output classes (10 for MNIST digits).
      hidden_dim (int): width of the hidden Dense layer; the Conv2D layer
        uses hidden_dim // 4 filters.

  Returns:
      tf.keras.Sequential: a compiled model (adam, sparse categorical
      cross-entropy, accuracy) expecting (28, 28, 1) inputs and integer labels.
  """
  model = tf.keras.Sequential()

  model.add(
      tf.keras.layers.Conv2D(
          hidden_dim // 4,
          3,
          activation="relu",
          padding="same",
          input_shape=(28, 28, 1),
          use_bias=True,
      )
  )
  model.add(tf.keras.layers.AveragePooling2D())
  model.add(tf.keras.layers.Flatten())
  model.add(tf.keras.layers.Dense(hidden_dim, activation="relu", use_bias=True))
  model.add(
      tf.keras.layers.Dense(num_classes, use_bias=False, activation="softmax")
  )

  model.compile(
      optimizer="adam",
      # Sparse loss: labels are integer class ids, not one-hot vectors.
      loss="sparse_categorical_crossentropy",
      metrics=["accuracy"],
  )

  return model


In [None]:
# @title train utils
def get_data():
  """Loads the MNIST *training* split, preprocessed for `create_model`.

  The test split returned by `tf.keras.datasets.mnist.load_data()` is
  discarded.

  Returns:
      tuple: (x_train, y_train). x_train is a float32 array of shape
      (N, 28, 28, 1) with pixel values scaled to 0-1; y_train holds the
      corresponding labels exactly as returned by `load_data()`.
  """
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  x_train = x_train / 255.0  # normalize pixel values to 0-1
  x_train = x_train.astype(np.float32)
  x_train = x_train.reshape([-1, 28, 28, 1])  # add a trailing channel axis
  return (x_train, y_train)

def train_model(epochs, x_values, y_values, batch_size=256,
                validation_split=0.2):
  """Builds a fresh model with `create_model` and trains it on the given data.

  Args:
      epochs (int): number of epochs to train the model.
      x_values (numpy.array): training inputs, e.g. the images from `get_data`.
      y_values (numpy.array): labels corresponding to `x_values`.
      batch_size (int): mini-batch size passed to `model.fit`.
      validation_split (float): fraction of the data `model.fit` holds out for
        validation (Keras takes it from the end of the arrays, before
        shuffling).

  Returns:
      tf.keras.Model: the trained, compiled Keras model.
  """
  model = create_model()
  model.fit(
      x_values,
      y_values,
      epochs=epochs,
      validation_split=validation_split,
      batch_size=batch_size,
      verbose=1,
  )
  return model

In [None]:
# Train the conv + fc MNIST model.
x_values, y_values = get_data()
epochs = 5
trained_model = train_model(epochs=epochs, x_values=x_values, y_values=y_values)

In [None]:
# Serialize the trained MNIST model to a float TFLite flatbuffer and save it.
save_path = f'{base_model_path}/conv_fc_mnist.tflite'
tflite_model_float = convert_tflite_model(trained_model)
write_tfl_model(model=tflite_model_float, model_path=save_path)

# Branching model

In [None]:
from tensorflow.keras import layers
from tensorflow.keras.models import Model


def create_branching_model():
  """Builds a functional Keras model whose input fans out into three branches.

  The (3, 4, 4, 1) input batch is split along axis 0 into three (1, 4, 4, 1)
  slices:
    * branch A: flattened to (1, 16), then Dense(16) with a glorot-uniform bias;
    * branch B: 3x3 Conv2D (1 filter, same padding, relu), flattened to (1, 16);
    * branch C: flattened to (1, 16) only.
  B and C are concatenated first, then A is prepended, yielding (3, 16), which
  is flattened to the model output of shape (48,).

  NOTE(review): raw tf.split/tf.reshape/tf.concat on Keras symbolic tensors
  presumably relies on TF-Keras (Keras 2) auto-wrapping TF ops as layers;
  confirm against the Keras version in use.

  Returns:
      Model: an uncompiled `tensorflow.keras.models.Model`.
  """
  inputs = layers.Input(shape=(4, 4, 1), batch_size=3)
  branch_a, branch_b, branch_c = tf.split(inputs, 3, axis=0)

  # Branch A: fully connected.
  branch_a = tf.reshape(branch_a, (1, 16))
  branch_a = tf.keras.layers.Dense(
      16, activation=None, use_bias=True, bias_initializer="glorot_uniform"
  )(branch_a)

  # Branch B: convolution.
  branch_b = layers.Conv2D(
      filters=1, kernel_size=(3, 3), padding="same", activation="relu"
  )(branch_b)
  branch_b = tf.reshape(branch_b, (1, 16))

  # Branch C: reshape-only passthrough.
  branch_c = tf.reshape(branch_c, (1, 16))

  # Join B with C, then prepend A.
  merged_bc = tf.concat([branch_b, branch_c], axis=0)
  merged_all = tf.concat([branch_a, merged_bc], axis=0)
  flat_output = tf.reshape(merged_all, (48,))
  return Model(inputs=inputs, outputs=flat_output)

In [None]:
branching_model = create_branching_model()
# Model.summary() prints the table itself and returns None; wrapping it in
# print() would emit a stray "None" line after the summary.
branching_model.summary()

In [None]:
# Serialize the branching model to a float TFLite flatbuffer and save it.
save_path = f'{base_model_path}/branching_conv_fc.tflite'
tflite_model_float = convert_tflite_model(branching_model)
write_tfl_model(model=tflite_model_float, model_path=save_path)