remove _use_xeinsum from jnp.einsum api
can still call jnp.einsum with a '{' in the spec string to trigger xeinsum, or
just call lax.xeinsum directly
mattjj committed Dec 9, 2023
1 parent e50ef1b commit 9a1a09c
Showing 2 changed files with 7 additions and 13 deletions.
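Before the diff, a quick sketch of the two ways xeinsum remains reachable after this change. The arrays and names (x, y) are illustrative; the lax.xeinsum call mirrors the updated tests below, and the '{'-spec dispatch is only meaningful where the operands carry named axes (e.g. under xmap):

import numpy as np
import jax.numpy as jnp
from jax import lax

x = np.random.randn(3, 2)
y = np.random.randn(3, 2)

# Ordinary specs take the usual jnp.einsum path; no flag is needed anymore.
out = jnp.einsum('ij,ij->i', x, y)

# xeinsum can still be invoked directly on a positional spec,
# exactly as the updated tests below do.
out = lax.xeinsum('ij,ij->i', x, y)

# A '{' in the spec string (the named-axis syntax, e.g. '{i},{i}->') still
# makes jnp.einsum dispatch to lax.xeinsum via jax.named_call.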
jax/_src/numpy/lax_numpy.py (10 changes: 2 additions & 8 deletions)
@@ -3440,7 +3440,6 @@ def einsum(
     optimize: str = "optimal",
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
-    _use_xeinsum: bool = False,
     _dot_general: Callable[..., Array] = lax.dot_general,
 ) -> Array: ...

@@ -3453,7 +3452,6 @@ def einsum(
     optimize: str = "optimal",
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
-    _use_xeinsum: bool = False,
     _dot_general: Callable[..., Array] = lax.dot_general,
 ) -> Array: ...

@@ -3465,20 +3463,15 @@ def einsum(
     optimize: str = "optimal",
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
-    _use_xeinsum: bool = False,
     _dot_general: Callable[..., Array] = lax.dot_general,
 ) -> Array:
   operands = (subscripts, *operands)
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")
 
   spec = operands[0] if isinstance(operands[0], str) else None
 
-  if (_use_xeinsum or spec is not None and '{' in spec):
+  if spec is not None and '{' in spec:
     return jax.named_call(lax.xeinsum, name=spec)(*operands)
 
   optimize = 'optimal' if optimize is True else optimize
-  # using einsum_call=True here is an internal api for opt_einsum
 
   # Allow handling of shape polymorphism
   non_constant_dim_types = {
@@ -3490,6 +3483,7 @@
   else:
     ty = next(iter(non_constant_dim_types))
     contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
+  # using einsum_call=True here is an internal api for opt_einsum... sorry
   operands, contractions = contract_path(
       *operands, einsum_call=True, use_blas=True, optimize=optimize)

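As an aside on the comment this commit moves ("using einsum_call=True here is an internal api for opt_einsum"): a minimal standalone sketch of that hook, assuming opt_einsum's contract_path accepts the same kwargs the JAX call above passes it:

import numpy as np
import opt_einsum

x = np.ones((3, 4))
y = np.ones((4, 5))

# With the internal einsum_call=True flag, contract_path returns
# (operands, contraction_list) instead of the documented (path, PathInfo),
# so the caller can execute each pairwise contraction step itself.
operands, contractions = opt_einsum.contract_path(
    'ij,jk->ik', x, y, einsum_call=True, use_blas=True, optimize='optimal')
for step in contractions:  # one tuple per contraction in the chosen path
    print(step)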
tests/xmap_test.py (10 changes: 5 additions & 5 deletions)
@@ -1491,15 +1491,15 @@ def test_xeinsum_no_named_axes_vector_dot(self):
     rng = self.rng()
     x = rng.randn(3)
     y = rng.randn(3)
-    out = jnp.einsum('i,i->', x, y, _use_xeinsum=True)
+    out = jnp.einsum('i,i->', x, y)
     expected = np.einsum('i,i->', x, y)
     self.assertAllClose(out, expected, check_dtypes=False)
 
   def test_xeinsum_no_named_axes_batch_vector_dot(self):
     rng = self.rng()
     x = rng.randn(3, 2)
     y = rng.randn(3, 2)
-    out = jnp.einsum('ij,ij->i', x, y, _use_xeinsum=True)
+    out = lax.xeinsum('ij,ij->i', x, y)
     expected = np.einsum('ij,ij->i', x, y)
     self.assertAllClose(out, expected, check_dtypes=True)

@@ -1508,15 +1508,15 @@ def test_xeinsum_no_named_axes_batch_matmul(self):
     rng = np.random.RandomState(0)
     x = rng.randn(3, 5, 4)
     y = rng.randn(3, 4, 2)
-    out = jnp.einsum('bij,bjk->bik', x, y, _use_xeinsum=True)
+    out = lax.xeinsum('bij,bjk->bik', x, y)
     expected = np.einsum('bij,bjk->bik', x, y)
     self.assertAllClose(out, expected, check_dtypes=True)
 
   def test_xeinsum_no_named_axes_reduce_sum(self):
     rng = self.rng()
     x = rng.randn(3)
     y = rng.randn()
-    out = jnp.einsum('i,->', x, y, _use_xeinsum=True)
+    out = lax.xeinsum('i,->', x, y)
     expected = np.einsum('i,->', x, y)
     self.assertAllClose(out, expected, check_dtypes=True)

@@ -1526,7 +1526,7 @@ def test_xeinsum_no_named_axes_reduce_and_contract(self):
     rng = np.random.RandomState(0)
     x = rng.randn(3, 5, 4)
     y = rng.randn(2, 4, 2)
-    out = jnp.einsum('bij,cjk->ik', x, y, _use_xeinsum=True)
+    out = lax.xeinsum('bij,cjk->ik', x, y)
     expected = np.einsum('bij,cjk->ik', x, y)
     self.assertAllClose(out, expected, check_dtypes=True)

