Skip to content

Commit

Permalink
[shape_poly] Simplify the API for processing polymorphic_shape specif…
Browse files Browse the repository at this point in the history
…ications

Before, we had `export.poly_spec` to create a `jax.ShapeDtypeStruct`
given a polymorphic shape specification. This function was
invoked `poly_spec(arg_shape, arg_dtype, polymorphic_shape)`.
The `arg_shape` was only needed when the polymorphic shape spec
contained placeholders.

We break out an `export.symbolic_shape` that is just a parser
of polymorphic shape specs and we ask the user to invoke
`jax.ShapeDtypeStruct` directly:

`jax.ShapeDtypeStruct(export.symbolic_shape(polymorphic_shape, like=arg_shape), arg_dtype)`.

We also rename the `export.poly_specs` to `export.arg_specs`.
  • Loading branch information
gnecula committed Nov 28, 2023
1 parent 86e99a9 commit c6afdfd
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 170 deletions.
4 changes: 2 additions & 2 deletions jax/_src/internal_test_util/export_back_compat_test_util.py
Expand Up @@ -284,7 +284,7 @@ def serialize(self,
a string (for debugging), and (c) the module serialization version.
"""
# Use the native exporter, to make sure we get the proper serialization.
args_specs = export.poly_specs(data.inputs, polymorphic_shapes)
args_specs = export.args_specs(data.inputs, polymorphic_shapes)
exported = export.export(
jax.jit(func),
lowering_platforms=(self.default_jax_backend(),),
Expand All @@ -300,7 +300,7 @@ def serialize(self,

def run_serialized(self, data: CompatTestData,
polymorphic_shapes: Optional[Sequence[str]] = None):
args_specs = export.poly_specs(data.inputs, polymorphic_shapes)
args_specs = export.args_specs(data.inputs, polymorphic_shapes)
def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray:
return core.ShapedArray(a.shape, a.dtype)
in_avals_tree = tree_util.tree_map(ndarray_to_aval, args_specs)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/typing.py
Expand Up @@ -59,7 +59,7 @@ def dtype(self) -> DType: ...

# Shapes are tuples of dimension sizes, which are normally integers. We allow
# modules to extend the set of dimension sizes to contain other types, e.g.,
# symbolic dimensions in jax2tf.shape_poly.DimVar and masking.Poly.
# symbolic dimensions in export.DimExpr.
DimSize = Union[int, Any] # extensible
Shape = Sequence[DimSize]

Expand Down
60 changes: 26 additions & 34 deletions jax/experimental/export/export.py
Expand Up @@ -53,6 +53,7 @@
zip = util.safe_zip

DType = Any
Shape = jax._src.core.Shape

class DisabledSafetyCheck:
"""A safety check should be skipped on (de)serialization.
Expand Down Expand Up @@ -307,52 +308,40 @@ def default_lowering_platform() -> str:
# Canonicalize to turn 'gpu' into 'cuda' or 'rocm'
return xb.canonicalize_platform(jax.default_backend())

def symbolic_shape(
    shape_spec: Optional[str],
    *,
    like: Optional[Sequence[Optional[int]]] = None) -> Shape:
  """Parses a polymorphic shape specification into a symbolic shape.

  .. warning:: The shape-polymorphic lowering is an experimental feature.
    It is meant to be sound, but it is known to reject some JAX programs
    that are shape polymorphic. The details of this feature can change.

  Args:
    shape_spec: a symbolic shape specification. None stands for "...".
    like: when `shape_spec` contains placeholders ("_", "..."), use this
      shape to fill in the placeholders.
      The dimensions of `like` that are used for filling
      must be known (not `None`). If a dimension in `like` is known and
      the corresponding dimension in `shape_spec` is a constant then they
      must be equal.

  Note that this function does not check that `like` is fully compatible
  with `shape_spec`; `like` is used only to fill in placeholders.

  See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
  for more details.

  Returns: a tuple of dimension sizes, some of which may be symbolic
    expressions involving dimension variables.
  """
  # Delegates to the shape_poly parser; this wrapper exists so users can
  # build `jax.ShapeDtypeStruct(export.symbolic_shape(spec, like=s), dtype)`.
  return shape_poly.symbolic_shape(shape_spec, like=like)

def shape_and_dtype_jax_array(a) -> tuple[Sequence[Optional[int]], DType]:
  """Returns the shape and dtype of a jax.Array.

  The value is first raised to its shaped abstract value; the
  (shape, dtype) pair is then read off that aval.
  """
  shaped_aval = core.raise_to_shaped(core.get_aval(a))
  return shaped_aval.shape, shaped_aval.dtype

def poly_specs(
def args_specs(
args, # pytree of arguments
polymorphic_shapes, # prefix pytree of strings
get_shape_and_dtype=shape_and_dtype_jax_array,
):
"""Constructs a pytree of jax.ShapeDtypeSpec.
"""Constructs a pytree of jax.ShapeDtypeStruct argument specs for `export`.
Args:
args: a pytree of arguments
Expand All @@ -363,12 +352,14 @@ def poly_specs(
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
Note that this function does not ensure that the provided `args` shapes
are compatible with `polymorphic_shapes`. The `args.shape` are used only
to fill-in placeholders from `polymorphic_shapes`.
are compatible with `polymorphic_shapes`. The `.shape` of the `args` are
used only to fill-in placeholders from `polymorphic_shapes`.
See docstring of `poly_spec` and
See docstring of `symbolic_shape` and
[the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
for more details.
get_shape_and_dtype: a function that given an argument extracts a tuple
of a shape and a dtype.
Returns: a pytree of jax.ShapeDtypeStruct matching `args`.
"""
Expand All @@ -394,8 +385,9 @@ def poly_specs(
raise e("jax_export polymorphic_shapes") from None

# Now add in the polymorphic shapes
args_specs_flat = tuple(
map(poly_spec, shapes, dtypes, polymorphic_shapes_flat))
args_specs_flat = (
jax.ShapeDtypeStruct(symbolic_shape(spec, like=s), t)
for s, t, spec in zip(shapes, dtypes, polymorphic_shapes_flat))

return args_tree.unflatten(args_specs_flat)

Expand Down
78 changes: 40 additions & 38 deletions jax/experimental/export/shape_poly.py
Expand Up @@ -925,20 +925,22 @@ def __str__(self):
return "(" + ", ".join(["..." if d is ... else str(d) for d in self]) + ")"


def _parse_spec(shape_spec: Union[str, PolyShape, None],
arg_shape: Sequence[Optional[int]]) -> Sequence[DimSize]:
"""Parses the shape polymorphic specification for one array argument.
def symbolic_shape(shape_spec: Union[str, PolyShape, None],
*,
like: Optional[Sequence[Optional[int]]] = None
) -> Sequence[DimSize]:
"""Parses the shape polymorphic specification into a symbolic shape.
We have to be able to parse all strings produced by str(_DimExpr) because
sometimes the output polymorphic shapes of one function become the input
polymorphic shapes of another.
Args:
shape_spec: a shape polymorphic specification. None stands for "...".
arg_shape: an actual shape, possibly containing unknown dimensions (None).
We use `arg_shape` to fill-in the placeholders `_` and `...` in
the `shape_spec`. The dimensions of `arg_shape` that are used for filling
must be known (not `None`). If a dimension in `arg_shape` is known and
like: when `shape_spec` contains placeholders ("_", "..."), use this
shape to fill in the placeholders.
The dimensions of `like` that are used for filling
must be known (not `None`). If a dimension in `like` is known and
the corresponding dimension in `shape_spec` is a constant then they
must be equal.
Expand All @@ -952,16 +954,16 @@ def _parse_spec(shape_spec: Union[str, PolyShape, None],
elif not isinstance(shape_spec, str):
raise ValueError("polymorphic shape spec should be None or a string. "
f"Found {shape_spec_repr}.")
return _Parser(shape_spec, arg_shape, shape_spec_repr).parse()
return _Parser(shape_spec, like, shape_spec_repr).parse()

class _Parser:
def __init__(self,
             shape_spec: str,
             like_shape: Optional[Sequence[Optional[int]]],
             shape_spec_repr: str):
  """Initializes parser state for one symbolic shape specification.

  Args:
    shape_spec: the specification string to parse.
    like_shape: optional concrete shape used to fill in the "_" and "..."
      placeholders; None when no such shape was provided.
    shape_spec_repr: a printable form of the spec, for error messages.
  """
  self.shape_spec = shape_spec
  self.shape_spec_repr = shape_spec_repr  # For error messages
  self.like_shape = like_shape
  self.dimensions: list[DimSize] = []  # dimensions we have parsed

def parse(self) -> Sequence[DimSize]:
Expand All @@ -975,19 +977,20 @@ def parse(self) -> Sequence[DimSize]:
def add_dim(self, expr: Optional[DimSize], tok: tokenize.TokenInfo):
  """Appends one parsed dimension, validating it against the `like` shape.

  Args:
    expr: the parsed dimension, or None when a placeholder could not be
      resolved from the `like` shape.
    tok: the token being processed, used to locate parse errors.

  Raises a parse error when `expr` is None, or when it is a constant that
  disagrees with the known corresponding dimension of `like`.
  """
  if expr is None:
    raise self.parse_err(tok,
                         ("unexpected placeholder for unknown dimension; "
                          f"like={self.like_shape}"))

  if core.is_constant_dim(expr) and self.like_shape is not None:
    like_shape_dim = self.like_shape[len(self.dimensions)]
    if expr != like_shape_dim:
      raise self.parse_err(tok,
                           (f"different size {expr} for known dimension; "
                            f"like={self.like_shape}"))
  self.dimensions.append(expr)

def parse_err(self, tok: Optional[tokenize.TokenInfo], detail: str) -> Exception:
msg = (
f"syntax error in polymorphic shape {self.shape_spec_repr} "
f"syntax error in symbolic shape {self.shape_spec_repr} "
f"in dimension {len(self.dimensions)}: {detail}. ")
if tok is not None:
msg += f"Parsed '{tok.line[:tok.start[1]]}', remaining '{tok.line[tok.start[1]:]}'."
Expand Down Expand Up @@ -1033,17 +1036,31 @@ def shape(self, tok: tokenize.TokenInfo) -> tuple[Sequence[DimSize], tokenize.To
while True:
if tok.exact_type in self.FOLLOW_SHAPE:
break
# Error checking in presence of placeholders
if (tok.exact_type == tokenize.ELLIPSIS or
tok.exact_type == tokenize.NAME and tok.string == "_"):
if self.like_shape is None:
raise self.parse_err(tok,
"spec contains ... but no 'like' shape was given")
if tok.exact_type == tokenize.ELLIPSIS:
min_len_like_shape = len(self.dimensions)
else:
min_len_like_shape = len(self.dimensions) + 1
if len(self.like_shape) < min_len_like_shape:
raise self.parse_err(
tok,
f"cannot resolve placeholder '{tok.string}' because we parsed "
f"{len(self.dimensions)} already and 'like' shape has "
f"only {len(self.like_shape)} dimensions")
if tok.exact_type == tokenize.ELLIPSIS:
to_add = self.arg_shape[len(self.dimensions):]
to_add = self.like_shape[len(self.dimensions):] # type: ignore[index]
for ad in to_add:
self.add_dim(ad, tok)
tok = self.next_tok()
break
if len(self.dimensions) >= len(self.arg_shape):
raise self.parse_err(tok,
f"too many dimensions, arg_shape has {len(self.arg_shape)}")

if tok.exact_type == tokenize.NAME and tok.string == "_":
e = self.arg_shape[len(self.dimensions)]
e = self.like_shape[len(self.dimensions)] # type: ignore[index]
tok = self.next_tok()
else:
e, tok = self.expr(tok)
Expand Down Expand Up @@ -1170,21 +1187,6 @@ def _dimension_size_lowering_rule(ctx, arg, *, dimension):
mlir.register_lowering(dimension_size_p, _dimension_size_lowering_rule)


def arg_aval(
    arg_shape: Sequence[Optional[int]],
    arg_jax_dtype: DType,
    polymorphic_shape: Optional[Union[str, PolyShape]]) -> core.ShapedArray:
  """Computes the abstract value for one argument.

  Args:
    arg_shape: the shape for the argument, possibly having None dimensions.
    arg_jax_dtype: the inferred JAX dtype for the arg.
    polymorphic_shape: the polymorphic specification for the argument;
      placeholders in it are filled from `arg_shape`.

  Returns: the JAX abstract value for the argument.
  """
  # Parse the spec into a (possibly symbolic) shape, then wrap it with the
  # dtype into a ShapedArray abstract value.
  aval_shape = _parse_spec(polymorphic_shape, arg_shape)
  return core.ShapedArray(aval_shape, arg_jax_dtype)

def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Sequence[str]:
dim_vars: set[str] = set()
for a in args_avals:
Expand Down
6 changes: 3 additions & 3 deletions jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -367,12 +367,12 @@ def shape_and_dtype_tf(a: TfVal) -> tuple[Sequence[Optional[int]], DType]:
_, a_jax_dtype = _tfval_to_tensor_jax_dtype(a)
return tf_arg_shape, a_jax_dtype

args_specs = export.poly_specs(args_tf,
args_specs = export.args_specs(args_tf,
polymorphic_shapes=polymorphic_shapes,
get_shape_and_dtype=shape_and_dtype_tf)
# The polymorphic_shapes argument refers to positional arguments only.
# We assume None for the kwargs.
kwargs_specs = export.poly_specs(kwargs_tf,
kwargs_specs = export.args_specs(kwargs_tf,
polymorphic_shapes=None,
get_shape_and_dtype=shape_and_dtype_tf)
combined_args_tf = (args_tf, kwargs_tf)
Expand Down Expand Up @@ -652,7 +652,7 @@ def eval_polymorphic_shape(fun_jax: Callable,
(c, a)
"""
def do_eval_polymorphic_shape(*args_specs) -> Any:
args_poly_specs = export.poly_specs(
args_poly_specs = export.args_specs(
args_specs, polymorphic_shapes=polymorphic_shapes)
res_poly_spec = jax.eval_shape(fun_jax, *args_poly_specs)
# TODO(necula): For now we export the polymorphic shapes using `str`.
Expand Down

0 comments on commit c6afdfd

Please sign in to comment.