# Extracting Intermediate Values

This notebook shows how to thread state using the Statax `set_state` and `get_state` primitives along with the `stateful` transformation. 

In [48]:
import jax
import jax.numpy as jnp
import statax as stx

Outside of `stateful`, `set_state` is a no-op that returns its input. It behaves essentially the same as `save_inter`, but requires that a name be given for the state.

`get_state` also requires a name, which is used to look up the state, as well as an initialiser: a callable that takes no arguments and returns the initial state.

In [49]:
def f(x):
  y = stx.get_state(name="y", init=lambda: 4 * jnp.ones(()))
  z = x * y
  z = stx.set_state(z, name="z")
  return z

f(2.0)

Array(8., dtype=float32)

This works as expected, with the initialiser providing the value for `y` and `set_state` having no impact.

When instead we use the `stateful` transformation, the initialiser is used again but the state is threaded out of the function, in the form of a dictionary.

In [50]:
f_state = stx.stateful(f)
f_state(1.0)

(Array(4., dtype=float32), {'z': Array(4., dtype=float32)})

Note that only "z" appears in the output state dictionary, not "y". This is an efficiency measure: since "y" is never set, its output value would always equal its input (or initialiser) value, so we can avoid threading it out of the function.

To force output of all states, we can set `output_unchanged=True` in the `stateful` transformation. This adds the values of all states to the output state dictionary, even those that are unchanged.

In [51]:
f_state = stx.stateful(f, output_unchanged=True)
f_state(2.0)

(Array(8., dtype=float32),
 {'z': Array(8., dtype=float32), 'y': Array(4., dtype=float32)})

We can provide input values for the state by passing in a (potentially partially filled) dictionary of state values.

In [52]:
f_state(2.0, state={"y": 3.0})

(Array(6., dtype=float32), {'z': Array(6., dtype=float32), 'y': 3.0})
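
The behaviour shown so far can be mimicked with a small pure-Python toy, which may help build intuition. This is only a sketch under stated assumptions: the `get_state`, `set_state` and `stateful` functions below are hypothetical stand-ins that share names with the Statax ones, the frame stack `_frames` is invented for illustration, and nothing here reflects how Statax is actually implemented.

```python
# Toy stand-ins for illustration only; this is NOT how Statax is implemented.
_frames = []  # stack of active state frames, empty outside `stateful`


def get_state(name, init):
    """Return the stored state for `name`, falling back to `init()`."""
    if not _frames:
        return init()
    frame = _frames[-1]
    frame["read"].add(name)
    if name not in frame["values"]:
        frame["values"][name] = init()
    return frame["values"][name]


def set_state(value, name):
    """Record `value` under `name` if a frame is active; always return `value`."""
    if _frames:
        _frames[-1]["values"][name] = value
        _frames[-1]["written"].add(name)
    return value


def stateful(fn, output_unchanged=False):
    """Wrap `fn` so it accepts `state=` and returns `(output, state)`."""
    def wrapped(*args, state=None):
        frame = {"values": dict(state or {}), "read": set(), "written": set()}
        _frames.append(frame)
        try:
            out = fn(*args)
        finally:
            _frames.pop()
        # States that were only read are reported only when requested.
        keys = set(frame["written"])
        if output_unchanged:
            keys |= frame["read"]
        return out, {k: frame["values"][k] for k in keys}
    return wrapped


def f(x):
    y = get_state(name="y", init=lambda: 4.0)
    z = set_state(x * y, name="z")
    return z


plain = f(2.0)                                        # 8.0
default = stateful(f)(1.0)                            # (4.0, {'z': 4.0})
everything = stateful(f, output_unchanged=True)(2.0)  # z is 8.0, y is 4.0
overridden = stateful(f, output_unchanged=True)(2.0, state={"y": 3.0})  # z is 6.0, y is 3.0
```

The toy uses plain Python floats instead of JAX arrays, so it reproduces the values above but not their array types.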

States can be any valid PyTree.

In [53]:
def g(x):
  y = stx.get_state(name="y", init=lambda: (4 * jnp.ones(()), {"g": 1.0}))
  z = x * y[0]
  z = stx.set_state((z, y), name="zy")
  return z

stx.stateful(g)(1.0)

((Array(4., dtype=float32), (Array(4., dtype=float32), {'g': 1.0})),
 {'zy': (Array(4., dtype=float32), (Array(4., dtype=float32), {'g': 1.0}))})

In [54]:
stx.stateful(g)(1.0, state={"y": (2.0, {"g": 0.0})})

((Array(2., dtype=float32), (2.0, {'g': 0.0})),
 {'zy': (Array(2., dtype=float32), (2.0, {'g': 0.0}))})
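
For readers unfamiliar with the term, a PyTree is any nesting of tuples, lists and dictionaries whose leaves are arrays or scalars. The sketch below uses only standard `jax.tree_util` functions on a hand-built dictionary shaped like the state output above; the variable names are illustrative and no Statax calls are involved.

```python
import jax
import jax.numpy as jnp

# A hand-built dictionary with the same nesting as the "zy" state above.
state = {"zy": (jnp.asarray(4.0), (jnp.asarray(4.0), {"g": 1.0}))}

# Flattening visits every leaf regardless of how deeply it is nested.
leaves = jax.tree_util.tree_leaves(state)

# Mapping applies a function to each leaf and rebuilds the same structure.
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, state)
```

Here `leaves` holds three entries (two arrays and the float under "g"), and `doubled` has the same tuple and dictionary layout as `state`.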

We can compose with other JAX transformations, like `jit`, `grad` and `vmap`.

In [55]:
jax.jit(stx.stateful(f))(1.0, state={"y": 3.0})

(Array(3., dtype=float32), {'z': Array(3., dtype=float32)})

Note: `grad` needs `has_aux=True` when it wraps a `stateful` function, because that function returns an `(output, state)` pair rather than a single scalar.

In [56]:
stx.stateful(jax.jit(f))(1.0, state={"y": 3.0})

(Array(3., dtype=float32), {'z': Array(3., dtype=float32)})

In [57]:
stx.stateful(jax.grad(f))(1.0, state={"y": 3.0})

(Array(3., dtype=float32, weak_type=True), {'z': Array(3., dtype=float32)})

In [58]:
jax.grad(stx.stateful(f), has_aux=True)(1.0, state={"y": 3.0})

(Array(3., dtype=float32, weak_type=True), {'z': Array(3., dtype=float32)})
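
The `has_aux` requirement comes from `jax.grad` itself rather than from Statax. As a minimal sketch, the hand-written function below (a hypothetical stand-in for a `stateful`-wrapped function, taking the state value as a plain argument) returns an `(output, state)` pair, and `has_aux=True` tells `jax.grad` to differentiate only the first element and pass the second through untouched.

```python
import jax
import jax.numpy as jnp


def f_with_state(x, y):
    # Hand-written stand-in: returns the output plus an auxiliary state dict.
    z = x * y
    return z, {"z": z}


# Differentiate with respect to `x` only; the dict is returned as auxiliary data.
grad_fn = jax.grad(f_with_state, has_aux=True)
dz_dx, aux_state = grad_fn(1.0, 3.0)
```

With these inputs `dz_dx` is 3.0 (the derivative of `x * y` with respect to `x` at `y = 3`) and `aux_state["z"]` is 3.0. Without `has_aux=True`, `jax.grad` rejects the function because its output is not a scalar.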

In [59]:
stx.stateful(jax.vmap(f))(jnp.ones(10), state={"y": 3.0})

(Array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.], dtype=float32),
 {'z': Array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.], dtype=float32)})

In [60]:
jax.vmap(stx.stateful(f))(jnp.ones(10), state={"y": jnp.zeros(10) + 3})

(Array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.], dtype=float32),
 {'z': Array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.], dtype=float32)})
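
The last two cells differ in how the state is batched. When `vmap` is outermost, the state is passed as a keyword argument, and `jax.vmap` always maps keyword arguments over their leading axis, which is why "y" is given as a length-10 array there. The sketch below illustrates the two batching choices with a hand-written function (a hypothetical stand-in that takes the state as a positional argument so that `in_axes` can apply to it); it does not use Statax.

```python
import jax
import jax.numpy as jnp


def f_with_state(x, state):
    # Hand-written stand-in for a stateful function with explicit state.
    z = x * state["y"]
    return z, {"z": z}


xs = jnp.ones(10)

# Share one scalar state across the batch: map `x`, broadcast `state`.
shared_out, shared_state = jax.vmap(f_with_state, in_axes=(0, None))(xs, {"y": 3.0})

# Give every batch element its own state: both arguments are mapped.
batched_out, batched_state = jax.vmap(f_with_state)(xs, {"y": jnp.full(10, 3.0)})
```

Both calls produce an output and a "z" state of shape `(10,)`; they differ only in whether "y" is shared or per-example.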

# Higher-order primitives

Higher-order primitives, which take functions as arguments, are supported but require a handler to be defined for each of them. Currently only `jax.jit` is handled, and its handler serves as an example implementation.

Missing primitives can be implemented by adding a handler of the form `Callable[[JaxprEqn], tuple[JaxprEqn, tuple[StateMeta, ...]]]`.

As an example, consider a primitive without a handler, `lax.cond`. Using it directly with functions that read or set state does not work:

In [61]:
def inner(a):
  b = stx.get_state(name="b", init=lambda: 1.0)
  a = a + b
  stx.set_state(a, name="a")
  return a

def outer(a, b):
  return jax.lax.cond(b, inner, inner, a)

try:
  stx.stateful(outer)(jnp.zeros(10) + 3, True, state={"b": 2.0})
except stx.StateError as e:
  print(e)

stx.stateful(outer)(jnp.zeros(10) + 3, True)

Unknown state(s) "('b',)" provided in input state


(Array([4., 4., 4., 4., 4., 4., 4., 4., 4., 4.], dtype=float32), {})

This fails in two ways. First, the functions passed to `cond` aren't parsed by the custom interpreter, so the interpreter has no knowledge of the "b" state, and passing it in the state dictionary raises an unknown-state error. Second, when we don't pass the state, the function runs but uses the initial values and returns no output state.
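
One way to see why the branch functions are hidden is to inspect the traced program. The sketch below assumes the interpreter walks the jaxpr equation by equation, which the `JaxprEqn`-based handler signature suggests but which is not confirmed here. It uses only standard JAX calls, and the branch functions are illustrative.

```python
import jax
import jax.numpy as jnp


def add_one(a):
    return a + 1.0


def sub_one(a):
    return a - 1.0


def outer(a, pred):
    return jax.lax.cond(pred, add_one, sub_one, a)


# Trace `outer` and list the primitives of its top-level equations.
closed = jax.make_jaxpr(outer)(jnp.zeros(3), True)
names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]

# The branches live inside the parameters of the single `cond` equation.
cond_eqn = next(eqn for eqn in closed.jaxpr.eqns if eqn.primitive.name == "cond")
n_branches = len(cond_eqn.params["branches"])
```

The addition and subtraction do not appear among the top-level equations; they sit inside the two nested branch programs of the `cond` equation, so anything inside them is invisible to an interpreter that does not descend into those parameters.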

We can fix this by manually passing state through `cond`:

In [62]:
def outer(a, b):

  b_s = stx.get_state(name="b", init=lambda: 1.0)

  def inner(a):
    a = a + b_s
    return a

  result = jax.lax.cond(b, inner, inner, a)
  stx.set_state(result, name="a")

  return result

stx.stateful(outer)(jnp.zeros(10) + 3, True, state={"b": 2.0})

(Array([5., 5., 5., 5., 5., 5., 5., 5., 5., 5.], dtype=float32),
 {'a': Array([5., 5., 5., 5., 5., 5., 5., 5., 5., 5.], dtype=float32)})

Now the state is threaded correctly: the input value of "b" is used, and "a" appears in the output state.
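
The same idea also works when the state value is passed to `cond` as an explicit operand instead of being closed over. The sketch below shows this pattern in plain JAX with a hand-written state dictionary, so the `state` argument and returned dictionary are illustrative stand-ins for what `get_state` and `set_state` would provide.

```python
import jax
import jax.numpy as jnp


def branch_add(a, b):
    return a + b


def branch_sub(a, b):
    return a - b


def outer(a, pred, state):
    # Read the state outside `cond`, then hand it to the branches as an operand.
    result = jax.lax.cond(pred, branch_add, branch_sub, a, state["b"])
    # Record the new state outside `cond` as well.
    return result, {"a": result}


a = jnp.zeros(10) + 3
added, added_state = outer(a, True, {"b": 2.0})
subtracted, _ = outer(a, False, {"b": 2.0})
```

With `pred=True` every element of `added` is 5.0, matching the output above, and with `pred=False` every element of `subtracted` is 1.0. Both branches must return values of the same shape and dtype, which is a requirement of `lax.cond`.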