# Extracting Intermediate Values

This notebook shows how to extract intermediate values from a model using `save_inter` and the `pull` transformation. 

This is useful for debugging and for understanding how a model works.

In [85]:
import jax
import jax.numpy as jnp
import statax as stx

In [86]:
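def func(x):
  y = 4 * x
  # save_inter returns its input unchanged, so it can be used inline...
  y = stx.save_inter(y, name="y")
  # ...or called for its side effect alone; a name may be reused
  stx.save_inter(y + 3, name="y")
  return jnp.sin(x) + 0.5 * y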

When called without the `pull` transformation, `save_inter` acts as the identity, so the function returns exactly what it would return without the `save_inter` calls.

In [87]:
func(1.0)

Array(2.841471, dtype=float32, weak_type=True)

When the function is wrapped in a `pull` transformation, it returns a tuple of its normal output and a dictionary of intermediate values. Each key is the name passed to `save_inter` with an index suffix `_n` appended, so that repeated names stay distinct; each value is the saved value.

Here `y_0` is `y = 4 * 1.0 = 4.0` and `y_1` is `y + 3 = 7.0`.

In [88]:
stx.pull(func)(1.0)

(Array(2.841471, dtype=float32, weak_type=True),
 {'y_0': Array(4., dtype=float32, weak_type=True),
  'y_1': Array(7., dtype=float32, weak_type=True)})
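Since repeated names come back as separate keys (`y_0`, `y_1`, ...), it can be convenient to regroup them by base name after pulling. The helper below is hypothetical, not part of `statax`, and assumes the suffix is always the final `_`-separated integer.

```python
from collections import defaultdict

def group_by_name(inters):
    """Group {'y_0': a, 'y_1': b} into {'y': [a, b]}, ordered by suffix."""
    grouped = defaultdict(list)
    for key, value in inters.items():
        base, _, idx = key.rpartition("_")
        grouped[base].append((int(idx), value))
    # Sort each group by its numeric suffix, then drop the index.
    return {base: [v for _, v in sorted(items, key=lambda p: p[0])]
            for base, items in grouped.items()}

print(group_by_name({"y_0": 4.0, "y_1": 7.0, "yx_0": (4.0, 1.0)}))
# {'y': [4.0, 7.0], 'yx': [(4.0, 1.0)]}
```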

This works for arbitrary pytrees. Here we save out the `y` and `x` values in a tuple.

In [89]:
def func(x):
  y = 4 * x
  stx.save_inter((y, x), name="yx")
  return jnp.sin(x) + 0.5 * y

stx.pull(func)(1.0)

(Array(2.841471, dtype=float32, weak_type=True),
 {'yx_0': (Array(4., dtype=float32, weak_type=True), 1.0)})
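Because a saved value can be any pytree, the intermediates dictionary may be nested. For logging, a small hypothetical helper can flatten it into path-named leaves. This plain-Python sketch only understands tuples, lists and dicts, whereas JAX's own `jax.tree_util` module handles general pytrees.

```python
def flatten_inters(tree, prefix=""):
    """Flatten nested tuples/lists/dicts into {'path/to/leaf': leaf}."""
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (tuple, list)):
        items = enumerate(tree)
    else:
        return {prefix: tree}  # a leaf
    flat = {}
    for key, child in items:
        path = f"{prefix}/{key}" if prefix else str(key)
        flat.update(flatten_inters(child, path))
    return flat

print(flatten_inters({"yx_0": (4.0, 1.0)}))
# {'yx_0/0': 4.0, 'yx_0/1': 1.0}
```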

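`pull` composes with other JAX transformations such as `jit`, `grad` and `vmap`, and can be applied either inside or outside them.

Note: when `jax.grad` wraps a `pull`ed function, pass `has_aux=True`. The pulled function returns an `(output, intermediates)` pair, and `has_aux=True` tells `grad` to differentiate only the first element and return the second alongside the gradient.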

In [91]:
stx.pull(jax.jit(func))(1.0)

(Array(2.841471, dtype=float32, weak_type=True),
 {'yx_0': (Array(4., dtype=float32, weak_type=True),
   Array(1., dtype=float32, weak_type=True))})

In [92]:
jax.jit(stx.pull(func))(1.0)

(Array(2.841471, dtype=float32, weak_type=True),
 {'yx_0': (Array(4., dtype=float32, weak_type=True),
   Array(1., dtype=float32, weak_type=True))})

In [93]:
stx.pull(jax.grad(func))(1.0)

(Array(2.5403023, dtype=float32, weak_type=True),
 {'yx_0': (Array(4., dtype=float32, weak_type=True), 1.0)})

In [94]:
jax.grad(stx.pull(func), has_aux=True)(1.0)

(Array(2.5403023, dtype=float32, weak_type=True),
 {'yx_0': (Array(4., dtype=float32, weak_type=True), 1.0)})
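Without `has_aux=True`, `jax.grad` would reject the pulled function, since it only differentiates scalar outputs. The sketch below mimics the `has_aux` contract in plain Python using central finite differences, a numerical stand-in rather than JAX's autodiff; `grad_has_aux_sketch` and `pulled_func` are hypothetical names, and `pulled_func` hand-codes the pair that `stx.pull(func)` returns.

```python
import math

def grad_has_aux_sketch(fn, eps=1e-6):
    """Finite-difference stand-in for jax.grad(fn, has_aux=True).

    `fn` must return (scalar, aux); only the scalar is differentiated,
    and aux is returned as evaluated at x.
    """
    def wrapped(x):
        (f_plus, _), (f_minus, _) = fn(x + eps), fn(x - eps)
        _, aux = fn(x)
        return (f_plus - f_minus) / (2 * eps), aux
    return wrapped

def pulled_func(x):
    # Mimics stx.pull(func): (output, intermediates).
    y = 4 * x
    return math.sin(x) + 0.5 * y, {"yx_0": (y, x)}

g, aux = grad_has_aux_sketch(pulled_func)(1.0)
print(round(g, 4), aux)  # 2.5403 {'yx_0': (4.0, 1.0)}
```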

In [95]:
stx.pull(jax.vmap(func))(jnp.ones(10))

(Array([2.841471, 2.841471, 2.841471, 2.841471, 2.841471, 2.841471,
        2.841471, 2.841471, 2.841471, 2.841471], dtype=float32),
 {'yx_0': (Array([4., 4., 4., 4., 4., 4., 4., 4., 4., 4.], dtype=float32),
   Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32))})

In [96]:
jax.vmap(stx.pull(func))(jnp.ones(10))

(Array([2.841471, 2.841471, 2.841471, 2.841471, 2.841471, 2.841471,
        2.841471, 2.841471, 2.841471, 2.841471], dtype=float32),
 {'yx_0': (Array([4., 4., 4., 4., 4., 4., 4., 4., 4., 4.], dtype=float32),
   Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32))})
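Under `vmap`, every leaf of every saved intermediate gains a leading batch axis, as the length-10 arrays above show. The plain-Python sketch below reproduces that layout by looping over examples and stacking matching leaves. A loop stands in for vectorization, `batched_pull_sketch` and `pulled_func` are hypothetical names, and saved values are assumed to be flat tuples.

```python
import math

def batched_pull_sketch(pulled_fn, xs):
    """Apply a pulled function per example and stack results leaf-wise.

    `pulled_fn` returns (output, {name: tuple_of_leaves}); each leaf becomes
    a list with one entry per example, mirroring vmap's leading batch axis.
    """
    outs, stacked = [], {}
    for x in xs:
        out, inters = pulled_fn(x)
        outs.append(out)
        for name, leaves in inters.items():
            columns = stacked.setdefault(name, tuple([] for _ in leaves))
            for column, leaf in zip(columns, leaves):
                column.append(leaf)
    return outs, stacked

def pulled_func(x):
    # Mimics stx.pull(func): (output, {'yx_0': (y, x)}).
    y = 4 * x
    return math.sin(x) + 0.5 * y, {"yx_0": (y, x)}

outs, inters = batched_pull_sketch(pulled_func, [1.0] * 3)
print(inters)  # {'yx_0': ([4.0, 4.0, 4.0], [1.0, 1.0, 1.0])}
```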

# Higher-order primitives

Higher-order primitives, which take functions as arguments (for example `jit`, `cond`, `scan` and `while_loop`), are supported, but each one needs a handler. Currently only `jax.jit` has a handler, and it serves as an example implementation.

Support for a missing primitive can be added by writing a handler of type `Callable[[JaxprEqn], tuple[JaxprEqn, tuple[StateMeta, ...]]]`: it receives the primitive's `JaxprEqn` and returns a (possibly rewritten) equation together with a tuple of `StateMeta` entries.
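That signature suggests a dispatch-table design: the interpreter looks up a handler per primitive, lets it rewrite the equation, and collects whatever state it reports. Below is a self-contained sketch of that pattern. `Eqn` and `StateMeta` are hypothetical plain dataclasses, not JAX's `JaxprEqn` or `statax`'s `StateMeta`, the `"save_inter"` primitive name is an assumption, and how `statax` actually registers handlers is not shown in this notebook.

```python
from dataclasses import dataclass, field
from typing import Callable

@dataclass
class Eqn:
    """Hypothetical stand-in for a jaxpr equation."""
    primitive: str
    params: dict = field(default_factory=dict)

@dataclass
class StateMeta:
    """Hypothetical stand-in: metadata for one exposed intermediate."""
    name: str

Handler = Callable[[Eqn], tuple[Eqn, tuple[StateMeta, ...]]]
HANDLERS: dict[str, Handler] = {}

def register(primitive: str):
    """Decorator that installs a handler for `primitive`."""
    def deco(fn: Handler) -> Handler:
        HANDLERS[primitive] = fn
        return fn
    return deco

@register("save_inter")
def handle_save_inter(eqn: Eqn):
    # Expose the saved value under its name; leave the equation unchanged.
    return eqn, (StateMeta(eqn.params.get("name", "state")),)

def interpret(eqns: list[Eqn]):
    """Run handlers where registered; pass other equations through."""
    out_eqns, metas = [], []
    for eqn in eqns:
        handler = HANDLERS.get(eqn.primitive)
        if handler is None:
            # No handler: any state hidden inside this equation is not exposed.
            out_eqns.append(eqn)
            continue
        new_eqn, new_metas = handler(eqn)
        out_eqns.append(new_eqn)
        metas.extend(new_metas)
    return out_eqns, metas

eqns = [Eqn("mul"), Eqn("save_inter", {"name": "y"}), Eqn("cond")]
print([m.name for m in interpret(eqns)[1]])  # ['y']
```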

Without a handler, intermediates saved inside an unsupported primitive can still be recovered by hand. Consider `lax.cond`:

In [97]:
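def inner(a):
  stx.save_inter(a)  # no name given, so the default name ("state") is used
  return a

def outer(a, b):
  # Both branches are `inner`; `b` selects which one runs.
  return jax.lax.cond(b, inner, inner, a)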

stx.pull(outer)(jnp.zeros(10) + 3, True)

(Array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.], dtype=float32), {})

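The dictionary is empty: the value saved inside `inner` is silently dropped, because the branch functions passed to `cond` are not traversed by `pull`'s custom interpreter. We can work around this by revealing the intermediate values inside the branch with an explicit `pull`, passing them out as ordinary outputs of the `cond`, and then re-hiding them with another `save_inter` call outside it.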

In [98]:
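def wrapped_inner(*args):
  # Reveal inner's intermediates as an ordinary return value of the branch.
  result, inters = stx.pull(inner)(*args)
  return result, inters

def outer_updated(a, b):
  result, inters = jax.lax.cond(b, wrapped_inner, wrapped_inner, a)
  # Re-hide the revealed values so the outer `pull` collects them.
  stx.save_inter(tuple(inters.values()))
  return result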

stx.pull(outer_updated)(jnp.zeros(10) + 3, True)

(Array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.], dtype=float32),
 {'state_0': (Array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.], dtype=float32),)})