<a href="https://colab.research.google.com/github/dpiponi/colabs/blob/main/Handling_Effects_with_Jax_(Public_version).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

Here is a simple Haskell function that makes use of the list monad to implement non-determinism.

```
f x y z = do
  u <- x
  v <- (2 *) <$> y
  w <- (v +) <$> z
  return $ u * v + w
```

For those not familiar with the Haskell list monad, the code draws `u` from each element of `x` in turn, then draws `v` from each element of `y` (multiplied by 2) in turn, and then draws `w` from `v` plus each element of `z` in turn. At the end, all of the values of `u * v + w` are collected into a single flat list. (In fact, it's similar to a Python list comprehension, which was derived from Haskell's.)
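For comparison, here is the same computation written as a plain Python nested list comprehension. It is a small sketch that runs without Jax, and `f_listcomp` is just an illustrative name:

```python
# Pure-Python equivalent of the Haskell list-monad example (no Jax, no GPU).
def f_listcomp(x, y, z):
    return [u * v + w
            for u in x                       # u <- x
            for v in [2 * b for b in y]      # v <- (2 *) <$> y
            for w in [v + c for c in z]]     # w <- (v +) <$> z

print(f_listcomp([1, 2], [10], [0, 1]))
```

The `for` clauses appear in the same order as the `<-` lines in the `do` block, so the results come out in the same order as the Haskell version.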

Here is a Python function that does the same:

```
def f(x, y, z):
  u = amb(x)
  v = 2. * amb(y)
  w = v + amb(z)
  return singleton(u * v + w)
```

But importantly, it runs on the GPU and draws the values of `u`, `v` and `w` from `x`, `y` and `z` in parallel.

How can it possibly do that? It requires running the portion of code after `u = ...` once for each value of `u`.
This means `amb` must affect the control flow of the lines that follow.
Haskell has monads and do-notation for this, but Python has no obvious mechanism to allow an expression to change subsequent control flow.

# Jax
Well, here's the path I'm going to take:
I'm going to use the GPU numerics compiler [Jax](https://github.com/google/jax).
As described on its website:
> "At its core, JAX is an extensible system for transforming numerical functions. Here are four of primary interest: grad, jit, vmap, and pmap"

so it's perfect for this task.
I'm going to implement something close to effect handling for it (à la [Bauer et al](https://arxiv.org/abs/1306.6316)).
(I'll sketch what "effect handling" means for us below.)
But implementing, or even just modifying, compilers isn't easy.
So instead I'm going to modify an interpreter and then use a trick simpler than the [Futamura projections](http://blog.sigfpe.com/2009/05/three-projections-of-doctor-futamura.html) to turn the interpreter into a compiler.

# A sketch of how it's going to work
Don't take the following too literally. It's really just a sketch of how the code works.
In particular, I'm abusing notation: $f=g$ means that the Python functions $f$ and $g$ are extensionally equal, in the sense that applying $f$ to $x$ and applying $g$ to $x$ give the same result, even though $f$ and $g$ may be implemented differently and may even run on different hardware.

Here's the idea: suppose we have

1.   a magic "decompiler" function $T$ that can turn a Python function into an abstract syntax tree (AST) representing the code
2.   an interpreter $I$ that can interpret an AST, in effect turning it back into a Python function.
3.   a compiler, $C$, that can compile an AST, except that the result runs on the GPU.


We expect properties like this:


$\begin{align}
I(T(f)) &= f \\
T(I(t)) &= t \\
C(T(f)) &= f \mbox{ except $C(T(f))$ runs faster}\\
\end{align}$

Note in particular that $C\circ T$ takes Python functions and runs them fast on the GPU. I'm guessing this is the primary feature Jax users are looking for.

Note also that $C(T(I(T(f))))$ is a Python-to-GPU compiler.

And now comes the trick: if $I'$ is another interpreter, for example one that supports effect handling, then $C(T(I'(T(f))))$ compiles `f` for the GPU with support for effect handling.
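Here is the step behind the first composite, spelled out. Take $t = T(f)$ in the second property above and then apply $C$:

$\begin{align}
T(I(T(f))) &= T(f) \\
C(T(I(T(f)))) &= C(T(f)) = f \mbox{ except it runs faster}
\end{align}$

No such cancellation is available for $I'$, because $I'(T(f))$ is in general not equal to $f$. It is $f$ with its effects handled, and that handled function is what $C\circ T$ compiles.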

Jax is designed to work on data in the form of "tensors". So $T$ can't take arbitrary functions as arguments. It can only work on functions that are ultimately built from the kinds of function that are found in the `numpy` library.

# The Implementation of our compiler
Here's the code. We start with some imports.


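```
from functools import partial
from typing import Any, Dict

import jax
import jax.numpy as jnp
from jax import grad, jit, make_jaxpr, vmap, core
from jax import linear_util as lu  # used by eval_jaxpr_handler below
from jax.util import safe_map, safe_zip

map = safe_map
zip = safe_zip
```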

The AST used by Jax is called a jaxpr. To build one, Jax traces your function using objects of a special type, called tracers. All of the standard operations in `numpy` are replaced by Jax versions, and operators like `+` and `*` are overloaded. The tracers are passed as arguments to your function so that, instead of doing any numerical work, your function builds an AST that is a trace of its execution. As a normal user you never have to see the tracer objects explicitly. This functionality is all wrapped up for you.

So given a function like
```
def f(x):
  y = jnp.sum(x)  # jnp is the Jax-provided version of numpy
  return y * y + 1
```
the result of tracing looks something like this
```
{ lambda  ; a.
  let b = reduce_sum[ axes=(0,) ] a
      c = mul b b
      d = add c 1
  in (d,) }
```
It's a sort of assembly language for GPUs.
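To make the tracing idea concrete, here is a toy tracer in pure Python. It is a sketch of the general technique, not Jax's real machinery. `Tracer`, `Trace`, `T` and `I` are hypothetical names chosen to echo the $T$ and $I$ of the sketch above. Overloaded `+` and `*` record jaxpr-like equations instead of computing numbers, and a tiny interpreter turns the recorded program back into a function.

```python
import itertools

# Toy tracer in the spirit of Jax's tracing. NOT Jax's real API:
# Tracer, Trace, T and I are hypothetical names for this sketch.

class Tracer:
    """Stands in for a value; arithmetic records an equation instead."""
    def __init__(self, trace, name):
        self.trace, self.name = trace, name

    def __add__(self, other):
        return self.trace.emit("add", self, other)
    __radd__ = __add__           # add and mul commute, so reuse them

    def __mul__(self, other):
        return self.trace.emit("mul", self, other)
    __rmul__ = __mul__

def arg_name(x):
    # Tracers are referred to by variable name; anything else is a literal.
    return x.name if isinstance(x, Tracer) else x

class Trace:
    def __init__(self):
        self.eqns = []           # list of (out, op, lhs, rhs) tuples
        self.names = (f"v{i}" for i in itertools.count())

    def new_var(self):
        return Tracer(self, next(self.names))

    def emit(self, op, a, b):
        out = self.new_var()
        self.eqns.append((out.name, op, arg_name(a), arg_name(b)))
        return out

def T(f, nargs):
    """Toy 'decompiler': trace f into (input names, equations, output)."""
    trace = Trace()
    invars = [trace.new_var() for _ in range(nargs)]
    out = f(*invars)
    return [v.name for v in invars], trace.eqns, arg_name(out)

def I(prog):
    """Toy interpreter: turn a traced program back into a Python function."""
    invars, eqns, outvar = prog
    ops = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}

    def run(*args):
        env = dict(zip(invars, args))
        def read(x):
            return env[x] if isinstance(x, str) else x
        for out, op, a, b in eqns:
            env[out] = ops[op](read(a), read(b))
        return read(outvar)
    return run

def g(x, y):
    return x * x + 2 * y + 1

prog = T(g, 2)
for eqn in prog[1]:
    print(eqn)
print(I(prog)(3, 4), g(3, 4))   # I(T(g)) agrees with g
```

Because this tracer only overloads `+` and `*`, a function such as `lambda x: x if x > 0 else -x` fails to trace, since the comparison raises a `TypeError`. This is a miniature version of the restriction that $T$ only accepts functions built from traceable numerical operations.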

Over at GitHub, [core.py contains an interpreter](https://github.com/google/jax/blob/master/jax/core.py) for jaxprs called `eval_jaxpr`. The main thing it does is loop over the equations in a jaxpr, evaluating each one and updating a dictionary of values with the results.

I've tweaked that in two ways.


1.   I've replaced the loop with a recursion. This allows me to split the entire interpreter into a part that evaluates things to get a value and a continuation that does something with the value.
2.   If the primitive being evaluated has a `handler` attribute then, instead of simply applying the continuation to the value, the handler gets to do whatever it likes with the continuation and the value. This is the key feature of Bauer et al's effect handling. As far as we're concerned, an effect is something you get by applying the continuation to its argument in a non-standard way.

I've only changed a handful of lines from the code on GitHub.

```
def eval_jaxpr_handler(jaxpr: core.Jaxpr, consts, *args):
  env: Dict[core.Var, Any] = {}

  def write(v, val):
    env[v] = val

  write(core.unitvar, core.unit)
  map(write, jaxpr.constvars, consts)

  # This is the recursion that replaces the main loop in the original
  # `eval_jaxpr`.
  def eval_jaxpr_loop(eqns, env, invars, args):
    # The handler could call the continuation multiple times, so we
    # need this function to be somewhat pure. We copy `env` to
    # ensure it isn't mutated.
    env = env.copy()

    def read(v):
      if type(v) is core.Literal:
        return v.val
      else:
        return env[v]

    def write(v, val):
      env[v] = val

    map(write, invars, args)

    if eqns:
      eqn = eqns[0]
      in_vals = map(read, eqn.invars)
      in_vals = list(in_vals)
      call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
      if call_jaxpr:
        subfuns = [lu.wrap_init(partial(eval_jaxpr_handler, call_jaxpr, ()))]
      else:
        subfuns = []
      with jax.source_info_util.user_context(eqn.source_info):
        if hasattr(eqn.primitive, 'handler'):
          args = subfuns + in_vals
          # This definition "reifies" the remainder of the evaluation
          # loop so it can be explicitly passed to the handler.
          def continuation(args):
            return eval_jaxpr_loop(eqns[1:], env, eqn.outvars, [args])
          return [eqn.primitive.handler(continuation, *args)]
        else:
          ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
          if not eqn.primitive.multiple_results:
            ans = [ans]
          return eval_jaxpr_loop(eqns[1:], env, eqn.outvars, ans)
    else:
      return map(read, jaxpr.outvars)

  return eval_jaxpr_loop(jaxpr.eqns, env, jaxpr.invars, args)
```
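To see the control flow in isolation, here is a stripped-down, pure-Python sketch of the same recursive, continuation-passing structure. It is not Jax code. `eval_with_handlers`, `PRIMITIVES` and the equation format are hypothetical, and `handle_list` here loops over Python lists where the notebook's version uses `jax.vmap` and `flatten` on arrays.

```python
# Toy recursive interpreter with effect handlers (a sketch, not Jax code).
# Equations are (out, op, args); variable names are strings and anything
# else is a literal.

PRIMITIVES = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
    "singleton": lambda a: [a],
}

def eval_with_handlers(eqns, outvar, handlers, env):
    def read(x):
        return env[x] if isinstance(x, str) else x

    if not eqns:
        return read(outvar)
    (out, op, args), rest = eqns[0], eqns[1:]
    vals = [read(a) for a in args]

    def continuation(val):
        # The remainder of the evaluation, reified as a function. It builds a
        # new dict instead of mutating `env`, so it is safe to call repeatedly.
        return eval_with_handlers(rest, outvar, handlers, {**env, out: val})

    if op in handlers:
        # An effect: the handler decides how (and how often) to resume.
        return handlers[op](continuation, *vals)
    return continuation(PRIMITIVES[op](*vals))

def handle_list(k, xs):
    # List-monad bind (concatMap): resume once per element, then concatenate.
    return [y for x in xs for y in k(x)]

# The function f(x, y, z) from the introduction, written as equations.
f_eqns = [
    ("u", "amb", ["x"]),
    ("b", "amb", ["y"]),
    ("v", "mul", [2, "b"]),
    ("c", "amb", ["z"]),
    ("w", "add", ["v", "c"]),
    ("uv", "mul", ["u", "v"]),
    ("r", "add", ["uv", "w"]),
    ("out", "singleton", ["r"]),
]
env0 = {"x": [1, 2], "y": [10], "z": [0, 1]}

print(eval_with_handlers(f_eqns, "out", {"amb": handle_list}, env0))
# A different handler for the same program: always take the first choice.
print(eval_with_handlers(f_eqns, "out", {"amb": lambda k, xs: k(xs[0])}, env0))
```

Swapping the handler changes what `amb` means without touching the program, which is the flexibility the real interpreter gets from the `handler` attribute.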


Now comes the implementation of an example effect. We need the functionality of Haskell's list monad. The `return` of the monad is `singleton`, and the bind function is `concatMap` (with its arguments flipped). Jax will provide us with the `map` bit, so we just need to implement `concat`, which I've called `flatten` here.

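```
def flatten(xs):
  # `concat`: merge the two leading axes, (m, n, ...) -> (m * n, ...).
  return jnp.reshape(xs, (xs.shape[0] * xs.shape[1],) + xs.shape[2:])

def singleton(x):
  # `return`: wrap a single value in a length-1 array.
  return jnp.array([x])
```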
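As a quick sanity check that `flatten` composed with a map really behaves like list-monad bind, here is a NumPy sketch. It uses NumPy instead of `jax.numpy` so it runs without Jax, and `bind` is a hypothetical helper name. The map is an ordinary Python loop here.

```python
import numpy as np

def flatten(xs):
    # concat: merge the two leading axes, (m, n, ...) -> (m * n, ...).
    return np.reshape(xs, (xs.shape[0] * xs.shape[1],) + xs.shape[2:])

def singleton(x):
    # return/unit: wrap one value in a length-1 array.
    return np.array([x])

def bind(xs, g):
    # xs >>= g  ==  concatMap g xs  ==  flatten (map g xs).
    # np.stack requires every g(x) to have the same shape, which is the
    # static-shape restriction in miniature.
    return flatten(np.stack([g(x) for x in xs]))

xs = np.array([1., 2., 3.])
g = lambda x: np.array([x, 10. * x])
print(bind(xs, g))                                     # concatMap
print(np.array_equal(bind(singleton(5.), g), g(5.)))  # left identity law
print(np.array_equal(bind(xs, singleton), xs))         # right identity law
```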


# The implementation of `amb`
Now I define a new primitive called `amb`. Jax provides a `Primitive` class for this. Primitives need an "abstract evaluation" rule that determines the shape and dtype of the result of `amb` without actually computing it.

The `handle_list` function is similar to Haskell's bind. But instead of composing `flatten` with `map`, I compose it with `jax.vmap`, which maps a function in *parallel*. So the continuation is applied to every element of the argument of `amb` at once, rather than to each element in turn. The continuation returns a list of outputs, and `[0]` picks out the single one we need. (By the way, because the continuation is used repeatedly, this is not a "tame" effect handler.)

```
amb_p = core.Primitive('amb')

def handle_list(f, x):
  return flatten(jax.vmap(f)(x)[0])

amb_p.handler = handle_list

def amb(xs):
  return amb_p.bind(xs)

def amb_abstract_eval(xs):
  return core.ShapedArray(xs.shape[1:], xs.dtype)

amb_p.def_abstract_eval(amb_abstract_eval)
```
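To see what `handle_list` computes without Jax, here is a NumPy emulation. A hypothetical `vmap_loop` stands in for `jax.vmap` by looping in Python and stacking the results, whereas the real `vmap` vectorizes instead of looping. `f_cps` is the introduction's `f` with each `amb` rewritten by hand into a call to `handle_list` whose continuation is the rest of the function. This is the rewriting that `eval_jaxpr_handler` performs automatically.

```python
import numpy as np

def flatten(xs):
    return np.reshape(xs, (xs.shape[0] * xs.shape[1],) + xs.shape[2:])

def singleton(x):
    return np.array([x])

def vmap_loop(k):
    # Sequential stand-in for jax.vmap(k). Like the continuations built by
    # eval_jaxpr_handler, k returns a *list* of output arrays.
    def mapped(xs):
        outs = [k(x) for x in xs]
        return [np.stack([out[i] for out in outs]) for i in range(len(outs[0]))]
    return mapped

def handle_list(k, xs):
    # Same structure as the notebook's handler: map, take the single output, flatten.
    return flatten(vmap_loop(k)(xs)[0])

def f_cps(x, y, z):
    # f(x, y, z) with each `amb` turned, by hand, into handle_list applied
    # to the continuation "rest of the function".
    return handle_list(
        lambda u: [handle_list(
            lambda b: [handle_list(
                lambda c: [singleton(u * (2. * b) + (2. * b + c))],
                z)],
            y)],
        x)

print(f_cps(np.array([1., 2.]), np.array([10.]), np.array([0., 1.])))
```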


Now I define functions, built from what Jax provides, that play the roles of $T$, $I$, $I'$ and $C\circ T$ described above.

One issue is that Jax needs to be able to determine types and tensor sizes statically. When we apply $T$ to a function, it needs to see the shapes and types of the arguments you're going to use so that it can make its inferences. So $T$ here is a family of functions parameterised by the arguments $f$ will be applied to. And because `jax.jit` both traces and compiles for the GPU, I can implement $C\circ T$ as a single function.


```
def T(*xs):
  return lambda f: make_jaxpr(f)(*xs)

# The usual interpreter
def I(f):
  # `make_jaxpr` builds a separate "symbol table" containing the constants
  # needed by the jaxpr. This is why we also pass `f.literals` into
  # `eval_jaxpr`.
  return lambda *xs: jax.core.eval_jaxpr(f.jaxpr, f.literals, *xs)

# Our special interpreter
def I_prime(f):
  return lambda *xs: eval_jaxpr_handler(f.jaxpr, f.literals, *xs)

CT = jax.jit
```

# The example, implemented

```
xs = jnp.arange(1000.)
ys = jnp.arange(1000.)
zs = jnp.arange(1000.)

def f(x, y, z):
  u = amb(x)
  v = 2. * amb(y)
  w = v + amb(z)
  return singleton(u * v + w)

T_ = T(xs, ys, zs)  # Specialize `T` to the particular arguments.
```

First let's look at the jaxpr for the original function. We can see the new "assembly language mnemonic" for `amb`.

```
print(T_(f))
```

```
{ lambda  ; a b c.
  let d = amb a
      e = amb b
      f = mul e 2.0
      g = mul d f
      h = amb c
      i = add f h
      j = add g i
      k = broadcast_in_dim[ broadcast_dimensions=()
                            shape=(1,) ] j
  in (k,) }
```


# The big moment

OK, this is it. Here's where we actually run the code with the `amb` effect.

```
print(CT(I_prime(T_(f)))(xs, ys, zs))
```

```
[DeviceArray([0.000000e+00, 1.000000e+00, 2.000000e+00, ..., 1.998997e+06,
             1.998998e+06, 1.998999e+06], dtype=float32)]
```


Let's have a look at the jaxpr generated after the `amb` effect is handled.

```
print(T_(I_prime(T_(f))))
```

```
{ lambda  ; a b c.
  let d = reshape[ dimensions=None
                   new_sizes=(1000, 1) ] a
      e = mul b 2.0
      f = broadcast_in_dim[ broadcast_dimensions=(1,)
                            shape=(1000, 1000) ] e
      g = mul d f
      h = reshape[ dimensions=None
                   new_sizes=(1000, 1000, 1) ] g
      i = reshape[ dimensions=None
                   new_sizes=(1000, 1) ] e
      j = broadcast_in_dim[ broadcast_dimensions=(1,)
                            shape=(1000, 1000) ] c
      k = add i j
      l = broadcast_in_dim[ broadcast_dimensions=(1, 2)
                            shape=(1000, 1000, 1000) ] k
      m = add h l
      n = broadcast_in_dim[ broadcast_dimensions=(0, 1, 2)
                            shape=(1000, 1000, 1000, 1) ] m
      o = reshape[ dimensions=None
                   new_sizes=(1000, 1000, 1000) ] n
      p = reshape[ dimensions=None
                   new_sizes=(1000, 1000000) ] o
      q = reshape[ dimensions=None
```

The result is similar to the original except a bunch of reshapes and broadcasts have been inserted.
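Here is a rough, hand-written NumPy analogue of that vectorized program. Each `amb` gets its own axis, broadcasting computes every combination at once, and a final reshape plays the role of `flatten`. The name `f_broadcast` is illustrative, and the intermediate shapes differ from Jax's output, which for example carries trailing size-1 axes from `singleton`.

```python
import numpy as np

def f_broadcast(x, y, z):
    # One axis per `amb`; broadcasting enumerates all combinations.
    u = x[:, None, None]          # axis 0 ranges over x
    v = 2. * y[None, :, None]     # axis 1 ranges over y
    w = v + z[None, None, :]      # axis 2 ranges over z
    # Row-major reshape: x varies slowest, z fastest, like the nested do-block.
    return (u * v + w).reshape(-1)

x = y = z = np.arange(3.)
print(f_broadcast(x, y, z))
```

The flat result has one entry per combination. With the notebook's 1000-element inputs that is $1000^3 = 10^9$ entries, roughly 4 GB as `float32`, so memory grows multiplicatively with each `amb`.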

# Discussion and caveats
This method is very flexible. You can implement a wide array of handlers including readers, writers, state, debugging tools, [sow/reap](https://www.tensorflow.org/probability/oryx/notebooks/a_tour_of_oryx) and probability. In many ways it's more flexible than Haskell monads and you have some freedom with types that Haskell doesn't give you.

On the negative side, Jax needs to be able to determine all tensor sizes and types statically. So in a fragment of code like
```
a <- amb(...)
b <- amb(...a...)
```
it's perfectly fine for the argument to the second `amb` to depend on the value of `a`, but its shape can't depend on the value of `a`. For similar reasons you can't implement Haskell's `guard` function, because that would dynamically change the size of a tensor. (But you could probably build a tensor of flags and plumb that through the code in a way that the user doesn't see.)
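The flags idea can be sketched in plain NumPy, without `amb` or a handler. The plumbing through the interpreter is not shown, and `pairs_with_guard` is a hypothetical name. Every combination is computed, and a boolean tensor with the same static shape records which ones `guard` would accept. Only the final filtering step produces a dynamically sized result, and in Jax that step would presumably have to happen outside `jax.jit`.

```python
import numpy as np

def pairs_with_guard(x, y):
    # Haskell: do { a <- x; b <- y; guard (a < b); return (a + b) }
    a = x[:, None]
    b = y[None, :]
    values = (a + b).reshape(-1)   # every combination, static shape
    keep = (a < b).reshape(-1)     # the "tensor of flags", same static shape
    return values, keep

x = np.array([1., 2., 3.])
y = np.array([2., 3.])
values, keep = pairs_with_guard(x, y)
print(values[keep])   # the dynamically shaped filtering happens last
```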

And another caveat is that in the interests of brevity I skipped some things. Jax can work with data more structured than pure tensors. You can work with Python objects that contain tensors, such as lists, tuples, or your own types. The extra structure is removed before the data goes to the GPU and is put back afterwards. I have ignored this.

The function `amb` is named after a related function invented by John McCarthy, the inventor of Lisp. [Here](http://www.randomhacks.net/2005/10/11/amb-operator/)'s where I first learnt about it. The version I describe above doesn't use backtracking; it runs all paths.

Lastly, whether the code above runs on a CPU or a GPU depends on how you have Colab configured.

And thanks to Matt Johnson on the Jax team for explaining all of the details of jaxprs and how to use them.