[jax2tf] A new experimental version with JAX native lowering.
In the future JAX will be able to use a serialization format
based on a variant of MHLO. This is not yet ready, but in this PR
we start getting jax2tf ready for it. As a temporary step, we have
introduced a TF op called XlaCallModule, which carries a serialized
MHLO module and which we can use to wrap the JAX native MHLO as a
TF op. We still reuse parts of jax2tf, in particular the gradient
machinery.

This functionality can be enabled locally with the
`experimental_native_lowering` parameter of `jax2tf.convert`, or
globally with the flag `--jax2tf_default_experimental_native_lowering`.
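
As a quick illustration, here is a minimal usage sketch (not part of this
commit; `f_jax` is a hypothetical example function, and the sketch assumes
TensorFlow and this version of jax2tf are available):

# Usage sketch only; `f_jax` is a made-up example, not code from this commit.
import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def f_jax(x):
  return jnp.sin(jnp.cos(x))

# Enable the experimental native lowering for a single conversion:
f_tf = jax2tf.convert(f_jax, experimental_native_lowering=True)
print(f_tf(tf.constant([1.0, 2.0])))

# Or enable it globally; calls that leave the parameter at "default" pick it up:
jax.config.update("jax2tf_default_experimental_native_lowering", True)
f_tf2 = jax2tf.convert(f_jax)

Note that the native-lowering path wraps the whole computation in a single
XlaCallModule op, so it requires enable_xla=True (the default); passing
enable_xla=False together with experimental_native_lowering raises an error.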
gnecula committed Jul 19, 2022
1 parent ea627b8 commit ee50140
Showing 6 changed files with 310 additions and 31 deletions.
9 changes: 9 additions & 0 deletions jax/_src/config.py
@@ -582,6 +582,15 @@ def update_thread_local_jit_state(**kw):
)
)

jax2tf_default_experimental_native_lowering = config.define_bool_state(
name='jax2tf_default_experimental_native_lowering',
default=bool_env('JAX2TF_DEFAULT_EXPERIMENTAL_NATIVE_LOWERING', False),
help=(
'DO NOT USE, highly experimental. Sets the default value of the '
'experimental_native_lowering parameter to jax2tf.convert.'
)
)

jax_platforms = config.define_string_state(
name='jax_platforms',
default=None,
151 changes: 140 additions & 11 deletions jax/experimental/jax2tf/jax2tf.py
@@ -19,6 +19,8 @@
import threading
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast

from absl import logging

import jax
from jax import lax
from jax import config
@@ -30,6 +32,7 @@
from jax.experimental import pjit
from jax.experimental import sharding
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import partial_eval
from jax.interpreters import pxla
from jax.interpreters import xla
@@ -174,6 +177,9 @@ def __init__(self):
self.constant_cache = None # None means that we don't use a cache. We
# may be outside a conversion scope.

# Experimental flag to use the JAX native lowering.
self.experimental_native_lowering = False


_thread_local_state = _ThreadLocalState()

@@ -197,8 +203,8 @@ def convert(fun: Callable,
*,
polymorphic_shapes=None,
with_gradient=True,
enable_xla=True
) -> Callable:
enable_xla=True,
experimental_native_lowering="default") -> Callable:
"""Transforms `fun` to be executed by TensorFlow.
See
@@ -257,12 +263,20 @@ def convert(fun: Callable,
for the TFLite and TFjs converters. For those cases, unset this parameter
so the converter tries harder to use non-XLA TF ops to convert the
function and aborts if this is not possible.
experimental_native_lowering: DO NOT USE, for experimental purposes only.
The value "default" defers to --jax2tf_default_experimental_native_lowering.
Returns:
A version of `fun` that expects TfVals as arguments (or
tuple/lists/dicts) thereof, and returns TfVals as outputs, and uses
only TensorFlow ops.
"""
if experimental_native_lowering == "default":
experimental_native_lowering = config.jax2tf_default_experimental_native_lowering

if experimental_native_lowering and not enable_xla:
raise ValueError(
"experimental_native_lowering is not supported with enable_xla=False")
api._check_callable(fun)
fun_name = getattr(fun, "__name__", "unknown")
name_stack = util.wrap_name(fun_name, "jax2tf")
@@ -417,6 +431,9 @@ def fix_in_ct(in_ct, arg_aval: core.ShapedArray):
prev_enable_xla = _thread_local_state.enable_xla
_thread_local_state.enable_xla = enable_xla

prev_experimental_native_lowering = _thread_local_state.experimental_native_lowering
_thread_local_state.experimental_native_lowering = experimental_native_lowering

prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata
# TODO(b/189306134): implement support for XLA metadata
_thread_local_state.include_xla_op_metadata = False
@@ -453,6 +470,7 @@ def converted_fun_flat_with_custom_gradient(*args_flat: TfVal) -> TfVal:
finally:
_thread_local_state.shape_env = ()
_thread_local_state.enable_xla = prev_enable_xla
_thread_local_state.experimental_native_lowering = prev_experimental_native_lowering
_thread_local_state.include_xla_op_metadata = prev_include_xla_op_metadata

out_flat = [tf.identity(x, "jax2tf_out") for x in out_flat]
@@ -505,17 +523,128 @@ def _interpret_fun(
extra_name_stack: Optional[str],
fresh_constant_cache: bool = False
) -> Sequence[Tuple[TfVal, core.ShapedArray]]:
  if _thread_local_state.experimental_native_lowering:
    return _lower_native(fun, in_vals, in_avals, extra_name_stack)
  else:
    with core.new_base_main(TensorFlowTrace) as main:  # type: ignore
      fun = _interpret_subtrace(fun, main, in_avals)
      with _extended_name_stack(extra_name_stack):
        with core.new_sublevel():
          out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
            _call_wrapped_with_new_constant_cache(fun, in_vals,
                                                  fresh_constant_cache=fresh_constant_cache)
      del main

    return tuple(out_vals)

def _lower_native(fun: lu.WrappedFun, in_vals: Sequence[TfVal],
in_avals: Sequence[core.ShapedArray],
extra_name_stack: Optional[str]):
"""Lowers the function using native lowering.
Work-in-progress.
  Uses JAX native lowering to MHLO, and then wraps the result in an
  XlaCallModule TF op. This op does not have backward compatibility yet.
  Special care must be taken in the presence of shape polymorphism.
"""
# Look for shape polymorphism
abstract_axes: Sequence[Dict[int, str]] = [] # one for each argument
for aval in in_avals:
one_abstract_axes = {}
for i, d in enumerate(aval.shape):
if not core.is_constant_dim(d):
d_var = d.to_var()
if d_var is None:
raise ValueError(f"Only simple variables supported: {aval.shape}")
one_abstract_axes[i] = d_var
abstract_axes.append(one_abstract_axes)
if any(abstract_axes):
if not config.jax_dynamic_shapes:
raise ValueError(
"Found shape polymorphism but --jax_dynamic_shapes is not on")
# In order to use infer_input_type, we must manufacture some JAX arguments.
# Actually the only thing that matters is that get_aval(x) and x.shape work for them.
    # This is a hack; we should find a way to refactor infer_lambda_input_type
    # so that we can reuse it here more cleanly.
top_trace = core.find_top_trace(())
fake_jax_vals = [
TensorFlowTracer(top_trace, val, aval) # type: ignore
for val, aval in zip(in_vals, in_avals)
]
in_type = partial_eval.infer_lambda_input_type(abstract_axes, fake_jax_vals) # type: ignore
fun = lu.annotate(fun, in_type)
arg_specs = [(None, None) for _ in in_avals]

nr_dim_vars = 0
# For each dimension variable, encode how to compute its value from the
# shape of the explicit arguments. E.g., "2.1" denotes args[2].shape[1]
dim_args_spec_dict: Dict[int, str] = {}
for arg_idx, (arg_aval, is_explicit) in enumerate(in_type):
if not is_explicit:
nr_dim_vars += 1
else:
for i, d in enumerate(arg_aval.shape):
if isinstance(d, core.DBIdx) and d.val not in dim_args_spec_dict:
dim_args_spec_dict[d.val] = f"{arg_idx - nr_dim_vars}.{i}"
dim_args_spec = [dim_args_spec_dict[i] for i in range(nr_dim_vars)]
else:
arg_specs = [(aval, None) for aval in in_avals] # type: ignore
dim_args_spec = []

# TODO: specify the backend for experimental_native_lowering
device = None
backend = jax.default_backend()
lowered = dispatch.lower_xla_callable(
fun,
device,
backend,
extra_name_stack,
(False,) * len(in_avals), # donated
True, # always_lower,
True, # keep_unused
*arg_specs)

mhlo_module = lowered.mhlo()
mhlo_module_text = mlir.module_to_string(mhlo_module)
logging.vlog(2, f"XlaCallModule {mhlo_module_text}")

# Figure out the result types and shapes
out_avals = lowered.compile_args["out_avals"]
# TODO: handle d being InDBIdx
out_shapes = tuple(
tuple(d if type(d) is int else None
for d in out_aval.shape)
for out_aval in out_avals)

def _out_type(jax_type):
if jax_type == dtypes.float0:
return dtypes.bool_
return jax_type
out_types = tuple(_out_type(out_aval.dtype) for out_aval in out_avals)

res = tfxla.call_module(
in_vals,
module=mhlo_module_text,
Tout=out_types,
Sout=out_shapes,
dim_args_spec=dim_args_spec)

# Convert the results to the needed TF types
def _convert_res(res_val, res_jax_type):
conversion_dtype = _to_tf_dtype(res_jax_type)
if conversion_dtype != res_jax_type:
return tf.cast(res_val, conversion_dtype)
else:
return res_val

res = tuple(
_convert_res(res_val, out_aval.dtype)
for res_val, out_aval in zip(res, out_avals))
return zip(res, out_avals)


def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
in_vals: Sequence[TfVal],
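
For context on the shape-polymorphism branch of `_lower_native` above, here is
a hedged sketch of how it would be exercised from user code (again not part of
this commit; `f` is a hypothetical function, and the sketch assumes the
`--jax_dynamic_shapes` flag mentioned in the error message above):

# Sketch only; `f` is a made-up example. Shape polymorphism with native
# lowering requires jax_dynamic_shapes, per the check in _lower_native.
import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

jax.config.update("jax_dynamic_shapes", True)

def f(x):
  return jnp.sum(x, axis=0)

# "b" names the unknown batch dimension; at run time its value is recovered
# from the argument shapes via the dim_args_spec mechanism shown above.
f_tf = jax2tf.convert(f, polymorphic_shapes=["(b, 4)"],
                      experimental_native_lowering=True)
concrete_fn = tf.function(f_tf).get_concrete_function(
    tf.TensorSpec([None, 4], tf.float32))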
10 changes: 10 additions & 0 deletions jax/experimental/jax2tf/shape_poly.py
@@ -409,6 +409,15 @@ def evaluate(self, env: ShapeEnv):
terms = [_multiply(mon.evaluate(env), np.int32(coeff)) for mon, coeff in self.monomials()]
return functools.reduce(_add, terms) if len(terms) > 1 else terms[0]

@staticmethod
def get_aval(_: "_DimPolynomial"):
return core.ShapedArray((),
dtypes.canonicalize_dtype(np.int64),
weak_type=True)


core.pytype_aval_mappings[_DimPolynomial] = _DimPolynomial.get_aval

def _ensure_poly(p: DimSize) -> _DimPolynomial:
if isinstance(p, _DimPolynomial): return p
return _DimPolynomial({_DimMon(): p})
@@ -417,6 +426,7 @@ def is_poly_dim(p: DimSize) -> bool:
return isinstance(p, _DimPolynomial)



class DimensionHandlerPoly(core.DimensionHandler):
"""See core.DimensionHandler.
