Skip to content

Commit

Permalink
factor out subs_list and subs_list2
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Oct 19, 2023
1 parent 0944010 commit 1ce8313
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 35 deletions.
15 changes: 4 additions & 11 deletions jax/_src/interpreters/partial_eval.py
Expand Up @@ -48,7 +48,7 @@
KeyPath, generate_key_paths, keystr)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, OrderedSet,
as_hashable_function, weakref_lru_cache)
as_hashable_function, weakref_lru_cache, subs_list)


map, unsafe_map = safe_map, map
Expand Down Expand Up @@ -272,7 +272,7 @@ def process_call(self, primitive, f, tracers, params):
*in_consts, **const_params)
fwds, out_knowns, out_type, jaxpr, env = aux()
# Split apart known outputs from the original call and non-fwded residuals.
out_consts, non_fwd_res_ = split_list(out, [sum(out_knowns)])
out_consts, non_fwd_res = split_list(out, [sum(out_knowns)])

# Form the complete list of residuals by forwarding some inputs.
if config.dynamic_shapes.value:
Expand All @@ -288,10 +288,7 @@ def process_call(self, primitive, f, tracers, params):
in_consts_full[d1.val] = d2
else:
in_consts_full = in_consts
non_fwd_res = iter(non_fwd_res_)
res = [next(non_fwd_res) if i is None else in_consts_full[i] for i in fwds]
sentinel = object()
assert next(non_fwd_res, sentinel) is sentinel
res = subs_list(fwds, in_consts_full, non_fwd_res)

# Create the input tracers for the staged-out (unknown-value) call.
res_tracers = map(self.instantiate_const, map(self.new_const, res))
Expand Down Expand Up @@ -1436,11 +1433,7 @@ def closed_call_partial_eval_custom_rule(
[newvar(res_aval(params_known, v))
for v in jaxpr_staged.in_avals[:num_res]], [num_res_val])
res_val_binders = [v for v, f in zip(res_val_binders, out_fwd) if f is None]
res_val_binders_ = iter(res_val_binders)
res_val_vars = [out_binders_known[f] if f is not None
else next(res_val_binders_) for f in out_fwd]
sentinel = object()
assert next(res_val_binders_, sentinel) is sentinel
res_val_vars = subs_list(out_fwd, out_binders_known, res_val_binders)
eqn_known = new_jaxpr_eqn([*ins_known, *res_ref_binders],
[*out_binders_known, *res_val_binders],
eqn.primitive, params_known, jaxpr_known.effects,
Expand Down
14 changes: 4 additions & 10 deletions jax/_src/pjit.py
Expand Up @@ -70,7 +70,7 @@
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wraps,
distributed_debug_log, split_list, weakref_lru_cache,
merge_lists, flatten, unflatten)
merge_lists, flatten, unflatten, subs_list2)

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Expand Down Expand Up @@ -1571,18 +1571,12 @@ def keep_where(l, should_keep):
# Bind known things to pjit_p.
known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()]
all_known_outs = pjit_p.bind(*known_inputs, **known_params)

all_known_outs_ = iter(all_known_outs)
all_known_outs = [known_inputs[f1] if f1 is not None else
all_known_outs[f2] if f2 is not None else
next(all_known_outs_) for f1, f2 in zip(in_fwd, out_fwd)]
sentinel = object()
assert next(all_known_outs_, sentinel) is sentinel
del all_known_outs_, known_inputs
all_known_outs = subs_list2(in_fwd, out_fwd, known_inputs, all_known_outs,
all_known_outs)

known_out_vals, residual_vals = \
split_list(all_known_outs, [len(all_known_outs) - num_residuals])
residual_tracers = [trace.new_instantiated_const(residual) for residual in residual_vals]
residual_tracers = map(trace.new_instantiated_const, residual_vals)

# The convention of partial_eval_jaxpr_nounits is to place residual binders
# at the front of the jaxpr produced, so we move them to the back since both
Expand Down
21 changes: 21 additions & 0 deletions jax/_src/util.py
Expand Up @@ -143,6 +143,27 @@ def merge_lists(bs: Sequence[bool], l0: Sequence[T], l1: Sequence[T]) -> list[T]
assert next(i0, sentinel) is next(i1, sentinel) is sentinel
return out

def subs_list(
subs: Sequence[Optional[int]], src: Sequence[T], base: Sequence[T],
) -> list[T]:
base_ = iter(base)
out = [src[i] if i is not None else next(base_) for i in subs]
sentinel = object()
assert next(base_, sentinel) is sentinel
return out

def subs_list2(
subs1: Sequence[Optional[int]], subs2: Sequence[Optional[int]],
src1: Sequence[T], src2: Sequence[T], base: Sequence[T],
) -> list[T]:
assert len(subs1) == len(subs2)
base_ = iter(base)
out = [src1[f1] if f1 is not None else src2[f2] if f2 is not None else
next(base_) for f1, f2, in zip(subs1, subs2)]
sentinel = object()
assert next(base_, sentinel) is sentinel
return out

def split_dict(dct, names):
dct = dict(dct)
lst = [dct.pop(name) for name in names]
Expand Down
19 changes: 5 additions & 14 deletions jax/experimental/shard_map.py
Expand Up @@ -49,7 +49,7 @@
windowed_reductions, fft, linalg, control_flow)
from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
as_hashable_function, memoize, partition_list,
merge_lists, split_list)
merge_lists, split_list, subs_list2)
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
Expand Down Expand Up @@ -1284,11 +1284,7 @@ def known_out_names():
assert not jaxpr.constvars
unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk())
known_out_names_ = known_out_names()
non_fwd_res_ = iter(non_fwd_res)
res = [in_consts[f1] if f1 is not None else out_consts[f2] if f2 is not None
else next(non_fwd_res_) for f1, f2 in zip(in_fwd, out_fwd)]
sentinel = object()
assert next(non_fwd_res_, sentinel) is sentinel
res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res)
res_names = [known_in_names[f1] if f1 is not None else
known_out_names_[f2] if f2 is not None else
{0: (*mesh.axis_names,)} for f1, f2 in zip(in_fwd, out_fwd)]
Expand All @@ -1315,8 +1311,8 @@ def _shard_map_partial_eval_post_process(
unk_tracers = [t for t in tracers if not t.is_known()]
jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers)
# TODO(mattjj): output forwarding optimization
which = [not v.aval.shape for v in jaxpr.constvars]
res = [jax.lax.broadcast(x, (1,)) if not v.aval.shape else x
which = [not getattr(v.aval, 'shape', True) for v in jaxpr.constvars]
res = [jax.lax.broadcast(x, (1,)) if not getattr(v.aval, 'shape', True) else x
for x, v in zip(res, jaxpr.constvars)]
jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which)

Expand Down Expand Up @@ -1466,12 +1462,7 @@ def _partial_eval_jaxpr_custom_rule(
eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
eqn.primitive, params_known, jaxpr_known.effects,
eqn.source_info)
residuals_ = iter(residuals)
full_res = [ins_known[f1] if f1 is not None else
out_binders_known[f2] if f2 is not None else
next(residuals_) for f1, f2 in zip(in_fwd, out_fwd)]
sentinel = object()
assert next(residuals_, sentinel) is sentinel
full_res = subs_list2(in_fwd, out_fwd, ins_known, out_binders_known, residuals)
eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged,
eqn.primitive, params_staged,
jaxpr_staged.effects, eqn.source_info)
Expand Down

0 comments on commit 1ce8313

Please sign in to comment.