From c36c4287210b9829bae58abf0ce3ccb8db887065 Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 19 Oct 2023 00:37:49 -0700 Subject: [PATCH] Cleanup untile lowering to remove platform dependence. The workaround for cpu and gpu for booleans is not necessary anymore. --- jax/_src/maps.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/jax/_src/maps.py b/jax/_src/maps.py index b8a75bc7a512..ca7ff271b9e0 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -1365,8 +1365,7 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes, outs = [ mlir.lower_fun( - partial(_untile, out_axes=ans_out_axes, axis_sizes=local_mesh_shape, - platform=ctx.module_context.platform), + partial(_untile, out_axes=ans_out_axes, axis_sizes=local_mesh_shape), multiple_results=False)( ctx.replace(primitive=None, avals_in=[vectorized_outvar.aval], @@ -1520,13 +1519,7 @@ def _tile(x, in_axes, axis_sizes): # TODO(b/110096942): more efficient gather -def _untile(x, out_axes, axis_sizes, platform): - # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU - convert_bool = (np.issubdtype(x.dtype, np.bool_) - and platform in ('cpu', 'gpu')) - if convert_bool: - x = lax.convert_element_type(x, np.dtype(np.float32)) - +def _untile(x, out_axes, axis_sizes): tile_shape = list(x.shape) shape = list(tile_shape) for name, axis in out_axes.items(): @@ -1536,11 +1529,6 @@ def _untile(x, out_axes, axis_sizes, platform): padded = lax.broadcast(np.array(0, x.dtype), shape) padded = lax.dynamic_update_slice(padded, x, base_idxs) out = lax.psum(padded, tuple(out_axes.keys())) - - # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU - if convert_bool: - nonzero = lax.ne(out, np.array(0, dtype=np.float32)) - out = lax.convert_element_type(nonzero, np.dtype(np.bool_)) return out