A need for contiguity-based axis-order optimization in tensordot #11940

Open
rsokl opened this issue Sep 13, 2018 · 2 comments
Comments

@rsokl
Contributor

rsokl commented Sep 13, 2018

The ordering of axes fed to tensordot can have a massive (order-of-magnitude) impact on its efficiency, depending on the memory layout of the array(s) being summed:

>>> import numpy as np
>>> x = np.random.rand(100, 100, 100)
>>> %%timeit
... np.tensordot(x, x, axes=((0, 1, 2), (0, 1, 2)))  
151 µs ± 6.9 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
>>> %%timeit
... np.tensordot(x, x, axes=((1, 2, 0), (1, 2, 0))) 
7.9 ms ± 143 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Moving x's last axis to the front swaps the timings:

>>> xt = np.moveaxis(x, -1, 0)
>>> %%timeit
... np.tensordot(xt, xt, axes=((0, 1, 2), (0, 1, 2)))  
10.8 ms ± 213 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %%timeit
... np.tensordot(xt, xt, axes=((1, 2, 0), (1, 2, 0))) 
146 µs ± 4.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

As suggested by @eric-wieser, tensordot would benefit from axis-ordering based on memory contiguity to help guard against these massive slowdowns.
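
A minimal sketch of what such an optimization could look like (illustration only; contiguity_ordered_tensordot and its stride heuristic are hypothetical names, not NumPy code): reorder the paired contraction axes so that the axes of the first operand are visited in decreasing-stride order, which tends to keep the internal transpose().reshape() copy-free for contiguous inputs.

import numpy as np

def contiguity_ordered_tensordot(a, b, axes):
    # Assumes `axes` is a pair of sequences, e.g. ((0, 1, 2), (0, 1, 2)).
    axes_a, axes_b = axes
    # Reordering the axis *pairs* does not change the result; pick the order
    # that matches a's memory layout (largest stride first, i.e. C order).
    order = sorted(range(len(axes_a)),
                   key=lambda i: a.strides[axes_a[i]], reverse=True)
    axes_a = tuple(axes_a[i] for i in order)
    axes_b = tuple(axes_b[i] for i in order)
    return np.tensordot(a, b, axes=(axes_a, axes_b))

x = np.random.rand(100, 100, 100)
xt = np.moveaxis(x, -1, 0)
# Both calls now take the fast path regardless of the axis order passed in.
print(np.allclose(contiguity_ordered_tensordot(x, x, ((1, 2, 0), (1, 2, 0))),
                  np.tensordot(x, x, axes=((1, 2, 0), (1, 2, 0)))))
print(np.allclose(contiguity_ordered_tensordot(xt, xt, ((0, 1, 2), (0, 1, 2))),
                  np.tensordot(xt, xt, axes=((0, 1, 2), (0, 1, 2)))))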

@liwt31
Contributor

liwt31 commented Oct 17, 2018

I've run into this problem in my own projects, and found that the difference originates in the reshape calls inside tensordot:

at = a.transpose(newaxes_a).reshape(newshape_a)
bt = b.transpose(newaxes_b).reshape(newshape_b)

If the reshape has to copy the data into a new array here, it is very time consuming.

I don't think a general solution is easy, because as long as we rely on numpy.dot to perform the final contraction, we can't avoid rearranging the data in every case. A better approach might be to fall back to einsum when the reshape would be expensive, as sketched below.
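
A rough sketch of that einsum fallback, assuming the axes are given as a pair of sequences (the helper name and the subscript construction are mine, for illustration only; deciding *when* to fall back would additionally need a cheap check for whether the reshape copies):

import numpy as np

def tensordot_via_einsum(a, b, axes):
    axes_a, axes_b = axes
    letters = iter('abcdefghijklmnopqrstuvwxyz')
    sub_a = [None] * a.ndim
    sub_b = [None] * b.ndim
    # Contracted axis pairs share a subscript letter.
    for i, j in zip(axes_a, axes_b):
        sub_a[i] = sub_b[j] = next(letters)
    # Free axes each get their own letter; tensordot orders the output as
    # a's free axes followed by b's free axes.
    out = ''
    for sub in (sub_a, sub_b):
        for k, s in enumerate(sub):
            if s is None:
                sub[k] = next(letters)
                out += sub[k]
    subscripts = '{},{}->{}'.format(''.join(sub_a), ''.join(sub_b), out)
    # einsum iterates over the operands with their existing strides, so no
    # large temporary copy is made when the memory layout is unfavourable.
    return np.einsum(subscripts, a, b)

x = np.random.rand(100, 100, 100)
print(np.allclose(tensordot_via_einsum(x, x, ((1, 2, 0), (1, 2, 0))),
                  np.tensordot(x, x, axes=((1, 2, 0), (1, 2, 0)))))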

I'm happy to prepare a PR for this, but I would like to hear more advice first.

@liwt31
Contributor

liwt31 commented Oct 17, 2018

Here are the results of some experiments. First, reproduce the earlier results:

>>> import numpy as np
>>> x = np.random.rand(100, 100, 100)
>>> %timeit np.tensordot(x, x, axes=((0, 1, 2), (0, 1, 2)))
... 1000 loops, best of 3: 547 µs per loop
>>> %timeit np.tensordot(x, x, axes=((1, 2, 0), (1, 2, 0))) 
... 100 loops, best of 3: 13.8 ms per loop
>>> xt = np.moveaxis(x, -1, 0)
>>> %timeit np.tensordot(xt, xt, axes=((0, 1, 2), (0, 1, 2)))  
... 100 loops, best of 3: 17 ms per loop
>>> %timeit np.tensordot(xt, xt, axes=((1, 2, 0), (1, 2, 0))) 
... 1000 loops, best of 3: 498 µs per loop

Everything goes "as expected".
Next, for x, mimic the logic of the tensordot function to break down where the time goes:

>>> %timeit x.transpose((0, 1, 2)).reshape(100 ** 3)
... The slowest run took 7.41 times longer than the fastest. This could mean that an intermediate result is being cached.
    1000000 loops, best of 3: 1.23 µs per loop
>>> x_ = x.transpose((0, 1, 2)).reshape(100 ** 3)
>>> %timeit x_.dot(x_)
... 1000 loops, best of 3: 340 µs per loop
>>> %timeit x.transpose((1, 2, 0)).reshape(100 ** 3)
... 100 loops, best of 3: 5.55 ms per loop
>>> x_ = x.transpose((1, 2, 0)).reshape(100 ** 3)
>>> %timeit x_.dot(x_)
... 1000 loops, best of 3: 344 µs per loop

I'm not sure what the caching warning means, but the conclusion that the reshape takes most of the time seems clear.
Now the same for xt:

>>> %timeit xt.transpose((0, 1, 2)).reshape(100 ** 3)
... 100 loops, best of 3: 7.85 ms per loop
>>> xt_ = xt.transpose((0, 1, 2)).reshape(100 ** 3)
>>> %timeit xt_.dot(xt_)
... 1000 loops, best of 3: 345 µs per loop
>>> %timeit xt.transpose((1, 2, 0)).reshape(100 ** 3)
... The slowest run took 8.00 times longer than the fastest. This could mean that an intermediate result is being cached.
    1000000 loops, best of 3: 1.28 µs per loop
>>> xt_ = xt.transpose((1, 2, 0)).reshape(100 ** 3)
>>> %timeit xt_.dot(xt_)
... 1000 loops, best of 3: 345 µs per loop
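
One way to see at a glance which combinations above are fast (a side note, not part of tensordot itself): np.shares_memory reveals whether transpose().reshape() returned a view or had to copy the data. The fast timings correspond exactly to the view cases.

import numpy as np

x = np.random.rand(100, 100, 100)
xt = np.moveaxis(x, -1, 0)

for name, arr in (('x', x), ('xt', xt)):
    for perm in ((0, 1, 2), (1, 2, 0)):
        flat = arr.transpose(perm).reshape(100 ** 3)
        kind = 'view' if np.shares_memory(arr, flat) else 'copy'
        print(name, perm, kind)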
