In [2]:
from jax import core
multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

# @trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
  """The JAX-traceable way to use the JAX primitive.

  Note that the traced arguments must be passed as positional arguments
  to `bind`.
  """
  return multiply_add_p.bind(x, y, z)

# @trace("square_add_prim")
def square_add_prim(a, b):
  """A square-add function implemented using the new JAX-primitive."""
  return multiply_add_prim(a, a, b)

In [4]:
square_add_prim(2., 10.)

NotImplementedError: Evaluation rule for 'multiply_add' not implemented

In [6]:
import numpy as np

In [7]:
def multiply_add_impl(x, y, z):
  """Concrete implementation of the primitive.

  This function does not need to be JAX traceable.
  Args:
    x, y, z: the concrete arguments of the primitive. Will only be called with
      concrete values.
  Returns:
    the concrete result of the primitive.
  """
  # Note that we can use the original numpy, which is not JAX traceable
  return np.add(np.multiply(x, y), z)

# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl)

<function __main__.multiply_add_impl(x, y, z)>

In [8]:
assert square_add_prim(2., 10.) == 14.


In [12]:
square_add_prim(2., 10.)

14.0

In [10]:
from jax._src import api


In [11]:
api.jit(square_add_prim)(2., 10.)

NotImplementedError: Abstract evaluation for 'multiply_add' not implemented

In [13]:
from jax._src import abstract_arrays
# @trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
  """Abstract evaluation of the primitive.

  This function does not need to be JAX traceable. It will be invoked with
  abstractions of the actual arguments.
  Args:
    xs, ys, zs: abstractions of the arguments.
  Result:
    a ShapedArray for the result of the primitive.
  """
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape
  return abstract_arrays.ShapedArray(xs.shape, xs.dtype)

# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)

<function __main__.multiply_add_abstract_eval(xs, ys, zs)>

In [14]:
api.jit(square_add_prim)(2., 10.)

NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu

In [15]:
from jax._src.lib import xla_client
# @trace("multiply_add_xla_translation")
def multiply_add_xla_translation(ctx, avals_in, avals_out, xc, yc, zc):
  """The compilation to XLA of the primitive.

  Given an XlaBuilder and XlaOps for each argument, return the XlaOp for the
  result of the function.

  Does not need to be a JAX-traceable function.
  """
  return [xla_client.ops.Add(xla_client.ops.Mul(xc, yc), zc)]

# Now we register the XLA compilation rule with JAX
# TODO: for GPU? and TPU?
from jax.interpreters import xla
xla.register_translation(multiply_add_p, multiply_add_xla_translation, platform='cpu')

In [16]:
assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.

In [17]:
api.jvp(square_add_prim, (2., 10.), (1., 1.))

NotImplementedError: Differentiation rule for 'multiply_add' not implemented

In [18]:
from jax import lax


In [32]:
from jax.interpreters import ad


# @trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
  """Evaluates the primal output and the tangents (Jacobian-vector product).

  Given values of the arguments and perturbation of the arguments (tangents),
  compute the output of the primitive and the perturbation of the output.

  This method must be JAX-traceable. JAX may invoke it with abstract values
  for the arguments and tangents.

  Args:
    arg_values: a tuple of arguments
    arg_tangents: a tuple with the tangents of the arguments. The tuple has
      the same length as the arg_values. Some of the tangents may also be the
      special value ad.Zero to specify a zero tangent.
  Returns:
     a pair of the primal output and the tangent.
  """
  x, y, z = arg_values
  xt, yt, zt = arg_tangents
  # _trace("Primal evaluation:")
  # Now we have a JAX-traceable computation of the output.
  # Normally, we can use the ma primtive itself to compute the primal output.
  primal_out = multiply_add_prim(x, y, z)

  # _trace("Tangent evaluation:")
  # We must use a JAX-traceable way to compute the tangent. It turns out that
  # the output tangent can be computed as (xt * y + x * yt + zt),
  # which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.

  # We do need to deal specially with Zero. Here we just turn it into a
  # proper tensor of 0s (of the same shape as 'x').
  # An alternative would be to check for Zero and perform algebraic
  # simplification of the output tangent computation.
  def make_zero(tan):
    return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan
  print(xt, yt, zt)
  print(make_zero(xt))
  output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
  return (primal_out, output_tangent)

# Register the forward differentiation rule with JAX
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp

In [34]:
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 2.)) == (14., 5.)

1.0 1.0 2.0
1.0


AssertionError: 

In [25]:
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
api.jvp(square_add_prim, (2., 10.), (2., 1.))

2.0
2.0


(14.0, 9.0)

In [31]:
ad.Zero()


Zero(2)