Handle error in case outputs subscripts of xeinsum are not unique #18670
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
throws an error since the output subscript
j
is specified multiple times:However
runs into an internal assertion error which is not very user friendly.
3629 raise NotImplementedError # if this is actually reachable, open an issue! 3631 # the resulting 'operand' with axis labels 'names' should be a permutation 3632 # of the desired result -> 3633 assert len(names) == len(result_names) == len(set(names)) 3634 assert set(names) == set(result_names) 3635 if names != result_names: AssertionError:
Since jax relies on
opt_einsum
for its error handling I made an upstream PR to fix that in dgasmith/opt_einsum#222 (also see numpy/numpy#25230 which does the same fornp.einsum_path
)This PR fixes the
_use_xeinsum=True
code path which currently would completely ignore the error and still output an array: