New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement shared_intermediates context manager #43
Changes from 18 commits
dd6b34a
560726e
589950b
5ff111f
72b9ef8
d208e3d
f5b53fc
472dbff
4fb5ae5
e771973
632c266
b3f2394
f0bebe5
a6983c6
e96a8e4
911eca4
86d4550
0139c75
e43d68b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -123,7 +123,7 @@ Table of Contents | |
install | ||
backends | ||
reusing_paths | ||
|
||
sharing_intermediates | ||
|
||
.. toctree:: | ||
:maxdepth: 1 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
===================== | ||
Sharing Intermediates | ||
===================== | ||
|
||
If you want to compute multiple similar contractions with common terms, you can embed them in a :func:`~opt_einsum.shared_intermediates` context. Computations of subexpressions in this context will be memoized, and will be garbage collected when the context exits. | ||
|
||
For example, suppose we want to compute marginals at each point in a factor chain: | ||
|
||
.. code:: python | ||
|
||
>>> inputs = 'ab,bc,cd,de,ef' | ||
>>> factors = [np.random.rand(4, 4) for _ in range(5)] | ||
>>> marginals = {output: contract('{}->{}'.format(inputs, output), *factors) | ||
...              for output in 'abcdef'} | ||
|
||
To share this computation, we can perform all contractions in a shared context | ||
|
||
.. code:: python | ||
|
||
>>> with shared_intermediates(): | ||
...     marginals = {output: contract('{}->{}'.format(inputs, output), *factors) | ||
...                  for output in 'abcdef'} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might be worth showing a quick timing comparison on bigger tensors with an explicit demonstration of the contractions done. |
||
|
||
If it is difficult to fit your code into a context, you can instead save the sharing cache for later reuse. | ||
|
||
.. code:: python | ||
|
||
>>> with shared_intermediates() as cache:  # create a cache | ||
...     pass | ||
>>> marginals = {} | ||
>>> for output in 'abcdef': | ||
...     with shared_intermediates(cache):  # reuse a common cache | ||
...         marginals[output] = contract('{}->{}'.format(inputs, output), *factors) | ||
>>> del cache  # garbage collect intermediates | ||
|
||
Note that sharing contexts can be nested, so it is safe to use :func:`~opt_einsum.shared_intermediates` in library code without leaking intermediates into user caches. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import contextlib | ||
import functools | ||
import numbers | ||
from collections import Counter | ||
|
||
from .parser import alpha_canonicalize, parse_einsum_input | ||
|
||
# Stack of active sharing caches; the innermost (most recent) context's
# cache sits at the end and is the one the wrappers consult.
_SHARING_STACK = []


@contextlib.contextmanager
def shared_intermediates(cache=None):
    """Enable sharing of intermediate results between contractions.

    While this context is active, subexpression results are memoized in a
    cache and reused by later contractions. Cached intermediates stay alive
    until both (1) this context has exited and (2) any captured reference
    to the yielded cache has been dropped.

    Parameters
    ----------
    cache : dict, optional
        A user-supplied dict in which intermediate results are stored.
        Passing the same dict to several contexts interleaves their
        sharing.

    Returns
    -------
    cache : dict
        The dict holding the shared results. If the caller ignores it, the
        intermediates are garbage collected when the context exits; it can
        instead be passed to another context to resume sharing.
    """
    active_cache = {} if cache is None else cache
    _SHARING_STACK.append(active_cache)
    try:
        yield active_cache
    finally:
        # Always unwind, even if the body raised, so outer contexts (or
        # no context at all) see a consistent stack.
        _SHARING_STACK.pop()
|
||
|
||
def count_cached_ops(cache):
    """Return a ``Counter`` mapping op name to its number of cache entries.

    Each cache key is a tuple whose first element names the operation
    (``'einsum'``, ``'tensordot'``, ...), so this gives a quick profile of
    how much sharing a computation achieved.
    """
    op_names = (key[0] for key in cache)
    return Counter(op_names)
|
||
|
||
def _save_tensors(*tensors):
    """Anchor ``tensors`` in the active cache so their ids stay unique.

    Cache keys identify operands by ``id()``; holding a reference here
    prevents Python from recycling an id for a new object, which would
    otherwise produce a false cache hit.
    """
    active = _SHARING_STACK[-1]
    for tensor in tensors:
        active['tensor', id(tensor)] = tensor
|
||
|
||
def _memoize(key, fn, *args, **kwargs):
    """Look up ``key`` in the active sharing cache, computing on a miss.

    On a miss, ``fn(*args, **kwargs)`` is evaluated and its result stored
    under ``key`` before being returned; on a hit the stored result is
    returned without calling ``fn``.
    """
    active = _SHARING_STACK[-1]
    try:
        return active[key]
    except KeyError:
        result = fn(*args, **kwargs)
        active[key] = result
        return result
|
||
|
||
def transpose_cache_wrap(transpose):
    """Decorate ``transpose`` so results are memoized while sharing is on."""

    @functools.wraps(transpose)
    def cached_transpose(a, axes, backend='numpy'):
        if _SHARING_STACK:
            # Keep ``a`` alive so its id cannot be recycled, then key the
            # lookup on the operand's identity plus the (hashable) axes.
            _save_tensors(a)
            perm = tuple(axes)
            cache_key = ('transpose', backend, id(a), perm)
            return _memoize(cache_key, transpose, a, perm, backend=backend)
        # No sharing context active: call straight through.
        return transpose(a, axes, backend=backend)

    return cached_transpose
|
||
|
||
def tensordot_cache_wrap(tensordot):
    """Decorate ``tensordot`` so results are memoized while sharing is on."""

    @functools.wraps(tensordot)
    def cached_tensordot(x, y, axes=2, backend='numpy'):
        # No sharing context active: call straight through.
        if not _SHARING_STACK:
            return tensordot(x, y, axes, backend=backend)

        _save_tensors(x, y)
        # Normalize ``axes`` to the explicit (axes_x, axes_y) pair of
        # tuples so equivalent integer and pair forms share one cache key.
        if isinstance(axes, numbers.Number):
            axes_x = list(range(len(x.shape)))[len(x.shape) - axes:]
            axes_y = list(range(len(y.shape)))[:axes]
            axes = axes_x, axes_y
        axes = (tuple(axes[0]), tuple(axes[1]))
        cache_key = ('tensordot', backend, id(x), id(y), axes)
        return _memoize(cache_key, tensordot, x, y, axes, backend=backend)

    return cached_tensordot
|
||
|
||
def einsum_cache_wrap(einsum):
    """Decorate ``einsum`` so results are memoized while sharing is active.

    Cache keys are computed modulo commutativity of the input terms: the
    input subscripts are sorted by operand id and the equation is
    alpha-canonicalized, so e.g. ``'ab,bc->ac'`` on ``(x, y)`` and
    ``'bc,ab->ac'`` on ``(y, x)`` hit the same cache entry.
    """

    @functools.wraps(einsum)
    def cached_einsum(*args, **kwargs):
        # No sharing context active: call straight through.
        if not _SHARING_STACK:
            return einsum(*args, **kwargs)

        # hash modulo commutativity by computing a canonical ordering and names
        backend = kwargs.pop('backend', 'numpy')
        inputs, output, operands = parse_einsum_input(args)
        # Rebuild the equation from the parsed form rather than trusting
        # ``args[0]``: in the interleaved calling convention
        # einsum(op0, sublist0, op1, sublist1, ..., [sublistout])
        # ``args[0]`` is an operand, not a subscript string, and must not
        # be forwarded as the equation.
        equation = '{}->{}'.format(inputs, output)
        inputs = inputs.split(',')
        _save_tensors(*operands)
        canonical = sorted(zip(inputs, map(id, operands)), key=lambda x: x[1])
        canonical_ids = tuple(id_ for _, id_ in canonical)
        canonical_inputs = ','.join(input_ for input_, _ in canonical)
        canonical_equation = alpha_canonicalize('{}->{}'.format(canonical_inputs, output))
        key = 'einsum', backend, canonical_equation, canonical_ids
        return _memoize(key, einsum, equation, *operands, backend=backend)

    return cached_einsum
|
||
|
||
def to_backend_cache_wrap(to_backend):
    """Decorate a backend-conversion function with sharing memoization.

    Conversions are keyed by the converter's name and the source array's
    identity, so each array is converted at most once per sharing context.
    """

    @functools.wraps(to_backend)
    def cached_to_backend(array):
        if _SHARING_STACK:
            # NOTE(review): unlike the other wrappers this does not call
            # _save_tensors(array), so a recycled id could in principle
            # cause a stale hit — confirm whether the caller keeps the
            # array alive.
            cache_key = (to_backend.__name__, id(array))
            return _memoize(cache_key, to_backend, array)
        # No sharing context active: convert directly.
        return to_backend(array)

    return cached_to_backend
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dgasmith my employer requires me to add a copyright line somewhere. Is it ok here, or would you like me to move it to sharing.py or somewhere else?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yea, could we move this to
sharing.py
? We should probably look at changing the copyright to the "opt_einsum developers" in the future. I need to look into this angle of things a bit more.