# Image edge detection module

## High-level overview

1.  Define a `tf.Module` containing a `@tf.function` that performs edge detection
2.  Save the `tf.Module` as a `SavedModel`
3.  Use IREE's Python bindings to load the `SavedModel` into MLIR in the `xla_hlo` dialect
4.  Save the MLIR to a file (you can stop here and use that file from another application)
5.  Compile the `xla_hlo` MLIR into a VM module for IREE to execute
6.  Run the VM module through IREE's runtime to test the edge detection function

In [0]:
#@title Imports and common setup

import os
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
import pyiree
from pyiree import binding

# Directory that receives the exported SavedModel and the MLIR text file.
# os.path.expanduser("~") resolves to $HOME when it is set (so the path is
# unchanged there) but, unlike os.environ["HOME"], does not raise KeyError in
# environments where HOME is unset (e.g. Windows, some CI/service accounts).
SAVE_PATH = os.path.join(os.path.expanduser("~"), "saved_models")
os.makedirs(SAVE_PATH, exist_ok=True)

In [0]:
#@title Construct a module containing the edge detection function

class EdgeDetectionModule(tf.Module):
  """A `tf.Module` exposing a Sobel edge detector as a traced `tf.function`."""

  @tf.function(input_signature=[tf.TensorSpec([1, 128, 128, 1], tf.float32)])
  def edge_detect_sobel_operator(self, image):
    """Compute the Sobel gradient magnitude of a single-channel image.

    Args:
      image: float32 tensor of shape [1, 128, 128, 1], as fixed by the input
        signature (interpreted by `tf.nn.conv2d` in its default NHWC layout).

    Returns:
      float32 tensor of the same shape as `image` holding sqrt(gx*gx + gy*gy),
      where gx and gy are the responses to the horizontal and vertical Sobel
      kernels (stride 1, "SAME" padding, one output channel).
    """
    # https://en.wikipedia.org/wiki/Sobel_operator
    def as_filter(rows):
      # Reshape a 3x3 kernel into conv2d's [height, width, in, out] layout.
      return tf.constant(rows, dtype=tf.float32, shape=[3, 3, 1, 1])

    kernel_x = as_filter([[-1.0, 0.0, 1.0],
                          [-2.0, 0.0, 2.0],
                          [-1.0, 0.0, 1.0]])
    kernel_y = as_filter([[1.0, 2.0, 1.0],
                          [0.0, 0.0, 0.0],
                          [-1.0, -2.0, -1.0]])

    horizontal = tf.nn.conv2d(image, kernel_x, 1, "SAME")
    vertical = tf.nn.conv2d(image, kernel_y, 1, "SAME")
    return tf.math.sqrt(horizontal * horizontal + vertical * vertical)

tf_module = EdgeDetectionModule()
saved_model_path = os.path.join(SAVE_PATH, "edge_detection.sm")
# save_debug_info=True also writes the graph debug info (op stack traces)
# alongside the SavedModel.
save_options = tf.saved_model.SaveOptions(save_debug_info=True)
tf.saved_model.save(tf_module, saved_model_path, options=save_options)

# Compile from SavedModel to MLIR xla_hlo, then save to a file.
#
# Do *not* further compile to a bytecode module for a particular backend.
#
# By stopping at xla_hlo in text format, we can more easily take advantage of
# future compiler improvements within IREE, and we can use
# iree_bytecode_module to compile and bundle the module into a sample
# application. A production application would probably freeze the IREE
# version, compile as completely as possible ahead of time, and use some other
# scheme to load the module into the application at runtime.
mlir_module = pyiree.compiler.tf_load_saved_model(saved_model_path)
# Serialize the module to text once; the same string is printed and written.
mlir_asm = mlir_module.to_asm()
print("Edge Detection MLIR:", mlir_asm)

edge_detection_mlir_path = os.path.join(SAVE_PATH, "edge_detection.mlir")
# An explicit encoding keeps the written file independent of the platform's
# default locale encoding.
with open(edge_detection_mlir_path, "wt", encoding="utf-8") as output_file:
  output_file.write(mlir_asm)
print("Wrote MLIR to path '%s'" % edge_detection_mlir_path)

In [0]:
#@title Prepare to test the edge detection module

# Compiler target backend(s) and runtime driver. The module is compiled for
# TARGET_BACKENDS and executed on DRIVER_NAME, so the two presumably need to
# name the same device API -- confirm before changing either one.
TARGET_BACKENDS = ["vulkan-spirv"]
DRIVER_NAME = "vulkan"

# Lower the imported MLIR to a sequencer blob for the chosen backends (with
# print_mlir=True the compiler prints MLIR as it goes), then wrap the blob as
# a VM module.
sequencer_blob = mlir_module.compile_to_sequencer_blob(
    print_mlir=True,
    target_backends=TARGET_BACKENDS,
)
vm_module = binding.vm.create_module_from_blob(sequencer_blob)

# Runtime objects: a Policy() built with no arguments, an Instance bound to
# the driver, and a Context over both into which the VM module is registered.
# `policy` and `rt_ctx` are reused by the invocation cell further down.
policy = binding.rt.Policy()
instance = binding.rt.Instance(driver_name=DRIVER_NAME)
rt_ctx = binding.rt.Context(policy=policy, instance=instance)
rt_ctx.register_module(vm_module)

In [0]:
#@title Load a test image of a [labrador](https://commons.wikimedia.org/wiki/File:YellowLabradorLooking_new.jpg)

def load_image(path_to_image, size=(128, 128)):
  """Load an image file as a single-image, single-channel float32 batch.

  Args:
    path_to_image: Path of an image file readable by `tf.io.read_file` and
      decodable by `tf.image.decode_image`.
    size: `(height, width)` to resize to. The default, 128x128, matches the
      input signature of `EdgeDetectionModule.edge_detect_sobel_operator`.

  Returns:
    A float32 tensor of shape `[1, height, width, 1]`. Pixels decoded as
    uint8 are rescaled to [0, 1] by `tf.image.convert_image_dtype`.
  """
  image = tf.io.read_file(path_to_image)
  # channels=1 decodes to grayscale. expand_animations=False makes animated
  # GIFs decode to their first frame as a 3-D [h, w, c] tensor (the default
  # returns 4-D frames), so the resize and batching below always see 3-D input.
  image = tf.image.decode_image(image, channels=1, expand_animations=False)
  image = tf.image.convert_image_dtype(image, tf.float32)
  image = tf.image.resize(image, size)
  # Add the leading batch dimension: [h, w, 1] -> [1, h, w, 1].
  return image[tf.newaxis, ...]

# Fetch the sample photo through the Keras file cache (downloaded only when it
# is not already cached) and convert it into the model's input format.
content_path = tf.keras.utils.get_file(
    fname='YellowLabradorLooking_new.jpg',
    origin=('https://storage.googleapis.com/download.tensorflow.org/'
            'example_images/YellowLabradorLooking_new.jpg'))
content_image = load_image(content_path)

In [0]:
#@title Test the "edge_detect_sobel_operator" function

edge_detect_image_f = rt_ctx.resolve_function(
    "module.edge_detect_sobel_operator")

# Prepare inputs for the function. load_image already produces a float32
# [1, 128, 128, 1] batch, which matches the function's input signature.
print("Invoke function: '%s'" % edge_detect_image_f.name)
arg0_numpy = content_image.numpy()
arg0 = rt_ctx.wrap_for_input(arg0_numpy)

# Invoke the function and wait for completion. The status is printed right
# after dispatch, i.e. before await_ready() returns.
invocation = rt_ctx.invoke(edge_detect_image_f, policy, [arg0])
print("Invocation status:", invocation.query_status())
invocation.await_ready()

# Map the first result buffer and view it as float32. np.asarray copies only
# when it has to; np.array(..., copy=False) raises ValueError on NumPy >= 2
# whenever a copy would be needed.
results = invocation.results
result = results[0].map()
result_array = np.asarray(result, dtype=np.float32)

# Plot both images as 2-D [height, width]. The output keeps the input's
# spatial size because the Sobel convolutions use stride 1 and "SAME" padding
# with a single output channel.
image_hw = arg0_numpy.shape[1:3]
print("Input:")
plt.imshow(arg0_numpy.reshape(image_hw), cmap="gray")
plt.show()
print("Output:")
plt.imshow(result_array.reshape(image_hw), cmap="gray")
plt.show()