In [1]:
%matplotlib inline

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")
import numpy as np
import jax.numpy as jnp
from jax import partial, jit, grad, make_jaxpr, random, vmap
import jax



# Jax Reading group session 3: the Jaxpr


Session summary:
1. Define and explore the Jaxpr
2. Write an interpreter: evaluate a function
3. Write an "inverse" interpreter: a higher-order function that returns the inverse of a function

# Explore the Jaxpr


**What is the Jaxpr?**

Jaxpr is a simple statically-typed functional language that is used to represent Python functions. A Jaxpr instance represents a function with one or more typed parameters and one or more typed results.

Jax includes an interpreter that translates Python functions into this intermediate representation. Jax can then use this representation to evaluate the function, do automatic differentiation, or compile it into XLA.


**Example:** consider the following python function


In [2]:

def foo(x):
    """Return `x` incremented by one."""
    incremented = x + 1
    return incremented


This is represented by the following Jaxpr. Note that it's essentially a computational graph (as used in automatic differentiation)

In [3]:
closed_jaxpr = make_jaxpr(foo)(2.)
print(closed_jaxpr)

{ lambda  ; a.
  let b = add a 1.0
  in (b,) }


This Jaxpr has a bunch of info attached to it. This means you can interact with it easily. Later we'll evaluate a function ourselves by starting from some concrete value and looping through the equations to evaluate the final result.

In [4]:
print(f"Input type: {closed_jaxpr.in_avals}")
print(f"Output type: {closed_jaxpr.out_avals}")
print(f"Equations in the jaxpr: {closed_jaxpr.jaxpr.eqns}")

Input type: [ShapedArray(float32[], weak_type=True)]
Output type: [ShapedArray(float32[])]
Equations in the jaxpr: [b = add a 1.0]


In [5]:
closed_jaxpr.jaxpr

{ lambda  ; a.
  let b = add a 1.0
  in (b,) }

From the [docs](https://jax.readthedocs.io/en/latest/jaxpr.html):

- `ClosedJaxpr` is a "partially applied `jax.core.Jaxpr`"
- `Jaxpr`: a field of ClosedJaxpr. the actual execution content

## Some properties of the Jaxpr


### 1. Only objects that have data dependence appear in the trace

In [6]:
def foo2(x):
    """Add one to `x`, then add a constant built from plain Python floats.

    Only the two additions involving `x` depend on the input; the
    arithmetic on the literals does not touch `x` at all.
    """
    shifted = x + 1.
    small, large = 10., 100.
    # (10 + 100) / 10 == 11: computed purely from Python constants.
    offset = (small + large) / small
    return shifted + offset

# all the intermediate stuff has been evaluated and is now fixed (it shows up as the constant 11.0)
make_jaxpr(foo2)(1.)

{ lambda  ; a.
  let b = add a 1.0
      c = add b 11.0
  in (c,) }

### 2. The traced function must be pure

In [7]:
global_list = []

def impure_function(x):
    """Return `x + 1`, recording `x` in the module-level `global_list` as a side effect."""
    global_list.append(x)  # side effect: mutates state that lives outside the function
    incremented = x + 1
    return incremented


# call it: here it's a standard Python function
impure_function(1)
print(f"global_list: {global_list}")

global_list: [1]


In [8]:
# call it with an int32
make_jaxpr(impure_function)(1)

# Show everything appended so far. A plain loop is used because the
# printing is a side effect; there is no list worth building.
for e in global_list:
    print(e)

1
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


In [9]:
# call it with a float32
make_jaxpr(impure_function)(1.)

# Show everything appended so far. A plain loop is used because the
# printing is a side effect; there is no list worth building.
for e in global_list:
    print(e)

1
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


In [10]:
# call it with an array of float32s
make_jaxpr(impure_function)(jnp.ones(4))

# Show everything appended so far. A plain loop is used because the
# printing is a side effect; there is no list worth building.
for e in global_list:
    print(e)



1
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>


**Impure behaviour with `jit`:**

- The function is traced the first time you call it. So a tracer (wrapping a `ShapedArray`) is appended to `global_list`
- Every other time nothing is appended.

In [11]:
jit_im_list = jit(impure_function)

global_list = []

jit_im_list(1)

print(global_list)

[Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>]


In [12]:
# after it's compiled nothing is appended to `global_list`
jit_im_list(1)

print(global_list)

[Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>]


### 3. Shapes and types are checked during tracing

In [13]:
def foo3(x):
    """Return `2*x**3` for a scalar input and `3*x` for anything with a non-empty shape."""
    # The shape is a plain Python tuple, so this comparison is an ordinary bool.
    is_scalar = jnp.shape(x) == ()
    if not is_scalar:
        return 3*x
    return 2*x**3

In [14]:
print(f"Jaxpr for integer input: \n{make_jaxpr(foo3)(1)}")


print(f"\nJaxpr for array input: \n{make_jaxpr(foo3)(jnp.ones(10))}")

Jaxpr for integer input: 
{ lambda  ; a.
  let b = integer_pow[ y=3 ] a
      c = mul b 2
  in (c,) }

Jaxpr for array input: 
{ lambda  ; a.
  let b = mul a 3.0
  in (b,) }


But if we branch on the value of a traced argument it doesn't work:

In [15]:
def foo4(x):
    """Return `2*x` for negative `x`, and `x` itself otherwise."""
    # Python needs a concrete boolean from `x < 0` to pick a branch.
    return 2*x if x < 0 else x
    
print(f"This works, as it's pure Python. {foo4(2)}")

This works, as it's pure Python. 2


In [16]:
# this doesn't work, as when the function is traced it replaces inputs with abstract values!

make_jaxpr(foo4)(2)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The problem arose with the `bool` function. 

While tracing the function foo4 at <ipython-input-15-66d6bf88dce5>:1, this concrete value was not available in Python because it depends on the value of the arguments to foo4 at <ipython-input-15-66d6bf88dce5>:1 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/0)>

However conditioning on values works for autodiff:


In [17]:
# this works!
grad(foo4)(4.)

array(1., dtype=float32)

**Remark 1: Tracing with `grad`:** taking the grad of `foo4` still works. This is odd as the function is also traced!

This is because Jax considers different levels of abstraction (from [this talk](https://slideslive.com/38923687/jax-accelerated-machinelearning-research-via-composable-function-transformations-in-python) at 32:29):
 
 
- `Unshaped(f32)`: no control flow allowed. Ex: `z = cos(x+y)`
- `Shaped(f32, (2,2))`: can branch on shape. Ex: `if x.shape[0]>2:..` or `for subarray in array: ..`
    - `jit` and `vmap` use this
- `EpsilonBall(f32, [[1.,2.], [3.,4.]])...`: can branch on value if x.val != 0. Ex: `if x>0:...`
    - `grad` uses this
- `Concrete(f32)`: can always branch on value
    - `eval` uses this


**Remark 2:** There's no `EpsilonBall` class/variable in the source, so they might be doing something else now. That talk was from December 2019.

### 4. Higher-order primitives


#### Conditionals: To condition on a concrete value you can use `jnp.where` or `lax.cond`

In [18]:
from jax import lax

# python version
def python_cond(x):
    """Return `2*x` when `x` is strictly positive, otherwise the integer 0."""
    if not x > 0:
        return 0
    return 2*x
    
@jit
def jax_cond1(x):
    """Jitted conditional built on `lax.cond`: `2.*x` if `x > 0.`, else `0.`."""
    def when_positive(operand):
        return 2.*operand

    def otherwise(_):
        return 0.

    # `x` is passed as the operand handed to whichever branch is selected.
    return lax.cond(x>0., when_positive, otherwise, x)

@jit
def jax_cond2(x):
    """Jitted conditional built on `jnp.where`: `2*x` where `x > 0`, else `0.`."""
    is_positive = x>0
    doubled = 2*x
    return jnp.where(is_positive, doubled, 0.)


In [19]:
python_cond(4.)

8.0

In [20]:
# this works
make_jaxpr(jax_cond1)(4.)

{ lambda  ; a.
  let b = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a.
                                 let b = gt a 0.0
                                     c = convert_element_type[ new_dtype=int32
                                                               old_dtype=bool ] b
                                     d = cond[ branches=( { lambda  ; a.
                                                            let 
                                                            in (0.0,) }
                                                          { lambda  ; a.
                                                            let b = mul a 2.0
                                                            in (b,) } )
                                               linear=(False,) ] c a
                                 in (d,) }
                    device=None
                    donated_invars=(False,)
                    name=jax_cond1 ] a
  in (b,) }

In [21]:
make_jaxpr(jax_cond2)(4.)

{ lambda  ; a.
  let b = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a.
                                 let b = gt a 0.0
                                     c = mul a 2.0
                                     d = xla_call[ backend=None
                                                   call_jaxpr={ lambda  ; a b c.
                                                                let d = select a b c
                                                                in (d,) }
                                                   device=None
                                                   donated_invars=(False, False, False)
                                                   name=_where ] b c 0.0
                                 in (d,) }
                    device=None
                    donated_invars=(False,)
                    name=jax_cond2 ] a
  in (b,) }

In [22]:
# jitting also works
print(jit(jax_cond1)(4.))
print(jit(jax_cond2)(4.))
print(jit(jax_cond1)(-4.))
print(jit(jax_cond2)(-4.))

8.0
8.0
0.0
0.0


#### Scan: this is the Jax way to loop over an array

In [23]:
def python_scan(x):
    """Accumulate `elem * x` for each `elem` in `np.arange(10)`, keeping every running total.

    Returns:
        A `(final_total, running_totals)` pair, where `running_totals` is a
        list holding the running total after each of the ten steps.
    """
    running_totals = []
    total = 0

    for step in np.arange(10):
        total += step*x
        running_totals.append(total)
    return total, running_totals


def jax_scan(x):
    """`lax.scan` version of the running sum of `elem * x` over `jnp.arange(10)`.

    Returns:
        `(final_carry, running_totals, constants)`: the last carry, the carry
        collected at every step, and the constant 120 collected at every step.
    """
    def step(carry, elem):
        # Returns (new carry, per-step outputs); scan stacks the per-step outputs.
        updated = carry + elem*x
        return updated, (updated, 120)

    steps = jnp.arange(10)
    final_carry, stacked = lax.scan(step, 0, steps)
    running_totals, constants = stacked
    return final_carry, running_totals, constants

print(jax_scan(5))

(DeviceArray(225, dtype=int32), DeviceArray([  0,   5,  15,  30,  50,  75, 105, 140, 180, 225], dtype=int32), DeviceArray([120, 120, 120, 120, 120, 120, 120, 120, 120, 120], dtype=int32))


In [24]:
python_scan(5)

(225, [0, 5, 15, 30, 50, 75, 105, 140, 180, 225])

In [25]:
make_jaxpr(jax_scan)(10)

{ lambda  ; a.
  let b = iota[ dimension=0
                dtype=int32
                shape=(10,) ] 
      c d e = scan[ jaxpr={ lambda  ; a b c.
                            let d = mul c a
                                e = add b d
                            in (e, e, 120) }
                    length=10
                    linear=(False, False, False)
                    num_carry=1
                    num_consts=1
                    reverse=False
                    unroll=1 ] a 0 b
  in (c, d, e) }

# Writing  a Jaxpr interpreter

## Part 1: rewrite the Jax interpreter 

Writing a basic interpreter is in some cases surprisingly simple: see Peter Norvig's blog post about writing a [Lisp interpreter in Python](https://norvig.com/lispy.html). This involves doing the following 2 things:
1. Parse the program to represent it in a nice way. In Peter Norvig's post it's an abstract syntax tree. For Jax this is the Jaxpr!
2. Evaluate this intermediate representation. This is what we're doing here as we already have the Jaxpr



**Plan:** We will do the following:

1. Look at function [`eval_jaxpr`](https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html#2.-Evaluating-a-Jaxpr) to see the big picture
2. Go through each line in the notebook to see what's going on in detail
3. Look at the function again and test it!


In [26]:
import numpy as np
from functools import wraps

from jax import core, lax
from jax.util import safe_map

In [27]:
# Here is the function we'll be playing with
def f(x):
    """Compute `exp(tanh(x))`: `tanh` is applied first, then `exp`."""
    squashed = jnp.tanh(x)
    return jnp.exp(squashed)

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))

print(f"ClosedJaxpr:\n{closed_jaxpr}")
print(f"\nLiterals: {closed_jaxpr.literals}")

ClosedJaxpr:
{ lambda  ; a.
  let b = tanh a
      c = exp b
  in (c,) }

Literals: []


In [28]:
print(type(closed_jaxpr))
print(type(closed_jaxpr.jaxpr))

<class 'jax.core.ClosedJaxpr'>
<class 'jax.core.Jaxpr'>


In [29]:
# see what fields `closed_jaxpr` has:
closed_jaxpr.consts

[]

### Note about ClosedJaxpr vs Jaxpr

- printing a `ClosedJaxpr` simply prints `self.jaxpr`
- printing `self.jaxpr` prints a nice looking version of all the stuff inside (the inargs, equations etc..)


#### Define `read` and `write` functions

In [30]:
env = {}

def read(var):
    """Return the value of `var`: a Literal's own `val`, otherwise its entry in `env`."""
    is_literal = type(var) is core.Literal
    return var.val if is_literal else env[var]

def write(var, val):
    """Bind `val` to `var` in the module-level `env` dictionary, replacing any previous binding."""
    env[var] = val
    
print(env)

{}


#### Bind args and consts to environment

`Unit` and `UnitVar` are in [core.py](https://github.com/google/jax/blob/master/jax/core.py#L859)

In [31]:
# for some reason we must add this to the env dictionary (both have `__repr__(self): return "*"`)
write(core.unitvar, core.unit)

print(core.unitvar)
print(core.unit)

*
*


In [32]:
# the env dictionary is now updated
env

{*: *}

In [33]:
# Here are the variables (ie: arguments) of our function (without the concrete value)
closed_jaxpr.jaxpr.invars

[a]

In [34]:
# We assign `jnp.ones(5)` to this variable (ie: imagine that we've passed it to the function)
write(closed_jaxpr.jaxpr.invars[0], jnp.ones(5))

In [35]:
env

{*: *, a: DeviceArray([1., 1., 1., 1., 1.], dtype=float32)}

In [36]:
# Our function might have many arguments, so we should do this for all of them
# We need to call list with the map function as `map` is lazy
list(map(write, closed_jaxpr.jaxpr.invars, [jnp.ones(5)]))

[None]

Jax uses [`safe_map`](https://github.com/google/jax/blob/master/jax/util.py#L30) to write all the input arguments to the environment

In [37]:
# Note that in Jax they define `safe_map` which first checks 
# there are the same number of `invars` as there are args:
def safe_map(f, *args):
    """Eager `map` that first checks all argument sequences have equal length.

    Args:
        f: callable applied to one element from each sequence.
        *args: iterables; each is materialised into a list first.

    Returns:
        A list with `f` applied across the sequences in lockstep.

    Raises:
        AssertionError: if any sequence differs in length from the first.
    """
    sequences = [list(seq) for seq in args]
    expected_len = len(sequences[0])
    for seq in sequences[1:]:
        assert len(seq) == expected_len, 'length mismatch: {}'.format(list(map(len, sequences)))
    return list(map(f, *sequences))

In [38]:
args = (closed_jaxpr.jaxpr.invars, [jnp.ones(5)])

args = list(map(list, args))

print(args)

# notice that all the lengths agree!
# This means that we've passed in the correct number of arguments
print([len(i) for i in args])

[[a], [DeviceArray([1., 1., 1., 1., 1.], dtype=float32)]]
[1, 1]


In [39]:
# We now deal with the `constvars`. There are none for our function
closed_jaxpr.literals

[]

In [40]:
# If there were, we would now add these to the env dictionary
safe_map(write, closed_jaxpr.jaxpr.constvars, [])

print(env)

{*: *, a: DeviceArray([1., 1., 1., 1., 1.], dtype=float32)}


#### Loop through equations

In [41]:
# We now loop through each equation, evaluate them, and add them to the env
print(closed_jaxpr.jaxpr.eqns)

# Let's go through the first equation
my_eqn = closed_jaxpr.jaxpr.eqns[0]

[b = tanh a, c = exp b]


In [42]:
# see what fields my_eqn has
my_eqn.invars

[a]

In [43]:
print(env) # recall the environment

# this just gets the value from `a`
invals = safe_map(read, my_eqn.invars) 
print(f"invals: {invals}")

{*: *, a: DeviceArray([1., 1., 1., 1., 1.], dtype=float32)}
invals: [DeviceArray([1., 1., 1., 1., 1.], dtype=float32)]


See the [Primitive](https://github.com/google/jax/blob/master/jax/core.py#L252) class in `core.py`

In [44]:
my_eqn.params

{}

In [45]:
# the `bind` method is how a primitive is called
outvals = my_eqn.primitive.bind(*invals, **my_eqn.params)
print(outvals)

[0.7615942 0.7615942 0.7615942 0.7615942 0.7615942]


In [46]:
# Let's check that this is correct
# The equation is tanh
print(my_eqn)
print("Answer: ", np.tanh(np.ones(5)))

b = tanh a
Answer:  [0.76159416 0.76159416 0.76159416 0.76159416 0.76159416]


In [47]:
# if the function returns a single output, place outvals in a list:
print("Multiple results? ",my_eqn.primitive.multiple_results)

if not my_eqn.primitive.multiple_results:
    outvals = [outvals]
    
print("outvals", outvals)

Multiple results?  False
outvals [DeviceArray([0.7615942, 0.7615942, 0.7615942, 0.7615942, 0.7615942], dtype=float32)]


In [48]:
print("name of output of our equation:", my_eqn.outvars)
print("value of output:", outvals)

name of output of our equation: [b]
value of output: [DeviceArray([0.7615942, 0.7615942, 0.7615942, 0.7615942, 0.7615942], dtype=float32)]


In [49]:
# now write the result (ie: `outvals`) to env
safe_map(write, my_eqn.outvars, outvals)

[None]

In [50]:
env

{*: *,
 a: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
 b: DeviceArray([0.7615942, 0.7615942, 0.7615942, 0.7615942, 0.7615942], dtype=float32)}

**Now do the same for the other equation**

In [51]:
my_eqn = closed_jaxpr.jaxpr.eqns[1]
print(my_eqn)

c = exp b


In [52]:
invals = safe_map(read, my_eqn.invars)
outvals = my_eqn.primitive.bind(*invals, **my_eqn.params)
if not my_eqn.primitive.multiple_results:
      outvals = [outvals]
safe_map(write, my_eqn.outvars, outvals)

print("env:\n")
for k,v in env.items():
    print(f"key: {k}, val: {v}")

env:

key: *, val: *
key: a, val: [1. 1. 1. 1. 1.]
key: b, val: [0.7615942 0.7615942 0.7615942 0.7615942 0.7615942]
key: c, val: [2.1416876 2.1416876 2.1416876 2.1416876 2.1416876]


In [53]:
# read out final output (ie: the last variable)
safe_map(read, closed_jaxpr.jaxpr.outvars)

[DeviceArray([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]

In [54]:
# Check:
f(jnp.ones(5))

DeviceArray([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)

## Interpreter: put it all in a function

In [55]:
def eval_jaxpr(jaxpr, consts, *args):
    """Evaluate `jaxpr` by interpreting its equations one after another.

    Args:
        jaxpr: the Jaxpr to run; its `invars`, `constvars`, `eqns` and
            `outvars` are used.
        consts: values bound to `jaxpr.constvars`.
        *args: values bound to `jaxpr.invars`.

    Returns:
        A list with one value per entry of `jaxpr.outvars`.
    """
    # Environment: maps each Jaxpr variable to the value computed for it.
    values = {}

    def lookup(atom):
        # Literals are values baked into the Jaxpr; anything else is looked
        # up in the environment.
        return atom.val if type(atom) is core.Literal else values[atom]

    def store(var, val):
        values[var] = val

    # Seed the environment with the unit variable, the inputs and the constants.
    store(core.unitvar, core.unit)
    safe_map(store, jaxpr.invars, args)
    safe_map(store, jaxpr.constvars, consts)

    for equation in jaxpr.eqns:
        # Gather this equation's inputs from the environment.
        inputs = safe_map(lookup, equation.invars)
        # `bind` is how a primitive is called.
        result = equation.primitive.bind(*inputs, **equation.params)
        # Normalise to a list: single-result primitives return a bare value.
        results = result if equation.primitive.multiple_results else [result]
        # Record the outputs so that later equations can read them.
        safe_map(store, equation.outvars, results)

    # Once every equation has run, read off the Jaxpr's final outputs.
    return safe_map(lookup, jaxpr.outvars)

In [56]:
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))

print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals)

{ lambda  ; a.
  let b = tanh a
      c = exp b
  in (c,) }
[]


In [57]:
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))

[DeviceArray([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]

In [58]:
# check that it's the same as simply calling the function:
f(jnp.ones(5))

DeviceArray([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)

## inverse function

In [59]:
inverse_registry = {}
inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh

In [60]:
inverse_registry

{exp: <function jax._src.numpy.lax_numpy._one_to_one_unop.<locals>.<lambda>(x)>,
 tanh: <function jax._src.numpy.lax_numpy._one_to_one_unop.<locals>.<lambda>(x)>}

In [61]:
def inverse(fun):
    """Return a function that evaluates the inverse of `fun`.

    The returned function traces `fun` on the arguments it is given, runs the
    resulting Jaxpr backwards through `inverse_jaxpr`, and returns the first
    recovered input.
    """
    @wraps(fun)
    def wrapped(*args, **kwargs):
        # Unary functions are assumed, so arguments are neither flattened
        # nor unflattened here.
        traced = jax.make_jaxpr(fun)(*args, **kwargs)
        recovered_inputs = inverse_jaxpr(traced.jaxpr, traced.literals, *args)
        return recovered_inputs[0]
    return wrapped

In [62]:
def inverse_jaxpr(jaxpr, consts, *args):
    """Interpret `jaxpr` backwards, mapping values for its outputs to values for its inputs.

    Args:
        jaxpr: the Jaxpr of the forward function.
        consts: values bound to `jaxpr.constvars`.
        *args: values bound to `jaxpr.outvars` (the outputs of the forward function).

    Returns:
        A list with one value per entry of `jaxpr.invars`.

    Raises:
        NotImplementedError: if an equation's primitive has no entry in
            `inverse_registry`.
    """
    values = {}

    def lookup(atom):
        return atom.val if type(atom) is core.Literal else values[atom]

    def store(var, val):
        values[var] = val

    # The caller's arguments play the role of the Jaxpr's outputs here.
    store(core.unitvar, core.unit)
    safe_map(store, jaxpr.outvars, args)
    safe_map(store, jaxpr.constvars, consts)

    # Walk the equations last-to-first, swapping the roles of inputs and outputs.
    for equation in reversed(jaxpr.eqns):
        known_outputs = safe_map(lookup, equation.outvars)
        primitive = equation.primitive
        if primitive not in inverse_registry:
            raise NotImplementedError(
                "{} does not have registered inverse.".format(primitive)
            )
        # Unary primitives are assumed: the inverse yields the single input value.
        recovered = inverse_registry[primitive](*known_outputs)
        safe_map(store, equation.invars, [recovered])
    return safe_map(lookup, jaxpr.invars)

In [63]:
def f(x):
    """Forward function to invert: `tanh` followed by `exp`."""
    inner = jnp.tanh(x)
    return jnp.exp(inner)

f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)

In [64]:
jax.make_jaxpr(inverse(f))(f(1.))

{ lambda  ; a.
  let b = log a
      c = atanh b
  in (c,) }

In [65]:
jit(vmap(grad(inverse(f))))((jnp.arange(5) + 1.) / 5.)

DeviceArray([-3.1440797, 15.584931 ,  2.2551253,  1.3155028,  1.       ],            dtype=float32)

# Use the invert function to sample from 1D distributions

TODO: invert functions of two arguments (with one fixed), such as `add`

In [66]:
inverse_registry = {}
inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh

inverse_registry

{exp: <function jax._src.numpy.lax_numpy._one_to_one_unop.<locals>.<lambda>(x)>,
 tanh: <function jax._src.numpy.lax_numpy._one_to_one_unop.<locals>.<lambda>(x)>}

In [67]:
def cdf_exp(x, beta=0.5):
    """Evaluate `1 - exp(-beta * x)`, the exponential-distribution CDF with rate `beta`.

    Args:
        x: point(s) at which to evaluate; the formula is a CDF for `x >= 0`.
        beta: rate parameter. Defaults to 0.5, the value that used to be
            hard-coded, so existing single-argument calls behave as before.

    Returns:
        `1 - jnp.exp(-beta * x)`.
    """
    return 1 - jnp.exp(-beta*x)


# NOTE(review): despite the `pdf` in the name, this wraps the CDF `cdf_exp`,
# so it is the inverse CDF. Presumably calling it needs inverses registered
# for every primitive in `cdf_exp`'s Jaxpr, not just `exp` (see the TODO
# above about functions of two arguments) — confirm.
inv_exp_pdf = inverse(cdf_exp)