Add a jaxpr interpreter for propagating memory kinds to output. It only triggers if we detect multiple memory kinds in the jaxpr.

This hopefully should go away when XLA implements its own memory space propagation pass or JAX adds memory_kind to the type system of jaxpr, i.e., on avals.

It's required so that the following code blocks (1) and (2) are treated as equivalent when lowering to StableHLO. In general, shardings should also be treated the same way, but we'll cross that bridge later.

1. `jit(f, out_shardings=s_host)`

2. ```
   @jax.jit
   def f(x):
     return jax.device_put(x, s_host)
   ```
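
For illustration only (not part of this change), a minimal runnable sketch of that equivalence, assuming a single-device backend that exposes a `pinned_host` memory space:

```
# Minimal sketch of the (1)/(2) equivalence above; assumes a backend with a
# "pinned_host" memory space (e.g. TPU). Not part of this commit.
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

s_host = SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host")

def f(x):
  return x * 2

# (1) Memory kind requested via out_shardings on jit.
out1 = jax.jit(f, out_shardings=s_host)(jnp.arange(8.0))

# (2) Memory kind produced by a device_put inside the jitted function; the
#     propagation interpreter added here makes this lower the same way as (1).
out2 = jax.jit(lambda x: jax.device_put(f(x), s_host))(jnp.arange(8.0))

assert out1.sharding.memory_kind == out2.sharding.memory_kind == "pinned_host"
```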

PiperOrigin-RevId: 632621025
yashk2810 authored and jax authors committed May 10, 2024
1 parent 27c932a commit a4693db
Showing 5 changed files with 117 additions and 9 deletions.
6 changes: 6 additions & 0 deletions jax/_src/dispatch.py
@@ -497,3 +497,9 @@ def _common_device_put_lowering(ctx, x, *, device, src):
f" platforms {ctx.module_context.platforms}")
return [x]
mlir.register_lowering(device_put_p, _common_device_put_lowering)

def _propagate_mem_kind_dp(xm, device=None, src=None):
if isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)):
return device.memory_kind
return None
pxla.memory_kind_propagate_rule[device_put_p] = _propagate_mem_kind_dp
36 changes: 28 additions & 8 deletions jax/_src/interpreters/mlir.py
@@ -880,6 +880,7 @@ def lower_jaxpr_to_module(
num_partitions: int = 1,
all_default_mem_kind: bool = True,
input_output_aliases: None | tuple[int | None, ...] = None,
propagated_out_mem_kinds: tuple[None | str, ...] | None = None,
lowering_parameters: LoweringParameters,
) -> LoweringResult:
"""Lowers a top-level jaxpr to an MLIR module.
@@ -988,7 +989,8 @@ def lower_jaxpr_to_module(
arg_memory_kinds=arg_memory_kinds,
result_memory_kinds=result_memory_kinds,
arg_layouts=in_layouts,
result_layouts=out_layouts)
result_layouts=out_layouts,
propagated_out_mem_kinds=propagated_out_mem_kinds)

try:
if not ctx.module.operation.verify():
@@ -1137,6 +1139,7 @@ def lower_jaxpr_to_fun(
result_memory_kinds: Sequence[str | None] | None = None,
arg_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
result_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
propagated_out_mem_kinds: tuple[None | str, ...] | None = None,
) -> func_dialect.FuncOp:
"""Lowers jaxpr and its callees to an IR function.
@@ -1200,6 +1203,8 @@ def aval_to_types(aval):
input_types = [*dim_var_types, *token_types, *input_types]
output_avals = [core.abstract_token] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
output_types = [*output_token_types, *token_types, *output_types]
if propagated_out_mem_kinds is None:
propagated_out_mem_kinds = (None,) * len(output_avals)

if input_output_aliases is not None:
token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
@@ -1267,16 +1272,30 @@ def aval_to_types(aval):

ir_result_shardings = None
if result_shardings is not None:
out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals)
ir_result_shardings = util.flatten(
[[_to_physical_op_sharding(a, s)] * len(types)
for a, s, types in zip(out_avals, result_shardings, output_types)])
del out_avals
for a, s, types in zip(output_avals, result_shardings, output_types)])

ir_result_memory_kinds = None
custom_call_ir_result_memory_kinds = None
if result_memory_kinds is not None:
ir_result_memory_kinds = util.flatten(
[[mk] * len(types) for mk, types in zip(result_memory_kinds, output_types)])
res, custom_call_res = [], []
for pom, mk, types in zip(propagated_out_mem_kinds, result_memory_kinds,
output_types):
if pom is not None and mk is None:
res.append([pom] * len(types)) # type: ignore
else:
if pom is not None and mk is not None and pom != mk:
raise AssertionError(
f"propagated out memory kind ({pom}) does not match the memory"
f" kind specified in out_shardings of jit ({mk})")
res.append([mk] * len(types)) # type: ignore
# Only add the custom call on the output (to signal a transfer) when the
# memory kind comes from out_shardings on `jit`, i.e. from
# result_memory_kinds, not when it was inferred via propagation.
custom_call_res.append([mk] * len(types))
ir_result_memory_kinds = util.flatten(res)
custom_call_ir_result_memory_kinds = util.flatten(custom_call_res)

ir_result_layouts = None
if result_layouts is not None:
@@ -1462,10 +1481,11 @@ def aval_to_types(aval):

# Insert a custom call if output is on host because XLA needs that to do the
# transfer.
if ir_result_memory_kinds is not None:
if custom_call_ir_result_memory_kinds is not None and name == "main":
flat_outputs = [
o if mk is None else wrap_with_memory_kind(o, mk, o_aval)
for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)]
for o, mk, o_aval in zip(
flat_outputs, custom_call_ir_result_memory_kinds, output_avals)]

if ir_result_shardings is not None and name == "main":
flat_outputs = [
44 changes: 43 additions & 1 deletion jax/_src/interpreters/pxla.py
@@ -1882,6 +1882,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
in_layouts, out_layouts, num_devices, device_assignment,
donated_invars, name_stack, all_default_mem_kind,
inout_aliases: None | tuple[None | int, ...],
propagated_out_mem_kinds: tuple[None | str, ...],
lowering_parameters: mlir.LoweringParameters):
jaxpr = closed_jaxpr.jaxpr
in_shardings = semantic_in_shardings._gspmd_shardings
@@ -1959,6 +1960,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
num_partitions=num_partitions,
all_default_mem_kind=all_default_mem_kind,
input_output_aliases=inout_aliases,
propagated_out_mem_kinds=propagated_out_mem_kinds,
lowering_parameters=lowering_parameters)
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
unordered_effects = list(
@@ -1996,6 +1998,39 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
return False
return True

memory_kind_propagate_rule = {} # type: ignore

@weakref_lru_cache
def get_out_memory_kinds_via_propagation(closed_jaxpr: core.ClosedJaxpr
) -> tuple[None | str, ...]:
env = {} # type: ignore
jaxpr = closed_jaxpr.jaxpr

def read(var):
if type(var) is core.Literal:
return None
return env[var]

def write(var, val):
env[var] = val

def _default_rule(prim, num_outvars, *_, **__):
return [None] * num_outvars if prim.multiple_results else None

safe_map(write, jaxpr.invars, [None] * len(jaxpr.invars))
safe_map(write, jaxpr.constvars, [None] * len(jaxpr.constvars))

for eqn in jaxpr.eqns:
in_mem_kinds = safe_map(read, eqn.invars)
rule = memory_kind_propagate_rule.get(
eqn.primitive, partial(_default_rule, eqn.primitive, len(eqn.outvars)))
out_mem_kinds = rule(*in_mem_kinds, **eqn.params)
if not eqn.primitive.multiple_results:
out_mem_kinds = [out_mem_kinds]
safe_map(write, eqn.outvars, out_mem_kinds)
return tuple(safe_map(read, jaxpr.outvars))


MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]


@@ -2116,6 +2151,13 @@ def lower_sharding_computation(
da_object,
it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding])) # type: ignore

# TODO(yashkatariya): Remove this when XLA can propagate memory kinds or when
# JAX puts memory kinds in the types of jaxpr.
if not all_default_mem_kind:
propagated_out_mem_kinds = get_out_memory_kinds_via_propagation(closed_jaxpr)
else:
propagated_out_mem_kinds = (None,) * len(global_out_avals)

spmd_mode_check(da_object, inline)

# 2. Build up the HLO
@@ -2131,7 +2173,7 @@ def lower_sharding_computation(
semantic_out_shardings, in_layouts, out_layouts, len(da_object),
tuple(da_object) if prim_requires_devices else None, donated_invars,
name_stack, all_default_mem_kind, inout_aliases,
lowering_parameters=lowering_parameters)
propagated_out_mem_kinds, lowering_parameters=lowering_parameters)

# backend and device_assignment is passed through to MeshExecutable because
# if keep_unused=False and all in_shardings are pruned, then there is no way
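
As an illustrative aside (not part of the diff), the new interpreter can be called directly on a traced jaxpr to inspect which outputs pick up a non-default memory kind; the expected output below is an assumption based on the `device_put` rule registered in dispatch.py above:

```
# Illustrative only: inspect propagated output memory kinds for a traced
# function. Assumes a backend exposing "pinned_host" and that device_put's
# propagation rule (registered above) is in effect.
import jax
import jax.numpy as jnp
from jax._src.interpreters import pxla
from jax.sharding import SingleDeviceSharding

s_host = SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host")

def f(x):
  y = x + 1
  return y, jax.device_put(y, s_host)

closed_jaxpr = jax.make_jaxpr(f)(jnp.arange(4.0))
print(pxla.get_out_memory_kinds_via_propagation(closed_jaxpr))
# Expected under these assumptions: (None, 'pinned_host')
```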
6 changes: 6 additions & 0 deletions jax/_src/lax/control_flow/loops.py
@@ -42,6 +42,7 @@
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lax import windowed_reductions
@@ -1300,6 +1301,11 @@ def scan_bind(*args, **params):
# TODO(mattjj,frostig): un-comment this pp rule
# core.pp_eqn_rules[scan_p] = _scan_pp_rule

def _propagate_mem_kind_scan(*xm, reverse, length, num_consts, num_carry, jaxpr,
linear, unroll, _split_transpose):
return pxla.get_out_memory_kinds_via_propagation(jaxpr)
pxla.memory_kind_propagate_rule[scan_p] = _propagate_mem_kind_scan

### while_loop

@api_boundary
34 changes: 34 additions & 0 deletions tests/memories_test.py
@@ -1211,6 +1211,40 @@ def f(xs):
self.assertArraysEqual(out_host, np_inp + 1.0)
self.assertEqual(out_host.sharding, s_host)

def test_weight_offload_with_dp_on_output(self):
_, s_dev, np_inp, inp_dev = _create_inputs(
(8, 2), P("x", "y"), mem_kind="device")
s_host = s_dev.with_memory_kind('pinned_host')

@jax.jit
def f(x):
x = x * 2
y = jax.device_put(x, s_host)
return y

out_host = f(inp_dev)
self._check_device_put_addressable_shards(
out_host, np_inp * 2, s_host, 'pinned_host')

def test_output_streaming_inside_scan(self):
mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z"))
np_inp = np.arange(4096).reshape(16, 16, 16)
s_hbm = NamedSharding(mesh, P(None, "y", "z"), memory_kind="device")
arr_hbm = jax.device_put(np_inp, s_hbm)

@jax.jit
def f(xs):
def body(carry, x):
out_tpu = x + carry
return carry, jax.device_put(
out_tpu, NamedSharding(mesh, P("y", "z"), memory_kind="pinned_host"))
_, res = jax.lax.scan(body, 1, xs)
return res

out = f(arr_hbm)
self.assertArraysEqual(out, np_inp + 1)
self.assertEqual(out.sharding.memory_kind, 'pinned_host')


class ActivationOffloadingTest(jtu.JaxTestCase):

