Skip to content

Commit

Permalink
Clean up untile lowering to remove platform dependence.
Browse files Browse the repository at this point in the history
The workaround for booleans on CPU and GPU (converting to float32 around the psum) is no longer necessary.
  • Loading branch information
gnecula committed Oct 19, 2023
1 parent a40e7ed commit c36c428
Showing 1 changed file with 2 additions and 14 deletions.
16 changes: 2 additions & 14 deletions jax/_src/maps.py
Expand Up @@ -1365,8 +1365,7 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,

outs = [
mlir.lower_fun(
partial(_untile, out_axes=ans_out_axes, axis_sizes=local_mesh_shape,
platform=ctx.module_context.platform),
partial(_untile, out_axes=ans_out_axes, axis_sizes=local_mesh_shape),
multiple_results=False)(
ctx.replace(primitive=None,
avals_in=[vectorized_outvar.aval],
Expand Down Expand Up @@ -1520,13 +1519,7 @@ def _tile(x, in_axes, axis_sizes):


# TODO(b/110096942): more efficient gather
def _untile(x, out_axes, axis_sizes, platform):
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
convert_bool = (np.issubdtype(x.dtype, np.bool_)
and platform in ('cpu', 'gpu'))
if convert_bool:
x = lax.convert_element_type(x, np.dtype(np.float32))

def _untile(x, out_axes, axis_sizes):
tile_shape = list(x.shape)
shape = list(tile_shape)
for name, axis in out_axes.items():
Expand All @@ -1536,11 +1529,6 @@ def _untile(x, out_axes, axis_sizes, platform):
padded = lax.broadcast(np.array(0, x.dtype), shape)
padded = lax.dynamic_update_slice(padded, x, base_idxs)
out = lax.psum(padded, tuple(out_axes.keys()))

# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
if convert_bool:
nonzero = lax.ne(out, np.array(0, dtype=np.float32))
out = lax.convert_element_type(nonzero, np.dtype(np.bool_))
return out


Expand Down

0 comments on commit c36c428

Please sign in to comment.