Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle error in case outputs subscripts of xeinsum are not unique #18670

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

lgeiger
Copy link
Contributor

@lgeiger lgeiger commented Nov 24, 2023

np.einsum("ij->jij", np.zeros((2, 2)))

throws an error since the output subscript j is specified multiple times:

ValueError: einstein sum subscripts string includes output subscript 'j' multiple times

However

jnp.einsum_path("ij->jij", jnp.zeros((2, 2)))

runs into an internal assertion error which is not very user friendly.

   3629   raise NotImplementedError  # if this is actually reachable, open an issue!
   3631 # the resulting 'operand' with axis labels 'names' should be a permutation
   3632 # of the desired result
-> 3633 assert len(names) == len(result_names) == len(set(names))
   3634 assert set(names) == set(result_names)
   3635 if names != result_names:

AssertionError:

Since jax relies on opt_einsum for its error handling I made an upstream PR to fix that in dgasmith/opt_einsum#222 (also see numpy/numpy#25230 which does the same for np.einsum_path)

This PR fixes the _use_xeinsum=True code path which currently would completely ignore the error and still output an array:

jnp.einsum_path("ij->jij", jnp.zeros((2, 2)), _use_xeinsum=True)
# Array([[0., 0.],
#        [0., 0.]], dtype=float32)

Comment on lines +484 to +485
if len(out_subs) != len(set(out_subs)):
raise ValueError("Output subscripts should be unique.")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative would be to do something like this if we want a more verbose error message and don't mind iterating over the array

    for out_sub in out_subs:
        if out_subs.count(out_sub) != 1:
            raise ValueError("Output subscript %s appeared more than once in the output." % out_sub))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants