<a href="https://colab.research.google.com/github/ghpvnist/stablehlo/blob/tutorial/composite_e2e_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# The StableHLO Composite Primitive

In [1]:
# @title Hidden
%%capture
'''
Disclaimer:

There is an official LAX API coming soon. This is a POC implementation to show
what it would look like.
'''

import jax

from jax import api_util
from jax import lax
from jax import tree_util
from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir import passmanager as pm
from jax._src.lib.mlir.dialects import hlo as stablehlo
from jax.interpreters import mlir


# Call-style JAX primitive whose first bound argument is the composite's
# decomposition function.
composite_p = jax_core.CallPrimitive('composite')


def _composite_impl(f, *operands, **unused_params):
  """Eager rule for `composite_p`: run the decomposition on the operands.

  The remaining bound params (e.g. `name`, `attributes`) are accepted and
  ignored here; only the lowering rule consumes them.
  """
  sublevel = jax_core.new_sublevel()
  with sublevel:
    outputs = f.call_wrapped(*operands)
  return outputs


composite_p.def_impl(_composite_impl)


def _composite_lax_impl(f, *args, name: str, attributes: "dict | None" = None, **kwargs):
  """Bind `composite_p` for `f(*args, **kwargs)` and return its pytree output.

  Args:
    f: Python callable used as the composite's decomposition.
    *args: Positional operands; may be pytrees, and are flattened before
      binding.
    name: Composite name, forwarded to the primitive as the `name` param.
    attributes: Composite attributes forwarded to the primitive as the
      `attributes` param. Defaults to no attributes. A fresh dict is built on
      every call, instead of one `{}` default object shared by all calls.
    **kwargs: Extra keyword arguments bound to `f` via `lu.wrap_init`.

  Returns:
    The decomposition's outputs, unflattened to `f`'s output pytree structure.
  """
  # Copy so later mutation of the caller's dict cannot leak into the bound
  # primitive params, and so no mutable default is shared across calls.
  attributes = {} if attributes is None else dict(attributes)
  fun = lu.wrap_init(f, kwargs)
  flat_args, in_tree = tree_util.tree_flatten(args)
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
  out_flat = composite_p.bind(
      flat_fun,
      *flat_args,
      name=name,
      attributes=attributes)
  # `out_tree` is only queried after `bind`, which runs `flat_fun`.
  return tree_util.tree_unflatten(out_tree(), out_flat)

# Lower `composite_p` to a `stablehlo.composite` op.
#
# The decomposition is first lowered with `mlir.core_call_lowering`, which
# emits a private function named `<name>.impl` plus a `call` op to it. The
# composite op reuses that call's operands, result types and callee (as its
# `decomposition`). The `call` op itself is left in place, but its results are
# not used by this lowering, so it should get cleaned up by DCE easily.
def _composite_stablehlo_lowering(ctx, *args, name, call_jaxpr, attributes, **kwargs):
  """Lowering rule for `composite_p`.

  Args:
    ctx: JAX lowering-rule context.
    *args: MLIR values for the flattened operands.
    name: Composite name; also used to derive the decomposition function name.
    call_jaxpr: Jaxpr of the decomposition.
    attributes: Dict converted with `ir.DictAttr.get` into the composite's
      attributes.
    **kwargs: Any other primitive params; ignored.

  Returns:
    The composite op's results, which replace the primitive's outputs.

  Raises:
    ValueError: If the decomposition produced no outputs. The emitted call op
      is located through its first result, so it cannot be found then.
  """
  impl = mlir.core_call_lowering(
      ctx,
      *args,
      name=f"{name}.impl",
      call_jaxpr=call_jaxpr)
  if not impl:
    raise ValueError(
        f"composite {name!r}: decomposition must produce at least one output "
        "so the emitted call op can be located")
  # The op that produced the first output is the emitted `call` op.
  call_op = impl[0].owner
  called_fn = call_op.attributes['callee']

  composite = stablehlo.CompositeOp(
    [r.type for r in call_op.results],
    call_op.operands,
    name=ir.StringAttr.get(name),
    composite_attributes=ir.DictAttr.get(attributes),
    decomposition=called_fn,
  )
  return composite.results

## Helper function to inline module to get rid of all the extra function calls.
## Not necessary, but makes it easier to view the composite's IR.
def _inline_ir(m_original: ir.Module) -> bytes:
  with m_original.context:
    m = m_original.operation.clone()
    passes = pm.PassManager.parse("builtin.module(inline)")
    passes.run(m.operation)
    return m

# Use `_composite_stablehlo_lowering` whenever `composite_p` is lowered to MLIR
# (e.g. under `jax.jit`), so it is emitted as a `stablehlo.composite` op.
mlir.register_lowering(composite_p, _composite_stablehlo_lowering)

# Experimental Composite API

In [2]:
# Official LAX API coming soon...
def composite(*args, name, decomposition, attributes=None):
  """POC stand-in for a `lax.composite` API.

  Args:
    *args: Operands passed to `decomposition`.
    name: Composite name (e.g. "my.gelu").
    decomposition: Callable implementing the composite's semantics.
    attributes: Optional dict of composite attributes; defaults to none.
      `None` is used as the default so no mutable `{}` is shared across
      calls.

  Returns:
    The outputs of `decomposition(*args)`.
  """
  if attributes is None:
    attributes = {}
  return _composite_lax_impl(decomposition, *args, name=name, attributes=attributes)

# Hack to mimic said API: expose the POC as `lax.composite`.
# NOTE(review): if the installed JAX already provides `lax.composite`, this
# assignment replaces it — confirm that is intended.
lax.composite = composite

# Example 1: Composite `my.gelu`

In [3]:
import jax.numpy as jnp
import numpy as np
from jax._src.lax.special import erf
from jax._src import dtypes


def my_gelu(x):
  """Exact (erf-based) GELU of `x`, emitted as a composite named "my.gelu".

  The decomposition computes `x * (erf(x / sqrt(2)) + 1) / 2`.
  """
  # Use a Python float so the constant stays weakly typed and does not change
  # the dtype of `x`. `np.sqrt(2)` is a NumPy float64 scalar, which JAX treats
  # as strongly typed; with `jax_enable_x64` it would promote a float32 `x`
  # (and the composite's result) to float64.
  sqrt2 = float(np.sqrt(2))

  def my_gelu_impl(x):
    return x * (erf(x / sqrt2) + 1) / 2

  return lax.composite(x, name="my.gelu", decomposition=my_gelu_impl)


@jax.jit
def module():
  """Jitted example: `my_gelu` applied to a scalar float32 1.0."""
  return my_gelu(jnp.array(1.0, dtype=jnp.float32))


# JAX eval: call the jitted `module` and print the numeric result.
print("JAX Eval:", module())


# Get StableHLO export (using internal API): `_lowering` is a private attribute
# of the lowered object, so this may break across JAX versions.
mlir_module = module.lower()._lowering.stablehlo()
# Print a copy of the module after the `inline` pass, for readability.
print(_inline_ir(mlir_module))

JAX Eval: 0.8413447
module @jit_module attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main() -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %0 = call @my.gelu.impl(%cst) : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.composite "my.gelu" %cst {decomposition = @my.gelu.impl} : (tensor<f32>) -> tensor<f32>
    return %1 : tensor<f32>
  }
  func.func private @my.gelu.impl(%arg0: tensor<f32>) -> tensor<f32> {
    %cst = stablehlo.constant dense<2.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %cst_1 = stablehlo.constant dense<1.41421354> : tensor<f32>
    %0 = stablehlo.divide %arg0, %cst_1 : tensor<f32>
    %1 = chlo.erf %0 : tensor<f32> -> tensor<f32>
    %2 = stablehlo.add %1, %cst_0 : tensor<f32>
    %3 = stablehlo.multiply %arg0, %2 : tensor<f32>
    %4 = stablehlo.divide %3, %cst : tensor<f32>
  

# Example 2: Composite `my.acos`

In [4]:
def my_acos(x):
  """Wrap `jnp.acos(x)` in a composite named "my.acos" with no attributes."""
  no_attributes = {}
  return lax.composite(
      x,
      name="my.acos",
      attributes=no_attributes,
      decomposition=jnp.acos,
  )


@jax.jit
def module():
  """Jitted example: `my_acos` applied to a scalar float32 1.0."""
  return my_acos(jnp.array(1.0, dtype=jnp.float32))


# JAX eval: call the jitted `module` and print the numeric result.
print("JAX Eval:", module())

# Get StableHLO export (using internal API): `_lowering` is a private attribute
# of the lowered object, so this may break across JAX versions.
mlir_module = module.lower()._lowering.stablehlo()
# Print a copy of the module after the `inline` pass, for readability.
print(_inline_ir(mlir_module))

JAX Eval: 0.0
module @jit_module attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main() -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %0 = call @my.acos.impl(%cst) : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.composite "my.acos" %cst {decomposition = @my.acos.impl} : (tensor<f32>) -> tensor<f32>
    return %1 : tensor<f32>
  }
  func.func private @my.acos.impl(%arg0: tensor<f32>) -> tensor<f32> {
    %cst = stablehlo.constant dense<3.14159274> : tensor<f32>
    %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %cst_2 = stablehlo.constant dense<-1.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  NE, %arg0, %cst_2,  FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %1 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
    %2 = stablehlo.subtract %cst_1, %1 : t

# (Not recommended) Composite `my.acos` using internal APIs for `grad`/`vmap`

In [5]:
# @title Hidden
%%capture
import jax
import jax.numpy as jnp

from absl.testing import absltest
from jax._src import test_util as jtu
from jax._src.lax import lax
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo as stablehlo
from jax._src.typing import ArrayLike, Array
from jax.interpreters import mlir

'''
Disclaimer:

This should NOT be used in production. Only proceed should you need to
experiment with composites while the official API is WIP.
'''

# Step 1: Define a jax primitive.
# `standard_unop` is an internal JAX helper. Here it is given the allowed input
# dtype classes (floating | complex) and the primitive name "my_acos".
# NOTE(review): it presumably also installs the shape/dtype and batching (vmap)
# rules — confirm against the installed JAX version.
my_acos_p = lax.standard_unop(lax._float | lax._complex, "my_acos")


# Step 2: Define a jax api.
def my_acos(x: ArrayLike) -> Array:
  """Apply the `my_acos` primitive to `x`."""
  result = my_acos_p.bind(x)
  return result


# Step 3 (Optional): Define auto diff rule if you need gradient.
# d/dx acos(x) = -1 / sqrt(1 - x^2), so the output tangent is
# g * -rsqrt(1 - x^2).
def _my_acos_jvp_rule(g, x):
  one = lax._const(x, 1)
  return lax.mul(g, -lax.rsqrt(one - lax.square(x)))


lax.ad.defjvp(my_acos_p, _my_acos_jvp_rule)


# Step 4: Define lowering to stablehlo.composite.
def _composite_acos_lowering(
    ctx: mlir.LoweringRuleContext, arg: ir.Value
) -> ir.OpResultList:
  """Lower `my_acos_p` to a `stablehlo.composite` named "my.acos".

  The decomposition is `jnp.acos`. It is lowered through a jitted helper, so
  it becomes a separate function plus a `call` op. The composite reuses that
  call's operands, result types and callee.

  Args:
    ctx: Lowering-rule context supplied by JAX.
    arg: MLIR value of the primitive's single operand.

  Returns:
    The composite op's results.
  """

  @jax.jit
  def my_acos_impl(x: ArrayLike) -> Array:
    return jnp.acos(x)

  # TODO: this implementation leaks a CallOp. Its results are unused; the
  # composite only borrows its operands and callee.
  lowered_fun = mlir.lower_fun(my_acos_impl, multiple_results=False)
  call_op = lowered_fun(ctx, arg)[0].owner

  composite = stablehlo.CompositeOp(
      [result.type for result in call_op.results],
      call_op.operands,
      # Use the user-level name this section describes ("my.acos"), matching
      # the `my_acos` primitive, rather than borrowing the CHLO dialect's
      # "chlo." namespace for a user-defined composite.
      name=ir.StringAttr.get("my.acos"),
      composite_attributes=ir.DictAttr.get({}),
      decomposition=call_op.attributes["callee"],
  )
  return composite.results


# Step 5: Register your custom composite lowering to stablehlo.composite.
# After this call, lowering `my_acos_p` (e.g. under `jax.jit`) goes through
# `_composite_acos_lowering`.
mlir.register_lowering(my_acos_p, _composite_acos_lowering)


# Step 6: Print module
@jax.jit
def myMain():
  """Jitted example: the `my_acos` primitive applied to a scalar float32 1.0."""
  return my_acos(jnp.array(1.0, dtype=jnp.float32))

# Lower `myMain` via the private `_lowering` attribute (internal API), then
# print a copy with the `inline` pass applied by `_inline_ir` from the first
# code cell.
mlir_module = myMain.lower()._lowering.stablehlo()
print(_inline_ir(mlir_module))

# Print module using JAX Export API

In [6]:
# @title Hidden
%%capture

# Hack to mimic LAX API
# The previous cell rebound `lax` to the internal `jax._src.lax.lax` module.
# `my_gelu` looks up `lax.composite` at call time, so the POC `composite` is
# attached to that module too.
lax.composite = composite

from jax import export as jax_export

# Official export API
# Scalar float32 input, read by the first `myMain` below and the eager prints.
x = jnp.array(1.0, dtype=jnp.float32)

@jax.jit
def myMain():
  """Apply the primitive-based `my_acos` to the module-level `x`."""
  result = my_acos(x)
  return result

# Eagerly compare the reference `jnp.acos` with the primitive-based `my_acos`.
print("Ref impl:", jnp.acos(x))
print("My impl:", my_acos(x))

# Export with the public `jax.export` API. `myMain` takes no arguments, so the
# exported function is called with none. Then print its MLIR module.
mlir_module = jax_export.export(myMain)().mlir_module()
print(mlir_module)


@jax.jit
def myMain():
  """Jitted example: `my_gelu` on a scalar float32 1.0 (replaces the earlier `myMain`)."""
  return my_gelu(jnp.array(1.0, dtype=jnp.float32))

# Eager evaluation. `x` here is the module-level value defined earlier in this
# cell, not the local `x` inside `myMain`.
print("My impl:", my_gelu(x))

# Export the gelu `myMain` with the public `jax.export` API and print its MLIR
# module.
mlir_module = jax_export.export(myMain)().mlir_module()
print(mlir_module)