From b52bcc1639368069075284eefc763f824ca155f1 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 13 Dec 2023 13:44:36 -0800
Subject: [PATCH] Reverts 3c07c10a9a55f9a32390dd10cf3f420bdf3f1ed8

PiperOrigin-RevId: 590700623
---
 CHANGELOG.md                  | 10 ----------
 jax/_src/interpreters/pxla.py | 14 ++++++++++++--
 tests/pjit_test.py            | 11 ++++++++---
 3 files changed, 20 insertions(+), 15 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4cc6606a451f..a1c7df9df882 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,16 +8,6 @@ Remember to align the itemized text with the first line of an item within a list

 ## jax 0.4.22

-* Changes
-  * JAX lowering to StableHLO does not depend on physical devices anymore.
-    If your primitive wraps custom_paritioning or JAX callbacks in the lowering
-    rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your
-    primitive to `jax._src.dispatch.prim_requires_devices_during_lowering` set.
-    This is needed because custom_partitioning and JAX callbacks need physical
-    devices to create `Sharding`s during lowering.
-    This is a temporary state until we can create `Sharding`s without physical
-    devices.
-
 * Deprecations
   * The `device_buffer` and `device_buffers` properties of JAX arrays are deprecated.
     Explicit buffers have been replaced by the more flexible array sharding interface,

diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index fdbe4edb8b2c..6a19d2844771 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -102,6 +102,12 @@ class WeakRefList(list):
 MeshDimAssignment = Union[ShardedAxis, Replicated]
 ShardingSpec = sharding_specs.ShardingSpec

+# TODO(yashkatariya): Remove this flag when nvidia's use cases are fixed.
+_JAX_REQUIRE_DEVICES_DURING_LOWERING = config.DEFINE_bool(
+    "jax_require_devices_during_lowering",
+    True,
+    help="Forces physical devices to be passed during lowering to stablehlo.")
+
 ### util

 def identity(x): return x
@@ -1971,13 +1977,17 @@ def lower_sharding_computation(
   semantic_in_shardings = SemanticallyEqualShardings(in_shardings)  # type: ignore
   semantic_out_shardings = SemanticallyEqualShardings(out_shardings)  # type: ignore
   prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
+  materialized_da = (
+      tuple(da_object)
+      if prim_requires_devices or _JAX_REQUIRE_DEVICES_DURING_LOWERING.value
+      else None)

   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
       closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
       semantic_out_shardings, in_layouts, out_layouts, len(da_object),
-      tuple(da_object) if prim_requires_devices else None, donated_invars,
-      name_stack, all_default_mem_kind, lowering_parameters=lowering_parameters)
+      materialized_da, donated_invars, name_stack, all_default_mem_kind,
+      lowering_parameters=lowering_parameters)

   # backend and device_assignment is passed through to MeshExecutable because
   # if keep_unused=False and all in_shardings are pruned, then there is no way

diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 73642ee1db6e..b6a5ab8d14c0 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -3802,9 +3802,14 @@ def g(a):
     b = jax.device_put(out_a, NamedSharding(mesh2, P('y')))
     f(b)  # lowering cache *hit*

-    with jtu.count_jit_and_pmap_compiles() as count:
-      g(np.arange(8))
-    self.assertEqual(count[0], 1)
+    prev_value = pxla._JAX_REQUIRE_DEVICES_DURING_LOWERING.value
+    try:
+      jax.config.update('jax_require_devices_during_lowering', False)
+      with jtu.count_jit_and_pmap_compiles() as count:
+        g(np.arange(8))
+      self.assertEqual(count[0], 1)
+    finally:
+      jax.config.update('jax_require_devices_during_lowering', prev_value)

   def test_lowering_cache_miss_different_devices_and_sharding(self):
     if jax.device_count() < 4:
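
Note (not part of the patch): a minimal sketch of how the flag added above can be toggled at runtime, mirroring the save/restore pattern in the updated test. It assumes a jax build that contains this change; on builds without it, the `jax_require_devices_during_lowering` option does not exist and `jax.config.update` will raise an error.

import jax
import jax.numpy as jnp
# Private module; imported only because the test above reads the flag the same way.
from jax._src.interpreters import pxla

# Save the current value so it can be restored afterwards.
prev_value = pxla._JAX_REQUIRE_DEVICES_DURING_LOWERING.value
try:
  # Opt out of passing physical devices during lowering; with the flag off,
  # lowering stays device-independent unless a primitive explicitly requires
  # devices (prim_requires_devices in lower_sharding_computation above).
  jax.config.update('jax_require_devices_during_lowering', False)
  print(jax.jit(lambda x: x + 1)(jnp.arange(8)))
finally:
  # Restore the previous value (True by default) for subsequent code.
  jax.config.update('jax_require_devices_during_lowering', prev_value)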