# Using Custom Ops with TF2ONNX

The custom ops framework lets you define new ONNX operators in Python or C++ and load them into ONNX Runtime (ORT).  This makes it possible to convert and run TF models containing ops that have no ONNX equivalent yet.  The framework also serves as a place to share custom op definitions.

There are 3 main ways to use this framework:
- Case 1: Converting a TF model using an existing custom op
  - Best option if the op is already implemented
- Case 2: Defining new custom ops in Python to use in conversion
  - Easier than C++, but performance may be poor
- Case 3: Defining new custom ops in C++
  - Likely faster than Python, but requires building the customops repo from source

For cases 1 and 2, you can use the off-the-shelf pip package `onnxruntime_extensions`.  For case 3, you will need to clone and build the customops repo.  Follow the instructions [here](https://github.com/microsoft/ort-customops#getting-started).

You will also need to install the onnxruntime, tensorflow, and tf2onnx packages.  **NOTE: tf2onnx version (FIXME) is required for this tutorial.**
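Before starting, you can check which of these packages are installed in the current environment. This is a minimal sketch using only the standard library's `importlib.metadata`; the distribution names in `REQUIRED` are the usual PyPI names and are an assumption (the extensions package is assumed to be published as `onnxruntime-extensions`). It only reports versions and does not check them against the required tf2onnx version.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names for the packages this tutorial uses.
REQUIRED = ["onnxruntime", "onnxruntime-extensions", "tensorflow", "tf2onnx"]

def installed_versions(packages=REQUIRED):
    """Map each distribution name to its installed version string, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'NOT INSTALLED'}")
```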

## Case 1: Converting a TF model using an existing custom op

First, let's create a model that requires a custom op that is already defined in the custom ops framework.

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
class Model1(tf.keras.Model):

    def __init__(self, name='model1', **kwargs):
        super(Model1, self).__init__(name=name, **kwargs)

    def call(self, inputs):
        return tf.strings.regex_replace(inputs, " ", "_", replace_global=True)

model1 = Model1()

In [3]:
model1(tf.constant(["Hello world!"]))

<tf.Tensor: shape=(1,), dtype=string, numpy=array([b'Hello_world!'], dtype=object)>

In [4]:
model1.save("saved_model1")

Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: saved_model1\assets


### Identifying unsupported ops

If a model has unsupported ops, tf2onnx will still convert it, but the unsupported ops will be left in the graph unchanged. An error message will list the unsupported ops.

In [5]:
!python -m tf2onnx.convert --saved-model "saved_model1" --output "model1.onnx"

Loading a model with unsupported ops into ORT raises an error.

In [6]:
import onnxruntime as ort

try:
    sess = ort.InferenceSession("model1.onnx")
except Exception as e:
    print(e)

[ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from model1.onnx failed:This is an invalid model. Error in Node:PartitionedCall/autoencoder/StaticRegexReplace : No Op registered for StaticRegexReplace with domain_version of 8
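When triaging many models, it can help to pull the offending op type out of this message programmatically. A small sketch with the standard `re` module; the helper name `missing_ops` is ours, and it assumes ORT keeps the "No Op registered for ..." wording shown above, which is not a documented contract and may change between ORT versions.

```python
import re

# Assumption: ORT reports unregistered ops as "No Op registered for <OpType> ...".
_MISSING_OP = re.compile(r"No Op registered for (\w+)")

def missing_ops(error_message):
    """Return the op types ORT reported as unregistered, in order of appearance."""
    return _MISSING_OP.findall(error_message)

msg = ("[ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from model1.onnx failed:"
       "This is an invalid model. Error in Node:PartitionedCall/autoencoder/StaticRegexReplace"
       " : No Op registered for StaticRegexReplace with domain_version of 8")
print(missing_ops(msg))  # ['StaticRegexReplace']
```

In the cell above, calling `missing_ops(str(e))` inside the `except` block would give the same result.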


### Enabling custom ops in the converter

Fortunately, in this case there is already a custom op implementing the functionality we need: StringRegexReplace.  The converter has a rule to replace TF's StaticRegexReplace op with the StringRegexReplace custom op.  To enable conversions that use custom ops, add the `--extra_opset ai.onnx.contrib:1` flag.

In [7]:
!python -m tf2onnx.convert --saved-model "saved_model1" --output "model1.onnx" --extra_opset ai.onnx.contrib:1

### Loading custom ops into ORT

Pass the location of the custom ops library into the ORT session options to use the op.

In [8]:
import onnxruntime as ort
from onnxruntime_extensions import get_library_path

so = ort.SessionOptions()
so.register_custom_ops_library(get_library_path())

sess = ort.InferenceSession("model1.onnx", so)
print("Inputs:", [inp.name for inp in sess.get_inputs()])
print("Outputs:", [out.name for out in sess.get_outputs()])

Inputs: ['input_1:0']
Outputs: ['Identity:0']


In [9]:
sess.run(["Identity:0"], {"input_1:0": ["Hello World!"]})

[array(['Hello_World!'], dtype=object)]
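To check ORT's output against something independent of TF, a pure-Python reference for the replacement can be handy. This is a sketch, not the StringRegexReplace implementation: the helper name `regex_replace_reference` is ours, and it uses Python's `re` module while TF uses RE2, so it only agrees with TF for patterns both engines interpret the same way.

```python
import re

def regex_replace_reference(strings, pattern, rewrite, replace_global=True):
    """Pure-Python stand-in for tf.strings.regex_replace on a list of str.

    Assumption: `pattern` uses syntax that RE2 and Python's `re` agree on.
    """
    count = 0 if replace_global else 1  # re.sub: count=0 replaces every match
    return [re.sub(pattern, rewrite, s, count=count) for s in strings]

print(regex_replace_reference(["Hello World!"], " ", "_"))  # ['Hello_World!']
print(regex_replace_reference(["a b c"], " ", "_", replace_global=False))  # ['a_b c']
```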

## Case 2: Defining new custom ops with Python

If there is no existing custom op implementation, you will need to define the op yourself and add a conversion rule for it.

In [10]:
import tensorflow as tf
import numpy as np

In [11]:
class Model2(tf.keras.Model):

    def __init__(self, name='model2', **kwargs):
        super(Model2, self).__init__(name=name, **kwargs)

    def call(self, inputs):
        x, segment_ids = inputs
        num_segs = tf.reduce_max(segment_ids) + 1
        return tf.strings.unsorted_segment_join(x, segment_ids, num_segs, separator='-')

model2 = Model2()

In [12]:
model2([tf.constant(["car", "java", "pet", "script"]), tf.constant([1, 0, 1, 0])])

<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'java-script', b'car-pet'], dtype=object)>
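The behavior the custom op has to reproduce can be pinned down with a plain-Python reference before touching ONNX. A minimal sketch assuming 1-D inputs and segment ids in the range `[0, num_segments)`; the helper name is ours, it is meant for checking outputs rather than as the op implementation used later, and in this sketch a segment that receives no strings yields `""`.

```python
def unsorted_segment_join_reference(strings, segment_ids, num_segments, separator=""):
    """Pure-Python sketch of tf.strings.unsorted_segment_join for 1-D inputs.

    Output slot i joins, in input order, every string whose segment id is i.
    """
    buckets = [[] for _ in range(num_segments)]
    for s, seg_id in zip(strings, segment_ids):
        buckets[seg_id].append(s)
    return [separator.join(parts) for parts in buckets]

ids = [1, 0, 1, 0]
num_segs = max(ids) + 1  # mirrors tf.reduce_max(segment_ids) + 1 in Model2.call
print(unsorted_segment_join_reference(["car", "java", "pet", "script"], ids, num_segs, separator="-"))
# ['java-script', 'car-pet']
```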

In [13]:
model2.save("saved_model2", save_format="tf")

INFO:tensorflow:Assets written to: saved_model2\assets


### Adding a custom op conversion rule using the command line

We need to tell the converter how to convert the TF UnsortedSegmentJoin op. Even if our custom op will have the same name as the TF op, the node must be tagged with the custom ops domain `ai.onnx.contrib`.

Pass the `--extra_opset ai.onnx.contrib:1` and `--custom-ops UnsortedSegmentJoin:ai.onnx.contrib` flags to the converter.

In [14]:
!python -m tf2onnx.convert --saved-model "saved_model2" --output "model2a.onnx" --extra_opset ai.onnx.contrib:1 --custom-ops UnsortedSegmentJoin:ai.onnx.contrib
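Outside a notebook, the same conversion can be scripted with `subprocess` instead of the `!` shell escape. A sketch that assembles the command from the flags used above; the helper name `tf2onnx_convert_cmd` is ours, and passing several ops as a comma-separated `--custom-ops` list is an assumption worth confirming with `python -m tf2onnx.convert --help`.

```python
import shlex
import sys

def tf2onnx_convert_cmd(saved_model, output, extra_opset=None, custom_ops=None):
    """Build an argv list for `python -m tf2onnx.convert` using the flags shown above.

    custom_ops: mapping of TF op name -> ONNX domain, emitted as "Op:domain" pairs.
    Assumption: multiple pairs are accepted comma-separated by --custom-ops.
    """
    cmd = [sys.executable, "-m", "tf2onnx.convert",
           "--saved-model", saved_model, "--output", output]
    if extra_opset:
        cmd += ["--extra_opset", extra_opset]
    if custom_ops:
        cmd += ["--custom-ops", ",".join(f"{op}:{domain}" for op, domain in custom_ops.items())]
    return cmd

cmd = tf2onnx_convert_cmd("saved_model2", "model2a.onnx",
                          extra_opset="ai.onnx.contrib:1",
                          custom_ops={"UnsortedSegmentJoin": "ai.onnx.contrib"})
print(shlex.join(cmd))
# To run it: subprocess.run(cmd, check=True)
```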

### Adding a custom op conversion rule using Python

For more complicated conversions, the rule can be defined in Python.  See the [tf2onnx repo](https://github.com/onnx/tensorflow-onnx/tree/master/tf2onnx/onnx_opset) for more conversion rule examples.

In [15]:
import numpy as np
from tf2onnx import utils, constants
from tf2onnx.handler import tf_op

# Registers a conversion rule for UnsortedSegmentJoin op
# Rule will only be run if ai.onnx.contrib domain is included via --extra_opset flag
@tf_op("UnsortedSegmentJoin", domain=constants.CONTRIB_OPS_DOMAIN)
class ConvertUnsortedSegmentJoinOp:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        node.type = "MyCustomStringSegmentJoin"
        # Don't forget to set the domain!
        node.domain = constants.CONTRIB_OPS_DOMAIN
        # Ops defined using the custom ops framework only get access to inputs, not attributes
        separator = node.get_attr_str("separator") if "separator" in node.attr else ''
        for a in list(node.attr.keys()):
            del node.attr[a]
        # Add the separator as an additional string input
        # Use the builtin `object`; the `np.object` alias was removed in NumPy 1.24
        separator_const = ctx.make_const(utils.make_name('separator_const'), np.array([separator], dtype=object))
        ctx.replace_inputs(node, node.input + [separator_const.output[0]])
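To see what this rule does to a node without running TF, the same steps (retag the node, drop its attributes, append the separator as a constant input) can be replayed on toy objects. `ToyNode`, `ToyGraph`, and `rewrite_segment_join` are hypothetical stand-ins written for this sketch; they do not reproduce tf2onnx's real `Node`/`Graph` API, and unlike `utils.make_name` they do not generate unique names.

```python
from dataclasses import dataclass, field

CONTRIB_OPS_DOMAIN = "ai.onnx.contrib"  # same domain string passed via --extra_opset

# Hypothetical stand-ins for tf2onnx graph objects, for illustration only.
@dataclass
class ToyNode:
    type: str
    input: list
    attr: dict = field(default_factory=dict)
    domain: str = ""

@dataclass
class ToyGraph:
    consts: dict = field(default_factory=dict)

    def make_const(self, name, value):
        self.consts[name] = value
        return name  # stands in for the const node's output name

def rewrite_segment_join(graph, node):
    """Mirror the steps of ConvertUnsortedSegmentJoinOp.version_1 on toy objects."""
    node.type = "MyCustomStringSegmentJoin"
    node.domain = CONTRIB_OPS_DOMAIN
    separator = node.attr.get("separator", "")
    node.attr.clear()  # custom ops only see inputs, so attributes are dropped...
    const_name = graph.make_const("separator_const", [separator])
    node.input = node.input + [const_name]  # ...and the separator becomes a 4th input

g = ToyGraph()
n = ToyNode("UnsortedSegmentJoin", ["x", "segment_ids", "num_segs"], {"separator": "-"})
rewrite_segment_join(g, n)
print(n)
```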

Next, call the converter using the [tf2onnx Python API](https://github.com/onnx/tensorflow-onnx#python-api-reference). All rules decorated with `@tf_op` will be used.

In [16]:
concrete_fn2 = tf.function(model2.call).get_concrete_function([tf.TensorSpec([None], tf.string), tf.TensorSpec([None], tf.int32)])
input_names = [inp.name for inp in concrete_fn2.inputs]
output_names = [out.name for out in concrete_fn2.outputs]
print("Inputs:", input_names)
print("Outputs:", output_names)

Inputs: ['inputs:0', 'inputs_1:0']
Outputs: ['Identity:0']


In [17]:
from tf2onnx import tf_loader
from tf2onnx.tfonnx import process_tf_graph
from tf2onnx.optimizer import optimize_graph

graph_def = tf_loader.from_function(concrete_fn2, input_names=input_names, output_names=output_names)
extra_opset = [utils.make_opsetid(constants.CONTRIB_OPS_DOMAIN, 1)]
with tf.Graph().as_default() as tf_graph:
    tf.import_graph_def(graph_def, name='')
with tf_loader.tf_session(graph=tf_graph):
    g = process_tf_graph(tf_graph, input_names=input_names, output_names=output_names, extra_opset=extra_opset)
onnx_graph = optimize_graph(g)
model_proto = onnx_graph.make_model("converted")
utils.save_protobuf("model2b.onnx", model_proto)
print("Conversion complete!")

Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
Conversion complete!


### Implementing the op in Python

Add a function with the `@onnx_op` decorator to register a custom op before creating the ORT InferenceSession.  The inputs are passed in as NumPy arrays, and the function should return a NumPy array of the declared output type.

**NOTE:** ORT will only allow an op to be registered once, so you must restart the Jupyter kernel each time you change the implementation below.

In [18]:
import numpy as np
from onnxruntime_extensions import onnx_op, PyCustomOpDef

@onnx_op(op_type="UnsortedSegmentJoin",
         inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_int32, PyCustomOpDef.dt_int32],
         outputs=[PyCustomOpDef.dt_string])
def unsorted_segment_join(x, segment_ids, num_segments):
    # The custom op implementation.
    result = np.full([num_segments], '', dtype=object)  # builtin object: np.object was removed in NumPy 1.24
    for s, seg_id in zip(x, segment_ids):
        result[seg_id] += s
    return result

@onnx_op(op_type="MyCustomStringSegmentJoin",
         inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_int32, PyCustomOpDef.dt_int32, PyCustomOpDef.dt_string],
         outputs=[PyCustomOpDef.dt_string])
def string_segment_join(x, segment_ids, num_segments, separator):
    result = [[] for i in range(num_segments)]
    separator = separator[0]
    for s, seg_id in zip(x, segment_ids):
        result[seg_id].append(s)
    result_joined = [separator.join(l) for l in result]
    return np.array(result_joined, dtype=object)

In [19]:
import onnxruntime as ort
from onnxruntime_extensions import get_library_path

so = ort.SessionOptions()
so.register_custom_ops_library(get_library_path())

sess = ort.InferenceSession("model2a.onnx", so)
# Use the input names from the saved_model_cli
print(sess.run(["Identity:0"], {"input_1:0": ["car", "java", "pet", "script"], "input_2:0": [1, 0, 1, 0]}))

sess = ort.InferenceSession("model2b.onnx", so)
# Use the input names from the concrete function
print(sess.run(["Identity:0"], {input_names[0]: ["car", "java", "pet", "script"], input_names[1]: [1, 0, 1, 0]}))

[array(['javascript', 'carpet'], dtype=object)]
[array(['java-script', 'car-pet'], dtype=object)]
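The two sessions disagree because custom ops defined with this framework receive only inputs, not attributes: the `UnsortedSegmentJoin` kernel used by `model2a.onnx` has no separator input, so the separator is lost, while the conversion rule used for `model2b.onnx` passes it as a fourth input. A quick way to catch this kind of mismatch is to compare ORT's results against TF's. TF returns string tensors as `bytes` while ORT returns `str`, so the sketch below (helper names are ours) decodes before comparing; the values are copied from the outputs above rather than computed live.

```python
def normalize(values):
    """Decode bytes (TF string tensors) to str so TF and ORT outputs compare equal."""
    return [v.decode("utf-8") if isinstance(v, bytes) else v for v in values]

def outputs_match(tf_values, ort_values):
    return normalize(tf_values) == normalize(ort_values)

tf_out = [b"java-script", b"car-pet"]  # from model2(...) in the TF cell above
ort_a = ["javascript", "carpet"]       # model2a.onnx: kernel has no separator input
ort_b = ["java-script", "car-pet"]     # model2b.onnx: separator passed as an input
print(outputs_match(tf_out, ort_a), outputs_match(tf_out, ort_b))  # False True
```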


## Case 3: Implementing custom ops in C++

Add a conversion rule for your custom op using the instructions in the previous section.  It can be useful to prototype the op in Python before developing a C++ version.  Follow the [C++ Custom Ops Tutorial](https://github.com/microsoft/ort-customops/blob/main/tutorials/cpp_custom_ops_tutorial.md) to create a C++ version of the op.