From 8a2d4a01f578f3e9dce7a48968031c6efc5403e8 Mon Sep 17 00:00:00 2001
From: George Necula
Date: Mon, 4 Dec 2023 10:24:01 +0200
Subject: [PATCH] [export] Add and fix a test for exporting higher-order
 gradients with sharding

There was already a test for export with gradients; we changed it to
(a) also export the 2nd-order gradient, and (b) export both with a mesh
context and without one (using NamedSharding). This test currently
fails, but only when we do NOT have a mesh context, as explained below.

When exporting gradient functions, we first export the primal functions
and use the in/out-shardings to construct the shardings of the gradient
function. Since Exported shardings now contain only HloSharding objects,
and since lowering the gradient function requires
`pjit(vjp(f)).lower()`, we construct GSPMDSharding objects from the
current devices and the HloSharding objects stored in the Exported
primal. However, these objects do not have the `_original_sharding`
attribute. Later, in `pjit._resource_typing_pjit`, we attempt
`parse_flatten_op_sharding` using the mesh context, which is empty, and
this fails.

This PR contains one workaround: skip `parse_flatten_op_sharding` if
the physical mesh of the `resource_env` is empty. Another, probably
better, solution is to ensure that `resource_env` is `None` when there
is no mesh context. That seemed reasonable, but currently the code
returns an empty mesh from the resource_env if there is no mesh
context. Changing this would affect more parts of the code, so I have
not done it here, but it may be worth doing.
---
 jax/_src/pjit.py                          |  4 +-
 .../jax2tf/tests/sharding_test.py         | 25 +++++--
 tests/export_test.py                      | 67 ++++++++++++++++---
 3 files changed, 80 insertions(+), 16 deletions(-)

diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 4140876e373d..8ace238be02f 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -1816,7 +1816,7 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
         s._original_sharding, '_parsed_pspec'):
       parsed_pspec = s._original_sharding._parsed_pspec
     else:
-      if resource_env is not None:
+      if resource_env is not None and not resource_env.physical_mesh.empty:
         parsed_pspec = parse_flatten_op_sharding(
             s._hlo_sharding, resource_env.physical_mesh)[0]
       else:
@@ -1838,7 +1838,7 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
         s._original_sharding, '_parsed_pspec'):
       parsed_pspec = s._original_sharding._parsed_pspec
     else:
-      if resource_env is not None:
+      if resource_env is not None and not resource_env.physical_mesh.empty:
         parsed_pspec = parse_flatten_op_sharding(
             s._hlo_sharding, resource_env.physical_mesh)[0]
       else:
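As background for the workaround above, here is a minimal sketch of the
failing scenario described in the commit message. It is illustrative only,
not part of the patch; it assumes at least 2 devices and the
jax.experimental.export API of this era:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental.export import export
    from jax.experimental import pjit
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Export a pjit'ed primal whose shardings are NamedShardings, with no
    # ambient mesh context (no `with mesh:` block).
    mesh = Mesh(np.array(jax.devices()[:2]), ("d",))
    f = pjit.pjit(jnp.sin, in_shardings=NamedSharding(mesh, P("d")))
    exp = export.export(f)(np.arange(8, dtype=np.float32))

    # Lowering the gradient goes through pjit(vjp(f)).lower() with
    # GSPMDSharding objects reconstructed from the serialized HloShardings.
    # These carry no `_original_sharding`, so before this patch
    # `_resource_typing_pjit` fell back to `parse_flatten_op_sharding` with
    # the empty mesh of the ambient resource_env, and failed.
    exp_vjp = exp.vjp()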
diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py
index 568153f58c0a..0ba785261b57 100644
--- a/jax/experimental/jax2tf/tests/sharding_test.py
+++ b/jax/experimental/jax2tf/tests/sharding_test.py
@@ -22,6 +22,7 @@
 import contextlib
 from functools import partial
 import logging
+import math
 import os
 import re
 from typing import Any
@@ -40,6 +41,7 @@
 from jax.experimental import pjit
 from jax.experimental.maps import xmap
 from jax.experimental.shard_map import shard_map
+from jax.sharding import NamedSharding
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
 import jax.numpy as jnp
@@ -382,16 +384,25 @@ def f_jax(x):  # x: f32[10, 20], optionally some axes as polymorphic
       for in_shardings in ("missing", None, "P")
       for out_shardings in ("missing", None, "P")
   ])
-  @jtu.with_mesh([("x", 2)])
   def test_grad_pjit(self, in_shardings="P", out_shardings=None):
+    if not config.jax2tf_default_native_serialization.value:
+      self.skipTest("TODO: failure in non-native serialization")
+    local_devices = list(jax.local_devices())
+    size = 2
+    if len(local_devices) < size:
+      raise unittest.SkipTest(f"Test requires {size} local devices")
+    mesh_devices = np.array(local_devices[:size]).reshape((2,))
+    mesh = jax.sharding.Mesh(mesh_devices, ("x",))
     def f_jax(x):  # x: f32[10,20] -> f32[20,10]
       return jnp.sin(x.T)

     pjit_kwargs = {}
     if in_shardings != "missing":
-      pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None)
+      pjit_kwargs["in_shardings"] = (
+          NamedSharding(mesh, P(None, "x")) if in_shardings == "P" else None)
     if out_shardings != "missing":
-      pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None)
+      pjit_kwargs["out_shardings"] = (
+          NamedSharding(mesh, P("x", None)) if out_shardings == "P" else None)
     f_jax = pjit.pjit(f_jax, **pjit_kwargs)
     x_shape = (10, 20)
     x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
@@ -399,8 +410,12 @@ def f_jax(x):  # x: f32[10,20] -> f32[20,10]
     def f_grad_tf(x_v, res_ct):
       with tf.GradientTape(persistent=True) as tape:
         tape.watch(x_v)
-        res_tf = jax2tf.convert(f_jax)(x_v)
-        return tape.gradient(res_tf, x_v, output_gradients=res_ct)
+        with tf.GradientTape() as tape2:
+          tape2.watch(x_v)
+          res_tf = jax2tf.convert(f_jax)(x_v)
+        dy_dx = tape.gradient(res_tf, x_v, output_gradients=res_ct)
+      d2y_dx2 = tape.gradient(dy_dx, x_v)
+      return d2y_dx2

     # Annotation count for the primal input and the grad output
     count_in_P = self.GEQ(2) if in_shardings == "P" else 0
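The nested-tape change in f_grad_tf above follows the standard TensorFlow
recipe for higher-order derivatives: an inner tape produces the first
derivative, and a persistent outer tape differentiates that result again.
A minimal standalone sketch, independent of jax2tf (values illustrative):

    import tensorflow as tf

    x = tf.constant(0.5)
    with tf.GradientTape() as outer:
      outer.watch(x)
      with tf.GradientTape() as inner:
        inner.watch(x)
        y = tf.sin(x)
      # First derivative, computed inside the outer tape so it is traced.
      dy_dx = inner.gradient(y, x)      # cos(x)
    d2y_dx2 = outer.gradient(dy_dx, x)  # -sin(x)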
diff --git a/tests/export_test.py b/tests/export_test.py
index 4ad0f5f50699..65904dcdf92e 100644
--- a/tests/export_test.py
+++ b/tests/export_test.py
@@ -28,6 +28,7 @@
 from jax import tree_util
 from jax.experimental.export import export
 from jax.experimental import pjit
+from jax.sharding import NamedSharding
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P

@@ -755,13 +756,16 @@ def f_jax(b):  # b: f32[16 // DEVICES, 4]
     )(a)

   @jtu.parameterized_filterable(
+    one_containing="in_shardings_None_out_shardings_P_with_mesh_False",
     kwargs=[
-      dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
-           in_shardings=in_shardings, out_shardings=out_shardings)
+      dict(in_shardings=in_shardings, out_shardings=out_shardings,
+           with_mesh=with_mesh)
       for in_shardings in ("missing", None, "P")
       for out_shardings in ("missing", None, "P")
+      for with_mesh in (True, False)
     ])
-  def test_grad_with_sharding(self, in_shardings="P", out_shardings=None):
+  def test_grad_with_sharding(self, in_shardings="P", out_shardings=None,
+                              with_mesh=False):
     if len(jax.devices()) < 2:
       self.skipTest("Test requires at least 2 devices")
     x_shape = (10, 20)
@@ -769,16 +773,33 @@ def test_grad_with_sharding(self, in_shardings="P", out_shardings=None):

     def f_jax(x):  # x: f32[10,20] -> f32[20,10]
       return jnp.sin(x.T)

+    mesh = Mesh(jax.devices()[:2], "d")
     pjit_kwargs = {}
+    # Use NamedSharding objects if we don't have a mesh context
+    if with_mesh:
+      sharding_None_d = P(None, "d")
+      sharding_d_None = P("d", None)
+    else:
+      sharding_None_d = NamedSharding(mesh, P(None, "d"))
+      sharding_d_None = NamedSharding(mesh, P("d", None))
+
     if in_shardings != "missing":
-      pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None)
+      pjit_kwargs["in_shardings"] = (
+          sharding_None_d if in_shardings == "P" else None)
     if out_shardings != "missing":
-      pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None)
-    f_jax = pjit.pjit(f_jax, **pjit_kwargs)
+      pjit_kwargs["out_shardings"] = (
+          sharding_d_None if out_shardings == "P" else None)
+    f_jax_pjit = pjit.pjit(f_jax, **pjit_kwargs)
+
+    with contextlib.ExitStack() as stack:
+      if with_mesh:
+        stack.enter_context(mesh)
+      # Serialize higher-order gradients
+      exp = export.export(f_jax_pjit)(x)

-    with Mesh(jax.devices()[:2], "x"):
-      exp = export.export(f_jax)(x)
     exp_vjp = exp.vjp()
+    # Try 2nd-order grad as well
+    exp_vjp2 = exp_vjp.vjp()

     vjp_module_str = str(exp_vjp.mlir_module())
@@ -812,13 +833,41 @@ def f_jax(x):  # x: f32[10,20] -> f32[20,10]

     # Custom calls for the primal output shape all match primal_out_sharding
     primal_out_calls = re.findall(
-        r"custom_call @Sharding.* {mhlo.sharding = \"(.+)\".*:.*tensor<20x10xf32>",
+        r"custom_call @Sharding.*mhlo.sharding = \"(.+)\".*:.*tensor<20x10xf32>",
         vjp_module_str)
     self.assertTrue(
         all(s == primal_out_sharding for s in primal_out_calls),
         primal_in_calls
     )

+    # Call the exported gradient functions. In order to set the device context
+    # we replicate the inputs. If we don't use a mesh context and there are
+    # no shardings on inputs or outputs, then we have serialized for one
+    # device.
+    if in_shardings != "P" and out_shardings != "P" and not with_mesh:
+      self.assertEqual(exp_vjp.nr_devices, 1)
+      self.assertEqual(exp_vjp2.nr_devices, 1)
+      call_mesh = Mesh(jax.devices()[:1], "e")
+    else:
+      self.assertEqual(exp_vjp.nr_devices, 2)
+      self.assertEqual(exp_vjp2.nr_devices, 2)
+      call_mesh = Mesh(jax.devices()[:2], "e")
+
+    g1 = pjit.pjit(export.call_exported(exp_vjp),
+                   in_shardings=(NamedSharding(call_mesh, None),
+                                 NamedSharding(call_mesh, None)))(x, x.T)
+    _, f_jax_vjp = jax.vjp(f_jax, x)
+    xbar = f_jax_vjp(x.T)
+    self.assertAllClose(xbar, g1)
+
+    g2 = pjit.pjit(export.call_exported(exp_vjp2),
+                   in_shardings=(NamedSharding(call_mesh, None),
+                                 NamedSharding(call_mesh, None),
+                                 NamedSharding(call_mesh, None)))(x, x.T, x)
+    _, f_jax_vjp2 = jax.vjp(f_jax_vjp, x.T)
+    xbar2, = f_jax_vjp2((x,))
+    self.assertAllClose(xbar2, g2[1])
+
   def test_multi_platform(self):
     x = np.arange(8, dtype=np.float32)
     exp = export.export(_testing_multi_platform_func,
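For reference, the pure-JAX second-order computation that the new test
assertions compare against can be run on its own. This sketch mirrors the
reference code in the test above (x: f32[10,20]); names are illustrative:

    import jax
    import jax.numpy as jnp
    import numpy as np

    def f(x):  # f32[10,20] -> f32[20,10]
      return jnp.sin(x.T)

    x = np.arange(200, dtype=np.float32).reshape((10, 20))
    _, f_vjp = jax.vjp(f, x)         # pullback: f32[20,10] -> (f32[10,20],)
    xbar, = f_vjp(x.T)               # 1st-order cotangent
    _, f_vjp2 = jax.vjp(f_vjp, x.T)  # differentiate the pullback itself
    xbar2, = f_vjp2((x,))            # 2nd-order cotangent, f32[20,10]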