##### Copyright 2020 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
#@title License header
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Define a simple TF module, save it, load it into IREE, and lower it step by step.

## Start the kernel
*   [Install a TensorFlow 2 nightly pip package](https://www.tensorflow.org/install) (or bring your own build)
*   Enable IREE/TF integration by adding the following line to your `user.bazelrc`: `build --define=iree_tensorflow=true`
*   *Optional:* Prime the build: `bazel build bindings/python/pyiree`
*   Start the Colab kernel by running `python colab/start_colab_kernel.py` (see that file for initial setup instructions)

In [1]:
import tensorflow as tf
from pyiree.tf import compiler as ireec

In [2]:
class MyModule(tf.Module):
  """Toy tf.Module holding one variable `v` and exporting a single `add` function.

  `add` computes tanh(v * a + b) on two float32 inputs of shape [4].
  """

  def __init__(self):
    # tf.Module subclasses are expected to call the base constructor (see the
    # tf.Module API docs); without it, base-class attributes such as `name`
    # and `name_scope` are never initialized on the instance.
    super().__init__()
    # NOTE: `[4]` is a one-element list, so `v` has shape [1] and value 4.0
    # (the TF import shows `tensor<1xf32>`, value 4.0). It broadcasts against
    # the shape-[4] inputs in `add`.
    self.v = tf.Variable([4], dtype=tf.float32)

  # The fixed input_signature traces exactly one concrete function, which is
  # what gets exported for IREE to import.
  @tf.function(input_signature=[
      tf.TensorSpec([4], tf.float32),
      tf.TensorSpec([4], tf.float32)
  ])
  def add(self, a, b):
    """Returns tanh(v * a + b) for float32 tensors `a`, `b` of shape [4]."""
    return tf.tanh(self.v * a + b)

In [10]:
#@title Import the TF module into MLIR (TF dialect, no passes run).
# With an empty `pass_pipeline`, the printed ASM is still the raw TensorFlow
# import (`tf` / `tf_executor` ops), not mhlo. The following cells lower
# `compiler_module` in place, one stage at a time, so re-run this cell to start
# again from the raw import before re-running any of them.
compiler_module = ireec.tf_module_to_compiler_module(MyModule(),
                                                     pass_pipeline=())
print('LOADED ASM:', compiler_module.to_asm())

LOADED ASM: 

module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 504 : i32}, tf_saved_model.semantics} {
  "tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node1__v", tf_saved_model.exported_names = ["v"], type = tensor<1xf32>, value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()
  func @__inference_add_160(%arg0: tensor<4xf32> {tf._user_specified_name = "a", tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf._user_specified_name = "b", tf_saved_model.index_path = [1]}, %arg2: tensor<!tf.resource<tensor<1xf32>>> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []}) attributes {tf._input_shapes = [#tf.shape<4>, #tf.shape<4>, #tf.shape<*>], tf.signature.is_stateful, tf_saved_model.exported_names = ["add"]} {
    %0 = "tf.Cast"(%arg2) {Truncate = false} : (tensor<!tf.resource<tensor<1xf32>>>) -> tensor<*x!tf.resource>
    %1 = tf_executor.graph {
      %outputs, %control = tf_execu

In [12]:
#@title Canonicalize the TF import.
# Three stages, applied in order: prune the tf_executor graph, run the standard
# TF cleanup pipeline, then canonicalize.
# This rewrites `compiler_module` in place: run it exactly once after the
# import cell (re-run the import cell first if you need to repeat it).
compiler_module.run_pass_pipeline(
    ["tf-executor-graph-pruning", "tf-standard-pipeline", "canonicalize"])
print("LOWERED TF ASM:", compiler_module.to_asm())

LOWERED TF ASM: 

module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 504 : i32}, tf_saved_model.semantics} {
  "tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node1__v", tf_saved_model.exported_names = ["v"], type = tensor<1xf32>, value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()
  func @__inference_add_160(%arg0: tensor<4xf32> {tf._user_specified_name = "a", tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf._user_specified_name = "b", tf_saved_model.index_path = [1]}, %arg2: tensor<!tf.resource<tensor<1xf32>>> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []}) attributes {tf._input_shapes = [#tf.shape<4>, #tf.shape<4>, #tf.shape<*>], tf.signature.is_stateful, tf_saved_model.exported_names = ["add"]} {
    %0 = "tf.ReadVariableOp"(%arg2) : (tensor<!tf.resource<tensor<1xf32>>>) -> tensor<1xf32>
    %1 = "tf.Mul"(%0, %arg0) {device = ""} : (tensor<1xf32>, tensor<4xf32>) ->

In [13]:
#@title Legalize to XLA (high-level).
# Legalize TF ops to XLA HLO ops. `allow-partial-conversion=true` lets ops that
# have no legalization remain in the IR instead of failing the pass; for example,
# `tf.ReadVariableOp` is still present in the output below.
# Like the canonicalization cell, this rewrites `compiler_module` in place.
compiler_module.run_pass_pipeline(
    ["xla-legalize-tf{allow-partial-conversion=true}"])
print("XLA ASM:", compiler_module.to_asm())

XLA ASM: 

module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 504 : i32}, tf_saved_model.semantics} {
  "tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node1__v", tf_saved_model.exported_names = ["v"], type = tensor<1xf32>, value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()
  func @__inference_add_160(%arg0: tensor<4xf32> {tf._user_specified_name = "a", tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf._user_specified_name = "b", tf_saved_model.index_path = [1]}, %arg2: tensor<!tf.resource<tensor<1xf32>>> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []}) attributes {tf._input_shapes = [#tf.shape<4>, #tf.shape<4>, #tf.shape<*>], tf.signature.is_stateful, tf_saved_model.exported_names = ["add"]} {
    %0 = "tf.ReadVariableOp"(%arg2) : (tensor<!tf.resource<tensor<1xf32>>>) -> tensor<1xf32>
    %1 = shape.shape_of %0 : tensor<1xf32> -> tensor<?xindex>
    %2 = shape.shape_of