Skip to content

Commit

Permalink
small autodidax tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Aug 5, 2021
1 parent c75f773 commit 24de3e9
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 109 deletions.
133 changes: 96 additions & 37 deletions docs/autodidax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1766,6 +1766,7 @@
},
"outputs": [],
"source": [
"from typing import DefaultDict\n",
"from collections import defaultdict\n",
"import string\n",
"\n",
Expand Down Expand Up @@ -1800,7 +1801,7 @@
"def vcat(ps: List[PPrint]) -> PPrint:\n",
" return sum(ps, pp(''))\n",
"\n",
"def pp_jaxpr(jaxpr: Jaxpr):\n",
"def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:\n",
" namegen = (''.join(s) for r in it.count(1)\n",
" for s in it.permutations(string.ascii_lowercase, r))\n",
" names = defaultdict(lambda: next(namegen))\n",
Expand All @@ -1811,15 +1812,19 @@
" return (pp(f'{{ lambda {in_binders} .') +\n",
" ((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))\n",
"\n",
"def var_str(names: Dict[Var, str], v: Var) -> str:\n",
"def var_str(names: DefaultDict[Var, str], v: Var) -> str:\n",
" return f'{names[v]}:{v.aval.str_short()}'\n",
"\n",
"def pp_eqn(names: Dict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
" lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
" rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>\n",
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
" for x in eqn.inputs)))\n",
" return lhs >> pp(' = ') >> rhs\n",
"def pp_eqn(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
" rule = pp_rules.get(eqn.primitive)\n",
" if rule:\n",
" return rule(names, eqn)\n",
" else:\n",
" lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
" rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>\n",
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
" for x in eqn.inputs)))\n",
" return lhs >> pp(' = ') >> rhs\n",
"\n",
"def pp_params(params: Dict[str, Any]) -> PPrint:\n",
" items = sorted(params.items())\n",
Expand All @@ -1828,7 +1833,8 @@
" else:\n",
" return pp(' ')\n",
"\n",
"Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))"
"Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))\n",
"pp_rules: Dict[Primitive, Callable[..., PPrint]] = {}"
]
},
{
Expand Down Expand Up @@ -2167,7 +2173,7 @@
" [bool, int, float, np.ndarray, np.float64, np.float32]}\n",
"\n",
"def handle_result(aval: ShapedArray, buf):\n",
" del aval # Unused for now.\n",
" del aval # Unused for now\n",
" return buf.to_py()\n",
"\n",
"xla_translations = {}"
Expand Down Expand Up @@ -2332,7 +2338,7 @@
"outputs": [],
"source": [
"def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):\n",
" del num_consts # Unused.\n",
" del num_consts # Unused\n",
" new_jaxpr, new_consts = jvp_jaxpr(jaxpr)\n",
" outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,\n",
" num_consts=len(new_consts))\n",
Expand Down Expand Up @@ -2362,7 +2368,7 @@
"outputs": [],
"source": [
"def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):\n",
" del num_consts # Unused.\n",
" del num_consts # Unused\n",
" new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))\n",
" outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,\n",
" num_consts=len(new_consts))\n",
Expand Down Expand Up @@ -2397,7 +2403,7 @@
"outputs": [],
"source": [
"def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):\n",
" del num_consts # Unused.\n",
" del num_consts # Unused\n",
" jaxpr_type = typecheck_jaxpr(jaxpr)\n",
" if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):\n",
" raise TypeError\n",
Expand Down Expand Up @@ -2532,6 +2538,29 @@
"print(ydot)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_end_of_cell_marker": 0,
"lines_to_next_cell": 1,
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"def pprint_xla_call(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
" lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
" params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}\n",
" rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>\n",
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
" for x in eqn.inputs)))\n",
" return vcat([lhs >> pp(' = ') >> rhs,\n",
" pp_jaxpr(eqn.params['jaxpr']).indent(2)])\n",
"pp_rules[xla_call_p] = pprint_xla_call"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -2729,11 +2758,11 @@
"operations out of Python first before sorting out what can be evaluated now\n",
"and what must be delayed, we want only to form a jaxpr for those operations\n",
"that _must_ be delayed due to a dependence on unknown inputs. In the context\n",
"of automatic differentiation, this is the feature that ultimately enables us to\n",
"handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control\n",
"flow works because partial evaluation keeps the primal computation in Python.\n",
"As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out\n",
"what can be evaluated and what must be staged out into a jaxpr.\n",
"of automatic differentiation, this is the feature that ultimately enables us\n",
"to handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python\n",
"control flow works because partial evaluation keeps the primal computation in\n",
"Python. As a consequence, our `Trace` and `Tracer` subclasses must on the fly\n",
"sort out what can be evaluated and what must be staged out into a jaxpr.\n",
"\n",
"First, we start with a `PartialVal` class, which represents a value that can\n",
"be either known or unknown:"
Expand Down Expand Up @@ -2803,8 +2832,9 @@
"do so, it builds a bipartite directed acyclic graph (DAG) between\n",
"`PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`\n",
"nodes, representing formulas for how to compute some values from others. One\n",
"kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive\n",
"application, but we also have recipe types for constants and lambda binders:"
"kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s\n",
"primitive application, but we also have recipe types for constants and lambda\n",
"binders:"
]
},
{
Expand Down Expand Up @@ -2945,11 +2975,12 @@
"source": [
"def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],\n",
" tracers_out: List[PartialEvalTracer]):\n",
" tracer_to_var = {id(t): Var(raise_to_shaped(t.aval)) for t in tracers_in}\n",
" constvar_to_val = {}\n",
" constid_to_var = {}\n",
" processed_eqns = set()\n",
" eqns = []\n",
" tracer_to_var: Dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))\n",
" for t in tracers_in}\n",
" constvar_to_val: Dict[int, Any] = {}\n",
" constid_to_var: Dict[int, Var] = {}\n",
" processed_eqns: Set[int] = set()\n",
" eqns: List[JaxprEqn] = []\n",
" for t in toposort(tracers_out, tracer_parents):\n",
" if isinstance(t.recipe, LambdaBindingRecipe):\n",
" assert id(t) in set(map(id, tracers_in))\n",
Expand Down Expand Up @@ -3083,7 +3114,7 @@
"outputs": [],
"source": [
"def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):\n",
" del num_consts # Unused.\n",
" del num_consts # Unused\n",
" in_unknowns = [not t.pval.is_known for t in tracers]\n",
" jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)\n",
" known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)\n",
Expand All @@ -3106,8 +3137,8 @@
" env: Dict[Var, bool] = {}\n",
" residuals: Set[Var] = set()\n",
"\n",
" def read(v: Atom) -> bool:\n",
" return type(v) is Var and env[v]\n",
" def read(x: Atom) -> bool:\n",
" return type(x) is Var and env[x]\n",
"\n",
" def write(unk: bool, v: Var) -> None:\n",
" env[v] = unk\n",
Expand Down Expand Up @@ -3139,6 +3170,7 @@
" out_unknowns = map(op.or_, out_unknowns, instantiate)\n",
"\n",
" residuals, num_res = list(residuals), len(residuals)\n",
" assert all(type(v) is Var for v in residuals), residuals\n",
"\n",
" ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)\n",
" outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)\n",
Expand Down Expand Up @@ -3170,16 +3202,16 @@
"partial_eval_jaxpr_rules = {}\n",
"\n",
"def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,\n",
" ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:\n",
" ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Var]]:\n",
" jaxpr = eqn.params['jaxpr']\n",
" jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)\n",
" ins1, ins2 = partition_list(unks_in, eqn.inputs)\n",
" outs1, outs2 = partition_list(unks_out, eqn.out_binders)\n",
" residuals, _ = split_list(jaxpr2.in_binders, num_res)\n",
" out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)\n",
" residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]\n",
" eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),\n",
" outs1 + residuals)\n",
" out_binders1 + residuals)\n",
" eqn2 = JaxprEqn(xla_call_p, residuals + ins2,\n",
" dict(jaxpr=jaxpr2, num_consts=0), outs2)\n",
" dict(jaxpr=jaxpr2, num_consts=0), out_binders2)\n",
" return eqn1, eqn2, unks_out, residuals\n",
"partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn"
]
Expand Down Expand Up @@ -3395,7 +3427,7 @@
"transpose_rules[add_p] = add_transpose_rule\n",
"\n",
"def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):\n",
" del num_consts # Unused.\n",
" del num_consts # Unused\n",
" undef_primals = [type(x) is UndefPrimal for x in invals]\n",
" transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))\n",
" residuals, _ = partition_list(undef_primals, invals)\n",
Expand Down Expand Up @@ -3804,7 +3836,7 @@
"abstract_eval_rules[cond_p] = cond_abstract_eval\n",
"\n",
"def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):\n",
" del in_avals # Unused.\n",
" del in_avals # Unused\n",
" pred, *in_vals = in_vals\n",
" flat_vals, in_tree = tree_flatten(in_vals)\n",
" operand = xops.Tuple(c, flat_vals)\n",
Expand Down Expand Up @@ -3857,6 +3889,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_end_of_cell_marker": 0,
"lines_to_next_cell": 1
},
"outputs": [],
Expand Down Expand Up @@ -3954,7 +3987,8 @@
" eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],\n",
" dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),\n",
" outs2)\n",
" return eqn1, eqn2, unks_out, [eqn.inputs[0], *residuals]\n",
" res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals\n",
" return eqn1, eqn2, unks_out, res\n",
"partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn"
]
},
Expand Down Expand Up @@ -4002,12 +4036,37 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)\n",
"print(out)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"def pprint_cond(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
" true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']\n",
" new_params = {k:v for k, v in eqn.params.items() if not k.endswith('jaxpr')}\n",
" lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
" rhs = (pp(eqn.primitive.name) >> pp_params(new_params) >>\n",
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
" for x in eqn.inputs)))\n",
" return vcat([lhs >> pp(' = ') >> rhs,\n",
" pp_jaxpr(true_jaxpr).indent(2),\n",
" pp_jaxpr(false_jaxpr).indent(2)])\n",
"pp_rules[cond_p] = pprint_cond"
]
}
],
"metadata": {
Expand Down

0 comments on commit 24de3e9

Please sign in to comment.