# Meeting 2022-02-21

This notebook describes how to extend JAX to include new primitives. The material is heavily based on this [source](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html) in the JAX Docs.

In [1]:
import functools
import traceback
import contextlib

import numpy as np
import jax.numpy as jnp

from jax import jit, grad, jvp, vjp, vmap, pmap, make_jaxpr, lax
from jax.core import Primitive, ShapedArray
from jax._src.lib import xla_client
from jax.interpreters import xla, ad

In [2]:
INDENT = 0
SUPRESS = False

def _trace(msg=None):
    """Prints a message at current indentation."""
    if msg is not None and not SUPRESS:
        print("|   " * INDENT + msg)


def _trace_indent(msg=None):
    """Print a message and increase indentation."""
    global INDENT
    _trace(msg)
    INDENT = INDENT + 1


def _trace_unindent(msg=None):
    """Decrease indentation and print a message."""
    global INDENT
    INDENT = INDENT - 1
    _trace(msg)


def log_calls(name):
    """Decorator that shows when functions are invoked/returned."""
    def trace_func(fn):
        def pp(v):
            """Pretty-print a value."""
            vtype = str(type(v))
            if "jax._src.lib.xla_bridge._JaxComputationBuilder" in vtype:
                return "<JaxComputationBuilder>"
            if "jaxlib.xla_extension.XlaOp" in vtype:
                return f"<XlaOp at 0x{id(v):x}>"
            if ("partial_eval.JaxprTracer" in vtype or
                "batching.BatchTracer" in vtype or
                "ad.JVPTracer" in vtype):
                return f"Tracer<{v.aval}>"
            if isinstance(v, tuple):
                return f"({pp_vals(v)})"

            return str(v)

        def pp_vals(args):
            return ", ".join(pp(arg) for arg in args)

        @functools.wraps(fn)
        def fn_wrapper(*args):
            """Wrapper of fn that shows the calls as they happen."""
            _trace_indent(f"| CALL {name}({pp_vals(args)})")
            res = fn(*args)
            _trace_unindent(f"| RET  {name} = {pp(res)}")
            return res

        return fn_wrapper

    return trace_func


class ExpectNotImplemented:
    def __enter__(self):
        pass

    def __exit__(self, typ, value, tb):
        global INDENT
        INDENT = 0
        if typ is NotImplementedError:
            print(f"\nFound expected exception:")
            traceback.print_exc(limit=3)
            return True
        elif typ is None:
            assert False, "Expected NotImplementedError"
        else:
            return False

@contextlib.contextmanager
def SuppressCallLog():
    global SUPRESS
    SUPRESS = True
    try:
        yield None
    finally:
        SUPRESS = False

## PART 0: Last Week

<br>

<center>
    <img src="images/flows.png" width="80%" />
</center>

<hr />

## PART 1: JAX Primitives and JIT

### Preparing the Ground

Let us define a function that will be reimplemented as a primitive.

In [3]:
@log_calls("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
    return jnp.add(jnp.multiply(x, y), z)

@log_calls("square_add_numpy")
def square_add_numpy(a, b):
    return multiply_add_numpy(a, a, b)

**NORMAL EVALUATION**

- Nested calls to `square_add_numpy` and `multiply_add_numpy`.
- Floating point parameters.
- Returns `DeviceArray` with expected value.

In [4]:
square_add_numpy(2., 10.)

| CALL square_add_numpy(2.0, 10.0)
|   | CALL multiply_add_numpy(2.0, 2.0, 10.0)
|   | RET  multiply_add_numpy = 14.0
| RET  square_add_numpy = 14.0


DeviceArray(14., dtype=float32, weak_type=True)

**GRADIENT EVALUATION**


- Nested calls again.
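- Parameters are now `Tracer`s carrying `ConcreteArray` values (only `a`, since `grad` differentiates with respect to the first argument by default). 🤔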
- Returns `DeviceArray` with the gradient evaluated at `(2., 10.)`.
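The log shows `a` wrapped in a `Tracer` while `b` stays a plain `10.0`. A small sketch with the public `jax.grad` API makes the reason visible (`square_add` is a stand-in for `square_add_numpy` without the logging): by default `grad` differentiates only with respect to argument 0, and `argnums` selects the others.

```python
import jax
import jax.numpy as jnp


def square_add(a, b):
    # Same math as square_add_numpy: a * a + b
    return jnp.add(jnp.multiply(a, a), b)


# Default: derivative w.r.t. argument 0 only, d/da (a*a + b) = 2a
da_only = jax.grad(square_add)(2.0, 10.0)

# argnums=(0, 1): both partial derivatives, (2a, 1)
da, db = jax.grad(square_add, argnums=(0, 1))(2.0, 10.0)
print(da_only, da, db)
```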

In [5]:
grad(square_add_numpy)(2., 10.)

| CALL square_add_numpy(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
|   | CALL multiply_add_numpy(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
|   | RET  multiply_add_numpy = Tracer<ConcreteArray(14.0, dtype=float32, weak_type=True)>
| RET  square_add_numpy = Tracer<ConcreteArray(14.0, dtype=float32, weak_type=True)>


DeviceArray(4., dtype=float32, weak_type=True)

### Creating a new Primitive

- Simply instantiate `Primitive` with a name.
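- The method `Primitive.bind` dispatches the call to the active trace: without transformations it runs the concrete implementation; under `jit`, `grad`, etc. the arguments are `Tracer`s and the trace records or transforms the operation.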
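Built-in operations work the same way. In the `jaxpr` of a `lax`-only version of the function (`square_add_lax` is a hypothetical name used only for illustration), each equation corresponds to one call to `bind` and records which primitive was bound:

```python
import jax
from jax import lax


def square_add_lax(a, b):
    # lax.mul and lax.add each bind exactly one built-in primitive
    return lax.add(lax.mul(a, a), b)


closed_jaxpr = jax.make_jaxpr(square_add_lax)(2.0, 10.0)
# One equation per bound primitive, in evaluation order
names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(names)
```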

In [6]:
multiply_add_p = Primitive("multiply_add")

@log_calls("multiply_add_prim")
def multiply_add_prim(x, y, z):
    """The JAX-traceable way to use the JAX primitive"""
    return multiply_add_p.bind(x, y, z)

@log_calls("square_add_prim")
def square_add_prim(a, b):
    """The JAX-traceable way to use the JAX primitive"""
    return multiply_add_p.bind(a, a, b)

We created a new primitive but did not tell JAX how it should be evaluated (i.e. which operations it should compute).

In [7]:
with ExpectNotImplemented():
    square_add_prim(2., 10.)

| CALL square_add_prim(2.0, 10.0)

Found expected exception:


Traceback (most recent call last):
  File "/tmp/ipykernel_13341/3035546393.py", line 2, in <module>
    square_add_prim(2., 10.)
  File "/tmp/ipykernel_13341/785201755.py", line 50, in fn_wrapper
    res = fn(*args)
  File "/tmp/ipykernel_13341/3651717411.py", line 11, in square_add_prim
    return multiply_add_p.bind(a, a, b)
NotImplementedError: Evaluation rule for 'multiply_add' not implemented


Let us define a **concrete evaluation** rule for our primitive.

In [8]:
@log_calls("multiply_add_impl")
def multiply_add_impl(x, y, z):
    """Concrete implementation of the primitive.

    This function does NOT need to be JAX traceable.
    """
    # Note that we can use the original numpy, which is not JAX traceable
    return np.add(np.multiply(x, y), z)


# Register the implementation of `multiply_add_p` primitive.
multiply_add_p.def_impl(multiply_add_impl)

<function __main__.multiply_add_impl(x, y, z)>

Now we can call it.

> **NOTE**: We are not doing JIT yet, this invocation is simply interpreting.

In [9]:
assert square_add_prim(2., 10.) == 14.

| CALL square_add_prim(2.0, 10.0)
|   | CALL multiply_add_impl(2.0, 2.0, 10.0)
|   | RET  multiply_add_impl = 14.0
| RET  square_add_prim = 14.0


### JIT

In [10]:
with ExpectNotImplemented():
    jit(square_add_prim)(2., 10.)

| CALL square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)

Found expected exception:


Traceback (most recent call last):
  File "/tmp/ipykernel_13341/1315100091.py", line 2, in <module>
    jit(square_add_prim)(2., 10.)
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/jax/_src/api.py", line 430, in cache_miss
    out_flat = xla.xla_call(
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented


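Let us define the **abstract evaluation** rule: it computes the shape and dtype of the output from the shapes and dtypes of the inputs, without any concrete values.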

In [11]:
@log_calls("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
    # Make sure parameters are compatible
    assert xs.shape == ys.shape
    assert xs.shape == zs.shape
    # Inform that the output has the same shape as the inputs
    return ShapedArray(xs.shape, xs.dtype)


# Register the abstract implementation of `multiply_add_p` primitive.
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)

<function __main__.multiply_add_abstract_eval(xs, ys, zs)>
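The public `jax.eval_shape` exposes the same idea as abstract evaluation: it runs a function on shape/dtype descriptions (`jax.ShapeDtypeStruct`) instead of data and returns only the output's shape and dtype. A short sketch, with `square_add` as a stand-in for `square_add_numpy`:

```python
import jax
import jax.numpy as jnp


def square_add(a, b):
    return jnp.add(jnp.multiply(a, a), b)


# Describes a float32 vector of length 3; no data is allocated
spec = jax.ShapeDtypeStruct((3,), jnp.float32)
out_spec = jax.eval_shape(square_add, spec, spec)
print(out_spec)
```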

In [12]:
with ExpectNotImplemented():
    jit(square_add_prim)(2., 10.)

| CALL square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)
|   | CALL multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|   | RET  multiply_add_abstract_eval = ShapedArray(float32[])
| RET  square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>

Found expected exception:


Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 191, in _run_module_as_main
    msg = "%s: %s" % (sys.executable, exc)
  File "/usr/lib/python3.10/runpy.py", line 75, in _run_code
    fname = mod_spec.origin
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/ipykernel_launcher.py", line 12, in <module>
    if sys.path[0] == '':
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform gpu

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_13341/1315100091.py", line 2, in <module>
    jit(square_add_prim)(2., 10.)
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 1

In [13]:
@log_calls("multiply_add_xla_translation")
def multiply_add_xla_translation(ctx, avals_in, avals_out, xc, yc, zc):
    return [
        xla_client.ops.Add(
            xla_client.ops.Mul(xc, yc),
            zc)
    ]

# Associate the new translation rule with our primitive
xla.register_translation(multiply_add_p, multiply_add_xla_translation, platform='gpu')
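`xla.register_translation` and the `xla_client.ops` builder API belong to the JAX version this notebook was run with; newer JAX releases register MLIR *lowerings* instead (the error above already mentions an "MLIR translation rule"). The sketch below assumes a JAX version that provides `jax.interpreters.mlir.register_lowering` and `mlir.lower_fun`. It defines a separate, hypothetical `fma_sketch` primitive so it does not clash with `multiply_add_p`, and lowers it by tracing an ordinary `jnp` expression instead of building XLA ops by hand.

```python
import jax
import numpy as np
from jax.interpreters import mlir

try:  # newer JAX exposes Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older JAX
    from jax.core import Primitive

# Hypothetical primitive, separate from the notebook's multiply_add_p
fma_p = Primitive("fma_sketch")


def fma_impl(x, y, z):
    # Concrete (eager) evaluation; plain NumPy is fine here
    return np.add(np.multiply(x, y), z)


def fma_abstract_eval(x, y, z):
    assert x.shape == y.shape == z.shape
    return x  # output has the same shape and dtype as x


fma_p.def_impl(fma_impl)
fma_p.def_abstract_eval(fma_abstract_eval)

# lower_fun traces the jnp-level function into MLIR; platform=None (the
# default) registers the lowering for every backend.
mlir.register_lowering(
    fma_p, mlir.lower_fun(lambda x, y, z: x * y + z, multiple_results=False))


def square_add_sketch(a, b):
    return fma_p.bind(a, a, b)


eager = square_add_sketch(2.0, 10.0)          # uses fma_impl
jitted = jax.jit(square_add_sketch)(2.0, 10.0)  # uses the MLIR lowering
print(eager, jitted)
```

`lower_fun` keeps the sketch short; writing the lowering directly against the StableHLO dialect is also possible but is not shown here.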

In [14]:
assert jit(square_add_prim)(2., 10.) == 14.

| CALL square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)
|   | CALL multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|   | RET  multiply_add_abstract_eval = ShapedArray(float32[])
| RET  square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
| CALL multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
| RET  multiply_add_abstract_eval = ShapedArray(float32[])
| CALL multiply_add_xla_translation(TranslationContext(builder=<jaxlib.xla_extension.XlaBuilder object at 0x7f1cbc27e5b0>, platform='gpu', axis_env=AxisEnv(nreps=1, names=(), sizes=()), name_stack=''), [ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), Shaped

In [15]:
assert jit(lambda x, y: square_add_prim(x, y), static_argnums=1)(2., 10.) == 14.

| CALL square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, 10.0)
|   | CALL multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|   | RET  multiply_add_abstract_eval = ShapedArray(float32[])
| RET  square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
| CALL multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
| RET  multiply_add_abstract_eval = ShapedArray(float32[])
| CALL multiply_add_xla_translation(TranslationContext(builder=<jaxlib.xla_extension.XlaBuilder object at 0x7f1c54a83c30>, platform='gpu', axis_env=AxisEnv(nreps=1, names=(), sizes=()), name_stack=''), [ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], [ShapedArray(float32[])], <XlaOp at 0x7f1

JAXPR/IR of the function that uses regular `jax.numpy` (i.e. without our primitive).

In [16]:
with SuppressCallLog():
    print(make_jaxpr(square_add_numpy)(2., 10.))
    print("=" * 80)
    print(jit(square_add_numpy).lower(2., 10.).compiler_ir('mhlo'))

{ lambda ; a:f32[] b:f32[]. let c:f32[] = mul a a; d:f32[] = add c b in (d,) }
module @jit_square_add_numpy.11 {
  func public @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
    %0 = mhlo.multiply %arg0, %arg0 : tensor<f32>
    %1 = mhlo.add %0, %arg1 : tensor<f32>
    return %1 : tensor<f32>
  }
}



JAXPR/IR of the function that invokes our newly created primitive.

In [17]:
with SuppressCallLog():
    print(make_jaxpr(square_add_prim)(2., 10.))
    print("=" * 80)
    print(jit(square_add_prim).lower(2., 10.).compiler_ir('mhlo'))

{ lambda ; a:f32[] b:f32[]. let c:f32[] = multiply_add a a b in (c,) }
module @jit_square_add_prim.12 {
  func public @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
    %0 = call @multiply_add(%arg0, %arg0, %arg1) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
  func private @multiply_add(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
    %0 = call @xla_fallback_multiply_add(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
  func private @xla_fallback_multiply_add(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
    %0 = mhlo.constant dense<false> : tensor<i1>
    %1 = mhlo.multiply %arg0, %arg1 : tensor<f32>
    %2 = mhlo.add %1, %arg2 : tensor<f32>
    return %2 : tensor<f32>
  }
}



### RECAP

| Step | Description | API |
|:----:|:------------|:----|
| 1 | Create new primitive object | `Primitive(NAME)` |
| 2 | Define concrete evaluation rule | `PRIMITIVE.def_impl(IMPL_FN)` |
| 3 | Define abstract evaluation rule | `PRIMITIVE.def_abstract_eval(ABS_EVAL_FN)` |
| 4 | Define translation rule | `xla.register_translation(PRIMITIVE, TRANSLATION_FN)` |

### JAX Core

<center>
    <img src="images/jax_core.png" alt="" width="100%"/>
</center>

<hr />

## PART 2: JAX Primitives and Autodiff

In [18]:
with ExpectNotImplemented():
    jvp(square_add_prim, (2., 10.), (1., 1.))

| CALL square_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ConcreteArray(10.0, dtype=float32, weak_type=True)>)

Found expected exception:


Traceback (most recent call last):
  File "/tmp/ipykernel_13341/1685878697.py", line 2, in <module>
    jvp(square_add_prim, (2., 10.), (1., 1.))
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/jax/_src/api.py", line 2280, in jvp
    return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/jax/_src/api.py", line 2309, in _jvp
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
NotImplementedError: Differentiation rule for 'multiply_add' not implemented


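We have this function:

$$ f(x, y, z) = x \cdot y + z $$

By the product rule, its JVP (the output tangent, given input tangents $x_t, y_t, z_t$) is:

$$ f_t(x, y, z; x_t, y_t, z_t) = x_t \cdot y + (x \cdot y_t + z_t) $$

The grouping is deliberate: both the outer expression and the inner one in parentheses have the form `multiply_add`, so the JVP rule can be written with `multiply_add_prim` itself.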
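The formula can be checked numerically. For the call `multiply_add(a, a, b)` that `square_add` makes at `(2., 10.)`, with all input tangents equal to `1.`, both `jax.jvp` on a plain `jnp` version of the function (a stand-in, not the primitive) and the hand-written grouping give `5.`, matching the JVP result asserted in the cell below.

```python
import jax


def multiply_add(x, y, z):
    return x * y + z


primals = (2.0, 2.0, 10.0)   # square_add(2., 10.) calls multiply_add(a, a, b)
tangents = (1.0, 1.0, 1.0)

primal_out, jax_tangent = jax.jvp(multiply_add, primals, tangents)

x, y, z = primals
xt, yt, zt = tangents
# Hand-derived JVP, grouped as multiply_add(xt, y, multiply_add(x, yt, zt))
hand_tangent = xt * y + (x * yt + zt)
print(primal_out, jax_tangent, hand_tangent)
```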

In [19]:
@log_calls("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
    """Evaluates the primal output and the tangents (Jacobian-vector product).

    Given values of the arguments and perturbation of the arguments (tangents), 
    compute the output of the primitive and the perturbation of the output.

    This method must be JAX-traceable. JAX may invoke it with abstract values 
    for the arguments and tangents.
    """
    x,  y,  z  = arg_values
    xt, yt, zt = arg_tangents
    
    _trace(">>> Primal evaluation:")
    primal_out = multiply_add_prim(x, y, z)
    
    def make_zero(tan):
        return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan
    
    _trace(">>> Tangent evaluation:")
    output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
    
    return primal_out, output_tangent

# Register JVP rule for our `multiply_add_p` primitive
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp

In [20]:
assert jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)

| CALL square_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
|   | CALL multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))
|   |   >>> Primal evaluation:
|   |   | CALL multiply_add_prim(2.0, 2.0, 10.0)
|   |   |   | CALL multiply_add_impl(2.0, 2.0, 10.0)
|   |   |   | RET  multiply_add_impl = 14.0
|   |   | RET  multiply_add_prim = 14.0
|   |   >>> Tangent evaluation:
|   |   | CALL multiply_add_prim(2.0, 1.0, 1.0)
|   |   |   | CALL multiply_add_impl(2.0, 1.0, 1.0)
|   |   |   | RET  multiply_add_impl = 3.0
|   |   | RET  multiply_add_prim = 3.0
|   |   | CALL multiply_add_prim(1.0, 2.0, 3.0)
|   |   |   | CALL multiply_add_impl(1.0, 2.0, 3.0)
|   |   |   | RET  multiply_add_impl = 5.0
|   |   | RET  multiply_add_prim = 5.0
|   | RET  multiply_add_value_and_jvp = (14.0, 5.0)
| RET  square_add_prim = Tracer<ConcreteArray(14.0, dtype=float32)>


In [21]:
with ExpectNotImplemented():
  grad(square_add_prim)(2., 10.)

| CALL square_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
|   | CALL multiply_add_value_and_jvp((Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0), (Tracer<ShapedArray(float32[], weak_type=True)>, Tracer<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
|   |   >>> Primal evaluation:
|   |   | CALL multiply_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
|   |   |   | CALL multiply_add_impl(2.0, 2.0, 10.0)
|   |   |   | RET  multiply_add_impl = 14.0
|   |   | RET  multiply_add_prim = 14.0
|   |   >>> Tangent evaluation:
|   |   | CALL multiply_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ShapedArray(float32[], weak_type=True)>, 0.0)
|   |   |   | CALL multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), Shape

Traceback (most recent call last):
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/jax/interpreters/ad.py", line 258, in get_primitive_transpose
    return primitive_transposes[p]
KeyError: multiply_add

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 191, in _run_module_as_main
    msg = "%s: %s" % (sys.executable, exc)
  File "/usr/lib/python3.10/runpy.py", line 75, in _run_code
    fname = mod_spec.origin
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/ipykernel_launcher.py", line 12, in <module>
    if sys.path[0] == '':
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The ab

In [22]:
@log_calls("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
    """Evaluates the transpose of a linear primitive.

    This method is only used when computing the backward gradient following 
    value_and_jvp, and is only needed for primitives that are used in the JVP 
    calculation for some other primitive. We need transposition for multiply_add_prim, 
    because we have used multiply_add_prim in the computation of the output_tangent in 
    multiply_add_value_and_jvp.

    In our case, multiply_add is not a linear primitive. However, it is used linearly 
    w.r.t. tangents in multiply_add_value_and_jvp:
       output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))
  
    One of the first two (multiplicative) arguments is always a constant.
    """
    if not ad.is_undefined_primal(x):
        # This use of multiply_add is with a constant "x"
        assert ad.is_undefined_primal(y)
        ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
        return None, ct_y, ct
    else:
        # This use of multiply_add is with a constant "y"
        assert ad.is_undefined_primal(x)
        ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
        return ct_x, None, ct
        
# Register transpose rule for `multiply_add_p` primitive.
ad.primitive_transposes[multiply_add_p] = multiply_add_transpose
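Transposition can be sanity-checked on the inner linear piece of the JVP, $(y_t, z_t) \mapsto x \cdot y_t + z_t$ with $x$ constant: pulling back a cotangent `ct` must give $(x \cdot ct, ct)$, which is what the first branch of `multiply_add_transpose` returns. A sketch using `jax.vjp` on a plain-Python stand-in (`tangent_map` is a hypothetical name):

```python
import jax
import jax.numpy as jnp

x = 2.0  # the constant (known primal) argument


def tangent_map(yt, zt):
    # Linear in (yt, zt), like multiply_add(x, yt, zt) inside the JVP rule
    return x * yt + zt


out, pullback = jax.vjp(tangent_map, 0.0, 0.0)
# Cotangent of 1 for the output; expect (x * 1, 1)
ct_y, ct_z = pullback(jnp.ones_like(out))
print(ct_y, ct_z)
```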

In [23]:
assert grad(square_add_prim)(2., 10.) == 4.

| CALL square_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
|   | CALL multiply_add_value_and_jvp((Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0), (Tracer<ShapedArray(float32[], weak_type=True)>, Tracer<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
|   |   >>> Primal evaluation:
|   |   | CALL multiply_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
|   |   |   | CALL multiply_add_impl(2.0, 2.0, 10.0)
|   |   |   | RET  multiply_add_impl = 14.0
|   |   | RET  multiply_add_prim = 14.0
|   |   >>> Tangent evaluation:
|   |   | CALL multiply_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ShapedArray(float32[], weak_type=True)>, 0.0)
|   |   |   | CALL multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), Shape

### RECAP

Consider a composition of functions $f(g(h(\dots z(x) \dots)))$. We normally evaluate such a function from the inside out: first calculate $z(x)$, then use the result as input to the next function outward, until we reach $f$. Similarly, when computing the derivative, we can choose to propagate derivatives from the inside out or from the outside in. These are known as forward-mode and backward-mode (reverse-mode) autodiff, respectively.
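In matrix terms, writing $J_f$ for the Jacobian of $f$ evaluated at its input (and likewise for $g, h, \dots, z$), the chain rule makes the two orderings explicit. Forward mode pushes a tangent vector $v$ through the product starting at the innermost function, in the same order as evaluation; backward mode pulls a cotangent vector $u$ through it starting at the outermost function, which is why it needs the transpose rules from above:

$$ J = J_f \, J_g \, J_h \cdots J_z $$

$$ \text{forward (JVP):}\quad J v = J_f \Big( J_g \big( J_h \cdots (J_z v) \big) \Big) $$

$$ \text{backward (VJP):}\quad u^\top J = \Big( \big( (u^\top J_f) \, J_g \big) J_h \Big) \cdots J_z $$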

<br>

<center>
    <img src="images/fwd_bwd_ad.png" alt="autodiff" width="30%" />
</center>

- Is this correct or is it inverted?

<hr />

## PART 3: JAX Primitives and Batching

In [24]:
# TODO
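This part is still a TODO. As a hedged starting point, here is a sketch of a batching rule for an elementwise multiply-add. It assumes the `batching.primitive_batchers` registration listed in the summary below, and that unmapped operands arrive with a batch dimension of `None` (JAX's `not_mapped`). It uses a separate, hypothetical `fma_batch_sketch` primitive; the rule moves every batch axis to the front, broadcasts unbatched operands, and binds the primitive once on the stacked arrays.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.interpreters import batching

try:  # newer JAX exposes Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older JAX
    from jax.core import Primitive

# Hypothetical primitive, separate from the notebook's multiply_add_p
fma_batch_p = Primitive("fma_batch_sketch")
fma_batch_p.def_impl(lambda x, y, z: jnp.add(jnp.multiply(x, y), z))
fma_batch_p.def_abstract_eval(lambda x, y, z: x)  # same shape/dtype as x


def fma_batch_rule(batched_args, batch_dims):
    """Returns (batched output, axis of the output's batch dimension)."""
    # Batch size taken from any operand that is actually mapped
    size = next(a.shape[d] for a, d in zip(batched_args, batch_dims)
                if d is not None)

    def to_front(a, d):
        # Unmapped operand (d is None): broadcast along a new leading axis.
        # Mapped operand: move its batch axis to position 0.
        if d is None:
            return jnp.broadcast_to(a, (size,) + jnp.shape(a))
        return jnp.moveaxis(a, d, 0)

    x, y, z = (to_front(a, d) for a, d in zip(batched_args, batch_dims))
    # Elementwise primitive: one call on the stacked operands covers the
    # whole batch, and the batch axis stays at position 0.
    return fma_batch_p.bind(x, y, z), 0


batching.primitive_batchers[fma_batch_p] = fma_batch_rule


def square_add_batched(a, b):
    return fma_batch_p.bind(a, a, b)


# Map over a, keep b fixed: [0*0 + 10, 1*1 + 10, 2*2 + 10]
out = jax.vmap(square_add_batched, in_axes=(0, None))(jnp.arange(3.0), 10.0)
print(out)
```

Because the primitive is elementwise, no per-example loop is needed; a primitive that is not elementwise would need a rule that accounts for how its semantics change with an extra leading axis.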

### RECAP

<center>
    <img src="images/batching.png" alt="batching" width="50%" />
</center>

<hr />

## PART 4: Summary

1. Create a primitive:

```python
my_primitive_p = Primitive("my_primitive")
```

2. Define a concrete evaluation rule for interpretation:

```python
my_primitive_p.def_impl(my_primitive_impl)
```

3. Define an abstract evaluation rule for JIT compilation:

```python
my_primitive_p.def_abstract_eval(my_primitive_abs_eval)
```

4. Define an XLA translation rule for JIT compilation:

```python
jax.interpreters.xla.register_translation(
    my_primitive_p,
    my_primitive_xla_compile_gpu,
    platform='gpu')
```

5. Define a Jacobian-vector product (JVP) rule for forward AD:

```python
jax.interpreters.ad.primitive_jvps[my_primitive_p] = my_primitive_value_and_jvp
```

6. Define a transpose rule for backward AD:

```python
jax.interpreters.ad.primitive_transposes[my_primitive_p] = my_primitive_transpose
```

7. Define a batching rule:

```python
jax.interpreters.batching.primitive_batchers[my_primitive_p] = my_primitive_batch
```