From f1fc7fe302148dad29c602031892970cab252714 Mon Sep 17 00:00:00 2001
From: George Necula
Date: Sat, 3 Sep 2022 08:08:50 +0300
Subject: [PATCH] [jax2tf] Refactor the experimental_native_lowering path in jax2tf

This is part of a suite of refactorings aimed at supporting pjit in the
jax2tf experimental_native_lowering path. The goal here is to remove many
references to internal JAX core APIs and instead use the AOT APIs:
jax.jit(func_jax).lower(*args).

Only the experimental_native_lowering behavior should be affected.
---
 jax/experimental/jax2tf/jax2tf.py             | 123 ++++++++-----------
 jax/experimental/jax2tf/tests/jax2tf_test.py |   6 +-
 2 files changed, 55 insertions(+), 74 deletions(-)

diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index 8212a3133c6c..38fe57746211 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -17,7 +17,7 @@
 import os
 import re
 import threading
-from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union, cast

 from absl import logging

@@ -180,10 +180,6 @@ def __init__(self):
     self.constant_cache = None  # None means that we don't use a cache. We
                                 # may be outside a conversion scope.

-    # Experimental flag to use the JAX native lowering.
-    self.experimental_native_lowering = False
-
-
 _thread_local_state = _ThreadLocalState()

 def _get_current_name_stack() -> Union[NameStack, str]:
@@ -323,9 +319,6 @@ def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal:
       prev_enable_xla = _thread_local_state.enable_xla
       _thread_local_state.enable_xla = enable_xla

-      prev_experimental_native_lowering = _thread_local_state.experimental_native_lowering
-      _thread_local_state.experimental_native_lowering = experimental_native_lowering
-
       prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata
       # TODO(b/189306134): implement support for XLA metadata
       _thread_local_state.include_xla_op_metadata = False
@@ -343,7 +336,8 @@ def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
         outs_tf, out_avals = _interpret_fun_jax(fun_flat_jax,
                                                 args_flat_tf, args_avals_flat,
                                                 name_stack,
-                                                fresh_constant_cache=True)
+                                                fresh_constant_cache=True,
+                                                experimental_native_lowering=experimental_native_lowering)
         return (tuple(outs_tf),
                 make_custom_gradient_fn_tf(
                     fun_flat_jax=fun_flat_jax,
@@ -356,7 +350,9 @@
       else:
         outs_tf, out_avals = _interpret_fun_jax(fun_flat_jax,
                                                 args_flat_tf, args_avals_flat,
-                                                name_stack, fresh_constant_cache=True)
+                                                name_stack,
+                                                fresh_constant_cache=True,
+                                                experimental_native_lowering=experimental_native_lowering)
         message = ("The jax2tf-converted function does not support gradients. "
                    "Use `with_gradient` parameter to enable gradients")
         # We use PreventGradient, which is propagated through a SavedModel.
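The AOT API named in the commit message, and used below in `_lower_native_and_run`, can be exercised standalone. A minimal sketch, assuming only public JAX APIs (`jax.jit(...).lower(...)` accepting a `jax.ShapeDtypeStruct`, and `Lowered.as_text()`); the function `f` and its shapes are made up for illustration and do not come from the patch:

```python
# Sketch of the AOT lowering flow this patch migrates to. `f` is illustrative.
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * 2.0

# Lower from an abstract shape/dtype spec, without concrete array values,
# mirroring the arg_specs_jax construction in _lower_native_and_run below.
spec = jax.ShapeDtypeStruct((3, 4), jnp.float32)
lowered = jax.jit(f).lower(spec)
print(lowered.as_text())  # the lowered (MHLO/StableHLO) module as text
```

The patch itself reaches into the private `lowered._lowering` to call `.mhlo()`; `as_text()` is a public way to view the same module text.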
@@ -367,7 +363,6 @@ def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
     finally:
       _thread_local_state.shape_env = ()
       _thread_local_state.enable_xla = prev_enable_xla
-      _thread_local_state.experimental_native_lowering = prev_experimental_native_lowering
       _thread_local_state.include_xla_op_metadata = prev_include_xla_op_metadata

   out_flat_tf = [tf.identity(x, "jax2tf_out") for x in out_flat_tf]
@@ -538,30 +533,33 @@ def _extended_name_stack(extra_name_stack: Optional[str]):

 def _interpret_fun_jax(
     fun_jax: Callable,
-    in_vals_tf: Sequence[TfVal],
-    in_avals: Sequence[core.ShapedArray],
+    args_tf: Sequence[TfVal],
+    args_avals: Sequence[core.ShapedArray],
     extra_name_stack: Optional[str],
-    fresh_constant_cache: bool = False
+    fresh_constant_cache: bool = False,
+    experimental_native_lowering: bool = False
 ) -> Tuple[Sequence[TfVal], Tuple[core.ShapedArray]]:
-  if _thread_local_state.experimental_native_lowering:
-    return util.unzip2(_lower_native(fun_jax, in_vals_tf, in_avals, extra_name_stack))
+  if experimental_native_lowering:
+    del extra_name_stack
+    return _lower_native_and_run(fun_jax, args_avals, args_tf)
   else:
     with core.new_base_main(TensorFlowTrace) as main:  # type: ignore
-      subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), main, in_avals)
+      subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), main, args_avals)
       with _extended_name_stack(extra_name_stack):
         with core.new_sublevel():
           out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
-              _call_wrapped_with_new_constant_cache(subtrace_fun, in_vals_tf,
+              _call_wrapped_with_new_constant_cache(subtrace_fun, args_tf,
                                                     fresh_constant_cache=fresh_constant_cache)
       del main
     return util.unzip2(out_vals)

-def _lower_native(fun_jax: Callable, in_vals_tf: Sequence[TfVal],
-                  in_avals: Sequence[core.ShapedArray],
-                  extra_name_stack: Optional[str]):
-  """Lowers the function using native lowering.
+def _lower_native_and_run(fun_jax: Callable,
+                          args_avals: Sequence[core.ShapedArray],
+                          args_tf: Sequence[TfVal],
+                          ) -> Tuple[Sequence[TfVal], Tuple[core.ShapedArray]]:
+  """Lowers the function using native lowering and then invokes it.

   Work-in-progress.
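The next hunk rewrites the shape-polymorphism bookkeeping: it computes `abstracted_axes` (per-argument mappings from axis index to dimension-variable name) and `dim_args_spec`, where an entry "A.I" means "this dimension variable's value is `args_tf[A].shape[I]`". A toy model of that encoding, with plain strings standing in for JAX dimension variables (this is not the patch's code):

```python
# Toy illustration of the dim_args_spec encoding computed in the hunk below.
# Strings stand in for dimension variables; ints are constant dimensions.
from typing import List, Set, Tuple, Union

Dim = Union[int, str]

def toy_dim_args_spec(shapes: List[Tuple[Dim, ...]]) -> List[str]:
  spec: List[str] = []
  seen: Set[str] = set()
  for arg_idx, shape in enumerate(shapes):
    for axis_idx, d in enumerate(shape):
      if isinstance(d, str) and d not in seen:
        # Record where this variable is first seen, as "arg.axis".
        spec.append(f"{arg_idx}.{axis_idx}")
        seen.add(d)
  return spec

# "b" is first seen at args[0].shape[0], "w" at args[1].shape[1]:
assert toy_dim_args_spec([("b", 4), ("b", "w")]) == ["0.0", "1.1"]
```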
@@ -570,63 +568,46 @@ def _lower_native(fun_jax: Callable, in_vals_tf: Sequence[TfVal],
   Special care must be taken in presence of shape polymorphism.
   """
-  lu_fun = lu.wrap_init(fun_jax)
   # Look for shape polymorphism
-  abstract_axes: Sequence[Dict[int, str]] = []  # one for each argument
-  for aval in in_avals:
+  # For each arg, map axis idx to dimension variable name
+  abstracted_axes: Sequence[Dict[int, str]] = []
+  # For each dimension variable, encode how to compute its value from the
+  # shape of the explicit arguments. E.g., "2.1" denotes args_tf[2].shape[1].
+  # Note: We assume that lowering will introduce dim args in the order in which
+  # dim variables are first seen when scanning the explicit arguments
+  # in order and then scanning their shapes for dim variables.
+  dim_args_spec: List[str] = []
+  dim_vars_seen: Set[str] = set()
+  for arg_idx, aval in enumerate(args_avals):
     one_abstract_axes = {}
-    for i, d in enumerate(aval.shape):
+    for axis_idx, d in enumerate(aval.shape):
       if not core.is_constant_dim(d):
         d_var = d.to_var()
         if d_var is None:
-          raise ValueError(f"Only simple variables supported: {aval.shape}")
-        one_abstract_axes[i] = d_var
-    abstract_axes.append(one_abstract_axes)
-  if any(abstract_axes):
+          raise ValueError(f"Only simple dimension variables supported: {aval.shape}")
+        if not d_var in dim_vars_seen:
+          dim_args_spec.append(f"{arg_idx}.{axis_idx}")
+          dim_vars_seen.add(d_var)
+        one_abstract_axes[axis_idx] = d_var
+    abstracted_axes.append(one_abstract_axes)
+
+  if any(abstracted_axes):
     if not config.jax_dynamic_shapes:
       raise ValueError(
-          "Found shape polymorphism but --jax_dynamic_shapes is not on")
-    # In order to use infer_input_type, we must manufacture some JAX arguments.
-    # Actually the only thing that matters is that get_aval(x) and x.shape work for them.
-    # This is a hack, should find a way to refactor infer_lambda_input_type so that we
-    # can reuse it here more cleanly.
-    top_trace = core.find_top_trace(())
-    fake_vals_jax = [
-        TensorFlowTracer(top_trace, val, aval)  # type: ignore
-        for val, aval in zip(in_vals_tf, in_avals)
-    ]
-    in_type = partial_eval.infer_lambda_input_type(abstract_axes, fake_vals_jax)  # type: ignore
-    lu_fun = lu.annotate(lu_fun, in_type)
-    arg_specs = [(None, None) for _ in in_avals]
-
-    nr_dim_vars = 0
-    # For each dimension variable, encode how to compute its value from the
-    # shape of the explicit arguments. E.g., "2.1" denotes args[2].shape[1]
-    dim_args_spec_dict: Dict[int, str] = {}
-    for arg_idx, (arg_aval, is_explicit) in enumerate(in_type):
-      if not is_explicit:
-        nr_dim_vars += 1
-      else:
-        for i, d in enumerate(arg_aval.shape):
-          if isinstance(d, core.DBIdx) and d.val not in dim_args_spec_dict:
-            dim_args_spec_dict[d.val] = f"{arg_idx - nr_dim_vars}.{i}"
-    dim_args_spec = [dim_args_spec_dict[i] for i in range(nr_dim_vars)]
+          "Found shape polymorphism but --jax_dynamic_shapes is not set")
+    abstracted_axes = tuple(abstracted_axes)
   else:
-    arg_specs = [(aval, None) for aval in in_avals]  # type: ignore
-    dim_args_spec = []
+    abstracted_axes = None  # type: ignore
+
+  arg_specs_jax = [
+      jax.ShapeDtypeStruct(aval.shape, aval.dtype)
+      for aval in args_avals
+  ]
   # TODO: specify the backend for experimental_native_lowering
-  device = None
   backend = jax.default_backend()
-  lowered = dispatch.lower_xla_callable(
-      lu_fun,
-      device,
-      backend,
-      extra_name_stack,
-      (False,) * len(in_avals),  # donated
-      True,  # always_lower,
-      True,  # keep_unused
-      *arg_specs)
+  lowered = jax.jit(fun_jax, backend=backend,
+                    keep_unused=True,  # TODO: allow dropping unused
+                    abstracted_axes=abstracted_axes).lower(*arg_specs_jax)._lowering

   mhlo_module = lowered.mhlo()
   mhlo_module_text = mlir.module_to_string(mhlo_module)
@@ -652,7 +633,7 @@ def _lower_native(fun_jax: Callable, in_vals_tf: Sequence[TfVal],

   # Figure out the result types and shapes
   out_avals = lowered.compile_args["out_avals"]
-  # TODO: handle d being InDBIdx
+  # TODO(necula): handle d being InDBIdx
   out_shapes = tuple(
       tuple(d if type(d) is int else None
             for d in out_aval.shape)
@@ -665,7 +646,7 @@ def _out_type(jax_type):
   out_types = tuple(_out_type(out_aval.dtype) for out_aval in out_avals)

   res = tfxla.call_module(
-      in_vals_tf,
+      args_tf,
       module=mhlo_module_text,
       Tout=out_types,
       Sout=out_shapes,
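The `Sout`/`Tout` arguments passed to `tfxla.call_module` above are derived from `out_avals`, read here via the private `lowered.compile_args`; a `None` in `Sout` marks a dimension that is not statically known. Outside the patch, `jax.eval_shape` computes the same output avals through a public API. A sketch with an illustrative function (not from the patch, static shapes only):

```python
# Sketch: recovering output shapes/dtypes with jax.eval_shape, analogous
# to reading lowered.compile_args["out_avals"] above.
import jax
import jax.numpy as jnp

def fun_jax(x, y):
  return jnp.dot(x, y), jnp.sum(y, axis=0)

out_avals = jax.eval_shape(
    fun_jax,
    jax.ShapeDtypeStruct((2, 3), jnp.float32),
    jax.ShapeDtypeStruct((3, 5), jnp.float32))
# -> (ShapeDtypeStruct(shape=(2, 5), dtype=float32),
#     ShapeDtypeStruct(shape=(5,), dtype=float32))
print(out_avals)
```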
@@ -682,7 +663,7 @@ def _convert_res(res_val, res_jax_type):
   res = tuple(
       _convert_res(res_val, out_aval.dtype)
       for res_val, out_aval in zip(res, out_avals))
-  return zip(res, out_avals)
+  return res, out_avals

 def _fixup_mhlo_module_text(mhlo_module_text: str) -> str:
   # A workaround for MHLO not (yet) having backwards compatibility. With
diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py
index e40fb2040091..5b7f2b0d355b 100644
--- a/jax/experimental/jax2tf/tests/jax2tf_test.py
+++ b/jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -488,8 +488,8 @@ def g(x):  # x: i32
       dict(testcase_name=f"function={with_function}",
            with_function=with_function)
       for with_function in [False, True]))
-  def test_gradients_unused_argument_readme(self, with_function=True):
-    # x2 and x3 are not used. x3 has integer type.
+  def test_gradients_unused_argument_readme(self, with_function=False):
+    # x1 and x3 are not used. x3 has integer type.
     def fn(x0, x1, x2, x3):
       return x0 * 0. + x2 * 2.

@@ -536,7 +536,7 @@ def fn(x0, x1, x2, x3):
       dict(testcase_name=f"function={with_function}",
            with_function=with_function)
       for with_function in [False, True]))
-  def test_gradients_int_argument(self, with_function=True):
+  def test_gradients_int_argument(self, with_function=False):
    # https://github.com/google/jax/issues/6975
    # Also issue #6975.
    # An expanded version of test_gradients_unused_argument
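The two tests touched above exercise gradients of jax2tf-converted functions with unused and integer arguments. The basic pattern they rely on, sketched with made-up values and all-float arguments for simplicity (the real tests also cover an integer `x3` and the `tf.function` wrapping toggled by `with_function`):

```python
# Sketch of the test pattern: convert with gradients enabled, then
# differentiate through tf.GradientTape. Values are illustrative.
import tensorflow as tf
from jax.experimental import jax2tf

def fn(x0, x1, x2, x3):
  return x0 * 0. + x2 * 2.  # x1 and x3 are unused

fn_tf = jax2tf.convert(fn, with_gradient=True)
xs = [tf.Variable(float(i)) for i in range(4)]
with tf.GradientTape() as tape:
  y = fn_tf(*xs)
grads = tape.gradient(y, xs)
# Expected: d/dx0 == 0. and d/dx2 == 2.; what comes back for the unused
# x1 and x3 is jax2tf's policy for unused arguments, which is exactly
# what these tests pin down.
```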