
Einsum gets slow with multiple arguments #5366

Closed
perimosocordiae opened this issue Dec 11, 2014 · 4 comments

Comments

@perimosocordiae
Contributor

With multiple arguments, einsum blows up fast:

In [1]: import numpy as np

In [2]: a = np.random.random((5,5))

In [3]: %timeit a.dot(a).dot(a).dot(a)
100000 loops, best of 3: 3.3 µs per loop

In [4]: %timeit np.linalg.multi_dot((a,a,a,a))
10000 loops, best of 3: 30.8 µs per loop

In [5]: %timeit np.einsum('ij,kl,mn,op->ip', a,a,a,a)
100 loops, best of 3: 7.2 ms per loop

In [6]: a = np.random.random((50,50))

In [7]: %timeit a.dot(a).dot(a).dot(a)
10000 loops, best of 3: 82.1 µs per loop

In [8]: %timeit np.linalg.multi_dot((a,a,a,a))
10000 loops, best of 3: 150 µs per loop

In [9]: %timeit np.einsum('ij,kl,mn,op->ip', a,a,a,a)
# takes so long I had to kill IPython.

Is there some fundamental limitation involved, or is this a bug?

@njsmith
Member

njsmith commented Dec 11, 2014

I'm not an einsum expert, so I don't know what "ij,kl,mn,op->ip" means. But it doesn't compute the same thing as the chained dot calls:

In [12]: np.max(np.abs(np.einsum("ij,kl,mn,op->ip", a, a, a, a) - a.dot(a).dot(a).dot(a)))
Out[12]: 1629.9863026164651
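Written out, "ij,kl,mn,op->ip" sums over j, k, l, m, n and o independently, so result[i, p] = (sum_j a[i, j]) * (sum_kl a[k, l]) * (sum_mn a[m, n]) * (sum_o a[o, p]). That is a different matrix from a·a·a·a. The plain-Python sketch below checks that identity on a small integer matrix. The helpers lookup, shape_of and naive_einsum are hypothetical names for illustration only: they evaluate an explicit-mode subscript string with one nested loop over every index letter and no intermediate results, and they are not NumPy's implementation.

```python
from itertools import product

def lookup(tensor, idx):
    """Walk a nested list with a sequence of integer indices."""
    for k in idx:
        tensor = tensor[k]
    return tensor

def shape_of(tensor, rank):
    """Lengths along the first `rank` axes of a nested list."""
    dims = []
    for _ in range(rank):
        dims.append(len(tensor))
        tensor = tensor[0]
    return dims

def naive_einsum(subscripts, *operands):
    """Explicit-mode einsum ('ij,jk->ik') as ONE nested loop over every
    distinct index letter, keeping no intermediate results. Returns a dict
    mapping output index tuples to values. A teaching sketch only."""
    inputs, output = subscripts.split("->")
    terms = inputs.split(",")
    sizes = {}
    for term, op in zip(terms, operands):
        sizes.update(zip(term, shape_of(op, len(term))))
    letters = sorted(sizes)
    result = {}
    for values in product(*(range(sizes[c]) for c in letters)):
        env = dict(zip(letters, values))
        term_product = 1
        for term, op in zip(terms, operands):
            term_product *= lookup(op, [env[c] for c in term])
        key = tuple(env[c] for c in output)
        result[key] = result.get(key, 0) + term_product
    return result

a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
r = naive_einsum("ij,kl,mn,op->ip", a, a, a, a)

# Every index except i and p is summed on its own, so the result factors into
# rowsum(a)[i] * total(a)**2 * colsum(a)[p] instead of (a.a.a.a)[i, p].
rowsum = [sum(row) for row in a]
colsum = [sum(col) for col in zip(*a)]
total = sum(rowsum)
assert all(r[i, p] == rowsum[i] * total ** 2 * colsum[p]
           for i in range(3) for p in range(3))
```

Integer entries keep the comparison exact; with floating-point inputs you would compare with a tolerance instead.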

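You want something like "ij,jk,kl,lm" (with no "->", einsum outputs the indices that appear exactly once, in alphabetical order, so this means "ij,jk,kl,lm->im"):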

In [13]: np.max(np.abs(np.einsum("ij,jk,kl,lm", a, a, a, a) - a.dot(a).dot(a).dot(a)))
Out[13]: 7.1054273576010019e-15

This is still slower than dot but by a much less embarrassing margin:

In [15]: %timeit a.dot(a).dot(a).dot(a)
100000 loops, best of 3: 2.58 µs per loop

In [16]: %timeit np.einsum("ij,jk,kl,lm", a, a, a, a)
10000 loops, best of 3: 57.7 µs per loop
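To get a rough feel for the gap, one can count loop iterations, assuming einsum makes a single pass over every combination of index values (no intermediate storage) and that each chained dot costs n³ multiply-adds. The function names below are hypothetical, and the counts ignore constant factors such as the BLAS speed-up that dot gets, so they are only a proxy for the timings above.

```python
def naive_iterations(subscripts, n):
    """Iterations of one nested loop over every distinct index letter,
    assuming every axis has length n (hypothetical helper)."""
    inputs = subscripts.split("->")[0]
    return n ** len(set(inputs) - {","})

def chained_dot_madds(num_matrices, n):
    """Multiply-adds for chaining num_matrices n-by-n matrices with dot."""
    return (num_matrices - 1) * n ** 3

for n in (5, 50):
    print(n,
          naive_iterations("ij,kl,mn,op->ip", n),  # 8 distinct indices
          naive_iterations("ij,jk,kl,lm", n),      # 5 distinct indices
          chained_dot_madds(4, n))
```

Under these assumptions, n = 5 gives 390625, 3125 and 375, and n = 50 gives 39062500000000, 312500000 and 375000, which is consistent with the original subscript string never finishing for the 50×50 case.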

@jaimefrio
Member

Someone more in the know may want to correct me, but I looked at something related to this a while back, and my conclusion was that what trips einsum up is that it doesn't use any intermediate storage.

A much simpler example:

In [2]: a, b, c = np.random.rand(3, 1000)

In [3]: %timeit a * np.einsum('i,i', b, c)
100000 loops, best of 3: 6.38 us per loop

In [4]: %timeit np.einsum('j,i,i', a, b, c)
1000 loops, best of 3: 1.16 ms per loop

In [5]: np.allclose(a * np.einsum('i,i', b, c), np.einsum('j,i,i', a, b, c))
Out[5]: True

In the first case, the dot product of b and c is computed once, and that value is multiplied by every item of a. In the second case, einsum redoes that dot product for every item of a.
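The two evaluation strategies can be sketched in plain Python with an explicit multiplication count. The function names are hypothetical and this is not how NumPy is implemented; it only mirrors the description above: the single pass recomputes b[i] * c[i] for every j, while the factored version keeps the dot product as an intermediate value.

```python
def single_pass_j_i_i(a, b, c):
    # One loop over (j, i): b[i] * c[i] is recomputed for every j,
    # costing 2 * len(a) * len(b) multiplications.
    out, mults = [], 0
    for aj in a:
        total = 0
        for bi, ci in zip(b, c):
            total += aj * bi * ci
            mults += 2
        out.append(total)
    return out, mults

def factored_j_i_i(a, b, c):
    # Contract b and c once, store the scalar, then scale a:
    # len(b) + len(a) multiplications.
    dot, mults = 0, 0
    for bi, ci in zip(b, c):
        dot += bi * ci
        mults += 1
    out = []
    for aj in a:
        out.append(aj * dot)
        mults += 1
    return out, mults
```

For len(a) == len(b) == 1000 that is 2,000,000 multiplications against 2,000.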

I don't think this qualifies as a bug, and I'm not even sure a general-case enhancement is viable: you would need an algorithm that detects groups of operands that can be collapsed and, if some operands belong to more than one group, decides in what order to collapse them. For chained matrix multiplication this is solved by a more or less standard dynamic-programming (DP) algorithm, as in multi_dot, but the general case looks like a daunting task. Perhaps for 3 or 4 operands a brute-force search is viable, though...
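One way to try the brute-force idea for a handful of operands is to enumerate every order of pairwise contractions, cost each step as a nested loop over the indices it touches, and drop an index as soon as neither the output nor any remaining operand needs it. The sketch below is a hypothetical plain-Python illustration of that idea (the cost model and all names are assumptions); it only searches for an order and does not perform the contraction. The number of orders grows very quickly with the number of operands, so it is only practical for a few.

```python
from itertools import combinations
from math import prod

def parse(subscripts):
    """Split 'ij,jk->ik' into input index sets and the output index set.
    Only explicit '->' strings are handled in this sketch."""
    inputs, output = subscripts.split("->")
    return [frozenset(term) for term in inputs.split(",")], frozenset(output)

def loop_cost(indices, sizes):
    """Iterations of one nested loop over every index in `indices`."""
    return prod(sizes[ix] for ix in indices)

def best_order(operands, output, sizes):
    """Exhaustively try every pairwise contraction order.

    Returns (total_cost, steps); each step is (indices looped over,
    indices kept in the intermediate result)."""
    if len(operands) == 1:
        # A leftover operand may still need a final sum over non-output indices.
        only = operands[0]
        if only == output:
            return 0, []
        return loop_cost(only, sizes), [(only, output)]
    best = None
    for x, y in combinations(range(len(operands)), 2):
        rest = [op for k, op in enumerate(operands) if k not in (x, y)]
        union = operands[x] | operands[y]
        # Keep an index only if the output or a remaining operand still needs it.
        kept = union & output.union(*rest)
        cost, steps = best_order(rest + [kept], output, sizes)
        cost += loop_cost(union, sizes)
        if best is None or cost < best[0]:
            best = (cost, [(union, kept)] + steps)
    return best

if __name__ == "__main__":
    ops, out = parse("j,i,i->j")
    sizes = {"i": 1000, "j": 1000}
    cost, steps = best_order(ops, out, sizes)
    print(cost, loop_cost(frozenset("ij"), sizes))  # pairwise order vs one big loop
```

For "j,i,i->j" with both axes of length 1000, the search contracts b with c first and reports 2000 iterations, against 1000000 for a single loop over i and j; for the four-matrix chain "ij,jk,kl,lm->im" it recovers the 3·n³ cost of pairwise dot calls.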

@perimosocordiae
Contributor Author

@njsmith: Thanks, that's what I get for not sanity checking. Your example is the one I had initially run across.

@jaimefrio Ok, that makes sense. I think my mental model for how einsum works was off. I'm not sure how to make the docs clearer, though.

@charris
Member

charris commented Dec 20, 2014

Closing.
