Skip to content

Commit

Permalink
sketch out psum changes under new pmap semantics
Browse files Browse the repository at this point in the history
Co-authored-by: James Bradbury <jekbradbury@google.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
  • Loading branch information
3 people committed Aug 11, 2020
1 parent 161e81c commit 76bf9a8
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions jax/lax/lax_parallel.py
Expand Up @@ -75,7 +75,7 @@ def psum(x, axis_name, *, axis_index_groups=None):
leaves, treedef = tree_util.tree_flatten(x)
leaves = [lax.convert_element_type(l, np.int32)
if dtypes.dtype(l) == np.bool_ else l for l in leaves]
out_flat = psum_p.bind(*leaves, axis_name=axis_name,
out_flat = psum_p.bind(*leaves, axis_name=axis_name, keepdims=True,
axis_index_groups=axis_index_groups)
return tree_util.tree_unflatten(treedef, out_flat)

Expand Down Expand Up @@ -377,8 +377,14 @@ def _psum_transpose_rule(cts, _, axis_name, axis_index_groups):
nonzero_in_cts = [lax.pbroadcast(ct, axis_name) for ct in nonzero_out_cts]
return tree_util.tree_unflatten(treedef, nonzero_in_cts)

def _psum_abstract_eval(*avals, axis_name, axis_index_groups):
return [_remove_named_axis(axis_name, raise_to_shaped(aval)) for aval in avals]
def _psum_abstract_eval(*avals, axis_name, keepdims, axis_index_groups):
if not keepdims and axis_index_groups is not None:
raise ValueError(
f'unsupported psum with keepdims and '
f'axis_index_groups {axis_index_groups}')
return [raise_to_shaped(aval) if keepdims
else _remove_named_axis(axis_name, raise_to_shaped(aval))
for aval in avals]

def _remove_named_axis(axis_name, aval):
assert isinstance(aval, ShapedArray)
Expand Down

0 comments on commit 76bf9a8

Please sign in to comment.