[jax2tf] Extend shape polymorphism to handle add_transpose with broadcasting
gnecula committed Jun 12, 2021
1 parent 750f586 commit edd9688
Showing 3 changed files with 12 additions and 7 deletions.
12 changes: 6 additions & 6 deletions jax/_src/lax/lax.py
@@ -656,7 +656,7 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
Returns:
An array containing the product.
"""
- if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and lhs.shape[-1] == rhs.shape[0]:
+ if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and core.symbolic_equal_dim(lhs.shape[-1], rhs.shape[0]):
return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
precision=precision, preferred_element_type=preferred_element_type)
else:
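The recurring change in this file replaces plain `==`/`!=` on dimension sizes with `core.symbolic_equal_dim` (and `core.symbolic_equal_shape` for whole shapes), so the checks still work when a dimension is a shape-polymorphism variable rather than a concrete `int`. A rough sketch of the intended semantics, using a hypothetical `DimVar` stand-in rather than JAX's actual dimension objects:

```python
# Illustrative sketch only -- not JAX's implementation of symbolic dimensions.
from typing import Union

class DimVar:
    """Hypothetical stand-in for a symbolic dimension such as "b"."""
    def __init__(self, name: str):
        self.name = name

Dim = Union[int, DimVar]

def symbolic_equal_dim(d1: Dim, d2: Dim) -> bool:
    """True only when the two dimensions are known to be equal."""
    if isinstance(d1, int) and isinstance(d2, int):
        return d1 == d2                 # both concrete: ordinary comparison
    if isinstance(d1, DimVar) and isinstance(d2, DimVar):
        return d1.name == d2.name       # the same symbolic variable
    return False                        # int vs. symbol: equality cannot be proven

def symbolic_equal_shape(s1, s2) -> bool:
    return len(s1) == len(s2) and all(
        symbolic_equal_dim(a, b) for a, b in zip(s1, s2))
```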
@@ -2274,22 +2274,22 @@ def _unbroadcast(aval, x):
if not isinstance(aval, ShapedArray):
raise TypeError("transpose with implicit broadcasting of unshaped values")
x_shape = np.shape(x)
- if aval.shape == x_shape:
+ if core.symbolic_equal_shape(aval.shape, x_shape):
return x
assert not aval.shape or len(x_shape) == len(aval.shape)
if not aval.shape:
return _reduce_sum(x, list(range(len(x_shape))))
else:
- dims = [i for i, (a, b) in enumerate(zip(x_shape, aval.shape)) if a != b]
+ dims = [i for i, (a, b) in enumerate(zip(x_shape, aval.shape)) if not core.symbolic_equal_dim(a, b)]
if config.jax_enable_checks: assert all(aval.shape[i] == 1 for i in dims)
return reshape(_reduce_sum(x, dims), aval.shape)

def _maybe_broadcast(target_shape, x):
x_shape = np.shape(x)
- if x_shape == target_shape:
+ if core.symbolic_equal_shape(x_shape, target_shape):
return x
else:
- dims = [i for i, (a, b) in enumerate(zip(x_shape, target_shape)) if a == b]
+ dims = [i for i, (a, b) in enumerate(zip(x_shape, target_shape)) if core.symbolic_equal_dim(a, b)]
squeeze_shape = [x_shape[i] for i in dims]
return broadcast_in_dim(reshape(x, squeeze_shape), target_shape, dims)
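`_unbroadcast` and `_maybe_broadcast` are what the transpose of a broadcasting `add` relies on: the cotangent has to be summed back over the dimensions that were broadcast on the forward pass. A small standalone example of the kind of program this change enables under shape polymorphism (it mirrors the new test harness added below):

```python
import jax
import jax.numpy as jnp

# The inner sum has shape (4,) and is broadcast against x of shape (3, 4)
# before the add, so transposing the add must "unbroadcast" the cotangent.
f = lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=False) + x)

x = jnp.ones((3, 4), dtype=jnp.float32)
g = jax.grad(f)(x)
print(g.shape)  # (3, 4); every entry is 4.0: each x[i, j] contributes once
                # directly and three times via the broadcast column sum, 1 + 3 = 4.
```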

@@ -2941,7 +2941,7 @@ def _conv_general_dilated_shape_rule(
msg = ("conv_general_dilated feature_group_count must divide lhs feature "
"dimension size, but {} does not divide {}.")
raise ValueError(msg.format(feature_group_count, lhs_feature_count))
- if quot != rhs.shape[dimension_numbers.rhs_spec[1]]:
+ if not core.symbolic_equal_dim(quot, rhs.shape[dimension_numbers.rhs_spec[1]]):
msg = ("conv_general_dilated lhs feature dimension size divided by "
"feature_group_count must equal the rhs input feature dimension "
"size, but {} // {} != {}.")
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/README.md
@@ -880,7 +880,7 @@ the ``a_inference_cos_tf_68__`` HLO function that was compiled by TF from ``cos_t
## TensorFlow versions supported

The ``jax2tf.convert`` and `call_tf` require very recent versions of TensorFlow.
- As of today, the tests are run using `tf_nightly==2.6.0-dev20210601`.
+ As of today, the tests are run using `tf_nightly==2.6.0-dev20210611`.

## Running on GPU

5 changes: 5 additions & 0 deletions jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -884,6 +884,11 @@ def _make_harness(group_name: str, name: str,
[RandArg((3, 4), _f32)],
poly_axes=[0]),

_make_harness("add_transpose", "",
jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=0) + x)),
[RandArg((3, 4), _f32)],
poly_axes=[0]),

_make_harness("clamp", "",
lax.clamp,
[RandArg((3, 4, 5), _f32), RandArg((3, 4, 5), _f32),
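The new `add_transpose` harness with `poly_axes=[0]` corresponds roughly to the following use of shape polymorphism (a sketch of the intent, assuming the `jax2tf.convert` / `polymorphic_shapes` API described in the jax2tf README, not the harness machinery itself):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf

# Gradient of a broadcasting add, with the leading axis left polymorphic ("b").
f = jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=False) + x))
f_tf = jax2tf.convert(f, polymorphic_shapes=["(b, 4)"])

# The converted function works for any concrete batch size b.
print(np.asarray(f_tf(np.ones((3, 4), dtype=np.float32))))
```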
