Skip to content

Commit

Permalink
[shape_poly] Improve compile-time shape checking.
Browse files Browse the repository at this point in the history
JAX shape polymorphism relies on implicit assumptions.
For example, when tracing with input specification `(a, a)`,
we assume that the first two dimensions have the same size
greater or equal to 1.

Here we extend the checking that these assumptions hold. When
we call an `Exported` module from jax, with `jax_export.call_exported`
we check these assumptions statically. However, when we
stage an `Exported` using `XlaCallModule` to be called from
TensorFlow, or when we use TF graph serialization we need
to check these assumptions when we execute and compile
the op (that is when the shapes are available).

To prepare for this compile-time shape checking we add
`Exported.shape_check_module` to produce a serialized
MLIR module containing the shape checking code. This
will be added in a future change to `XlaCallModule`.
  • Loading branch information
gnecula committed Jun 13, 2023
1 parent 3e8506b commit 5b035d5
Show file tree
Hide file tree
Showing 6 changed files with 749 additions and 306 deletions.
41 changes: 40 additions & 1 deletion jax/_src/test_util.py
Expand Up @@ -22,7 +22,7 @@
import os
import tempfile
import textwrap
from typing import Callable, List, Generator, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Generator, Optional, Sequence, Tuple, Union
import unittest
import warnings
import zlib
Expand Down Expand Up @@ -1207,3 +1207,42 @@ def _parse_version(v: str) -> Tuple[int, ...]:

def numpy_version():
return _parse_version(np.__version__)

def parameterized_filterable(*,
kwargs: Sequence[Dict[str, Any]],
testcase_name: Optional[Callable[[Dict[str, Any]], str]] = None,
one_containing: Optional[str] = None,
):
"""
Decorator for named parameterized tests, with filtering.
Works like parameterized.named_parameters, except that it supports the
`one_containing` option. This is useful to select only one of the tests,
and to leave the test name unchanged (helps with specifying the desired test
when debugging).
Args:
kwargs: Each entry is a set of kwargs to be passed to the test function.
testcase_name: Optionally, a function to construct the testcase_name from
one kwargs dict. If not given then the kwarg must contain `testcase_name`.
one_containing: If given, then leave the test name unchanged, and use
only one `kwargs` whose `testcase_name` includes `one_containing`.
"""
# Ensure that all kwargs contain a testcase_name
kwargs_with_testcase_name: Sequence[Dict[str, Any]]
if testcase_name is not None:
kwargs_with_testcase_name = [dict(testcase_name=testcase_name(kw), **kw)
for kw in kwargs]
else:
for kw in kwargs:
assert "testcase_name" in kw
kwargs_with_testcase_name = kwargs
if one_containing is not None:
filtered = tuple(kw for kw in kwargs_with_testcase_name
if one_containing in kw["testcase_name"])
assert filtered, f"No testcase_name contains '{one_containing}'"
kw = filtered[0]
kw["testcase_name"] = ""
return parameterized.named_parameters([kw])
else:
return parameterized.named_parameters(*kwargs_with_testcase_name)
12 changes: 12 additions & 0 deletions jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -575,6 +575,18 @@ def _restore_context():
partial(shape_poly.compute_dim_vars_from_arg_shapes,
self.args_avals_flat, args_kwargs_tree=self.in_tree),
self.args_flat_tf, self.args_avals_flat, self.name_stack)

# We invoke shape checking to give it a chance to raise shape errors that
# are evident statically. This should work in TF eager mode because all
# the shapes are known.
# TODO: handle non-static shape checking for graph serialization
acc_shape_check_messages: List[str] = []
_, _ = _interpret_fun_jax(
partial(shape_poly.compute_shape_check_from_arg_shapes,
self.args_avals_flat, args_kwargs_tree=self.in_tree,
acc_shape_check_messages=acc_shape_check_messages),
self.args_flat_tf, self.args_avals_flat, self.name_stack)

_thread_local_state.shape_env = zip(dim_vars, dim_values)

fun_flat_jax, out_tree_thunk = flatten_fun_jax(self.fun_jax, self.in_tree)
Expand Down

0 comments on commit 5b035d5

Please sign in to comment.