Skip to content

Commit

Permalink
Switch the order of sharding and memory kind custom call
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 634505383
  • Loading branch information
yashk2810 authored and jax authors committed May 16, 2024
1 parent 01194bd commit 74ffed9
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,10 +477,10 @@ def _tpu_gpu_device_put_lowering(ctx, x, *, device, src):
device.memory_kind is not None):
aval, = ctx.avals_in
out_aval, = ctx.avals_out
x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
if isinstance(device, XLACompatibleSharding):
x = mlir.wrap_with_sharding_op(
ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
return [x]
return [x]
mlir.register_lowering(
Expand Down

0 comments on commit 74ffed9

Please sign in to comment.