diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index c69de586c920..91c7af983f74 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -3440,7 +3440,6 @@ def einsum(
     optimize: str = "optimal",
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
-    _use_xeinsum: bool = False,
     _dot_general: Callable[..., Array] = lax.dot_general,
 ) -> Array: ...
 
@@ -3453,7 +3452,6 @@ def einsum(
     optimize: str = "optimal",
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
-    _use_xeinsum: bool = False,
     _dot_general: Callable[..., Array] = lax.dot_general,
 ) -> Array: ...
 
@@ -3465,20 +3463,15 @@ def einsum(
     optimize: str = "optimal",
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
-    _use_xeinsum: bool = False,
     _dot_general: Callable[..., Array] = lax.dot_general,
 ) -> Array:
   operands = (subscripts, *operands)
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")
-  spec = operands[0] if isinstance(operands[0], str) else None
-
-  if (_use_xeinsum or spec is not None and '{' in spec):
+  spec = operands[0] if isinstance(operands[0], str) else None
+  if spec is not None and '{' in spec:
     return jax.named_call(lax.xeinsum, name=spec)(*operands)
-
-  optimize = 'optimal' if optimize is True else optimize
-  # using einsum_call=True here is an internal api for opt_einsum
   # Allow handling of shape polymorphism
   non_constant_dim_types = {
@@ -3490,6 +3483,7 @@ def einsum(
   else:
     ty = next(iter(non_constant_dim_types))
     contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
+  # using einsum_call=True here is an internal api for opt_einsum... sorry
   operands, contractions = contract_path(
         *operands, einsum_call=True, use_blas=True, optimize=optimize)
diff --git a/tests/xmap_test.py b/tests/xmap_test.py
index d93f2c5a59f9..5264af3f3d9c 100644
--- a/tests/xmap_test.py
+++ b/tests/xmap_test.py
@@ -1491,7 +1491,7 @@ def test_xeinsum_no_named_axes_vector_dot(self):
     rng = self.rng()
     x = rng.randn(3)
     y = rng.randn(3)
-    out = jnp.einsum('i,i->', x, y, _use_xeinsum=True)
+    out = jnp.einsum('i,i->', x, y)
     expected = np.einsum('i,i->', x, y)
     self.assertAllClose(out, expected, check_dtypes=False)
 
@@ -1499,7 +1499,7 @@ def test_xeinsum_no_named_axes_batch_vector_dot(self):
     rng = self.rng()
     x = rng.randn(3, 2)
     y = rng.randn(3, 2)
-    out = jnp.einsum('ij,ij->i', x, y, _use_xeinsum=True)
+    out = lax.xeinsum('ij,ij->i', x, y)
     expected = np.einsum('ij,ij->i', x, y)
     self.assertAllClose(out, expected, check_dtypes=True)
 
@@ -1508,7 +1508,7 @@ def test_xeinsum_no_named_axes_batch_matmul(self):
     rng = np.random.RandomState(0)
     x = rng.randn(3, 5, 4)
     y = rng.randn(3, 4, 2)
-    out = jnp.einsum('bij,bjk->bik', x, y, _use_xeinsum=True)
+    out = lax.xeinsum('bij,bjk->bik', x, y)
     expected = np.einsum('bij,bjk->bik', x, y)
     self.assertAllClose(out, expected, check_dtypes=True)
 
@@ -1516,7 +1516,7 @@ def test_xeinsum_no_named_axes_reduce_sum(self):
     rng = self.rng()
     x = rng.randn(3)
     y = rng.randn()
-    out = jnp.einsum('i,->', x, y, _use_xeinsum=True)
+    out = lax.xeinsum('i,->', x, y)
     expected = np.einsum('i,->', x, y)
     self.assertAllClose(out, expected, check_dtypes=True)
 
@@ -1526,7 +1526,7 @@ def test_xeinsum_no_named_axes_reduce_and_contract(self):
     rng = np.random.RandomState(0)
     x = rng.randn(3, 5, 4)
     y = rng.randn(2, 4, 2)
-    out = jnp.einsum('bij,cjk->ik', x, y, _use_xeinsum=True)
+    out = lax.xeinsum('bij,cjk->ik', x, y)
     expected = np.einsum('bij,cjk->ik', x, y)
     self.assertAllClose(out, expected, check_dtypes=True)
 
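For downstream callers, here is a minimal migration sketch, under the assumption that `jax.lax.xeinsum` remains exported (the updated tests in `tests/xmap_test.py` above call it directly): code that previously forced the xeinsum path with `_use_xeinsum=True` now either calls `lax.xeinsum` itself or uses plain `jnp.einsum`, which dispatches to xeinsum only when the spec contains `'{'` (named axes).

```python
# Migration sketch, not part of the patch; assumes jax.lax.xeinsum is still
# importable, as the updated xmap tests rely on.
import numpy as np
import jax.numpy as jnp
from jax import lax

rng = np.random.RandomState(0)
x = rng.randn(3, 5, 4)
y = rng.randn(3, 4, 2)

# Before this change:
#   out = jnp.einsum('bij,bjk->bik', x, y, _use_xeinsum=True)
# After it, call lax.xeinsum directly for the same positional contraction:
out = lax.xeinsum('bij,bjk->bik', x, y)

# Plain jnp.einsum still works and now always takes the opt_einsum /
# dot_general path unless the subscript string contains '{' named axes:
out_jnp = jnp.einsum('bij,bjk->bik', x, y)

# Loose tolerances since JAX defaults to float32 while np.einsum uses float64.
np.testing.assert_allclose(out, np.einsum('bij,bjk->bik', x, y),
                           rtol=1e-5, atol=1e-5)
```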