From 76bf9a8d1b13c81a1d8beb84c79697eedba7ed0c Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 11 Aug 2020 16:56:52 -0700 Subject: [PATCH] sketch out psum changes under new pmap semantics Co-authored-by: James Bradbury Co-authored-by: Matthew Johnson --- jax/lax/lax_parallel.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/jax/lax/lax_parallel.py b/jax/lax/lax_parallel.py index 2d5f6fe1f718..77048f2ccf3a 100644 --- a/jax/lax/lax_parallel.py +++ b/jax/lax/lax_parallel.py @@ -75,7 +75,7 @@ def psum(x, axis_name, *, axis_index_groups=None): leaves, treedef = tree_util.tree_flatten(x) leaves = [lax.convert_element_type(l, np.int32) if dtypes.dtype(l) == np.bool_ else l for l in leaves] - out_flat = psum_p.bind(*leaves, axis_name=axis_name, + out_flat = psum_p.bind(*leaves, axis_name=axis_name, keepdims=True, axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, out_flat) @@ -377,8 +377,14 @@ def _psum_transpose_rule(cts, _, axis_name, axis_index_groups): nonzero_in_cts = [lax.pbroadcast(ct, axis_name) for ct in nonzero_out_cts] return tree_util.tree_unflatten(treedef, nonzero_in_cts) -def _psum_abstract_eval(*avals, axis_name, axis_index_groups): - return [_remove_named_axis(axis_name, raise_to_shaped(aval)) for aval in avals] +def _psum_abstract_eval(*avals, axis_name, keepdims, axis_index_groups): + if not keepdims and axis_index_groups is not None: + raise ValueError( + f'unsupported psum with keepdims and ' + f'axis_index_groups {axis_index_groups}') + return [raise_to_shaped(aval) if keepdims + else _remove_named_axis(axis_name, raise_to_shaped(aval)) + for aval in avals] def _remove_named_axis(axis_name, aval): assert isinstance(aval, ShapedArray)