Skip to content

Commit

Permalink
[jax2tf] Allows shape polymorphic specification to be polynomials.
Browse files Browse the repository at this point in the history
Until now, the polymorphic_shapes parameters could contain only
a constant or a dimension variable in each dimension. With this PR
we allow polynomials. These are needed in two situations:

  * when converting the VJP of a shape polymorphic function the shape
  specification corresponding to the cotangent must match the
  shape of the output of the primal function, which may contain
  polynomials of dimension variables.
  * one can specify that a dimension is even-sized by writing it
  as "2 * b", or that it is at least 10, by writing "b + 9".

This change requires changes to the code that solves the dimension
variables in terms of `tf.shape(arg)`. The dimension variables
are solved only from linear uni-variate polynomials followed by
replacing the solved variables in the other polynomials and
repeating until all variables are solved.
  • Loading branch information
gnecula committed Aug 6, 2021
1 parent d6df61c commit 29ffe9a
Show file tree
Hide file tree
Showing 4 changed files with 288 additions and 165 deletions.
4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ self.assertEqual(0, f_jax(x0)) # JAX sees that the x.shape[0] == 0
# jax2tf catches the broken assumption b >= 1 if the converted function is executed
# eagerly.
# Raises: ValueError: PolyShape ('b',) has dimension variable 'b' corresponding to 0, for argument shapes (TensorShape([0]),)
# Raises: ValueError: Dimension variable b must have integer value >= 1. Found value 0 when solving b == 0
jax2tf.convert(f_jax, polymorphic_shapes=["b"])(x0))
# However, if we first trace to a TensorFlow graph, we may miss the broken assumption:
Expand All @@ -753,7 +753,7 @@ self.assertEqual(0, f_jax(x45)) # JAX seems that x.shape[0] != x.shape[1]
# jax2tf catches the broken assumption x.shape[0] == x.shape[1] if the converted
# function is executed eagerly.
# Raises: ValueError: PolyShape ('b, b',) has dimension variable 'b' corresponding to multiple values {4, 5}, for argument shapes (TensorShape([4, 5]),)
# Raises: ValueError: polymorphic shape ('b, b',) has dimension variable 'b' corresponding to multiple values {4, 5}, for argument shapes (TensorShape([4, 5]),)
jax2tf.convert(f_jax, polymorphic_shapes=["b, b"])(x45)
# However, if we first trace to a TensorFlow graph, we may miss the broken assumption.
Expand Down
48 changes: 10 additions & 38 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental module transforms JAX functions to be executed by TensorFlow."""
import collections
from functools import partial
import contextlib
import os
import re
import string
import threading
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import jax
from jax._src import ad_util
Expand Down Expand Up @@ -83,11 +82,6 @@ def _sanitize_scope_name(name):
DType = Any
PrecisionType = int # Enum xla_data.PrecisionConfig.Precision

# A dimension environment maps dimension variables to TF expressions that
# compute the value of the dimension. These expressions refer to the TF
# function arguments.
_ShapeEnv = Dict[str, TfVal]

def _is_tfval(v: TfVal) -> bool:
if isinstance(v, (tf.Tensor, tf.Variable)):
return True
Expand Down Expand Up @@ -147,7 +141,7 @@ def __init__(self):
self.inside_call_tf = False

# Maps dimension variables to TF expressions
self.shape_env: _ShapeEnv = {}
self.shape_env: shape_poly.ShapeEnv = {}

# Whether to actually include XLA op metadata in the generated TF ops
self.include_xla_op_metadata = True
Expand Down Expand Up @@ -595,67 +589,45 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype)
return tf.convert_to_tensor(val, dtype=conversion_dtype), jax_dtype


def _args_to_avals_and_env(
args: Sequence[TfVal],
arg_jax_dtypes: Sequence[DType],
polymorphic_shapes: Sequence[Optional[Union[str, PolyShape]]]) -> \
Tuple[Sequence[core.ShapedArray], _ShapeEnv]:
Tuple[Sequence[core.ShapedArray], shape_poly.ShapeEnv]:
"""Computes canonicalized args, abstract values and a dimension environment for arguments.
Args:
args: the arguments, TF inputs. Must be tf.Tensor or tf.Variable.
arg_dtypes: the inferred JAX dtypes for the args.
polymorphic_shapes: the polymorphic specifications for the arguments.
Returns: a tuple of: a sequence of abstract values corresponding to the
arguments, and a dimension environment.
arguments, and a dimension variable environment.
"""
shapeenv: _ShapeEnv = {}

# Map shape variables to the set of integers they correspond to in the
# actual arguments
shape_var_map: Dict[str, Set[int]] = collections.defaultdict(set)
dim_equations: List[shape_poly.DimEquation] = []

def input_aval(arg: TfVal,
arg_jax_dtype: DType,
polymorphic_shape: Optional[str]) -> core.ShapedArray:
"""The abstract value for an input."""
arg_shape = np.shape(arg)
aval_shape = shape_poly.parse_spec(polymorphic_shape, arg_shape)

arg_tf_shape = tf.shape(arg)
for i, d in enumerate(aval_shape):
dim_size = arg_shape[i]
if isinstance(dim_size, tf.compat.v1.Dimension):
dim_size = dim_size.value
if not shape_poly.is_poly_dim(d):
assert d == dim_size
else:
d_var = d.to_var() # type: ignore
if d_var is not None:
if d_var not in shapeenv:
# Even if the shape of `arg` is known, we still use `tf.shape` for
# safety, because the promise is that we will convert the function
# to work for any value of the dimension.
shapeenv[d_var] = tf.shape(arg)[i] # type: ignore[index]
if dim_size is not None:
shape_var_map[d_var].add(int(dim_size))
dim_equations.append(shape_poly.DimEquation(
poly=d, tf_expr=arg_tf_shape[i])) # type: ignore


return core.ShapedArray(aval_shape, arg_jax_dtype)

avals = tuple(map(input_aval, args, arg_jax_dtypes, polymorphic_shapes)) # type: ignore
arg_shapes = tuple(np.shape(a) for a in args)

for dim_var, dim_var_values in shape_var_map.items():
if len(dim_var_values) != 1:
msg = (f"PolyShape {tuple(polymorphic_shapes)} has dimension variable '{dim_var}' "
f"corresponding to multiple values {set(sorted(dim_var_values))}, for "
f"argument shapes {arg_shapes}")
raise ValueError(msg)
elif list(dim_var_values)[0] <= 0:
msg = (f"PolyShape {tuple(polymorphic_shapes)} has dimension variable '{dim_var}' "
f"corresponding to 0, for argument shapes {arg_shapes}")
raise ValueError(msg)

shapeenv = shape_poly.solve_dim_equations(dim_equations)
return avals, shapeenv


Expand Down

0 comments on commit 29ffe9a

Please sign in to comment.