Skip to content

Commit

Permalink
Update custom interpreter tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed May 11, 2022
1 parent 882a2d5 commit aca9dc6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 28 deletions.
25 changes: 11 additions & 14 deletions docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@
"source": [
"To get a first look at Jaxprs, consider the `make_jaxpr` transformation. `make_jaxpr` is essentially a \"pretty-printing\" transformation:\n",
"it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation.\n",
"Although we can't generally use the Jaxprs that it returns, it is useful for debugging and introspection.\n",
"Let's use it to look at how some example Jaxprs\n",
"are structured."
"`make_jaxpr` is useful for debugging and introspection.\n",
"Let's use it to look at how some example Jaxprs are structured."
]
},
{
Expand Down Expand Up @@ -201,7 +200,7 @@
"\n",
"### 1. Tracing a function\n",
"\n",
"We can't use `make_jaxpr` for this, because we need to pull out constants created during the trace to pass into the Jaxpr. However, we can write a function that does something very similar to `make_jaxpr`."
"Let's use `make_jaxpr` to trace a function into a Jaxpr."
]
},
{
Expand All @@ -227,8 +226,8 @@
"id": "CpTml2PTrzZ4"
},
"source": [
"This function first flattens its arguments into a list, which are the abstracted and wrapped as partial values. The `jax.make_jaxpr` function is used to then trace a function into a Jaxpr\n",
"from a list of partial value inputs."
"`jax.make_jaxpr` returns a *closed* Jaxpr, which is a Jaxpr that has been bundled with\n",
"the constants (`literals`) from the trace."
]
},
{
Expand All @@ -243,7 +242,7 @@
" return jnp.exp(jnp.tanh(x))\n",
"\n",
"closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))\n",
"print(closed_jaxpr)\n",
"print(closed_jaxpr.jaxpr)\n",
"print(closed_jaxpr.literals)"
]
},
Expand Down Expand Up @@ -321,7 +320,7 @@
"source": [
"Notice that `eval_jaxpr` will always return a flat list even if the original function does not.\n",
"\n",
"Furthermore, this interpreter does not handle `subjaxprs`, which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover."
"Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover."
]
},
{
Expand Down Expand Up @@ -389,9 +388,8 @@
"def inverse(fun):\n",
" @wraps(fun)\n",
" def wrapped(*args, **kwargs):\n",
" # Since we assume unary functions, we won't\n",
" # worry about flattening and\n",
" # unflattening arguments\n",
" # Since we assume unary functions, we won't worry about flattening and\n",
" # unflattening arguments.\n",
" closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)\n",
" out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)\n",
" return out[0]\n",
Expand Down Expand Up @@ -434,9 +432,8 @@
" # outvars are now invars \n",
" invals = safe_map(read, eqn.outvars)\n",
" if eqn.primitive not in inverse_registry:\n",
" raise NotImplementedError(\"{} does not have registered inverse.\".format(\n",
" eqn.primitive\n",
" ))\n",
" raise NotImplementedError(\n",
" f\"{eqn.primitive} does not have registered inverse.\")\n",
" # Assuming a unary function \n",
" outval = inverse_registry[eqn.primitive](*invals)\n",
" safe_map(write, eqn.invars, [outval])\n",
Expand Down
25 changes: 11 additions & 14 deletions docs/notebooks/Writing_custom_interpreters_in_Jax.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ for function transformation.

To get a first look at Jaxprs, consider the `make_jaxpr` transformation. `make_jaxpr` is essentially a "pretty-printing" transformation:
it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation.
Although we can't generally use the Jaxprs that it returns, it is useful for debugging and introspection.
Let's use it to look at how some example Jaxprs
are structured.
`make_jaxpr` is useful for debugging and introspection.
Let's use it to look at how some example Jaxprs are structured.

```{code-cell} ipython3
:id: RSxEiWi-EeYW
Expand Down Expand Up @@ -139,7 +138,7 @@ The way we'll implement this is by (1) tracing `f` into a Jaxpr, then (2) interp

### 1. Tracing a function

We can't use `make_jaxpr` for this, because we need to pull out constants created during the trace to pass into the Jaxpr. However, we can write a function that does something very similar to `make_jaxpr`.
Let's use `make_jaxpr` to trace a function into a Jaxpr.

```{code-cell} ipython3
:id: BHkg_3P1pXJj
Expand All @@ -155,8 +154,8 @@ from jax._src.util import safe_map

+++ {"id": "CpTml2PTrzZ4"}

This function first flattens its arguments into a list, which are the abstracted and wrapped as partial values. The `jax.make_jaxpr` function is used to then trace a function into a Jaxpr
from a list of partial value inputs.
`jax.make_jaxpr` returns a *closed* Jaxpr, which is a Jaxpr that has been bundled with
the constants (`literals`) from the trace.

```{code-cell} ipython3
:id: Tc1REN5aq_fH
Expand All @@ -165,7 +164,7 @@ def f(x):
return jnp.exp(jnp.tanh(x))
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr)
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals)
```

Expand Down Expand Up @@ -224,7 +223,7 @@ eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))

Notice that `eval_jaxpr` will always return a flat list even if the original function does not.

Furthermore, this interpreter does not handle `subjaxprs`, which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover.
Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover.

+++ {"id": "0vb2ZoGrCMM4"}

Expand Down Expand Up @@ -261,9 +260,8 @@ inverse_registry[lax.tanh_p] = jnp.arctanh
def inverse(fun):
@wraps(fun)
def wrapped(*args, **kwargs):
# Since we assume unary functions, we won't
# worry about flattening and
# unflattening arguments
# Since we assume unary functions, we won't worry about flattening and
# unflattening arguments.
closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
return out[0]
Expand Down Expand Up @@ -296,9 +294,8 @@ def inverse_jaxpr(jaxpr, consts, *args):
# outvars are now invars
invals = safe_map(read, eqn.outvars)
if eqn.primitive not in inverse_registry:
raise NotImplementedError("{} does not have registered inverse.".format(
eqn.primitive
))
raise NotImplementedError(
f"{eqn.primitive} does not have registered inverse.")
# Assuming a unary function
outval = inverse_registry[eqn.primitive](*invals)
safe_map(write, eqn.invars, [outval])
Expand Down

0 comments on commit aca9dc6

Please sign in to comment.