
When I use opt_einsum to optimize torch.einsum, the running time after optimization increases. #202

Open
edwin-zft opened this issue Oct 28, 2022 · 5 comments


@edwin-zft

import numpy as np
import time
import torch
from opt_einsum import contract

dim = 4

x = torch.randn(6, 4, 4, 4)
w1 = torch.randn(1,4,dim)
w2 = torch.randn(dim,4,dim)
w3 = torch.randn(dim,4,dim)
w4 = torch.randn(dim,8,dim)
w5 = torch.randn(dim,8,dim)
w6 = torch.randn(dim,4,1)

def naive(x, w1, w2, w3, w4, w5, w6):
    return torch.einsum('bkxy,ikj,jxm,myf,fpl,lqz,zri -> bpqr', x, w1, w2, w3, w4, w5, w6)

def optimized(x, w1, w2, w3, w4, w5, w6):
    return contract('bkxy,ikj,jxm,myf,fpl,lqz,zri -> bpqr', x, w1, w2, w3, w4, w5, w6)

The respective running time:

naive
0.0005145072937011719
optimized
0.876018762588501

I want to know what caused this. Thanks!

@jcmgray
Collaborator

jcmgray commented Oct 28, 2022

Hey @edwin-zft, I get:

%%timeit
y = naive(x, w1, w2, w3, w4, w5, w6)
# 536 µs ± 4.06 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

vs.

%%timeit
y = optimized(x, w1, w2, w3, w4, w5, w6)
# 470 µs ± 2.07 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

and as a bonus:

from opt_einsum import contract_expression

expr = contract_expression(
    'bkxy,ikj,jxm,myf,fpl,lqz,zri->bpqr',
    x.shape, w1.shape, w2.shape, w3.shape, w4.shape, w5.shape, w6.shape,
    optimize='dp',
)

%%timeit
y = expr(x, w1, w2, w3, w4, w5, w6)
# 72.2 µs ± 758 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

so maybe it's just a warm-up issue for you. Are you using timeit?
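
(For reference, a minimal sketch of timing that excludes first-call overhead: warm up once, then hand timeit a callable. The loop count here is arbitrary, not from the original posts.)

import timeit

# Warm-up call absorbs one-time costs (path search, kernel warm-up),
# so the loop below times only the steady-state contraction.
optimized(x, w1, w2, w3, w4, w5, w6)
t = timeit.timeit(lambda: optimized(x, w1, w2, w3, w4, w5, w6), number=10_000)
print(f"{t / 10_000 * 1e6:.2f} µs per call")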

@edwin-zft
Author

Thank you for your reply!
I reran the tests with timeit:

%%timeit
y = naive(x, w1, w2, w3, w4, w5, w6)
# 0.00007126200944185257s (10000 loops)

vs.

%%timeit
y = optimized(x, w1, w2, w3, w4, w5, w6)
# 0.00006402703002095222s (10000 loops)

The improvement of running speed after optimization is not obvious. I guess it is due to the particular structure of this expression.
Moreover, I tried contract_expression, but it didn't reduce the time. I want to know why.

expr = contract_expression(
    'bkxy,ikj,jxm,myf,fpl,lqz,zri->bpqr', 
    x.shape, w1.shape, w2.shape, w3.shape, w4.shape, w5.shape, w6.shape,
    optimize='dp',
)
%%timeit
y = expr(x, w1, w2, w3, w4, w5, w6)
print(timeit.timeit('y', setup="from __main__ import y", number=10000))
# 0.00006920704618096352s (10000 loops)

Finally, thank you very much for your answers and your work!

@jcmgray
Collaborator

jcmgray commented Oct 28, 2022

The improvement of running speed after optimization is not obvious.

Some of the recent PRs/issues in torch make it seem like they may have added path optimization themselves, possibly including a version of opt_einsum.

If I increase to dim=4000, the timings are still similar despite a theoretical speedup of 1.828e+14 (compared to doing a single einsum), which would be hard to miss... so it seems torch.einsum at least uses pairwise contractions now.
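
(A sketch of how to check that number yourself: contract_path reports the theoretical speedup of the pairwise path over a single naive einsum, using the tensors defined above.)

from opt_einsum import contract_path

# PathInfo summarizes naive vs. optimized FLOP counts for the contraction.
path, info = contract_path('bkxy,ikj,jxm,myf,fpl,lqz,zri->bpqr',
                           x, w1, w2, w3, w4, w5, w6, optimize='dp')
print(info.speedup)  # naive cost / optimized cost
print(info)          # full per-step pairwise contraction list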

but it didn't reduce the time. I want to know why.

I don't know the intricacies of timeit, but I guess it's running the path optimization to produce expr each time, despite the setup.
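
(One way to sidestep the setup-string pitfall entirely: pass timeit a callable that closes over the prebuilt expr, so the 'dp' path search happens once, outside the timed region. A sketch reusing expr and the tensors from above.)

import timeit

# The lambda closes over expr, so only the contraction itself is timed;
# contract_expression already did the path search, outside the loop.
t = timeit.timeit(lambda: expr(x, w1, w2, w3, w4, w5, w6), number=10_000)
print(t / 10_000)  # seconds per call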

@janeyx99
Contributor

janeyx99 commented Nov 1, 2022

FYI torch indeed does default to using opt_einsum if it's found in the environment.
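
(For the curious, recent PyTorch surfaces this under torch.backends.opt_einsum; a sketch assuming PyTorch 1.13+, where these switches exist.)

import torch

# Torch-side controls for the opt_einsum integration (assuming PyTorch 1.13+).
print(torch.backends.opt_einsum.is_available())  # True if opt_einsum is importable
print(torch.backends.opt_einsum.enabled)         # path optimization enabled by default
torch.backends.opt_einsum.strategy = "optimal"   # or "auto" (default) / "greedy"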

@dgasmith
Owner

dgasmith commented Nov 2, 2022

FYI torch indeed does default to using opt_einsum if it's found in the environment.

Super cool!
