Skip to content

Commit

Permalink
Add live-analysis memory optimization to more jaxpr interpreters.
Browse files Browse the repository at this point in the history
Follow-up on 8a85e76

PiperOrigin-RevId: 540857501
  • Loading branch information
LenaMartens authored and jax authors committed Jun 16, 2023
1 parent 9fdaf5a commit fbf8823
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 9 deletions.
11 changes: 2 additions & 9 deletions jax/_src/checkify.py
Expand Up @@ -401,11 +401,7 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
err_vals, in_args = split_list(args, [err_tree.num_leaves])
error = jtu.tree_unflatten(err_tree, err_vals)

last_used = {v: None for v in jaxpr.outvars if not isinstance(v, core.Literal)}
for eqn in jaxpr.eqns[::-1]:
for v in eqn.invars:
if not isinstance(v, core.Literal) and v not in last_used:
last_used[v] = eqn
last_used = core.last_used(jaxpr)

def read_env(var: core.Atom):
if isinstance(var, core.Literal):
Expand All @@ -432,10 +428,7 @@ def write_env(var: core.Var, val: Any):
map(write_env, eqn.outvars, outvals)
else:
write_env(eqn.outvars[0], outvals)
for v in set(v for v in eqn.invars if not isinstance(v, core.Literal)):
if last_used[v] is eqn:
# Delete ref to variable when it is no longer needed by next equations.
del env[v]
core.clean_up_dead_vars(eqn, env, last_used)

return error, map(read_env, jaxpr.outvars)

Expand Down
22 changes: 22 additions & 0 deletions jax/_src/core.py
Expand Up @@ -439,6 +439,7 @@ def write(v: Var, val: Any) -> None:
env: Dict[Var, Any] = {}
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
lu = last_used(jaxpr)
for eqn in jaxpr.eqns:
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
Expand All @@ -449,6 +450,7 @@ def write(v: Var, val: Any) -> None:
map(write, eqn.outvars, ans)
else:
write(eqn.outvars[0], ans)
clean_up_dead_vars(eqn, env, lu)
return map(read, jaxpr.outvars)


Expand Down Expand Up @@ -3148,3 +3150,23 @@ def pp_effect(effect: Effect, context: JaxprPpContext) -> pp.Doc:
if hasattr(effect, "_pretty_print"):
return effect._pretty_print(context)
return pp.text(str(effect))

# ------------------- Jaxpr util -------------------

def last_used(jaxpr: Jaxpr) -> Dict[Var, Optional[JaxprEqn]]:
  """Computes, for each non-literal var, the final equation that reads it.

  Every non-literal output var of `jaxpr` is mapped to `None`, so no equation
  is ever recorded as its last user. Every other non-literal var that appears
  among some equation's inputs is mapped to the latest equation (in program
  order) that takes it as an input. Literals never appear as keys.
  """
  # Seed with the outputs first so that the `setdefault` below can never
  # replace their `None` entry with an equation.
  final_use: Dict[Var, Optional[JaxprEqn]] = {}
  for out in jaxpr.outvars:
    if not isinstance(out, Literal):
      final_use[out] = None

  # Walk the program backwards: the first equation seen for a var is therefore
  # the last one to use it, and later (i.e. earlier-in-program) sightings are
  # ignored.
  for equation in reversed(jaxpr.eqns):
    for operand in equation.invars:
      if isinstance(operand, Literal):
        continue
      final_use.setdefault(operand, equation)
  return final_use

def clean_up_dead_vars(eqn: JaxprEqn, env: Dict[Var, Any],
                       last_used: Dict[Var, Optional[JaxprEqn]]):
  """Drops from `env` every input of `eqn` that no later equation needs.

  Args:
    eqn: the equation that has just been evaluated.
    env: the interpreter environment mapping vars to their values; it is
      mutated in place.
    last_used: a mapping from each var to the last equation that reads it,
      which is looked up for every non-literal input of `eqn`.
  """
  # De-duplicate the inputs: a var may be passed to `eqn` more than once, but
  # it can only be deleted from `env` a single time.
  candidates = {v for v in eqn.invars if not isinstance(v, Literal)}
  for var in candidates:
    if last_used[var] is not eqn:
      # A later equation (or the jaxpr output) still needs this value.
      continue
    # Delete ref to variable when it is no longer needed by next equations.
    del env[var]
2 changes: 2 additions & 0 deletions jax/_src/interpreters/mlir.py
Expand Up @@ -1106,6 +1106,7 @@ def write(v: core.Var, node: Sequence[ir.Value]):
assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
last_used = core.last_used(jaxpr)
for eqn in jaxpr.eqns:
in_nodes = map(read, eqn.invars)
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
Expand Down Expand Up @@ -1168,6 +1169,7 @@ def write(v: core.Var, node: Sequence[ir.Value]):
ans, "lowering function returned a bad output", eqn)
assert len(ans) == len(eqn.outvars), (ans, eqn)
map(write, eqn.outvars, out_nodes)
core.clean_up_dead_vars(eqn, env, last_used)
return map(read, jaxpr.outvars), tokens

def _ir_consts(consts):
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/interpreters/partial_eval.py
Expand Up @@ -2533,12 +2533,14 @@ def write(v, val) -> None:

map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
last_used = core.last_used(jaxpr)
for eqn in jaxpr.eqns:
in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars]
out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars]
rule = padding_rules[eqn.primitive]
outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params)
map(write, eqn.outvars, outs)
core.clean_up_dead_vars(eqn, env, last_used)
return map(read, jaxpr.outvars)

def _substitute_axis_sizes(env: Dict, aval: AbstractValue) -> AbstractValue:
Expand Down
2 changes: 2 additions & 0 deletions jax/experimental/shard_map.py
Expand Up @@ -480,6 +480,7 @@ def write(v: core.Var, val: Set[AxisName]) -> None:

map(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars))
map(write, jaxpr.invars, in_rep)
last_used = core.last_used(jaxpr)
for e in jaxpr.eqns:
rule = _rep_rules.get(e.primitive, partial(_rep_rule, e.primitive))
out_rep = rule(mesh, *map(read, e.invars), **e.params)
Expand All @@ -488,6 +489,7 @@ def write(v: core.Var, val: Set[AxisName]) -> None:
map(write, e.outvars, out_rep)
else:
write(e.outvars[0], out_rep)
core.clean_up_dead_vars(e, env, last_used)
return map(read, jaxpr.outvars)

def _valid_repeats(mesh: Mesh, rep: Set[AxisName], dst: AxisNames) -> bool:
Expand Down

0 comments on commit fbf8823

Please sign in to comment.