Implement shared_intermediates context manager #43
Conversation
opt_einsum/contract.py
Outdated
"The internal error was: '%s'" % original_msg, ) | ||
err.args = msg | ||
raise | ||
with handle_sharing(backend) as backend: |
This is really the only line changed; I've simply moved the following lines inside this context.
cc @eb8680
Nice, looking forward to testing this and putting it to use in the wild! I do have a suggestion to simplify this a bit further (sorry for suggesting all these refactors!). Basically, instead of making it a backend, and then storing the real desired backend in a temporary global, the caching logic could be factored out into decorators, e.g.:

```python
# opt_einsum/shared.py

def shared_intermediates(...):
    # and other stuff

def tensordot_cache_wrap(tensordot_fn):
    def tensordot_cached(x, y, axes, backend):
        if not _SHARING_STACK:
            return tensordot_fn(x, y, axes, backend)
        ...  # parse and get global cache etc.
        key = (id(x), id(y), axes, backend)
        if key not in cache:
            cache[key] = tensordot_fn(x, y, axes, backend)
        return cache[key]
    return tensordot_cached
```

The advantage is that one would then just need:

```python
@tensordot_cache_wrap
def _tensordot(x, y, axes, backend):
    ...
```

I.e. the caching logic is really factored out. Also the tests seem to be failing on travis currently? But otherwise looking good!
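To make the suggestion concrete, here is a minimal runnable sketch of the decorator approach (the `backend` argument is dropped for brevity; the names mirror the comment above but this is not the PR's actual code):

```python
import numpy as np

# Stack of sharing caches; a non-empty stack means sharing is active.
_SHARING_STACK = []

def tensordot_cache_wrap(tensordot_fn):
    """Wrap a tensordot implementation with id()-based memoization."""
    def tensordot_cached(x, y, axes=2):
        if not _SHARING_STACK:
            return tensordot_fn(x, y, axes)
        cache = _SHARING_STACK[-1]
        key = ('tensordot', id(x), id(y), axes)
        if key not in cache:
            cache[key] = tensordot_fn(x, y, axes)
        return cache[key]
    return tensordot_cached

@tensordot_cache_wrap
def tensordot(x, y, axes=2):
    return np.tensordot(x, y, axes=axes)

# demo: inside a sharing context, repeated calls return the cached object
x, y = np.random.rand(3, 4), np.random.rand(4, 5)
_SHARING_STACK.append({})
r1 = tensordot(x, y, 1)
r2 = tensordot(x, y, 1)
_SHARING_STACK.pop()
r3 = tensordot(x, y, 1)  # no sharing active: freshly computed
```

The wrapped function stays a drop-in replacement, so callers like `contract` need no knowledge of the caching.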
opt_einsum/backends/shared.py
Outdated
```python
_SHARING_STACK = []
_CURRENT_BACKEND = []
```
Could these just be explicit global variables? Or does the list functionality enable nested sharing with different caches or something?
`_SHARING_STACK` indeed allows nesting / separate caches. I'll switch `_CURRENT_BACKEND` to an explicit global variable.
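A small sketch of how a stack of caches enables nested contexts with separate caches (a simplified `shared_intermediates`, not the PR's exact implementation):

```python
from contextlib import contextmanager

_SHARING_STACK = []

@contextmanager
def shared_intermediates(cache=None):
    """Push a (possibly user-supplied) cache; the innermost cache wins."""
    if cache is None:
        cache = {}
    _SHARING_STACK.append(cache)
    try:
        yield cache
    finally:
        _SHARING_STACK.pop()

with shared_intermediates() as outer:
    outer['x'] = 1
    with shared_intermediates() as inner:
        # the inner context memoizes into its own, separate cache
        inner['y'] = 2
        depth_inside = len(_SHARING_STACK)
depth_after = len(_SHARING_STACK)
```

Lookups always go through `_SHARING_STACK[-1]`, so each nested context gets isolated memoization.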
opt_einsum/backends/shared.py
Outdated
```python
cache = _SHARING_STACK[-1]
cache['tensor', id(x)] = x
cache['tensor', id(y)] = y
```
Do we need to cache the tensors on their own, are they ever retrieved?
This is required to prevent the input tensors from being garbage collected and their `id`s being reused, which would lead to an incorrect cache lookup. I'll add a comment to this effect.
Ah ok! I hadn't realised that about `id`.
Fun example:

```python
>>> a = np.random.rand(4, 4)
>>> id(a[0])
4560101008
>>> id(a[0])
4585119616
>>> id(a[0])
4585119616
>>> id(a[0])
4585119616
>>> id(a[1])
4585119616
>>> id(a[2])
4585119616
```
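In other words, `id()` values are only unique among objects that are simultaneously alive, which is why the cache must keep references to the tensors themselves. A quick sketch (CPython/NumPy view behavior assumed):

```python
import numpy as np

a = np.random.rand(4, 4)

# Each a[i] below is a temporary view that is garbage collected immediately,
# so CPython is free to reuse the same memory address (and hence the same id()).
unheld = [id(a[i]) for i in range(4)]

# Keeping the views alive (as cache['tensor', id(x)] = x does in the PR)
# guarantees their ids stay distinct for as long as the cache exists.
views = [a[i] for i in range(4)]
held = [id(v) for v in views]
```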
opt_einsum/backends/shared.py
Outdated
```python
canonical_inputs = ','.join(input_ for input_, _ in canonical)
canonical_equation = _alpha_canonicalize('{}->{}'.format(canonical_inputs, output))
canonical_operands = tuple(d for _, d in canonical)
key = 'einsum', backend, canonical_equation, tuple(map(id, canonical_operands))
```
Does this bit enable it so that `einsum('ab,bc->ca', x, y)` matches e.g. `einsum('jk,ij->ki', y, x)`? If so, nice!
Yes, we're accounting for a little bit of commutativity in the cache lookup. We did this in Pyro to improve our sharing.
+1 for normal ordering. I think this takes care of most edge cases. The only missing edge case that I can think of is if a user takes identical views in different contexts, which would require checking of the `__array_interface__` syntax. Seems like a stretch to take care of in the first pass.
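For reference, a hedged reconstruction of the canonical-key idea from the snippet above (the real key also includes the backend; `get_cache_key` is a hypothetical helper name):

```python
def _alpha_canonicalize(equation):
    """Rename indices to canonical letters in order of first appearance."""
    rename = {}
    for symbol in equation:
        if symbol in ',->':
            continue
        if symbol not in rename:
            rename[symbol] = chr(ord('a') + len(rename))
    return ''.join(rename.get(s, s) for s in equation)

def get_cache_key(equation, *operands):
    """Canonical einsum cache key, modulo commutativity of the inputs."""
    inputs, output = equation.split('->')
    # normal-order the (input-spec, operand) pairs by operand id
    canonical = sorted(zip(inputs.split(','), operands), key=lambda pair: id(pair[1]))
    canonical_inputs = ','.join(spec for spec, _ in canonical)
    canonical_equation = _alpha_canonicalize(canonical_inputs + '->' + output)
    return ('einsum', canonical_equation, tuple(id(op) for _, op in canonical))

x, y = object(), object()
key1 = get_cache_key('ab,bc->ca', x, y)
key2 = get_cache_key('jk,ij->ki', y, x)   # same contraction, operands swapped
key3 = get_cache_key('ab,bc->ac', x, y)   # different output order
```

Sorting by `id` plus alpha-canonicalization makes the two spellings of the same contraction collide in the cache, while a genuinely different output keeps a distinct key.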
opt_einsum/backends/shared.py
Outdated
```python
cache = {}
_SHARING_STACK.append(cache)
yield cache
_SHARING_STACK.pop()
```
Does this need to be in a try/finally block so that `opt_einsum` doesn't remain in 'shared' mode if an error is raised?
The try-finally is automatically performed by `@contextlib.contextmanager`.
Are you sure about this specifically for `contextmanager`? The docs don't mention it, and in my tests it is needed, otherwise the post-yield code is never reached. Testing with these snippets:

```python
from contextlib import contextmanager

things = []

@contextmanager
def TrySomething(value):
    things.append(value)
    # try:
    yield things[-1]
    # finally:
    things.pop()
```

then

```python
with TrySomething('hello'):
    raise ValueError
```

gives `things == ['hello']` without the try/finally but `[]` with.
Gosh, it looks like I'm mistaken. I'll fix and push. Thanks for catching this!
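For completeness, a sketch of what the fixed context manager would look like (simplified; the PR's version also manages the backend state):

```python
from contextlib import contextmanager

_SHARING_STACK = []

@contextmanager
def shared_intermediates():
    cache = {}
    _SHARING_STACK.append(cache)
    try:
        yield cache
    finally:
        # runs even when the `with` body raises, so sharing mode always exits
        _SHARING_STACK.pop()

try:
    with shared_intermediates():
        raise ValueError('oops')
except ValueError:
    pass
leaked = list(_SHARING_STACK)
```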
Overall this looks really good. A couple minor points to look at.
Very excited to have these changes in, they will be very useful.
opt_einsum/backends/shared.py
Outdated
canonical_inputs = ','.join(input_ for input_, _ in canonical) | ||
canonical_equation = _alpha_canonicalize('{}->{}'.format(canonical_inputs, output)) | ||
canonical_operands = tuple(d for _, d in canonical) | ||
key = 'einsum', backend, canonical_equation, tuple(map(id, canonical_operands)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 for normal ordering. I think this takes care of most edge cases. The only missing edge case that I can think of is if a user does take identical views in different contexts which would require checking of the __array_interface__
syntax. Seems like a stretch to take care of in the first pass.
opt_einsum/backends/shared.py
Outdated
```python
cache['tensor', id(d)] = d

# compute a canonical hash, modulo commutativity
inputs, output = equation.split('->')
```
Are we guaranteed to have a "->" at this stage? I think so, but worth double checking.
I believe so, but I'm not super familiar with the entire `einsum` syntax. At this stage we are guaranteed to have the same equation that was passed to `contract(equation, ...)`.
opt_einsum/backends/shared.py
Outdated
```python
def tensordot(x, y, axes=2):
    backend = _CURRENT_BACKEND[0]
    cache = _SHARING_STACK[-1]
```
Can we normal order `x, y` here as well using `id` as below? With indices explicitly labeled I think this should be ok.
I believe we can't normal-order `x, y` because their non-contracted dimensions are treated differently: `x`'s are on the left and `y`'s are on the right.
One possibility would be to check for `(x, y, axes)` then `(y, x, axes[::-1])`, and if it's the second, just perform the transposition (which will be something like `transpose(x, [2, 3, 4, 0, 1])`).
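A sketch of that flip-and-transpose lookup (the helper name and cache layout are hypothetical; axes are assumed in the explicit tuple-of-tuples form):

```python
import numpy as np

def cached_tensordot(cache, x, y, axes):
    """Look up (x, y, axes); fall back to a cached flipped call plus a transpose."""
    key = ('tensordot', id(x), id(y), axes)
    if key in cache:
        return cache[key]
    flipped = ('tensordot', id(y), id(x), (axes[1], axes[0]))
    if flipped in cache:
        # the flipped result has y's free axes first, then x's: rotate x's to the front
        nx_free = x.ndim - len(axes[0])
        ny_free = y.ndim - len(axes[1])
        perm = list(range(ny_free, ny_free + nx_free)) + list(range(ny_free))
        result = np.transpose(cache[flipped], perm)
    else:
        result = np.tensordot(x, y, axes=axes)
    cache[key] = result
    return result

x = np.random.rand(2, 3, 4)
y = np.random.rand(4, 5)
cache = {}
first = cached_tensordot(cache, y, x, ((0,), (2,)))   # shape (5, 2, 3)
second = cached_tensordot(cache, x, y, ((2,), (0,)))  # served from the flipped entry
```

A transpose is essentially free compared to recomputing the contraction, so this trades a cheap permutation for a cache hit.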
Yeah, in this first PR I've implemented lookup modulo-commutativity but not modulo-transpose. I agree that we could do this modulo-transpose lookup here and also in the `einsum` part. But I'm still struggling to get tests to pass even in the current PR, so I'd be happy to limit this first PR to commutativity.
+1 to getting something in, can we flag this and make an issue so that we do not forget about this point?
Yeah absolutely - it's certainly not necessary for this PR. Good to keep track of these things so they can maybe be added later, however!
```python
print('-' * 40)
print('Without sharing: {} expressions'.format(num_exprs_nosharing))
print('With sharing: {} expressions'.format(num_exprs_sharing))
assert num_exprs_nosharing > num_exprs_sharing
```
We should consider adding additional tests where we test the normal ordering:

```python
with shared_intermediates() as cache:
    a = np.random.rand(4, 4)
    b = np.random.rand(4, 4)
    contract("ab,bc->ac", a, b)
    assert get_cache_key("zc,az->ac", b, a) in cache
    ...
```
opt_einsum/backends/shared.py
Outdated
```python
# compute a canonical hash, modulo commutativity
inputs, output = equation.split('->')
inputs = inputs.split(',')
canonical = sorted(zip(inputs, operands), key=lambda x: id(x[1]))
```
It might be good to break this code out into another function. I have a feeling that tokenizing contractions will be very useful in the future.
I like this idea! I'll try refactoring and see if I have better luck getting tests to pass with the decorator version.
@jcmgray I've refactored to use decorators as you suggested, and indeed the changes are now minimally intrusive. I've also had to decorate
Yes thanks very much for that change, it's looking great. My notes at this point are completely optional extensions:

But like I say, already looks good to go from my perspective without these niche things! Docs-wise, maybe added to the
A separate doc page would be good I think, probably under the "Getting Started" heading, and a small snippet in the current README would be good. We also should test caching with constant expressions. @fritzo All of these issues do not need to be tackled in this particular PR, but will need to be addressed before the next release. Please feel free to turn any points into issues to be tackled. I should have some time this weekend to work on a point or two.
LICENSE
Outdated
```
@@ -1,6 +1,7 @@
 The MIT License (MIT)

 Copyright (c) 2014 Daniel Smith
+Copyright (c) 2018 Uber Technologies
```
@dgasmith my employer requires me to add a copyright line somewhere. Is it ok here, or would you like me to move it to sharing.py or somewhere else?
Yea, could we move this to `sharing.py`? We should probably look at changing the copyright to the "opt_einsum developers" in the future. I need to look into this angle of things a bit more.
I really don't know, as I've never used threads in Python. What do you think?
Ok, I think this should be ready to go when tests pass. I can open issues for remaining improvements.
LGTM overall
```python
>>> with shared_intermediates():
>>>     marginals = {output: contract('{}->{}'.format(inputs, output), *factors)
>>>                  for output in 'abcdef'}
```
Might be worth showing a quick timing comparison on bigger tensors with an explicit demonstration of the contractions done.
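A sketch of how such a timing demo might look (opt_einsum itself is not assumed available here; a toy id-memoized matmul chain stands in for the shared contraction):

```python
import time
import numpy as np

def chain(factors, cache=None):
    """Multiply a chain of matrices, optionally memoizing each step by id()."""
    out = factors[0]
    for f in factors[1:]:
        key = (id(out), id(f))
        if cache is not None and key in cache:
            out = cache[key]
            continue
        out_new = out @ f
        if cache is not None:
            cache[key] = out_new
        out = out_new
    return out

factors = [np.random.rand(200, 200) for _ in range(10)]

t0 = time.perf_counter()
for _ in range(20):
    nosharing = chain(factors)
t_nosharing = time.perf_counter() - t0

cache = {}
t0 = time.perf_counter()
for _ in range(20):
    sharing = chain(factors, cache)
t_sharing = time.perf_counter() - t0

print('without sharing: {:.4f}s, with sharing: {:.4f}s'.format(t_nosharing, t_sharing))
```

After the first pass fills the cache, every later pass is pure lookups, so the speedup grows with how much work is repeated.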
opt_einsum/tests/test_sharing.py
Outdated
```python
assert num_exprs_nosharing['einsum'] > num_exprs_sharing['einsum']


def compute_cost(cache):
```
Can this be `_compute_cost`?
```python
cache['tensor', id(tensor)] = tensor


def _memoize(key, fn, *args, **kwargs):
```
Can we add docstrings on the next 4 functions? Not a lot, but just something to indicate their use.
done.
@jcmgray Can you give this a final review as well?
All looks good to me!
I thought about this a bit, since it's not completely unlikely that someone at some point might parallelize some numeric code which uses this. Anyway, I can't think of any remedy so really just noting it! I guess it's just a minor cost of the major convenience here of using module global state. NB. if you supply the same cache to all threads, they will all add to it, but will all use it, so that sidesteps the inefficiency to an extent.
Everything looks good to me. Thanks for the great PR! We will look at releasing a 2.2 soon to get these changes into production.
This is a really cool addition, thanks @fritzo!
@dgasmith I'd recommend using squash-and-merge for future merges to avoid leaking non-functioning commits into the commit history. That makes it much easier to
I usually do for larger projects, but
Resolves #9

Description

This implements a `shared_intermediates` context manager, within which multiple calls to `contract` will share intermediate computations. This is very useful e.g. in message passing algorithms. The first draft of this PR is copied directly from Pyro, where we are using sharing for probabilistic inference in graphical models.

The implementation uses a special internal backend `opt_einsum.backends.shared`. When a sharing context is opened, a special `handle_sharing()` context is activated inside the `contract` function, temporarily switching to the `shared` backend and setting the original backend `foo` in a global `shared._CURRENT_BACKEND` variable. The sharing backend then performs all the original `foo` operations and also memoizes them.

One design choice is to memoize by `id()` rather than by value. This makes sense from a computational perspective (equality comparison is expensive), but requires a bit more care by users.

A second design choice is to expose the memoization `cache` to users. This makes it easy to share intermediates by passing the cache around, as is necessary when it is difficult to perform all computation in a context.

Todos

- Convert `numpy` -> `torch` -> `numpy` so args can be compared by id.
- Move `_alpha_canonicalize()` to `parser.py`.

Questions

- Where should the `handle_sharing()` context be moved? I've attempted to insert it into `ContractExpression.__call__()`, but I'm a bit lost.

Status