
Re-use intermediates across various contractions (Common Subexpression Elimination) #9

Closed
ebatz opened this issue Nov 6, 2017 · 15 comments · Fixed by #43

ebatz commented Nov 6, 2017

Suppose you want to compute two contractions of the same tensors, e.g. contraction strings

['ab,dca,eb,cde', 'ab,cda,eb,cde']

The (globally) optimal way to do this would be to first perform the contractions over indices a,b and e, and then perform the remaining contractions over c and d for the two sets of contractions. The current opt_einsum implementation does not allow for such a global optimization of contraction order and re-use of common intermediates.
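For illustration, this factoring can be done by hand with plain numpy (the index dimension is an assumption for the sketch): the three tensors common to both expressions are contracted once into a shared intermediate, and each expression is then finished with its own remaining contraction.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4  # assumed dimension for every index
T_ab  = rng.standard_normal((d, d))
T_dca = rng.standard_normal((d, d, d))
T_cda = rng.standard_normal((d, d, d))
T_eb  = rng.standard_normal((d, d))
T_cde = rng.standard_normal((d, d, d))

# Shared intermediate: the three tensors common to both expressions,
# contracted over b and e, computed only once.
S_acd = np.einsum('ab,eb,cde->acd', T_ab, T_eb, T_cde)

# Each expression is then finished with its own remaining contraction.
r1 = np.einsum('acd,dca->', S_acd, T_dca)
r2 = np.einsum('acd,cda->', S_acd, T_cda)

# Same values as contracting each expression from scratch.
assert np.isclose(r1, np.einsum('ab,dca,eb,cde->', T_ab, T_dca, T_eb, T_cde))
assert np.isclose(r2, np.einsum('ab,cda,eb,cde->', T_ab, T_cda, T_eb, T_cde))
```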

I'm still organizing my thoughts on this, and all input would be most welcome. On a side note, I believe implementing such a more general optimization strategy will also fix #7 as a by-product.

Optimization logic

A relevant publication suggesting a concrete algorithm for this optimization problem is

Hartono et al, Identifying Cost-Effective Common Subexpressions to Reduce Operation Count in Tensor Contraction Evaluations

I do not know to what extent the current code can be re-used in this more general setup, but the single-term optimization should be under control with the current implementation.

Interface

Such a multi_contract function could be called with a list of tuples (contractionString, [tensors, ..]), and would return a list of results with the same length as the input list.

Internally, one would have to find out which tensors are actually identical between the various contractions, and then use the above contraction logic. Ideally this information should be deduced automatically and not rely on user input being in the correct form. In the same spirit, dummy indices should be allowed to have arbitrary names, i.e. they should not have to match across different contractions to be correctly identified as a common subexpression.
This may require transforming the input into a 'canonical' form first to make sure that common subexpressions are captured correctly.
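A minimal sketch of such a canonicalization (the helper name is hypothetical): rename indices in order of first appearance, so that two expressions differing only in dummy-index names map to the same string.

```python
import string

def canonicalize(subscripts):
    # Hypothetical helper: relabel indices in order of first appearance,
    # so 'ab,cda' and 'xy,zwx' both canonicalize to 'ab,cda'.
    mapping = {}
    out = []
    for ch in subscripts:
        if ch in ',->':
            out.append(ch)
        else:
            mapping.setdefault(ch, string.ascii_lowercase[len(mapping)])
            out.append(mapping[ch])
    return ''.join(out)

assert canonicalize('ab,cda') == canonicalize('xy,zwx') == 'ab,cda'
```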

In contrast to the setup outlined in Hartono et al, contraction strings should maybe remain in their current 'simple' form and not be generalized to allow for numerical expressions like sums of tensors etc. Such a behavior can be implemented a posteriori with the interface described here by computing the sum of the resulting tensors, e.g.

contract('ab,ab + ab,ba', N,M) --> 'sum'(multi_contract( [('ab,ab', N,M), ('ab,ba', N,M)] ))

Thus, restricting contraction strings to be of the form currently used does not cause loss of generality, the only downside being that it might lead to a slightly increased memory-footprint as the function would return several arrays instead of one.
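Since multi_contract does not yet exist, the decomposition above can be sketched with plain numpy: the hypothetical 'ab,ab + ab,ba' expression is just the sum of the two individual contractions.

```python
import numpy as np

rng = np.random.default_rng(1)
N = rng.standard_normal((3, 3))
M = rng.standard_normal((3, 3))

# 'ab,ab + ab,ba' expressed as two plain contractions summed afterwards
result = np.einsum('ab,ab->', N, M) + np.einsum('ab,ba->', N, M)

# Sanity check against the explicit elementwise forms
assert np.isclose(result, (N * M).sum() + (N * M.T).sum())
```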

Other thoughts?

dgasmith commented Nov 6, 2017

I need to think a bit about the optimization logic and how to implement it. I'm inclined to say that if we are going to build another interface, we should only build one that is relatively bulletproof. What if we had something like the following:

A = np.random.rand(4, 5, 6)
B = np.random.rand(5, 6, 7)
term1, term2 = contract(["A_abc B_bcd", "B_bcd A_abc"], {"A": A, "B": B}) # only one contraction
term1, term2 = contract(["A_abc B_bcd + A_abc", "B_bcd A_abc"], {"A": A, "B": B}) # one contraction and one addition

This could even be extended to:

C = np.zeros((4, 7))
contract(["C_cd -= A_abc B_bcd"], {"A": A, "B": B}, out={"C" : C})

This would be similar to something like numexpr.
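A sketch of how such named-tensor strings could be parsed down to an ordinary einsum call (the helper is hypothetical, and arithmetic like '+' or '-=' is not handled here):

```python
import numpy as np

def parse_named_expr(expr):
    # Hypothetical helper: 'A_abc B_bcd' -> ('abc,bcd', ['A', 'B'])
    names, subs = [], []
    for token in expr.split():
        name, _, sub = token.partition('_')
        names.append(name)
        subs.append(sub)
    return ','.join(subs), names

A = np.random.rand(4, 5, 6)
B = np.random.rand(5, 6, 7)
tensors = {'A': A, 'B': B}

eq, names = parse_named_expr('A_abc B_bcd')
assert eq == 'abc,bcd' and names == ['A', 'B']

# With implicit output, 'abc,bcd' contracts over b and c, leaving 'ad'
result = np.einsum(eq, *(tensors[n] for n in names))
assert result.shape == (4, 7)
```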

ebatz commented Nov 7, 2017

I might be bikeshedding the interface question, but to me your suggested call signature seems unnecessarily complicated in two regards:

  1. from the user's side, it requires identical tensors to be labeled consistently, whereas this can be deduced automatically from the input in the traditional form -- automatic deduction reduces the chance of error. Your suggested contraction-string form could serve as the 'canonical form' internally, so if we understand how to implement the logic starting from your argument form, we can decide on the final user interface later.

  2. I would rather avoid any sort of arithmetic in the input string; I feel this adds unneeded complexity, since your second example can be transformed as I suggested in the original post.

dgasmith commented Nov 7, 2017

  1. If we are extending the input I'd rather not limit the user; there are plenty of cases where there are common subexpressions but the input tensors are not identical for each equation. The trivial example would be:
contract(["A_ab B_bc C_cd", "A_ab B_bc D_cd"], ...)
  2. Agreed, that kind of tech can quickly get out of hand. It's worth exploring the limits of the scope and then shrinking back.

jcmgray commented Mar 23, 2018

This is now possible using the new backend support #17 (to an extent: you need to use dask).
Specifically, dask uniquely identifies arrays and operations and reuses the results whenever it can. I just checked that this works and can write an example if helpful.

ebatz commented Mar 23, 2018

That's exciting! If you could give me a hint on how to start that would be most welcome.

jcmgray commented Mar 24, 2018

Here's a quick demonstration:

Set up to the point where we have the 5 numpy arrays and two expressions:

import opt_einsum as oe
import numpy as np
import dask.array as da
from dask import delayed

sizes = {l: 10 for l in 'abcde'}

contraction1 = 'ab,dca,eb,cde'
terms1 = contraction1.split(',')

contraction2 = 'ab,cda,eb,cde'
terms2 = contraction2.split(',')

# define the initial set of arrays
inputs = sorted(set((*terms1, *terms2)))
np_arrays = {s: np.random.randn(*(sizes[c] for c in s)) for s in inputs}

Now convert them to dask arrays and perform the contraction:

da_arrays = {s: da.from_array(np_arrays[s], chunks=1000) for s in inputs}

# select dask arrays for each expression
ops1  = [da_arrays[s] for s in terms1]
ops2  = [da_arrays[s] for s in terms2]

# contract!
dy1 = oe.contract(contraction1, *ops1, backend='dask')
dy2 = oe.contract(contraction2, *ops2, backend='dask')

# just wrap them in delayed so as to combine in the same computation
dy3 = delayed([dy1, dy2])
dy3.compute()
[77.13078823390276, -41.30794090122288]

And we can check that the intermediates are being reused by visualizing the task graph:

dy3.visualize(optimize_graph=True)

[dask task graph, showing the shared intermediates computed only once]

@dgasmith

@jcmgray That's really neat, could you write this up in the docs?

fritzo commented Aug 17, 2018

We're achieving shared computation in Pyro by creating a custom deferred backend that hash-conses expressions. It can then use any other backend during evaluation (we're using the torch backend for evaluation).

The syntax is

x = torch.randn(2, 3)
y = torch.randn(3, 4)
z = torch.randn(4, 5)
with shared_intermediates():
    x_, y_, z_ = map(deferred_tensor, [x, y, z])
    a_ = opt_einsum.contract('ab,bc,cd->a', x_, y_, z_, backend='pyro.ops.einsum.deferred')
    b_ = opt_einsum.contract('ab,bc,cd->b', x_, y_, z_, backend='pyro.ops.einsum.deferred')
    c_ = opt_einsum.contract('ab,bc,cd->c', x_, y_, z_, backend='pyro.ops.einsum.deferred')
    d_ = opt_einsum.contract('ab,bc,cd->d', x_, y_, z_, backend='pyro.ops.einsum.deferred')
a = a_.eval(backend='torch')
b = b_.eval(backend='torch')
c = c_.eval(backend='torch')
d = d_.eval(backend='torch')

Do you have any interest in us moving this backend upstream into opt_einsum?

cc @eb8680

@dgasmith

I think this would be neat. We might need to play a bit with the interface, but the overall idea of being able to tie intermediates together is a very positive one, and the solution is elegant.

jcmgray commented Aug 18, 2018

It would be really nice to have. I did play around with something similar very briefly (e.g. this gist, which can act as a backend), but I was actually hashing the arrays using xxhash in case they had mutated, which is not very efficient! Ultimately my needs were met by the constants functionality.

A context manager is what I had in mind as well, since it nice and explicitly limits the scope/size of the cache.

My first thought would be, is the completely delayed .eval method needed? Couldn't einsum/tensordot just eagerly hash / id & cache / retrieve?

fritzo commented Aug 18, 2018

@jcmgray is the completely delayed .eval method needed?

That's a good point. I think we could change this to evaluate eagerly. It would also be nice to avoid the need to wrap inputs in deferred_tensor(). One option would be to add special handling inside contract() to check whether it is inside a shared_intermediates() context; if so, it would wrap inputs in deferred_tensor and unwrap via .eval() at the end, all under the hood. This would act exactly like the non-shared contract() from the user's perspective.

@dgasmith

For the hash, I don't think we need to hash all of the data. Looking at the interface of a non-writeable view would be sufficient, I believe?

>>> a = np.arange(6)
>>> b = a.reshape(3, 2)
>>> a.__array_interface__
{'data': (140415706605184, False), 'strides': None, 'descr': [('', '<i8')], 'typestr': '<i8', 'shape': (6,), 'version': 3}
>>> b.__array_interface__
{'data': (140415706605184, False), 'strides': None, 'descr': [('', '<i8')], 'typestr': '<i8', 'shape': (3, 2), 'version': 3}

Dask probably has an optimal way of getting this, but I didn't see the hashing function on a quick browse.

I don't think a bit of overhead from checking whether a result is in a cache would hurt contract; the overhead of parsing the string on the Python side is already quite high.

fritzo commented Aug 18, 2018

@dgasmith I don't think we need to hash all of the data

👍 I've updated my prototype to be eager (as @jcmgray suggested), and it now hashes by the id() of the passed array object. There is no longer any need for users to wrap inputs and unwrap outputs, and the shared backend is only used internally.

jcmgray commented Aug 18, 2018

Looks cool! I like the idea of doing this kind of thing:

with oe.shared_intermediates() as cache:
    # do some stuff
    oe.contract(...)

# do some non cached stuff

# use the same cache again
with oe.shared_intermediates(cache):
    # do even more stuff
    oe.contract(...)

This looks possible from my brief reading of your code.

Also, yes, with an explicitly managed scope I think hashing just using id is fine. On that note (and it's entirely possible that I'm missing something), is the _Shared object needed at all? Could the key not just be generated with id directly, e.g.:

key = 'tensordot', backend, id(x), id(y), axes

Thus keeping all the logic in the context manager and backend functions.

While I remember, I think this kind of change might require modifying opt_einsum to try {backend}.transpose(array, axes) before array.transpose(axes), to give priority to custom functions.

@dgasmith

Interestingly, id will work for deeper intermediates. Consider two contractions where the first three tensors are identical:

ab,bc,cd,defg->aefg
ab,bc,cd,def->aef

Steps:

1) ab,bc,cd,defg->aefg:   ab,bc->ac # cached
2) ac,cd,defg->aefg:      ac,cd->ad # cached
3) ad,defg->aefg:         ad,defg->aefg

4) ab,bc,cd,def->aef:     ab,bc->ac # pull result from step 1
5) ac,cd,def->aef:        ac,cd->ad # pull result from step 2 (the tensor id of `ac` is the same as above)
6) ad,def->aef:           ad,def->aef

The other thing to consider is the "switched case":

ab,bc->ac   # tensor 1/2
bc,ab->ac   # tensor 2/1

Normal ordering should be applied to the key to catch these cases; a simple id(x) > id(y) comparison would work.
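A sketch of such key normalization (the helper is hypothetical; for simplicity it keys on whole equation terms rather than tensordot axes):

```python
import numpy as np

def normalized_key(terms, operands):
    # Hypothetical helper: sort (term, operand-id) pairs by id so that
    # 'ab,bc' on (t1, t2) and 'bc,ab' on (t2, t1) yield the same key
    pairs = sorted(zip(terms, (id(op) for op in operands)), key=lambda p: p[1])
    return tuple(pairs)

t1 = np.random.rand(2, 2)
t2 = np.random.rand(2, 2)
assert normalized_key(['ab', 'bc'], [t1, t2]) == normalized_key(['bc', 'ab'], [t2, t1])
```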
