
Implement shared_intermediates context manager #43

Merged
merged 19 commits into dgasmith:master on Aug 22, 2018

Conversation

@fritzo (Contributor, author) commented Aug 20, 2018

Resolves #9

Description

This implements a shared_intermediates context manager, within which multiple calls to contract will share intermediate computations. This is very useful e.g. in message passing algorithms. The first draft of this PR is copied directly from Pyro, where we are using sharing for probabilistic inference in graphical models.

The implementation uses a special internal backend opt_einsum.backends.shared. When a sharing context is opened

with shared_intermediates():
    contract(..., backend='foo')
    contract(..., backend='foo')

a special handle_sharing() context is activated inside the contract function, temporarily switching to the shared backend and storing the original backend foo in a global shared._CURRENT_BACKEND variable. The shared backend then performs all of the original foo operations and also memoizes them.
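As a rough illustration of that mechanism (a minimal sketch only; the names handle_sharing, _SHARING_STACK and _CURRENT_BACKEND follow the PR, but the exact signatures and bodies here are assumptions):

from contextlib import contextmanager

_SHARING_STACK = []      # caches pushed by shared_intermediates()
_CURRENT_BACKEND = []    # holds the real backend while the 'shared' backend is active

@contextmanager
def handle_sharing(backend):
    if _SHARING_STACK and backend != 'shared':
        _CURRENT_BACKEND.append(backend)  # remember the user's backend
        try:
            yield 'shared'                # route this contraction through the shared backend
        finally:
            _CURRENT_BACKEND.pop()
    else:
        yield backend                     # no sharing context is active; pass through unchanged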

One design choice is to memoize by id() rather than by value. This makes sense from a computational perspective (equality comparison is expensive), but requires a bit more care by users.
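For example (illustrative only), a value-equal copy of an array will not hit the cache, because it has a different id():

import numpy as np
from opt_einsum import contract, shared_intermediates

x = np.random.rand(4, 4)
y = x.copy()  # equal values, but a different id()

with shared_intermediates():
    contract('ab,bc->ac', x, x)  # computed and cached
    contract('ab,bc->ac', x, x)  # cache hit: the same objects by id()
    contract('ab,bc->ac', x, y)  # cache miss: y is compared by id(), not by value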

A second design choice is to expose the memoization cache to users. This makes it easy to share intermediates by passing the cache around, as is necessary when it is difficult to perform all computation in a context:

with shared_intermediates() as cache:
    contract(...)
...pass control elsewhere...
with shared_intermediates(cache):  # <-- reuse previous cache
    contract(...)
del cache  # <-- intermediates will be garbage collected

Todos

  • Cache type conversions e.g. numpy->torch->numpy so args can be compared by id.
  • fix failing tests
  • add tests that pass around the cache
  • add tests where we test the normal ordering:
    with shared_intermediates() as cache:
        a = np.random.rand(4, 4)
        b = np.random.rand(4, 4)
        contract("ab,bc->ac", a, b)
    
    assert get_cache_key("zc,az->ac", b, a) in cache
    ...
  • test against the Pyro rm-einsum-shared branch
  • test nested sharing
  • write docs
  • move _alpha_canonicalize() to parser.py
  • test caching with constant expressions

Questions

  • Could a maintainer suggest where the handle_sharing() context should be moved? I've attempted to insert it into ContractExpression.__call__(), but I'm a bit lost.

Status

  • Ready to go

"The internal error was: '%s'" % original_msg, )
err.args = msg
raise
with handle_sharing(backend) as backend:
Contributor Author (@fritzo):

This is really the only line changed; I've simply moved the following lines inside this context.

@fritzo (Contributor, author) commented Aug 20, 2018

cc @eb8680

@jcmgray (Collaborator) commented Aug 20, 2018

Nice, looking forward to testing this and putting it to use in the wild! I do have a suggestion to simplify this a bit further (sorry for suggesting all these refactors!).

Basically, instead of making it a backend, storing the real desired backend in a temporary, and using handle_sharing, the cached versions of these operations could just be turned on directly in the core _tensordot, _einsum, etc. Maybe as decorators defined (along with shared_intermediates) in a shared.py file.

e.g.

# opt_einsum/shared.py

def shared_intermediates(...):
    # and other stuff

def tensordot_cache_wrap(tensordot_fn):

    def tensordot_cached(x, y, axes, backend):
        if not _SHARING_STACK:
            return tensordot_fn(x, y, axes, backend)

        cache = _SHARING_STACK[-1]  # get the currently active cache (other parsing elided)
        key = (id(x), id(y), axes, backend)
        if key not in cache:
            cache[key] = tensordot_fn(x, y, axes, backend)

        return cache[key]

    return tensordot_cached

The advantage is that in contract.py we then just need

@tensordot_cache_wrap
def _tensordot(x, y, axes, backend):
    ...

I.e. the caching logic is really factored out, and handle_sharing and _CURRENT_BACKEND are not needed. Additionally, the calls would be explicitly routed through _einsum, _tensordot & _transpose (in future they might be modified). The decorator approach is obviously optional - just a nice way of keeping the core contract logic clean -- but the key thing is that shared_intermediates might not need to use the backend API itself.

Also, the tests seem to be failing on Travis currently, but otherwise this is looking good!


_SHARING_STACK = []
_CURRENT_BACKEND = []

Collaborator (@jcmgray):

Could these just be explicit global variables? Or does the list functionality enable nested sharing with different caches or something?

Contributor Author (@fritzo):

_SHARING_STACK indeed allows nesting / separate caches. I'll switch _CURRENT_BACKEND to an explicit global variable.
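A hypothetical usage sketch of the nesting this enables (only the fact that each context pushes its own cache onto _SHARING_STACK is taken from the PR; the rest is illustration):

import numpy as np
from opt_einsum import contract, shared_intermediates

x, y = np.random.rand(4, 4), np.random.rand(4, 4)

with shared_intermediates() as outer_cache:
    contract('ab,bc->ac', x, y)            # memoized in outer_cache
    with shared_intermediates() as inner_cache:
        contract('ab,bc->ac', x, y)        # memoized in the innermost cache (_SHARING_STACK[-1])
    assert outer_cache is not inner_cache  # each context pushes its own cache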

cache = _SHARING_STACK[-1]
cache['tensor', id(x)] = x
cache['tensor', id(y)] = y

Collaborator (@jcmgray):

Do we need to cache the tensors on their own? Are they ever retrieved?

Contributor Author (@fritzo):

This is required to prevent the input tensors from being garbage collected and their ids being reused, which would lead to an incorrect cache lookup. I'll add a comment to this effect.

Collaborator (@jcmgray):

Ah ok! I hadn't realised that about id.

Owner (@dgasmith):

Fun example:

>>> a = np.random.rand(4, 4)
>>> id(a[0])
4560101008
>>> id(a[0])
4585119616
>>> id(a[0])
4585119616
>>> id(a[0])
4585119616
>>> id(a[1])
4585119616
>>> id(a[2])
4585119616

canonical_inputs = ','.join(input_ for input_, _ in canonical)
canonical_equation = _alpha_canonicalize('{}->{}'.format(canonical_inputs, output))
canonical_operands = tuple(d for _, d in canonical)
key = 'einsum', backend, canonical_equation, tuple(map(id, canonical_operands))
Collaborator (@jcmgray):

Does this bit enable it so that einsum('ab,bc->ca', x, y) matches e.g. einsum('jk,ij->ki', y, x)? If so, nice!

Contributor Author (@fritzo):

Yes, we're accounting for a little bit of commutativity in the cache lookup. We did this in Pyro to improve our sharing.

Owner (@dgasmith):

+1 for normal ordering. I think this takes care of most edge cases. The only missing edge case I can think of is if a user takes identical views in different contexts, which would require checking the __array_interface__ attribute. Seems like a stretch to take care of in the first pass.
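To make the canonicalization discussed in this thread concrete, here is a sketch of how such a cache key can be built; the helper names follow the diff above, but this is illustrative rather than the verbatim implementation:

def _alpha_canonicalize(equation):
    """Rename index symbols in order of first appearance (sketch)."""
    rename = {}
    for symbol in equation:
        if symbol not in ',->' and symbol not in rename:
            rename[symbol] = 'abcdefghijklmnopqrstuvwxyz'[len(rename)]
    return ''.join(rename.get(s, s) for s in equation)

def einsum_cache_key(equation, operands, backend='numpy'):
    """Build a key that is invariant to operand order and index renaming (sketch)."""
    inputs, output = equation.split('->')
    # sort input terms by operand id so that commuted calls produce the same key
    canonical = sorted(zip(inputs.split(','), operands), key=lambda term: id(term[1]))
    canonical_inputs = ','.join(term for term, _ in canonical)
    canonical_equation = _alpha_canonicalize(canonical_inputs + '->' + output)
    return 'einsum', backend, canonical_equation, tuple(id(op) for _, op in canonical)

With this sketch, einsum_cache_key('ab,bc->ca', (x, y)) and einsum_cache_key('jk,ij->ki', (y, x)) produce the same key for any two arrays x and y.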

cache = {}
_SHARING_STACK.append(cache)
yield cache
_SHARING_STACK.pop()
Collaborator (@jcmgray):

Does this need to be in a try/finally block so that opt_einsum doesn't remain in 'shared' mode if an error is raised?

Contributor Author (@fritzo):

The try-finally is automatically performed by @contextlib.contextmanager

Collaborator (@jcmgray):

Are you sure about this specifically for contextmanager? The docs don't mention it, and in my tests it is needed; otherwise the post-yield code is never reached. Testing with these snippets:

from contextlib import contextmanager

things = []

@contextmanager
def TrySomething(value):
    things.append(value)
#     try:
    yield things[-1]
#     finally:
    things.pop() 

then

with TrySomething('hello'):
    raise ValueError

gives things=['hello'] without the try/finally but [] with.

Contributor Author (@fritzo):

Gosh, it looks like I'm mistaken. I'll fix and push. Thanks for catching this!
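A minimal sketch of the fix being discussed (the final signature in the PR may differ):

import contextlib

_SHARING_STACK = []

@contextlib.contextmanager
def shared_intermediates(cache=None):
    if cache is None:
        cache = {}
    _SHARING_STACK.append(cache)
    try:
        yield cache
    finally:
        # always pop, even if an error is raised inside the with-block
        _SHARING_STACK.pop()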

@dgasmith (Owner) left a review:

Overall this looks really good. A couple of minor points to look at.

Very excited to have these changes in; they will be very useful.


cache['tensor', id(d)] = d

# compute a canonical hash, modulo commutativity
inputs, output = equation.split('->')
Owner (@dgasmith):

Are we guaranteed to have a "->" at this stage? I think so, but worth double checking.

Contributor Author (@fritzo):

I believe so, but I'm not super familiar with the entire einsum syntax. At this stage we are guaranteed to have the same equation that was passed to contract(equation, ...).


def tensordot(x, y, axes=2):
backend = _CURRENT_BACKEND[0]
cache = _SHARING_STACK[-1]
Owner (@dgasmith):

Can we normal order x, y here as well using id as below? With indices explicitly labeled I think this should be ok.

Contributor Author (@fritzo):

I believe we can't normal-order x,y because their non-contracted dimensions are treated differently: x's are on the left and y's are on the right.

Collaborator (@jcmgray):

One possibility would be to check for (x, y, axes) and then (y, x, axes[::-1]), and if it's the second, just perform the transposition (which will be something like transpose(x, [2, 3, 4, 0, 1])).

Contributor Author (@fritzo):

Yeah, in this first PR I've implemented lookup modulo-commutativity but not modulo-transpose. I agree that we could do this modulo-transpose lookup here and also in the einsum part. But I'm still struggling to get tests to pass even in the current PR, so I'd be happy to limit this first PR to commutativity.

Owner (@dgasmith):

+1 to getting something in. Can we flag this and open an issue so that we do not forget about this point?

Collaborator (@jcmgray) commented Aug 20, 2018:

Yeah absolutely - it's certainly not necessary for this PR. Good to keep track of these things so they can maybe be added later however!
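For reference, a sketch of the modulo-transpose lookup discussed above (not part of this PR; numpy is used for illustration, cache stands for the active sharing cache, and axes is assumed to be a pair of index sequences):

import numpy as np

def tensordot_with_commuted_lookup(x, y, axes, cache):
    """Check (x, y, axes), then (y, x, reversed axes), transposing the result on the second hit."""
    axes_x, axes_y = axes
    key = ('tensordot', id(x), id(y), tuple(axes_x), tuple(axes_y))
    swapped = ('tensordot', id(y), id(x), tuple(axes_y), tuple(axes_x))
    if key in cache:
        return cache[key]
    if swapped in cache:
        # the cached result has y's free axes first; move x's free axes to the front
        nx_free = x.ndim - len(axes_x)
        ny_free = y.ndim - len(axes_y)
        perm = list(range(ny_free, ny_free + nx_free)) + list(range(ny_free))
        return np.transpose(cache[swapped], perm)
    cache[key] = np.tensordot(x, y, (axes_x, axes_y))
    return cache[key]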


print('-' * 40)
print('Without sharing: {} expressions'.format(num_exprs_nosharing))
print('With sharing: {} expressions'.format(num_exprs_sharing))
assert num_exprs_nosharing > num_exprs_sharing
Owner (@dgasmith):

We should consider adding additional tests where we test the normal ordering:

with shared_intermediates() as cache:
    a = np.random.rand(4, 4)
    b = np.random.rand(4, 4)
    contract("ab,bc->ac", a, b)

assert get_cache_key("zc,az->ac", b, a) in cache
...

# compute a canonical hash, modulo commutativity
inputs, output = equation.split('->')
inputs = inputs.split(',')
canonical = sorted(zip(inputs, operands), key=lambda x: id(x[1]))
Owner (@dgasmith):

It might be good to break this code out into another function. I have a feeling that tokenizing contractions will be very useful in the future.

@fritzo (Contributor, author) commented Aug 20, 2018

@jcmgray Maybe as decorators ...

I like this idea! I'll try refactoring and see if I have better luck getting tests to pass with the decorator version.

@codecov-io commented Aug 21, 2018

Codecov Report

Merging #43 into master will increase coverage by 0.22%.
The diff coverage is 98.9%.

@fritzo (Contributor, author) commented Aug 21, 2018

@jcmgray I've refactored to use decorators as you suggested, and indeed the changes are now minimally intrusive.

I've also had to decorate to_torch and to_cupy. I don't know how to wrap the other automatically-converted backends, but all backends should work with sharing if the user does manual conversion.

@jcmgray (Collaborator) commented Aug 21, 2018

Yes, thanks very much for that change; it's looking great. My notes at this point are completely optional extensions:

  • With regards to the conversion functions, I can't think of any reason that every to_{backend} function couldn't be memoized (a sketch appears at the end of this comment). At the point it would be called, tensorflow (in non-eager mode) and theano are just building the expression, so there's no direct speed advantage, but I can imagine that it only helps the compilers if they can see that some tensors are the same object. On the other hand, there is no single function to_backend(x, backend=...) at the moment, so it gets a bit messy -- could well be left for the moment.

  • Do we need to think about multi-threading scenarios? With one thread popping the cache early?

  • On a similar note, it might be nice to test the nested sharing if it's an intentional feature.

  • As I think @dgasmith pointed out, it also might be nice to just move the einsum pre-cache canonicalization into its own function in parser since it might be useful elsewhere.

But like I say, it already looks good to go from my perspective without these niche things! Docs-wise, maybe add it to the readme.md bullet point list and give it its own page in the main docs, @dgasmith ?
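A sketch of the per-conversion memoization suggested in the first bullet above (conversion_cache_wrap is a hypothetical name, and _SHARING_STACK here stands in for the module-level sharing stack):

_SHARING_STACK = []  # stands in for the module-level sharing stack

def conversion_cache_wrap(convert_fn):
    """Memoize a to_{backend}-style conversion by operand id while sharing is active (sketch)."""
    def cached_convert(x):
        if not _SHARING_STACK:
            return convert_fn(x)
        cache = _SHARING_STACK[-1]
        cache['tensor', id(x)] = x                    # keep x alive so its id is not reused
        key = ('convert', convert_fn.__name__, id(x))
        if key not in cache:
            cache[key] = convert_fn(x)
        return cache[key]
    return cached_convert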

@dgasmith (Owner) commented Aug 21, 2018

I think a separate doc page would be good, probably under the "Getting Started" heading, along with a small snippet in the current README. We should also test caching with constant expressions.

@fritzo None of these issues needs to be tackled in this particular PR, but they will need to be addressed before the next release. Please feel free to turn any of the points into issues. I should have some time this weekend to work on a point or two.

LICENSE Outdated
@@ -1,6 +1,7 @@
The MIT License (MIT)

Copyright (c) 2014 Daniel Smith
Copyright (c) 2018 Uber Technologies
Contributor Author (@fritzo):

@dgasmith my employer requires me to add a copyright line somewhere. Is it ok here, or would you like me to move it to sharing.py or somewhere else?

Owner (@dgasmith):

Yea, could we move this to sharing.py? We should probably look at changing the copyright to the "opt_einsum developers" in the future. I need to look into this angle of things a bit more.

@fritzo (Contributor, author) commented Aug 21, 2018

@jcmgray Do we need to think about multi-threading scenarios?

I really don't know, as I've never used threads in Python. What do you think?

@fritzo (Contributor, author) commented Aug 21, 2018

Ok, I think this should be ready to go when tests pass. I can open issues for remaining improvements.

@dgasmith (Owner) left a review:

LGTM overall



>>> with shared_intermediates():
...     marginals = {output: contract('{}->{}'.format(inputs, output), *factors)
...                  for output in 'abcdef'}
Owner (@dgasmith):

Might be worth showing a quick timing comparison on bigger tensors with an explicit demonstration of the contractions done.
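Along those lines, a hypothetical timing comparison (the shapes and equation are made up purely for illustration):

import time
import numpy as np
from opt_einsum import contract, shared_intermediates

inputs = 'ab,bc,cd'
factors = [np.random.rand(500, 500) for _ in range(3)]

start = time.perf_counter()
marginals = {output: contract('{}->{}'.format(inputs, output), *factors)
             for output in 'abcd'}
print('without sharing: {:.3f}s'.format(time.perf_counter() - start))

start = time.perf_counter()
with shared_intermediates():
    marginals = {output: contract('{}->{}'.format(inputs, output), *factors)
                 for output in 'abcd'}
print('with sharing:    {:.3f}s'.format(time.perf_counter() - start))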

assert num_exprs_nosharing['einsum'] > num_exprs_sharing['einsum']


def compute_cost(cache):
Owner (@dgasmith):

Can this be _compute_cost?

cache['tensor', id(tensor)] = tensor


def _memoize(key, fn, *args, **kwargs):
Owner (@dgasmith):

Can we add docstrings on the next 4 functions? Not a lot, but just something to indicate their use.

Contributor Author (@fritzo):

done.

@dgasmith (Owner) commented:
@jcmgray Can you give this a final review as well?

@jcmgray (Collaborator) left a review:

All looks good to me!

@jcmgray (Collaborator) commented Aug 21, 2018

@fritzo I really don't know, as I've never used threads in Python. What do you think?

I thought about this a bit, since it's not completely unlikely that someone at some point might parallelize some numeric code which uses opt_einsum sharing in a thread pool or something. It might be inefficient but nothing catastrophic, I think: each thread will add its own cache to the list, but only use whichever is last (so a bit randomly jumbled). Additionally, while any one thread is in a shared context, opt_einsum in all threads will cache intermediates.

Anyway I can't think of any remedy so really just noting it! I guess it's just a minor cost of the major convenience here of using module global state.

NB: if you supply the same cache to all threads, they will all add it but will also all use it, which sidesteps the inefficiency to an extent (see the sketch below).
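A sketch of that workaround, giving every worker the same cache (purely illustrative):

from concurrent.futures import ThreadPoolExecutor
import numpy as np
from opt_einsum import contract, shared_intermediates

x, y = np.random.rand(64, 64), np.random.rand(64, 64)

with shared_intermediates() as cache:
    pass  # create an (empty) cache up front

def worker(_):
    with shared_intermediates(cache):  # every thread reuses the same cache
        return contract('ab,bc->ac', x, y)

with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(worker, range(4)))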

@fritzo (Contributor, author) commented Aug 22, 2018

Ok, I've addressed all review comments.

@dgasmith and @jcmgray thank you for your detailed review!

@dgasmith (Owner) commented:
Everything looks good to me. Thanks for the great PR! We will look at releasing a 2.2 soon to get these changes into production.

@dgasmith dgasmith merged commit 7a83c49 into dgasmith:master Aug 22, 2018
@jcmgray (Collaborator) commented Aug 22, 2018

This is a really cool addition, thanks @fritzo!

@fritzo (Contributor, author) commented Aug 22, 2018

@dgasmith I'd recommend using squash-and-merge for future merges to avoid leaking non-functioning commits into the commit history. That makes it much easier to git bisect, since you can restrict to merge commits that passed CI tests. I certainly didn't intend some of the intermediate commits in this PR to end up in your commit history.

@dgasmith (Owner) commented:
I usually do for larger projects, but opt_einsum is small enough that there are not a ton of commits coming in and we haven't declared a git model that we use.

Successfully merging this pull request may close these issues:

  • Re-use intermediates across various contractions (Common Subexpression Elimination)