In [2]:
import jax
import jax.numpy as jnp
import jax.tree_util as tu

In [28]:
#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0
def _trace(msg=None):
		"""Print a message at current indentation."""
		if msg is not None:
				print("  " * _indentation + msg)

def _trace_indent(msg=None):
		"""Print a message and then indent the rest."""
		global _indentation
		_trace(msg)
		_indentation = 1 + _indentation

def _trace_unindent(msg=None):
		"""Unindent then print a message."""
		global _indentation
		_indentation = _indentation - 1
		_trace(msg)

def trace(name):
	"""A decorator for functions to trace arguments and results."""

	def trace_func(func):  # pylint: disable=missing-docstring
		def pp(v):
				"""Print certain values more succinctly"""
				vtype = str(type(v))
				if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
						return "<JaxComputationBuilder>"
				elif "jaxlib.xla_extension.XlaOp" in vtype:
						return "<XlaOp at 0x{:x}>".format(id(v))
				elif ("partial_eval.JaxprTracer" in vtype or
							"batching.BatchTracer" in vtype or
							"ad.JVPTracer" in vtype):
						return "Traced<{}>".format(v.aval)
				elif isinstance(v, tuple):
						return "({})".format(pp_values(v))
				else:
						return str(v)
		def pp_values(args):
				return ", ".join([pp(arg) for arg in args])
		
		@functools.wraps(func)
		def func_wrapper(*args):
			_trace_indent("call {}({})".format(name, pp_values(args)))
			res = func(*args)
			_trace_unindent("|<- {} = {}".format(name, pp(res)))
			return res

		return func_wrapper

	return trace_func

class expectNotImplementedError(object):
	"""Context manager to check for NotImplementedError."""
	def __enter__(self): pass
	def __exit__(self, type, value, tb):
		global _indentation
		_indentation = 0
		if type is NotImplementedError:
			print("\nFound expected exception:")
			traceback.print_exc(limit=3)
			return True
		elif type is None:  # No exception
			assert False, "Expected NotImplementedError"
		else:
			return False

In [87]:
a = jnp.linspace(0, 10, 2)

#@trace("f")
def f(x): 
    return dict(a=[x * a, 2 * x * a], b=3 * x * a)

g = [lambda x: x * a, lambda x: 2 * x * a]

t2 = dict(a=[1, 2, 3], b=dict(c=3, d=0))

In [88]:
f(2)

{'a': [Array([ 0., 20.], dtype=float32), Array([ 0., 40.], dtype=float32)],
 'b': Array([ 0., 60.], dtype=float32)}

In [89]:
jax.make_jaxpr(f)(5)

{ [34m[22m[1mlambda [39m[22m[22ma[35m:f32[2][39m; b[35m:i32[][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] b
    d[35m:f32[2][39m = mul c a
    e[35m:i32[][39m = mul 2 b
    f[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] e
    g[35m:f32[2][39m = mul f a
    h[35m:i32[][39m = mul 3 b
    i[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] h
    j[35m:f32[2][39m = mul i a
  [34m[22m[1min [39m[22m[22m(d, g, j) }

In [90]:
g

[<function __main__.<lambda>(x)>, <function __main__.<lambda>(x)>]

In [91]:
def examine_jaxpr(closed_jaxpr):
    jaxpr = closed_jaxpr.jaxpr
    print("invars:", jaxpr.invars)
    print("outvars:", jaxpr.outvars)
    print("constvars:", jaxpr.constvars)
    for eqn in jaxpr.eqns:
        print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
    print()
    print("jaxpr:", jaxpr)

In [92]:
examine_jaxpr(jax.make_jaxpr(f)(5))

invars: [b]
outvars: [d, g, j]
constvars: [a]
equation: [b] convert_element_type [c] {'new_dtype': dtype('float32'), 'weak_type': False}
equation: [c, a] mul [d] {}
equation: [2, b] mul [e] {}
equation: [e] convert_element_type [f] {'new_dtype': dtype('float32'), 'weak_type': False}
equation: [f, a] mul [g] {}
equation: [3, b] mul [h] {}
equation: [h] convert_element_type [i] {'new_dtype': dtype('float32'), 'weak_type': False}
equation: [i, a] mul [j] {}

jaxpr: { [34m[22m[1mlambda [39m[22m[22ma[35m:f32[2][39m; b[35m:i32[][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] b
    d[35m:f32[2][39m = mul c a
    e[35m:i32[][39m = mul 2 b
    f[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] e
    g[35m:f32[2][39m = mul f a
    h[35m:i32[][39m = mul 3 b
    i[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] h
    j[35m:f32[2][39m = mul i a
  [34m[22m[

In [93]:
jax.make_jaxpr(f)(5)

{ [34m[22m[1mlambda [39m[22m[22ma[35m:f32[2][39m; b[35m:i32[][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] b
    d[35m:f32[2][39m = mul c a
    e[35m:i32[][39m = mul 2 b
    f[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] e
    g[35m:f32[2][39m = mul f a
    h[35m:i32[][39m = mul 3 b
    i[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] h
    j[35m:f32[2][39m = mul i a
  [34m[22m[1min [39m[22m[22m(d, g, j) }

In [81]:
jax.make_jaxpr(f)(5).__class__

jax._src.core.ClosedJaxpr

In [78]:
jax.core.eval_jaxpr(jax.make_jaxpr(f)(5), [], 1)

AttributeError: 'ClosedJaxpr' object has no attribute 'constvars'

In [77]:
jax.core.eval_jaxpr??

[0;31mSignature:[0m
[0mjax[0m[0;34m.[0m[0mcore[0m[0;34m.[0m[0meval_jaxpr[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mjaxpr[0m[0;34m:[0m [0;34m'Jaxpr'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mconsts[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0;34m*[0m[0margs[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpropagate_source_info[0m[0;34m=[0m[0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mSource:[0m   
[0;32mdef[0m [0meval_jaxpr[0m[0;34m([0m[0mjaxpr[0m[0;34m:[0m [0mJaxpr[0m[0;34m,[0m [0mconsts[0m[0;34m,[0m [0;34m*[0m[0margs[0m[0;34m,[0m [0mpropagate_source_info[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m  [0;32mdef[0m [0mread[0m[0;34m([0m[0mv[0m[0;34m:[0m [0mAtom[0m[0;34m)[0m [0;34m->[0m [0mAny[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0;32mreturn[0m [0mv[0m[0;34m.[0m[0mval[0m [0;32mif[0m [0mi

In [94]:
def examine_jaxpr(closed_jaxpr):
    jaxpr = closed_jaxpr.jaxpr
    print("invars:", jaxpr.invars)
    print("outvars:", jaxpr.outvars)
    print("constvars:", jaxpr.constvars)
    for eqn in jaxpr.eqns:
        print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
    print()
    print("jaxpr:", jaxpr)

In [95]:
examine_jaxpr(jax.make_jaxpr(f)(5))

invars: [b]
outvars: [d, g, j]
constvars: [a]
equation: [b] convert_element_type [c] {'new_dtype': dtype('float32'), 'weak_type': False}
equation: [c, a] mul [d] {}
equation: [2, b] mul [e] {}
equation: [e] convert_element_type [f] {'new_dtype': dtype('float32'), 'weak_type': False}
equation: [f, a] mul [g] {}
equation: [3, b] mul [h] {}
equation: [h] convert_element_type [i] {'new_dtype': dtype('float32'), 'weak_type': False}
equation: [i, a] mul [j] {}

jaxpr: { [34m[22m[1mlambda [39m[22m[22ma[35m:f32[2][39m; b[35m:i32[][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] b
    d[35m:f32[2][39m = mul c a
    e[35m:i32[][39m = mul 2 b
    f[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] e
    g[35m:f32[2][39m = mul f a
    h[35m:i32[][39m = mul 3 b
    i[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] h
    j[35m:f32[2][39m = mul i a
  [34m[22m[

In [101]:
jax.make_jaxpr(f)(5)

{ [34m[22m[1mlambda [39m[22m[22ma[35m:f32[2][39m; b[35m:i32[][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] b
    d[35m:f32[2][39m = mul c a
    e[35m:i32[][39m = mul 2 b
    f[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] e
    g[35m:f32[2][39m = mul f a
    h[35m:i32[][39m = mul 3 b
    i[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] h
    j[35m:f32[2][39m = mul i a
  [34m[22m[1min [39m[22m[22m(d, g, j) }

In [146]:
import jax
import jax.tree_util as jtu

def transform_jaxpr(f):
    """
        Expects 'f' to be a function with a single float argument
    """
    new_jaxpr = []
    closed_jaxpr = jax.make_jaxpr(f)(0.0)
    structure = jtu.tree_structure(jax.eval_shape(f, 0.0))
    for i in range(len(closed_jaxpr.jaxpr.outvars)):
        c_jaxpr = closed_jaxpr.jaxpr.replace(outvars=[closed_jaxpr.jaxpr.outvars[i]])
        c_jaxpr = jax.core.ClosedJaxpr(c_jaxpr, closed_jaxpr.consts)
        c_fun = jax.core.jaxpr_as_fun(c_jaxpr)
        new_jaxpr.append(c_fun)
    tree = jtu.tree_unflatten(structure, new_jaxpr)
    return tree


a = jnp.linspace(0, 10, 2)
def H(x): 
    return dict(a=[x * a, 2 * x * a], b=3 * x * a)
transform_jaxpr(H)

{'a': [functools.partial(<function jaxpr_as_fun at 0x1073c90d0>, { [34m[22m[1mlambda [39m[22m[22ma[35m:f32[2][39m; b[35m:f32[][39m. [34m[22m[1mlet
      [39m[22m[22mc[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] b
      d[35m:f32[2][39m = mul c a
      e[35m:f32[][39m = mul 2.0 b
      f[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] e
      g[35m:f32[2][39m = mul f a
      h[35m:f32[][39m = mul 3.0 b
      i[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] h
      j[35m:f32[2][39m = mul i a
    [34m[22m[1min [39m[22m[22m(d,) }),
  functools.partial(<function jaxpr_as_fun at 0x1073c90d0>, { [34m[22m[1mlambda [39m[22m[22ma[35m:f32[2][39m; b[35m:f32[][39m. [34m[22m[1mlet
      [39m[22m[22mc[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] b
      d[35m:f32[2][39m = mul c a
      e[35m:f32[][39m = mul 2.0 b
      f[35m:f32[][39m = 

In [144]:
transform_jaxpr(f)["a"][0](3.0)

PyTreeDef({'a': [*, *], 'b': *})


[Array([ 0., 30.], dtype=float32)]

In [44]:
import numpy as np
from functools import wraps

from jax import core
from jax import lax
from jax._src.util import safe_map

In [46]:
def f(x):
    return jnp.exp(jnp.tanh(x))

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[5][39m. [34m[22m[1mlet[39m[22m[22m b[35m:f32[5][39m = tanh a; c[35m:f32[5][39m = exp b [34m[22m[1min [39m[22m[22m(c,) }
[]


In [50]:
def eval_jaxpr(jaxpr, consts, *args):
    # Mapping from variable -> value
    env = {}

    def read(var):
        # Literals are values baked into the Jaxpr
        if type(var) is core.Literal:
            return var.val
        return env[var]

    def write(var, val):
        env[var] = val

    # Bind args and consts to environment
    safe_map(write, jaxpr.invars, args)
    safe_map(write, jaxpr.constvars, consts)

    # Loop through equations and evaluate primitives using `bind`
    for eqn in jaxpr.eqns:
        # Read inputs to equation from environment
        invals = safe_map(read, eqn.invars)  
        # `bind` is how a primitive is called
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        # Primitives may return multiple outputs or not
        if not eqn.primitive.multiple_results: 
            outvals = [outvals]
        # Write the results of the primitive into the environment
        safe_map(write, eqn.outvars, outvals) 
    # Read the final result of the Jaxpr from the environment
    return safe_map(read, jaxpr.outvars) 

In [51]:
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))

[Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]

In [199]:
fs = []
for i in np.arange(10):
    def f(t, i):
        return i * t
    fs.append(jtu.Partial(f, i=i))
    

@jax.jit
def g(f):
    return 2 * f(3.0)

g(arrays[2]) # OK 
#jax.tree_map(g, dict(a=fs[0], b=fs[1])) # explose

Array(12., dtype=float32, weak_type=True)

In [219]:
idx = jnp.arange(len(fs))
vmapped = jax.vmap(lambda i, t: jax.lax.switch(i, fs, t), in_axes=(0, None))
vmapped(idx, 1.0)

Array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32, weak_type=True)

In [229]:
a = jnp.linspace(0, 3, 1000)
b = jnp.linspace(-1, 1, 1000)

@jax.jit
def f(x):
    return jnp.exp(jnp.tanh(b * x))

In [230]:
%timeit f(a).block_until_ready()

3.81 µs ± 10.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [231]:
from jaxtyping import ArrayLike, PyTree

def callable_to_pytree(f) -> PyTree:
    """Turns the function `f` from a callable that retuns a tree
    into a tree of callables.

    Example:
        ```
            def f(t):
                return dict(a=t, b=2*t)

            f(1.0) # dict(a=1.0, b=2.0)

            tree = callable_to_pytree(f)
            tree["a"](1.0) # 1.0
            tree["b"](1.0) # 2.0
        ```
    """
    new_jaxpr = []
    closed_jaxpr = jax.make_jaxpr(f)(0.0)
    structure = jtu.tree_structure(jax.eval_shape(f, 0.0))
    for i in range(len(closed_jaxpr.jaxpr.outvars)):
        c_jaxpr = closed_jaxpr.jaxpr.replace(outvars=[closed_jaxpr.jaxpr.outvars[i]])
        c_jaxpr = jax.core.ClosedJaxpr(c_jaxpr, closed_jaxpr.consts)
        c_fun = jax.core.jaxpr_as_fun(c_jaxpr)
        new_jaxpr.append(c_fun)
    return jtu.tree_unflatten(structure, new_jaxpr)


f2 = callable_to_pytree(f)

In [232]:
f2

functools.partial(<function jaxpr_as_fun at 0x1073c90d0>, { [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[1000][39m = pjit[
      name=f
      jaxpr={ [34m[22m[1mlambda [39m[22m[22mc[35m:f32[1000][39m; d[35m:f32[][39m. [34m[22m[1mlet
          [39m[22m[22me[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] d
          f[35m:f32[1000][39m = mul c e
          g[35m:f32[1000][39m = tanh f
          h[35m:f32[1000][39m = exp g
        [34m[22m[1min [39m[22m[22m(h,) }
    ] a
  [34m[22m[1min [39m[22m[22m(b,) })