# Defines a simple TF module, saves it and loads it in IREE.

## Start kernel:
*   [Install a TensorFlow2 nightly pip](https://www.tensorflow.org/install) (or bring your own)
*   Enable IREE/TF integration by adding to your user.bazelrc: `build --define=iree_tensorflow=true`
*   *Optional:* Prime the build: `bazel build bindings/python/pyiree`
*   Start colab by running `python build_tools/scripts/start_colab_kernel.py` (see that file for initial setup instructions)

## TODO:

* This uses only the low-level binding classes; switch to the high-level API.
* Plumb through the ability to run TF compiler lowering passes and import directly into IREE.


In [0]:
import os
import tensorflow as tf
import pyiree
from pyiree import binding

# Directory where this notebook writes its SavedModels.
# Prefer $HOME when it is set, as before.
# Otherwise fall back to the platform's idea of the user's home directory, so
# the cell also works where HOME is not defined (e.g. Windows, where HOME is
# not set by default). The original os.environ["HOME"] raised KeyError there.
SAVE_PATH = os.path.join(
    os.environ.get("HOME") or os.path.expanduser("~"), "saved_models")
os.makedirs(SAVE_PATH, exist_ok=True)

In [5]:
class MyModule(tf.Module):
  """Toy tf.Module exporting a single `add` function.

  `add(a, b)` computes `tanh(v * a + b)` for two float32 inputs of shape [4].
  `v` is a tf.Variable owned by the module and captured by the function.
  """

  def __init__(self):
    # tf.Module subclasses should run the base initializer, which sets up the
    # module's name and name scope. The original skipped this call.
    super().__init__()
    # NOTE: `[4]` is a one-element list, so `v` has shape [1] and value 4.0
    # (not shape [4]). It broadcasts against the shape-[4] inputs in `add`.
    self.v = tf.Variable([4], dtype=tf.float32)

  @tf.function(
      input_signature=[tf.TensorSpec([4], tf.float32), tf.TensorSpec([4], tf.float32)]
  )
  def add(self, a, b):
    """Return tanh(v * a + b) for float32 tensors `a` and `b` of shape [4]."""
    return tf.tanh(self.v * a + b)

my_mod = MyModule()

# Export the module as a SavedModel (with TF debug info enabled), then load it
# back through IREE's TF interop binding.
sm_path = os.path.join(SAVE_PATH, "simple.sm")
options = tf.saved_model.SaveOptions(save_debug_info=True)
tf.saved_model.save(my_mod, sm_path, options=options)

ctx = binding.compiler.CompilerContext()
input_module = binding.tf_interop.load_saved_model(ctx, sm_path)
print("LOADED ASM:", input_module.to_asm())

# Lowering stages, applied in order to the same module.
# Each entry is (banner printed with the dump, passes to run).
# The module's ASM is printed after every stage.
lowering_stages = (
    # Canonicalize the TF import.
    ("LOWERED TF ASM:", [
        "tf-executor-graph-pruning",
        "tf-standard-pipeline",
        "canonicalize",
    ]),
    # Legalize to XLA (high-level).
    ("XLA ASM:", ["xla-legalize-tf"]),
)
for banner, passes in lowering_stages:
  input_module.run_pass_pipeline(passes)
  print(banner, input_module.to_asm())

INFO:tensorflow:Assets written to: C:\Users\laurenzo\saved_models\simple.sm\assets
LOADED ASM: 

module attributes {tf_saved_model.semantics} {
  "tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node1__v", tf_saved_model.exported_names = ["v"], value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()
  func @__inference_add_2620(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []})
  attributes  {tf._input_shapes = ["tfshape$dim { size: 4 }", "tfshape$dim { size: 4 }", "tfshape$unknown_rank: true"], tf.signature.is_stateful, tf_saved_model.exported_names = ["add"]} {
    %0 = tf_executor.graph {
      %1:2 = tf_executor.island wraps "tf.ReadVariableOp"(%arg2) {_output_shapes = ["tfshape$dim { size: 1 }"], device = "", dtype = "tfdtype$DT_FLOAT", name = "ReadVariableOp"} : (tensor<*x!tf.res