[jax2tf] Refactor the experimental_native_lowering path in jax2tf
This is part of a suite of refactorings aimed at supporting
pjit in jax2tf's experimental_native_lowering path. The goal here is
to remove many references to internal JAX core APIs and instead
use the AOT APIs: jax.jit(func_jax).lower(*args).

Only the experimental_native_lowering behavior should be affected.
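
For reference, a minimal sketch of the AOT lowering pattern the commit message refers to; func_jax and its input are illustrative, and the sketch uses the public jax.stages accessors rather than the internal ._lowering handle that the diff reaches into:

    import jax
    import jax.numpy as jnp

    def func_jax(x):
      return jnp.sin(x) * 2.0

    x = jnp.ones((3, 4), jnp.float32)

    # Trace and lower ahead of time, without compiling or executing.
    lowered = jax.jit(func_jax).lower(x)

    # Inspect the compiler input (MHLO at the time of this commit); this is
    # the module text that the native-lowering path hands to TensorFlow.
    print(lowered.as_text())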
gnecula committed Sep 7, 2022
1 parent 9c79439 commit f1fc7fe
Showing 2 changed files with 55 additions and 74 deletions.
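
For context, a usage sketch of the converter entry point whose behavior this commit changes; the flag name matches the variable threaded through the diff below, while the converted function and inputs are made up:

    import jax.numpy as jnp
    import tensorflow as tf
    from jax.experimental import jax2tf

    def func_jax(x):
      return jnp.sin(x) * 2.0

    # Enable the experimental native-lowering path that this commit refactors.
    func_tf = jax2tf.convert(func_jax, experimental_native_lowering=True)

    print(func_tf(tf.ones((3, 4), tf.float32)))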
123 changes: 52 additions & 71 deletions jax/experimental/jax2tf/jax2tf.py
@@ -17,7 +17,7 @@
import os
import re
import threading
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union, cast

from absl import logging

@@ -180,10 +180,6 @@ def __init__(self):
self.constant_cache = None # None means that we don't use a cache. We
# may be outside a conversion scope.

# Experimental flag to use the JAX native lowering.
self.experimental_native_lowering = False


_thread_local_state = _ThreadLocalState()

def _get_current_name_stack() -> Union[NameStack, str]:
@@ -323,9 +319,6 @@ def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal:
prev_enable_xla = _thread_local_state.enable_xla
_thread_local_state.enable_xla = enable_xla

prev_experimental_native_lowering = _thread_local_state.experimental_native_lowering
_thread_local_state.experimental_native_lowering = experimental_native_lowering

prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata
# TODO(b/189306134): implement support for XLA metadata
_thread_local_state.include_xla_op_metadata = False
@@ -343,7 +336,8 @@ def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
outs_tf, out_avals = _interpret_fun_jax(fun_flat_jax,
args_flat_tf, args_avals_flat,
name_stack,
fresh_constant_cache=True)
fresh_constant_cache=True,
experimental_native_lowering=experimental_native_lowering)
return (tuple(outs_tf),
make_custom_gradient_fn_tf(
fun_flat_jax=fun_flat_jax,
@@ -356,7 +350,9 @@ def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
else:
outs_tf, out_avals = _interpret_fun_jax(fun_flat_jax,
args_flat_tf, args_avals_flat,
name_stack, fresh_constant_cache=True)
name_stack,
fresh_constant_cache=True,
experimental_native_lowering=experimental_native_lowering)
message = ("The jax2tf-converted function does not support gradients. "
"Use `with_gradient` parameter to enable gradients")
# We use PreventGradient, which is propagated through a SavedModel.
@@ -367,7 +363,6 @@ def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
finally:
_thread_local_state.shape_env = ()
_thread_local_state.enable_xla = prev_enable_xla
_thread_local_state.experimental_native_lowering = prev_experimental_native_lowering
_thread_local_state.include_xla_op_metadata = prev_include_xla_op_metadata

out_flat_tf = [tf.identity(x, "jax2tf_out") for x in out_flat_tf]
@@ -538,30 +533,33 @@ def _extended_name_stack(extra_name_stack: Optional[str]):

def _interpret_fun_jax(
fun_jax: Callable,
in_vals_tf: Sequence[TfVal],
in_avals: Sequence[core.ShapedArray],
args_tf: Sequence[TfVal],
args_avals: Sequence[core.ShapedArray],
extra_name_stack: Optional[str],
fresh_constant_cache: bool = False
fresh_constant_cache: bool = False,
experimental_native_lowering: bool = False
) -> Tuple[Sequence[TfVal], Tuple[core.ShapedArray]]:
if _thread_local_state.experimental_native_lowering:
return util.unzip2(_lower_native(fun_jax, in_vals_tf, in_avals, extra_name_stack))
if experimental_native_lowering:
del extra_name_stack
return _lower_native_and_run(fun_jax, args_avals, args_tf)
else:
with core.new_base_main(TensorFlowTrace) as main: # type: ignore
subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), main, in_avals)
subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), main, args_avals)
with _extended_name_stack(extra_name_stack):
with core.new_sublevel():
out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
_call_wrapped_with_new_constant_cache(subtrace_fun, in_vals_tf,
_call_wrapped_with_new_constant_cache(subtrace_fun, args_tf,
fresh_constant_cache=fresh_constant_cache)

del main

return util.unzip2(out_vals)

def _lower_native(fun_jax: Callable, in_vals_tf: Sequence[TfVal],
in_avals: Sequence[core.ShapedArray],
extra_name_stack: Optional[str]):
"""Lowers the function using native lowering.
def _lower_native_and_run(fun_jax: Callable,
args_avals: Sequence[core.ShapedArray],
args_tf: Sequence[TfVal],
) -> Tuple[Sequence[TfVal], Tuple[core.ShapedArray]]:
"""Lowers the function using native lowering and then invokes it.
Work-in-progress.
@@ -570,63 +568,46 @@ def _lower_native(fun_jax: Callable, in_vals_tf: Sequence[TfVal],
Special care must be taken in the presence of shape polymorphism.
"""
lu_fun = lu.wrap_init(fun_jax)
# Look for shape polymorphism
abstract_axes: Sequence[Dict[int, str]] = [] # one for each argument
for aval in in_avals:
# For each arg, map axis idx to dimension variable name
abstracted_axes: Sequence[Dict[int, str]] = []
# For each dimension variable, encode how to compute its value from the
# shape of the explicit arguments. E.g., "2.1" denotes args_tf[2].shape[1].
# Note: We assume that lowering will introduce dim args in the order in which
# dim variables are first seen when scanning the explicit arguments
# in order and then scanning their shapes for dim variables.
dim_args_spec: List[str] = []
dim_vars_seen: Set[str] = set()
for arg_idx, aval in enumerate(args_avals):
one_abstract_axes = {}
for i, d in enumerate(aval.shape):
for axis_idx, d in enumerate(aval.shape):
if not core.is_constant_dim(d):
d_var = d.to_var()
if d_var is None:
raise ValueError(f"Only simple variables supported: {aval.shape}")
one_abstract_axes[i] = d_var
abstract_axes.append(one_abstract_axes)
if any(abstract_axes):
raise ValueError(f"Only simple dimension variables supported: {aval.shape}")
if not d_var in dim_vars_seen:
dim_args_spec.append(f"{arg_idx}.{axis_idx}")
dim_vars_seen.add(d_var)
one_abstract_axes[axis_idx] = d_var
abstracted_axes.append(one_abstract_axes)

if any(abstracted_axes):
if not config.jax_dynamic_shapes:
raise ValueError(
"Found shape polymorphism but --jax_dynamic_shapes is not on")
# In order to use infer_input_type, we must manufacture some JAX arguments.
# Actually the only thing that matters is that get_aval(x) and x.shape work for them.
# This is a hack, should find a way to refactor infer_lambda_input_type so that we
# can reuse it here more cleanly.
top_trace = core.find_top_trace(())
fake_vals_jax = [
TensorFlowTracer(top_trace, val, aval) # type: ignore
for val, aval in zip(in_vals_tf, in_avals)
]
in_type = partial_eval.infer_lambda_input_type(abstract_axes, fake_vals_jax) # type: ignore
lu_fun = lu.annotate(lu_fun, in_type)
arg_specs = [(None, None) for _ in in_avals]

nr_dim_vars = 0
# For each dimension variable, encode how to compute its value from the
# shape of the explicit arguments. E.g., "2.1" denotes args[2].shape[1]
dim_args_spec_dict: Dict[int, str] = {}
for arg_idx, (arg_aval, is_explicit) in enumerate(in_type):
if not is_explicit:
nr_dim_vars += 1
else:
for i, d in enumerate(arg_aval.shape):
if isinstance(d, core.DBIdx) and d.val not in dim_args_spec_dict:
dim_args_spec_dict[d.val] = f"{arg_idx - nr_dim_vars}.{i}"
dim_args_spec = [dim_args_spec_dict[i] for i in range(nr_dim_vars)]
"Found shape polymorphism but --jax_dynamic_shapes is not set")
abstracted_axes = tuple(abstracted_axes)
else:
arg_specs = [(aval, None) for aval in in_avals] # type: ignore
dim_args_spec = []
abstracted_axes = None # type: ignore

arg_specs_jax = [
jax.ShapeDtypeStruct(aval.shape, aval.dtype)
for aval in args_avals
]
# TODO: specify the backend for experimental_native_lowering
device = None
backend = jax.default_backend()
lowered = dispatch.lower_xla_callable(
lu_fun,
device,
backend,
extra_name_stack,
(False,) * len(in_avals), # donated
True, # always_lower,
True, # keep_unused
*arg_specs)
lowered = jax.jit(fun_jax, backend=backend,
keep_unused=True, # TODO: allow dropping unused
abstracted_axes=abstracted_axes).lower(*arg_specs_jax)._lowering

mhlo_module = lowered.mhlo()
mhlo_module_text = mlir.module_to_string(mhlo_module)
@@ -652,7 +633,7 @@ def _lower_native(fun_jax: Callable, in_vals_tf: Sequence[TfVal],

# Figure out the result types and shapes
out_avals = lowered.compile_args["out_avals"]
# TODO: handle d being InDBIdx
# TODO(necula): handle d being InDBIdx
out_shapes = tuple(
tuple(d if type(d) is int else None
for d in out_aval.shape)
@@ -665,7 +646,7 @@ def _out_type(jax_type):
out_types = tuple(_out_type(out_aval.dtype) for out_aval in out_avals)

res = tfxla.call_module(
in_vals_tf,
args_tf,
module=mhlo_module_text,
Tout=out_types,
Sout=out_shapes,
@@ -682,7 +663,7 @@ def _convert_res(res_val, res_jax_type):
res = tuple(
_convert_res(res_val, out_aval.dtype)
for res_val, out_aval in zip(res, out_avals))
return zip(res, out_avals)
return res, out_avals

def _fixup_mhlo_module_text(mhlo_module_text: str) -> str:
# A workaround for MHLO not (yet) having backwards compatibility. With
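
To make the dimension-variable bookkeeping in _lower_native_and_run concrete, here is a small self-contained sketch of how abstracted_axes and dim_args_spec are derived; the argument shapes are hypothetical and plain strings stand in for jax2tf dimension variables:

    from typing import Dict, List, Set

    # Hypothetical polymorphic argument shapes: each dimension is either an
    # int or the name of a dimension variable (a plain string here).
    args_shapes = [("b", 4), ("b", 8)]

    abstracted_axes: List[Dict[int, str]] = []  # one dict per argument
    dim_args_spec: List[str] = []               # "<arg>.<axis>" per dim variable
    dim_vars_seen: Set[str] = set()
    for arg_idx, shape in enumerate(args_shapes):
      one_abstracted_axes = {}
      for axis_idx, d in enumerate(shape):
        if isinstance(d, str):  # a non-constant dimension
          if d not in dim_vars_seen:
            # First occurrence: the value of this variable is read from
            # args_tf[arg_idx].shape[axis_idx] at call time.
            dim_args_spec.append(f"{arg_idx}.{axis_idx}")
            dim_vars_seen.add(d)
          one_abstracted_axes[axis_idx] = d
      abstracted_axes.append(one_abstracted_axes)

    print(abstracted_axes)  # [{0: 'b'}, {0: 'b'}]
    print(dim_args_spec)    # ['0.0']

The real code builds the same structures from the argument avals, passes abstracted_axes to jax.jit(...).lower(...) (non-None only when --jax_dynamic_shapes is on), and keeps the dim_args_spec strings (their consumer is elided in the hunks above) so that the value of each dimension variable can be recovered from the shape of an explicit TF argument at call time, per the "2.1" example in the comments.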
6 changes: 3 additions & 3 deletions jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -488,8 +488,8 @@ def g(x): # x: i32
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_gradients_unused_argument_readme(self, with_function=True):
# x2 and x3 are not used. x3 has integer type.
def test_gradients_unused_argument_readme(self, with_function=False):
# x1 and x3 are not used. x3 has integer type.
def fn(x0, x1, x2, x3):
return x0 * 0. + x2 * 2.

@@ -536,7 +536,7 @@ def fn(x0, x1, x2, x3):
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_gradients_int_argument(self, with_function=True):
def test_gradients_int_argument(self, with_function=False):
# https://github.com/google/jax/issues/6975
# Also issue #6975.
# An expanded version of test_gradients_unused_argument
