Add TransferToMemoryKind as a private API to allow device_put to transfer to different memories without specifying the sharding, allowing the SPMD partitioner to choose the sharding for the intermediate.

Exposing it as a public API can be done later.

PiperOrigin-RevId: 559314369
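
For context, a minimal sketch of the usage this enables. The `pinned_host` and `device` memory-kind names, and backend support for multiple memory spaces, are assumptions rather than anything guaranteed by this commit:

```python
import jax
import jax.numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind  # private API added here

@jax.jit
def f(x):
  # Only the target memory kind is named; no sharding is given, so the SPMD
  # partitioner is free to pick one for the intermediate.
  y = jax.device_put(x * 2, TransferToMemoryKind('pinned_host'))
  # Bring it back into the default device memory before computing on it.
  z = jax.device_put(y, TransferToMemoryKind('device'))
  return z + 1

f(jnp.ones((8, 8)))
```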
yashk2810 authored and jax authors committed Aug 23, 2023
1 parent bad217b commit aeb62cc
Showing 4 changed files with 41 additions and 12 deletions.
16 changes: 9 additions & 7 deletions jax/_src/api.py
@@ -66,7 +66,7 @@
from jax._src.lib import xla_extension_version
from jax._src.lib import pmap_lib
from jax._src.sharding import Sharding
from jax._src.sharding_impls import PmapSharding
from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps
@@ -2489,8 +2489,8 @@ def _infer_src_sharding(src, x):

def device_put(
x,
device: None | xc.Device | Sharding | Any = None,
*, src: None | xc.Device | Sharding | Any = None):
device: None | xc.Device | Sharding | Any | TransferToMemoryKind = None,
*, src: None | xc.Device | Sharding | Any | TransferToMemoryKind = None):
"""Transfers ``x`` to ``device``.
Args:
@@ -2514,8 +2514,10 @@ def device_put(
blocking the calling Python thread until any transfers are completed.
"""
with config_explicit_device_put_scope():
if ((device is None or isinstance(device, (xc.Device, Sharding))) and
(src is None or isinstance(src, (xc.Device, Sharding)))):
if ((device is None or
isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and
(src is None or
isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))):
return tree_map(
lambda y: dispatch.device_put_p.bind(
y, device=device, src=_infer_src_sharding(src, y)), x)
@@ -2524,8 +2526,8 @@
device_flat = flatten_axes("device_put device", treedef, device)
src_flat = flatten_axes("device_put source", treedef, src)
out_flat = [
dispatch.device_put_p.bind(y, device=d, src=_infer_src_sharding(s, y))
for y, d, s in zip(x_flat, device_flat, src_flat)
dispatch.device_put_p.bind(xf, device=d, src=_infer_src_sharding(s, xf))
for xf, d, s in zip(x_flat, device_flat, src_flat)
]
return tree_unflatten(treedef, out_flat)
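
When `device` is not a single target, the fallback path above flattens it against the pytree of `x`, so per-leaf targets keep working. A hedged sketch of that path (a single local device is assumed; the same pattern would apply with per-leaf memory kinds):

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

dev = jax.devices()[0]
x = {'a': jnp.ones(4), 'b': jnp.zeros(4)}

# One target per leaf of the pytree: a raw Device for 'a', a Sharding for 'b'.
out = jax.device_put(x, {'a': dev, 'b': SingleDeviceSharding(dev)})
print({k: v.sharding for k, v in out.items()})
```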

18 changes: 14 additions & 4 deletions jax/_src/dispatch.py
@@ -49,7 +49,7 @@
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
UNSPECIFIED, GSPMDSharding)
UNSPECIFIED, GSPMDSharding, TransferToMemoryKind)


JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
@@ -467,6 +467,14 @@ def _device_put_impl(
device: Device | Sharding | None = None,
src: Device | Sharding | None = None):
from jax._src import array

if (isinstance(device, TransferToMemoryKind) or
isinstance(src, TransferToMemoryKind)):
raise ValueError(
"TransferToMemoryKind argument to jax.device_put can only be used"
" inside jax.jit. If you are using device_put outside jax.jit, then"
" please provide a concrete Sharding with memory_kind.")

try:
aval = xla.abstractify(x)
except TypeError as err:
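
In eager mode the new check above rejects a bare memory kind. The sketch below shows the failure and the suggested alternative of a concrete sharding carrying a `memory_kind`; the `pinned_host` name and its availability on the local backend are assumptions:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax._src.sharding_impls import TransferToMemoryKind  # private API

x = jnp.ones((4, 4))

try:
  jax.device_put(x, TransferToMemoryKind('pinned_host'))  # outside jax.jit
except ValueError as e:
  print(e)  # "... can only be used inside jax.jit ..."

# Outside jit, spell out the placement, memory kind included.
mesh = Mesh(jax.devices()[:1], ('i',))
s = NamedSharding(mesh, PartitionSpec(), memory_kind='pinned_host')
y = jax.device_put(x, s)
```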
@@ -521,12 +529,14 @@ def device_put_transpose_rule(ct, _, device, src):
batching.defvectorized(device_put_p)

def _device_put_lowering(ctx, x, *, device, src):
if isinstance(device, XLACompatibleSharding) and device.memory_kind is not None:
if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
device.memory_kind is not None):
aval, = ctx.avals_in
out_aval, = ctx.avals_out
x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
x = mlir.wrap_with_sharding_op(
ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
if isinstance(device, XLACompatibleSharding):
x = mlir.wrap_with_sharding_op(
ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
return [x]
return [x]
mlir.register_lowering(device_put_p, _device_put_lowering)
14 changes: 13 additions & 1 deletion jax/_src/interpreters/pxla.py
@@ -1925,6 +1925,17 @@ def _create_da_object(  # pytype: disable=invalid-annotation
return _DeviceAssignment(device_assignment)


def jaxpr_has_dp_with_transfer_mem_kind(jaxpr: core.Jaxpr) -> bool:
for eqn in jaxpr.eqns:
if (eqn.primitive is dispatch.device_put_p and
isinstance(eqn.params['device'], sharding_impls.TransferToMemoryKind)):
return True
for subjaxpr in core.subjaxprs(jaxpr):
if jaxpr_has_dp_with_transfer_mem_kind(subjaxpr):
return True
return False


@profiler.annotate_function
def lower_sharding_computation(
fun_or_jaxpr: lu.WrappedFun | core.ClosedJaxpr,
@@ -1983,7 +1994,8 @@ def lower_sharding_computation(
len(device_assignment) > 1 or
any(not is_unspecified(i) for i in in_shardings) or
any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
any(not is_unspecified(o) for o in out_shardings))
any(not is_unspecified(o) for o in out_shardings) or
jaxpr_has_dp_with_transfer_mem_kind(jaxpr))

gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
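
A small sketch of how the traversal in `jaxpr_has_dp_with_transfer_mem_kind` can be exercised by hand, assuming `device_put` traces to a `device_put` equation under `jax.make_jaxpr`; the helper below is illustrative, not part of this change:

```python
import jax
import jax.numpy as jnp
from jax._src import core  # internal module, matching the imports used above

def device_put_eqns(jaxpr):
  # Collect device_put equations, recursing into sub-jaxprs the same way
  # jaxpr_has_dp_with_transfer_mem_kind does.
  eqns = [e for e in jaxpr.eqns if e.primitive.name == 'device_put']
  for sub in core.subjaxprs(jaxpr):
    eqns.extend(device_put_eqns(sub))
  return eqns

closed = jax.make_jaxpr(lambda x: jax.device_put(x + 1.0))(jnp.ones(4))
for eqn in device_put_eqns(closed.jaxpr):
  print(eqn.params['device'], eqn.params['src'])
```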
5 changes: 5 additions & 0 deletions jax/_src/sharding_impls.py
@@ -58,6 +58,11 @@
"and annotate Shardings with it."))


@dataclasses.dataclass(frozen=True)
class TransferToMemoryKind:
memory_kind: str


# Shardings that inherit from XLACompatibleSharding should implement the
# `_device_assignment` property and `_to_xla_hlo_sharding` method.
@use_cpp_class(xc.XLACompatibleSharding)
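
Because the new class is a frozen dataclass, instances are immutable, hashable, and compare by value. A quick sketch; the `pinned_host` string is an assumption about backend-provided memory kinds:

```python
from jax._src.sharding_impls import TransferToMemoryKind  # private API

host = TransferToMemoryKind('pinned_host')
assert host.memory_kind == 'pinned_host'
assert host == TransferToMemoryKind('pinned_host')          # value equality
assert hash(host) == hash(TransferToMemoryKind('pinned_host'))  # frozen => hashable
```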