<a href="https://colab.research.google.com/github/ghpvnist/stablehlo/blob/tutorial/composite_e2e_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# The StableHLO Composite Primitive

In [None]:
import jax
from jax import api_util
from jax import tree_util
from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir import ir
from jax._src.lib.mlir import passmanager as pm
from jax.interpreters import mlir


# Call primitive named 'composite'; its first bound operand is the wrapped
# function to invoke.
composite_p = jax_core.CallPrimitive('composite')


def _composite_impl(f, *args, **_unused_params):
  """Evaluate the wrapped function ``f`` on ``args``.

  Keyword parameters supplied at bind time (e.g. ``name`` / ``attributes``)
  are accepted and ignored.
  """
  with jax_core.new_sublevel():
    outputs = f.call_wrapped(*args)
  return outputs


composite_p.def_impl(_composite_impl)


def call_composite(f, *args, name: str, attributes: "dict | None" = None, **kwargs):
  """Call ``f(*args, **kwargs)`` through the ``composite`` primitive.

  Args:
    f: Function that forms the body of the composite.
    *args: Positional arguments for ``f``; they are flattened as a pytree
      before being bound.
    name: Forwarded to the primitive as its ``name`` parameter.
    attributes: Forwarded to the primitive as its ``attributes`` parameter.
      ``None`` (the default) means an empty dict.
    **kwargs: Keyword arguments passed through to ``f``.

  Returns:
    The bound outputs, unflattened back into the output structure of ``f``.
  """
  # Create a fresh dict per call instead of sharing one mutable default
  # object across every invocation.
  if attributes is None:
    attributes = {}
  fun = lu.wrap_init(f, kwargs)
  flat_args, in_tree = tree_util.tree_flatten(args)
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
  out_flat = composite_p.bind(
      flat_fun,
      *flat_args,
      name=name,
      attributes=attributes)
  # `out_tree` is a thunk that is only populated once `flat_fun` has run.
  return tree_util.tree_unflatten(out_tree(), out_flat)

# Lower the composite primitive to an `hlo.CompositeOp` whose decomposition is
# the function called by the lowered implementation.
# Currently this leaks a CallOp since we're using the `core_call_lowering`
# function, but this should get cleaned up by DCE easily.
def _composite_stablehlo_lowering(ctx, *args, name, call_jaxpr, attributes, **kwargs):
  """Lowering rule for ``composite_p``.

  Args:
    ctx: Lowering rule context, forwarded to ``mlir.core_call_lowering``.
    *args: Lowered operands of the composite.
    name: Composite name; the implementation call is lowered as
      ``name + ".impl"``.
    call_jaxpr: Jaxpr of the wrapped implementation.
    attributes: Dict converted to the op's ``composite_attributes``.
    **kwargs: Any remaining primitive parameters; unused.

  Returns:
    The results of the created ``hlo.CompositeOp``.
  """
  impl = mlir.core_call_lowering(
      ctx,
      *args,
      name=name+".impl",
      call_jaxpr=call_jaxpr)
  # The op that produced the first lowered result is the call to the
  # implementation function; reuse its operands, result types and callee.
  call_op = impl[0][0].owner
  called_fn = call_op.attributes['callee']

  composite = hlo.CompositeOp(
    [r.type for r in call_op.results],
    call_op.operands,
    name=ir.StringAttr.get(name),
    composite_attributes=ir.DictAttr.get(attributes),
    decomposition=called_fn,
  )
  return composite.results

## Helper function to inline module to get rid of all the extra function calls.
## Not necessary, but makes it easier to view the composite's IR.
def _inline_ir(m_original: ir.Module) -> ir.OpView:
  """Return an inlined clone of ``m_original``.

  The ``inline`` pass runs on a clone of the module's operation, so
  ``m_original`` itself is not modified.

  Args:
    m_original: Module to clone and inline.

  Returns:
    The cloned module operation after inlining (the object returned by
    ``Operation.clone()``), not serialized bytes.
  """
  with m_original.context:
    m = m_original.operation.clone()
    passes = pm.PassManager.parse("builtin.module(inline)")
    passes.run(m.operation)
    return m

# Register `_composite_stablehlo_lowering` as the MLIR lowering rule for
# `composite_p`.
mlir.register_lowering(composite_p, _composite_stablehlo_lowering)

# Create composite calls

In [None]:
import numpy as np

def mySquared(x):
  """Return ``x`` multiplied by itself."""
  product = x * x
  return product

def myMain():
  """Square the constant 2 through the "my.squared" composite."""
  squared = call_composite(mySquared, 2, name="my.squared")
  return squared

# myMain() is called directly here, without jax.jit.
print("Value from jaxpr eval:", myMain())

# DO NOT ATTEMPT TO EVALUATE THIS JITTED FUNCTION.
# The backend does not handle these custom calls yet, so evaluating it crashes
# with "unknown custom call target 'stablehlo.composite'".
# Adding a StableHLO pass that converts this custom call to a CallOp would resolve this.
# mlir_module = jax.jit(myMain).lower()._lowering.stablehlo()
# print("MLIR Module:")
# print(_inline_ir(mlir_module))

Value from jaxpr eval: 4


# Syntactic Sugar

In [None]:
import functools

def composite_decl(name):
  """Decorator factory that routes calls to a function through ``call_composite``.

  ``@composite_decl(name)`` replaces the decorated function ``f`` with
  ``functools.partial(call_composite, f, name=name)``.
  """
  def wrapper(f):
    bound_call = functools.partial(call_composite, f, name=name)
    return bound_call

  return wrapper

# Example 1: GELU Composite

In [None]:
import jax.numpy as jnp
from jax._src.lax import lax
from jax._src import dtypes

# GELU from JAX BERT implementation
@composite_decl("my.gelu")
def gelu(x):
  """Erf-based GELU: ``x * (erf(x / sqrt(2)) + 1) / 2``.

  Args:
    x: Input array.

  Returns:
    The result wrapped in ``jnp.array``.
  """
  sqrt_2 = np.sqrt(2)
  # `erf` is not imported anywhere in this file; use the public `jax.lax.erf`
  # so the cell does not depend on leftover notebook state.
  return jnp.array(x * (jax.lax.erf(x / sqrt_2) + 1) / 2)

def myMain():
  """Apply the ``gelu`` composite to a length-4 float32 vector of ones."""
  ones = np.ones([4], dtype=np.float32)
  return gelu(ones)

# JAX eval: call myMain() directly, without jax.jit.
print("JAX Eval:", myMain())

# Get StableHLO export:
# mlir_module = jax.jit(myMain).lower()._lowering.stablehlo()
# print("MLIR Module:")
# print(_inline_ir(mlir_module))

JAX Eval: [0.8413447 0.8413447 0.8413447 0.8413447]


In [None]:
from functools import partial

# from jax._src.lax import _nary_lower_hlo, _float, _complex, standard_unop
from jax._src.typing import Array, ArrayLike

# Draft unary primitive named 'my_acos'. The first argument
# (`lax._float | lax._complex`) is presumably the set of accepted dtypes — confirm.
# NOTE(review): the "Composite Acos" cell below creates `my_acos_p` again.
my_acos_p = lax.standard_unop(lax._float | lax._complex, 'my_acos')

def my_acos(x: ArrayLike) -> Array:
  """Bind ``x`` to the ``my_acos_p`` primitive and return the result."""
  bound = my_acos_p.bind(x)
  return bound

# Register a lowering for `my_acos_p` built from `lax._nary_lower_hlo`
# partially applied to `hlo.composite`.
# NOTE(review): the "Composite Acos" cell below registers a dedicated
# `_composite_acos_lowering` for its own `my_acos_p`; confirm whether this
# draft registration is still needed.
mlir.register_lowering(my_acos_p, partial(lax._nary_lower_hlo, hlo.composite))

functools.partial(<function _nary_lower_hlo at 0x7ceef4364670>, <function composite at 0x7cef05840940>)

# Composite Acos

In [None]:
from absl.testing import absltest
import jax
from jax import export as jax_export
from jax._src import test_util as jtu
from jax._src.lax import lax
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo as stablehlo
from jax.interpreters import mlir
import jax.numpy as jnp
from jax._src.typing import ArrayLike, Array

# Step 1: Define a jax primitive.
# `standard_unop` creates a unary primitive named "my_acos"; the first argument
# (`lax._float | lax._complex`) is presumably the set of accepted dtypes — confirm.
my_acos_p = lax.standard_unop(lax._float | lax._complex, "my_acos")


# Step 2: Define a jax api.
def my_acos(x: ArrayLike) -> Array:
  """User-facing API: bind ``x`` to the ``my_acos_p`` primitive."""
  bound = my_acos_p.bind(x)
  return bound


# Step 3: Define auto diff rule.
# JVP rule: the incoming tangent is multiplied by -rsqrt(1 - x^2).
lax.ad.defjvp(
    my_acos_p,
    lambda tangent, primal: lax.mul(
        tangent,
        -lax.rsqrt(lax._const(primal, 1) - lax.square(primal)),
    ),
)


# Step 4: Define lowering to stablehlo.composite.
def _composite_acos_lowering(
    ctx: mlir.LoweringRuleContext, arg: mlir.ir.BlockArgument
) -> mlir.ir.OpResultList:
  """Lower ``my_acos_p`` to a ``stablehlo.composite`` op named "chlo.acos".

  The decomposition is the function called when the jitted ``jnp.acos``
  implementation below is lowered.
  """

  @jax.jit
  def my_acos_impl(x: ArrayLike) -> Array:
    return jnp.acos(x)

  # TODO(gunhyun): this implementation leaks a CallOp.
  lower_impl = mlir.lower_fun(my_acos_impl, multiple_results=False)
  impl_results = lower_impl(ctx, arg)
  # The op that produced the first result is the call to the implementation.
  impl_call = impl_results[0].owner

  result_types = [result.type for result in impl_call.results]
  composite_op = stablehlo.CompositeOp(
      result_types,
      impl_call.operands,
      name=ir.StringAttr.get("chlo.acos"),
      composite_attributes=ir.DictAttr.get({}),
      decomposition=impl_call.attributes["callee"],
  )
  return composite_op.results


# Step 5: Register your custom composite lowering to stablehlo.composite.
# This makes `_composite_acos_lowering` the lowering rule for `my_acos_p`.
mlir.register_lowering(my_acos_p, _composite_acos_lowering)

In [None]:
@jax.jit
def f(x: ArrayLike) -> Array:
  """Jitted entry point that applies ``my_acos`` to ``x``."""
  result = my_acos(x)
  return result

# Scalar float32 input.
x = jnp.array(1.0, dtype=jnp.float32)

# Print the reference jnp.acos result next to the custom primitive's result.
print(jnp.acos(x), f(x))

# Export `f` for this input and print what `mlir_module()` returns.
mlir_module = jax_export.export(f)(x).mlir_module()
print(mlir_module)
# In a unit test this would be checked with:
# self.assertIn('stablehlo.composite "chlo.acos"', mlir_module)

0.0 0.0
#loc1 = loc("x")
#loc2 = loc("<ipython-input-35-7c9fd849e790>":7:0)
#loc3 = loc("<ipython-input-57-b5d2814a3f38>":3:0)
#loc4 = loc("<ipython-input-57-b5d2814a3f38>":7:0)
#loc5 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3553:0)
#loc6 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3473:0)
#loc7 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3257:0)
#loc8 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py":78:0)
#loc9 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3030:0)
#loc10 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":2975:0)
#loc11 = loc("/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py":539:0)
#loc12 = loc("my_acos"(#loc2))
#loc13 = loc("f"(#loc3))
#loc14 = loc("<cell line: 7>"(#loc4))
#loc15 = loc("run_code"(#loc5))
#loc16 = loc("run_ast_nodes"(#loc6))
#loc17 = 

# StableHLO Dump

```
module @jit_f attributes {
  func.func public @main(%arg0: tensor<f32>) -> (tensor<f32>) {
    %0 = stablehlo.composite "chlo.acos" %arg0 {decomposition = @my_acos_impl} : (tensor<f32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
  func.func private @my_acos_impl(%arg0: tensor<f32>) -> (tensor<f32>) {
    %cst = stablehlo.constant dense<-1.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  NE, %arg0, %cst,  FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %1 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %2 = stablehlo.subtract %cst_0, %1 : tensor<f32>
    %3 = stablehlo.sqrt %2 : tensor<f32>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %4 = stablehlo.add %cst_1, %arg0 : tensor<f32>
    %5 = stablehlo.atan2 %3, %4 : tensor<f32>
    %cst_2 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
    %6 = stablehlo.multiply %cst_2, %5 : tensor<f32>
    %cst_3 = stablehlo.constant dense<3.14159274> : tensor<f32>
    %7 = stablehlo.select %0, %6, %cst_3 : tensor<i1>, tensor<f32>
    return %7 : tensor<f32>
  }
}
```

# HLO Dump

```
HloModule jit_f, entry_computation_layout={(f32[])->f32[]}

%my_acos_impl.2 (Arg_0.3: f32[]) -> f32[] {
  %Arg_0.3 = f32[] parameter(0)
  %constant.4 = f32[] constant(-1)
  %compare.5 = pred[] compare(f32[] %Arg_0.3, f32[] %constant.4), direction=NE
  %constant.13 = f32[] constant(2)
  %constant.7 = f32[] constant(1)
  %multiply.6 = f32[] multiply(f32[] %Arg_0.3, f32[] %Arg_0.3)
  %subtract.8 = f32[] subtract(f32[] %constant.7, f32[] %multiply.6)
  %sqrt.9 = f32[] sqrt(f32[] %subtract.8)
  %constant.10 = f32[] constant(1)
  %add.11 = f32[] add(f32[] %constant.10, f32[] %Arg_0.3)
  %atan2.12 = f32[] atan2(f32[] %sqrt.9, f32[] %add.11)
  %multiply.14 = f32[] multiply(f32[] %constant.13, f32[] %atan2.12)
  %constant.15 = f32[] constant(3.14159274)
  ROOT %select.16 = f32[] select(pred[] %compare.5, f32[] %multiply.14, f32[] %constant.15)
}

%my_acos_impl.18 (Arg_0.19: f32[]) -> f32[] {
  %Arg_0.19 = f32[] parameter(0)
  %constant.20 = f32[] constant(-1)
  %compare.21 = pred[] compare(f32[] %Arg_0.19, f32[] %constant.20), direction=NE
  %constant.29 = f32[] constant(2)
  %constant.23 = f32[] constant(1)
  %multiply.22 = f32[] multiply(f32[] %Arg_0.19, f32[] %Arg_0.19)
  %subtract.24 = f32[] subtract(f32[] %constant.23, f32[] %multiply.22)
  %sqrt.25 = f32[] sqrt(f32[] %subtract.24)
  %constant.26 = f32[] constant(1)
  %add.27 = f32[] add(f32[] %constant.26, f32[] %Arg_0.19)
  %atan2.28 = f32[] atan2(f32[] %sqrt.25, f32[] %add.27)
  %multiply.30 = f32[] multiply(f32[] %constant.29, f32[] %atan2.28)
  %constant.31 = f32[] constant(3.14159274)
  ROOT %select.32 = f32[] select(pred[] %compare.21, f32[] %multiply.30, f32[] %constant.31)
}

ENTRY %main.34 (Arg_0.1: f32[]) -> f32[] {
  %Arg_0.1 = f32[] parameter(0)
  %call.17 = f32[] call(f32[] %Arg_0.1), to_apply=%my_acos_impl.2
  ROOT %call.33 = f32[] call(f32[] %Arg_0.1), to_apply=%my_acos_impl.18, is_composite=true, frontend_attributes={composite.attributes={},composite.name="chlo.acos",composite.version="0"}
}
```