[jax2tf] Fix grad of pjit in native lowering.
Since jax2tf.convert is called recursively for the purpose of
serializing the vjp function, we must ensure that if the primal
function is a pjit with shardings then the vjp function must also
be converted as a pjit.

Without this fix, the serialization with gradients of a pjit function
will fail with an error that there are shardings but no pjit at
the top level.
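
For context, here is a minimal sketch (not part of the commit) of the kind of program this change addresses: a pjit function with explicit shardings is converted with native serialization, and its gradient is then requested through TensorFlow, which re-invokes jax2tf.convert on the VJP function. The mesh, axis name, and shapes are illustrative only and assume at least two devices.

import numpy as np
import tensorflow as tf
import jax
from jax import numpy as jnp
from jax.experimental import jax2tf, pjit
from jax.sharding import Mesh, PartitionSpec as P

# Illustrative 1-D mesh over two devices (assumes len(jax.devices()) >= 2).
mesh = Mesh(np.array(jax.devices()[:2]), ("x",))

# A pjit-ed primal function with explicit shardings, as in the new test below.
f_jax = pjit.pjit(lambda x: jnp.sin(x.T),
                  in_shardings=P(None, "x"),
                  out_shardings=P("x", None))

x = np.arange(200, dtype=np.float32).reshape((10, 20))
with mesh:
  x_tf = tf.Variable(x)
  with tf.GradientTape() as tape:
    res_tf = jax2tf.convert(f_jax, native_serialization=True)(x_tf)
  # Computing the gradient converts the VJP of f_jax; before this fix the
  # recursive conversion saw shardings without a top-level pjit and failed.
  _ = tape.gradient(res_tf, x_tf)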
gnecula authored and joglekara committed Mar 24, 2023
1 parent ffc8a34 commit ce3f534
Showing 4 changed files with 150 additions and 27 deletions.
2 changes: 1 addition & 1 deletion jax/_src/pjit.py
@@ -784,7 +784,7 @@ def flatten_axis_resources(what, tree, shardings, tupled_args):

axis_tree = shardings

# Because ecause we only have the `tree` treedef and not the full pytree here,
# Because we only have the `tree` treedef and not the full pytree here,
# we construct a dummy tree to compare against. Revise this in callers?
dummy_tree = tree_unflatten(tree, [PytreeLeaf()] * tree.num_leaves)
errors = prefix_errors(axis_tree, dummy_tree)
89 changes: 66 additions & 23 deletions jax/experimental/jax2tf/jax2tf.py
@@ -430,16 +430,15 @@ def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal:
lowering_platform = native_serialization_platforms[0]
else:
lowering_platform = None
exported: Exported = serialize_native(
exported: Optional[Exported] = serialize_native(
fun_flat_jax, args_avals_flat,
lowering_platform=lowering_platform,
strict_checks=native_serialization_strict_checks)

def run_fun_flat_as_tf(
args_flat_tf: Sequence[TfVal]
) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
outs_tf, out_avals = run_exported_as_tf(
args_avals_flat, args_flat_tf, exported,
args_avals_flat, args_flat_tf, exported, # type: ignore
native_serialization_strict_checks)
return outs_tf, out_avals
else:
@@ -448,6 +447,7 @@ def run_fun_flat_as_tf(
dim_values, _ = _interpret_fun_jax(get_dim_values_jax, args_flat_tf,
args_avals_flat, name_stack)
shape_env = zip(dim_vars, dim_values) # type: ignore
exported = None
def run_fun_flat_as_tf(
args_flat_tf: Sequence[TfVal]
) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
@@ -477,15 +477,16 @@ def run_fun_flat_as_tf(
def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
outs_tf, out_avals = run_fun_flat_as_tf(args_flat_tf)
return (tuple(outs_tf),
make_custom_gradient_fn_tf(
_make_custom_gradient_fn_tf(
fun_flat_jax=fun_flat_jax,
args_flat_tf=args_flat_tf,
args_avals_flat=args_avals_flat,
polymorphic_shapes_flat=polymorphic_shapes_flat,
out_avals=out_avals,
native_serialization=native_serialization,
native_serialization_platforms=native_serialization_platforms,
native_serialization_strict_checks=native_serialization_strict_checks))
native_serialization_strict_checks=native_serialization_strict_checks,
exported_primal=exported))

out_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
else:
@@ -599,17 +600,19 @@ def preprocess_arg_tf(arg_idx: int,
return arg_tf, arg_aval


# Prepare the grad_fn for tf.custom_gradient.
def make_custom_gradient_fn_tf(*,
fun_flat_jax: Callable,
args_flat_tf: Sequence[TfVal],
polymorphic_shapes_flat: Sequence[str],
args_avals_flat: Sequence[core.ShapedArray],
out_avals: Sequence[core.ShapedArray],
native_serialization: Union[str, bool],
native_serialization_platforms: Sequence[str],
native_serialization_strict_checks: bool
):
def _make_custom_gradient_fn_tf(*,
fun_flat_jax: Callable,
args_flat_tf: Sequence[TfVal],
polymorphic_shapes_flat: Sequence[str],
args_avals_flat: Sequence[core.ShapedArray],
out_avals: Sequence[core.ShapedArray],
native_serialization: Union[str, bool],
native_serialization_platforms: Sequence[str],
native_serialization_strict_checks: bool,
exported_primal: Optional["Exported"]):
"""Prepares the TF function to be used with tf.custom_gradient.
"""

def grad_fn_tf(*out_cts_flat_tf: TfVal,
variables=None):
@@ -659,6 +662,45 @@ def fix_in_ct(in_ct_jax, arg_aval: core.ShapedArray):
in_cts_fixed_flat_jax = tuple(map(fix_in_ct, in_cts_flat_jax, args_avals_flat))
return in_cts_fixed_flat_jax

if exported_primal is not None:
# Native lowering
all_in_shardings = [pxla._UNSPECIFIED] * len(exported_primal.in_avals)
for idx, in_s in zip(sorted(exported_primal.module_kept_var_idx),
exported_primal.in_shardings):
all_in_shardings[idx] = in_s # type: ignore
all_shardings = all_in_shardings + list(exported_primal.out_shardings)
# We cannot mix unspecified and specified shardings. Make the unspecified
# ones replicated
specified_shardings = [
s for s in all_shardings if not pxla._is_unspecified(s)]
if 0 < len(specified_shardings) < len(all_shardings):
# There are some specified, but not all
in_s = specified_shardings[0] # pjit will enforce that all have same devices
assert isinstance(in_s, sharding.XLACompatibleSharding)
replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
all_shardings = [
s if not pxla._is_unspecified(s) else replicated_s
for s in all_shardings]
# Since fun_vjp_jax takes two tuples of arguments we must split the in_shardings
vjp_in_args_shardings, vjp_in_out_ct_shardings = util.split_list(all_shardings,
[len(exported_primal.in_avals)])
# pjit front-end does not like all-unspecified
if all(pxla._is_unspecified(s) for s in vjp_in_args_shardings):
vjp_in_args_shardings = pxla._UNSPECIFIED
else:
vjp_in_args_shardings = tuple(vjp_in_args_shardings)
if all(pxla._is_unspecified(s) for s in vjp_in_out_ct_shardings):
vjp_in_out_ct_shardings = pxla._UNSPECIFIED
else:
vjp_in_out_ct_shardings = tuple(vjp_in_out_ct_shardings)

if pxla._is_unspecified(vjp_in_args_shardings) and pxla._is_unspecified(vjp_in_out_ct_shardings):
vjp_in_shardings = pxla._UNSPECIFIED
else:
vjp_in_shardings = (vjp_in_args_shardings, vjp_in_out_ct_shardings)
fun_vjp_jax = pjit.pjit(fun_vjp_jax,
in_shardings=vjp_in_shardings,
out_shardings=vjp_in_args_shardings)
# TODO: enable higher-order gradients
with tf.name_scope("jax2tf_vjp"):
in_cts_flat = convert(
@@ -707,15 +749,16 @@ class Exported:
"""Represents a lowered and serialized module."""
in_avals: Sequence[core.ShapedArray]
out_avals: Sequence[core.ShapedArray]
in_shardings: Optional[Sequence[Any]]
out_shardings: Optional[Sequence[Any]]
# The in_shardings contain entries only for the arguments in module_kept_var_idx
in_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]]
out_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]]
lowering_platform: str # One of "tpu", "cpu", "cuda", "rocm"

mlir_module: mlir.ir.Module
mlir_module_serialized: bytes # VHLO bytecode format
xla_call_module_version: int # Follows the versions of XlaCallModule
module_kept_var_idx: Sequence[bool] # Specifies if an argument is kept in the
# lowering. As long as `out_avals`.
module_kept_var_idx: Sequence[int]  # The sorted indices of the arguments
                                    # that are kept in the lowering.
dim_args_spec: Sequence[str]

def serialize_native(fun_jax: Callable,
@@ -767,7 +810,7 @@ def serialize_native(fun_jax: Callable,
raise NotImplementedError("host_callbacks are not yet implemented for the jax2tf native lowering")

if "kept_var_idx" in lowered.compile_args:
module_kept_var_idx = lowered.compile_args["kept_var_idx"]
module_kept_var_idx = tuple(sorted(lowered.compile_args["kept_var_idx"]))
else:
# For pmap
module_kept_var_idx = tuple(range(len(args_avals)))
@@ -837,8 +880,8 @@ def serialize_native(fun_jax: Callable,
return Exported(
in_avals=args_avals,
out_avals=out_avals,
in_shardings=lowered.compile_args.get("in_shardings"),
out_shardings=lowered.compile_args.get("out_shardings"),
in_shardings=lowered.compile_args["in_shardings"],
out_shardings=lowered.compile_args["out_shardings"],
lowering_platform=lowering_platform or default_jax_backend(),
mlir_module=mlir_module,
mlir_module_serialized=mlir_module_serialized,
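As an aside on the pjit wrapping of fun_vjp_jax above: since the VJP takes the primal arguments and the output cotangents as two tuples, the in_shardings passed to pjit.pjit must mirror that two-tuple structure (hence the util.split_list), and the primal-argument shardings are reused as the out_shardings of the VJP. The following self-contained sketch shows the same calling pattern on a toy function; the names and shardings are illustrative and assume a two-device mesh.

import numpy as np
import jax
from jax import numpy as jnp
from jax.experimental import pjit
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]), ("x",))

# Like fun_vjp_jax, g takes two tuples: (primal args, output cotangents).
def g(args, out_cts):
  (a,), (ct,) = args, out_cts
  return (a * ct,)  # plays the role of the input cotangents

g_pjit = pjit.pjit(
    g,
    # in_shardings mirrors the ((a,), (ct,)) argument structure.
    in_shardings=((P("x", None),), (P(None, "x"),)),
    # The primal-argument shardings double as the output shardings.
    out_shardings=(P("x", None),))

a = np.ones((4, 6), dtype=np.float32)
ct = np.ones((4, 6), dtype=np.float32)
with mesh:
  out, = g_pjit((a,), (ct,))
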
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -1101,7 +1101,7 @@ def test_error_disallowed_custom_call(self):
"Cannot serialize code with custom calls whose targets .*"):
jax2tf.convert(
lambda a, b: jax.lax.linalg.triangular_solve(a, b, left_side=True),
experimental_native_lowering=True)(a, b)
native_serialization=True)(a, b)

def test_op_metadata_simple(self):
self.skipTest("include_xla_op_metadata not yet enabled")
84 changes: 82 additions & 2 deletions jax/experimental/jax2tf/tests/sharding_test.py
@@ -358,6 +358,58 @@ def f_jax(x): # x: f32[10, 20]
(r"custom_call_target.*Sharding", 2 + count_inner_sharding)
])


@parameterized.named_parameters(
dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
in_shardings=in_shardings, out_shardings=out_shardings)
for in_shardings in ("missing", None, "P")
for out_shardings in ("missing", None, "P")
)
@jtu.with_mesh([("x", 2)])
def test_grad_pjit(self, in_shardings="missing", out_shardings="missing"):
def f_jax(x): # x: f32[10,20] -> f32[20,10]
return jnp.sin(x.T)

pjit_kwargs = {}
if in_shardings != "missing":
pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None)
if out_shardings != "missing":
pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None)
f_jax = pjit.pjit(f_jax, **pjit_kwargs)
x_shape = (10, 20)
x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)

def f_grad_tf(x_v, res_ct):
with tf.GradientTape(persistent=True) as tape:
tape.watch(x_v)
res_tf = jax2tf.convert(f_jax)(x_v)
return tape.gradient(res_tf, x_v, output_gradients=res_ct)

# Annotation count for the primal input and the grad output
count_in_P = self.GEQ(2) if in_shardings == "P" else 0
if config.jax2tf_default_native_serialization:
# With native serialization even unspecified in_shardings turn into replicated
count_in_replicated = self.GEQ(2) if in_shardings in [None, "missing"] else 0
else:
count_in_replicated = self.GEQ(2) if in_shardings is None else 0
# Annotation count for the cotangent input
count_out_P = self.GEQ(1) if out_shardings == "P" else 0
if config.jax2tf_default_native_serialization:
# With native serialization even unspecified out_shardings turn into replicated
count_out_replicated = self.GEQ(1) if out_shardings in [None, "missing"] else 0
else:
count_out_replicated = self.GEQ(1) if out_shardings is None else 0

self.check_sharding(f_grad_tf, [x, x.T],
checks=[
# The input primal argument, and the output grad
(r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", count_in_P),
(r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", count_in_replicated),
# The primal result, and the input cotangent
(r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P),
(r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", count_out_replicated),
])

@parameterized.named_parameters(
dict(testcase_name=f"_kind={kind}_in_shardings={in_shardings}_out_shardings={out_shardings}",
kind=kind, in_shardings=in_shardings, out_shardings=out_shardings)
@@ -460,8 +512,8 @@ def test_xmap_basic(self):
bshape = (2, 7)
b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape)

# f32[16,8,5], f32[2,7] -> f32[16,8,10], f32[2,28]
# f_jax: f32[5], f32[7] -> f32[10], f32[28]
# f_jax: f32[16,8,5], f32[2,7] -> f32[16,8,10], f32[2,28]
# lambda ...: f32[5], f32[7] -> f32[10], f32[28]
f_jax = xmap(lambda a, b: (jnp.concatenate([a, a], axis=0) * 2.,
jnp.concatenate([b, b, b, b], axis=0) * 4.),
in_axes=({0: 'a', 1: 'b'}, ['c', ...]),
@@ -535,6 +587,34 @@ def f_tf(a, b):
(r"f32\[8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1),
])

def test_grad_xmap(self):
devices = np.reshape(self.devices, (1, 2))
ashape = (16, 8, 5)
a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape)

# f_jax: f32[16,8,5]-> f32[16,8,10]
# lambda ...: f32[5]-> f32[10]
f_jax = xmap(lambda a: jnp.concatenate([a, a], axis=0) * 2.,
in_axes=({0: 'a', 1: 'b'}),
out_axes={0: 'a', 1: 'b'},
axis_resources={'a': 'x', 'b': 'y'})

def f_grad_tf(a, res_ct):
with tf.GradientTape(persistent=True) as tape:
tape.watch(a)
res_tf = jax2tf.convert(f_jax, native_serialization=True)(a)
return tape.gradient(res_tf, a, output_gradients=res_ct)


with Mesh(devices, ('x', 'y')):
self.check_sharding(f_grad_tf, [a, np.concatenate([a, a], axis=2)],
checks=[
# Primal input and grad output
(r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(2)),
# Input cotangent
(r"f32\[16,8,10\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(1)),
])

@jtu.ignore_warning(category=UserWarning,
message="all_to_all .* are only implemented properly for TPUs and GPUs .*")
def test_shmap_all_to_all(self):
