diff --git a/jax/lax/lax_parallel.py b/jax/lax/lax_parallel.py
index 2d5f6fe1f718..77048f2ccf3a 100644
--- a/jax/lax/lax_parallel.py
+++ b/jax/lax/lax_parallel.py
@@ -75,7 +75,7 @@ def psum(x, axis_name, *, axis_index_groups=None):
   leaves, treedef = tree_util.tree_flatten(x)
   leaves = [lax.convert_element_type(l, np.int32)
             if dtypes.dtype(l) == np.bool_ else l for l in leaves]
-  out_flat = psum_p.bind(*leaves, axis_name=axis_name,
+  out_flat = psum_p.bind(*leaves, axis_name=axis_name, keepdims=True,
                          axis_index_groups=axis_index_groups)
   return tree_util.tree_unflatten(treedef, out_flat)
 
@@ -377,8 +377,22 @@ def _psum_transpose_rule(cts, _, axis_name, axis_index_groups):
   nonzero_in_cts = [lax.pbroadcast(ct, axis_name) for ct in nonzero_out_cts]
   return tree_util.tree_unflatten(treedef, nonzero_in_cts)
 
-def _psum_abstract_eval(*avals, axis_name, axis_index_groups):
-  return [_remove_named_axis(axis_name, raise_to_shaped(aval)) for aval in avals]
+def _psum_abstract_eval(*avals, axis_name, keepdims, axis_index_groups):
+  """Abstract evaluation rule for psum: one output aval per input aval.
+
+  With `keepdims=True` each aval is only passed through `raise_to_shaped`;
+  otherwise `_remove_named_axis(axis_name, ...)` is applied on top of that.
+
+  Raises:
+    ValueError: if `keepdims` is false and `axis_index_groups` is not None.
+  """
+  if not keepdims and axis_index_groups is not None:
+    raise ValueError(
+        'unsupported psum with keepdims=False and '
+        f'axis_index_groups {axis_index_groups}')
+  return [raise_to_shaped(aval) if keepdims
+          else _remove_named_axis(axis_name, raise_to_shaped(aval))
+          for aval in avals]
 
 def _remove_named_axis(axis_name, aval):
   assert isinstance(aval, ShapedArray)