Implement shared_intermediates context manager #43

Merged 19 commits on Aug 22, 2018
Changes from 18 commits
1 change: 1 addition & 0 deletions LICENSE
@@ -1,6 +1,7 @@
The MIT License (MIT)

Copyright (c) 2014 Daniel Smith
Copyright (c) 2018 Uber Technologies
Contributor Author:

@dgasmith my employer requires me to add a copyright line somewhere. Is it ok here, or would you like me to move it to sharing.py or somewhere else?

Owner:

Yea, could we move this to sharing.py? We should probably look at changing the copyright to the "opt_einsum developers" in the future. I need to look into this angle of things a bit more.


Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
1 change: 1 addition & 0 deletions README.md
@@ -19,6 +19,7 @@ As well as [`opt_einsum.contract`](https://optimized-einsum.readthedocs.io/en/la
* Perform the contractions with many [different backends](http://optimized-einsum.readthedocs.io/en/latest/backends.html), including on the GPU and with libraries such as [TensorFlow](https://www.tensorflow.org) and [PyTorch](https://pytorch.org).
* Generate [reusable expressions](http://optimized-einsum.readthedocs.io/en/latest/reusing_paths.html), potentially with [constant tensors](http://optimized-einsum.readthedocs.io/en/latest/reusing_paths.html#specifying-constants), that can be compiled for greater performance.
* Use an arbitrary number of indices to find contractions for [hundreds or even thousands of tensors](http://optimized-einsum.readthedocs.io/en/latest/ex_large_expr_with_greedy.html).
* Share [intermediate computations](http://optimized-einsum.readthedocs.io/en/latest/sharing_intermediates.html) among multiple contractions.

## Quick tutorial
Einsum is a powerful function for contracting tensors of arbitrary dimension and index.
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -123,7 +123,7 @@ Table of Contents
install
backends
reusing_paths

sharing_intermediates

.. toctree::
:maxdepth: 1
36 changes: 36 additions & 0 deletions docs/source/sharing_intermediates.rst
@@ -0,0 +1,36 @@
=====================
Sharing Intermediates
=====================

If you want to compute multiple similar contractions with common terms, you can embed them in a :func:`~opt_einsum.shared_intermediates` context. Computations of subexpressions in this context will be memoized, and will be garbage collected when the context exits.

For example, suppose we want to compute marginals at each point in a factor chain:

.. code:: python

>>> inputs = 'ab,bc,cd,de,ef'
>>> factors = [np.random.rand(4, 4) for _ in range(5)]
>>> marginals = {output: contract('{}->{}'.format(inputs, output), *factors)
...              for output in 'abcdef'}

To share this computation, we can perform all contractions in a shared context:

.. code:: python

>>> with shared_intermediates():
...     marginals = {output: contract('{}->{}'.format(inputs, output), *factors)
...                  for output in 'abcdef'}
Owner:

Might be worth showing a quick timing comparison on bigger tensors with an explicit demonstration of the contractions done.
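
A rough sketch of such a comparison (illustrative only, not part of this PR; the helper names, factor sizes, and repetition count are arbitrary choices):

    import timeit
    import numpy as np
    from opt_einsum import contract, shared_intermediates

    inputs = 'ab,bc,cd,de,ef'
    factors = [np.random.rand(100, 100) for _ in range(5)]

    def marginals_unshared():
        return {output: contract('{}->{}'.format(inputs, output), *factors)
                for output in 'abcdef'}

    def marginals_shared():
        with shared_intermediates():
            return {output: contract('{}->{}'.format(inputs, output), *factors)
                    for output in 'abcdef'}

    print('unshared:', timeit.timeit(marginals_unshared, number=10))
    print('shared:  ', timeit.timeit(marginals_shared, number=10))

With sharing enabled, pairwise products along the chain should be computed once and reused across the six marginals rather than recomputed per output.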


If it is difficult to fit your code into a context, you can instead save the sharing cache for later reuse.

.. code:: python

>>> with shared_intermediates() as cache: # create a cache
...     pass
>>> marginals = {}
>>> for output in 'abcdef':
...     with shared_intermediates(cache):  # reuse a common cache
...         marginals[output] = contract('{}->{}'.format(inputs, output), *factors)
>>> del cache # garbage collect intermediates

Note that sharing contexts can be nested, so it is safe to use :func:`~opt_einsum.shared_intermediates` in library code without leaking intermediates into user caches.
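
A minimal sketch of that guarantee (illustrative, not part of the diff; the helper name is hypothetical):

    import numpy as np
    from opt_einsum import contract, shared_intermediates

    def library_routine(factors, inputs='ab,bc,cd,de,ef'):
        # Library code opens its own sharing context; intermediates built
        # here are memoized in the inner cache, not in the caller's cache.
        with shared_intermediates():
            return [contract('{}->{}'.format(inputs, output), *factors)
                    for output in 'abc']

    factors = [np.random.rand(4, 4) for _ in range(5)]
    with shared_intermediates() as cache:  # user context
        library_routine(factors)           # nested context: nothing leaks into cache
        user_marginal = contract('ab,bc,cd,de,ef->f', *factors)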
1 change: 1 addition & 0 deletions opt_einsum/__init__.py
@@ -4,6 +4,7 @@

from .contract import contract, contract_path, contract_expression
from .parser import get_symbol
from .sharing import shared_intermediates
from . import paths
from . import blas
from . import helpers
2 changes: 2 additions & 0 deletions opt_einsum/backends/cupy.py
@@ -4,8 +4,10 @@

from __future__ import absolute_import
import numpy as np
from ..sharing import to_backend_cache_wrap


@to_backend_cache_wrap
def to_cupy(array): # pragma: no cover
import cupy

2 changes: 2 additions & 0 deletions opt_einsum/backends/torch.py
@@ -6,6 +6,7 @@
import numpy as np

from ..parser import convert_to_valid_einsum_chars, einsum_symbols_base
from ..sharing import to_backend_cache_wrap

_TORCH_DEVICE = None

@@ -84,6 +85,7 @@ def tensordot(x, y, axes=2):
return einsum(einsum_str, x, y)


@to_backend_cache_wrap
def to_torch(array):
torch, device = _get_torch_and_device()

4 changes: 4 additions & 0 deletions opt_einsum/contract.py
@@ -9,6 +9,7 @@
from . import helpers
from . import parser
from . import paths
from . import sharing


def contract_path(*operands, **kwargs):
@@ -267,6 +268,7 @@ def contract_path(*operands, **kwargs):
return path, path_print


@sharing.einsum_cache_wrap
def _einsum(*operands, **kwargs):
"""Base einsum, but with pre-parse for valid characters if string given.
"""
@@ -289,6 +291,7 @@ def _einsum(*operands, **kwargs):
return fn(einsum_str, *operands, **kwargs)


@sharing.transpose_cache_wrap
def _transpose(x, axes, backend='numpy'):
"""Base transpose.
"""
@@ -300,6 +303,7 @@ def _transpose(x, axes, backend='numpy'):
return fn(x, axes)


@sharing.tensordot_cache_wrap
def _tensordot(x, y, axes, backend='numpy'):
"""Base tensordot.
"""
15 changes: 14 additions & 1 deletion opt_einsum/parser.py
@@ -4,8 +4,9 @@
A functionally equivalent parser of the numpy.einsum input parser
"""

import numpy as np
from collections import OrderedDict

import numpy as np

einsum_symbols_base = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'

@@ -65,6 +66,18 @@ def convert_to_valid_einsum_chars(einsum_str):
return "".join(replacer.get(x, x) for x in einsum_str)


def alpha_canonicalize(equation):
"""Alpha convert an equation in an order-independent canonical way.
"""
rename = OrderedDict()
for name in equation:
if name in '.,->':
continue
if name not in rename:
rename[name] = get_symbol(len(rename))
return ''.join(rename.get(x, x) for x in equation)


def find_output_str(subscripts):
"""Find the output string for the inputs ``susbcripts``.
"""
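
For illustration (not part of the diff), two equations that are the same up to a renaming of indices canonicalize to the same string, which is what lets the sharing cache treat such einsum calls as equivalent:

    >>> from opt_einsum.parser import alpha_canonicalize
    >>> alpha_canonicalize('ij,jk->ik')
    'ab,bc->ac'
    >>> alpha_canonicalize('xy,yz->xz')
    'ab,bc->ac'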
135 changes: 135 additions & 0 deletions opt_einsum/sharing.py
@@ -0,0 +1,135 @@
import contextlib
import functools
import numbers
from collections import Counter

from .parser import alpha_canonicalize, parse_einsum_input

_SHARING_STACK = []


@contextlib.contextmanager
def shared_intermediates(cache=None):
"""Context in which contract intermediate results are shared.

Note that intermediate computations will not be garbage collected until
1. this context exits, and
2. the yielded cache is garbage collected (if it was captured).

Parameters
----------
cache : dict
If specified, a user-stored dict in which intermediate results will
be stored. This can be used to interleave sharing contexts.

Returns
-------
cache : dict
A dictionary in which sharing results are stored. If ignored,
sharing results will be garbage collected when this context is
exited. This dict can be passed to another context to resume
sharing.
"""
if cache is None:
cache = {}
try:
_SHARING_STACK.append(cache)
yield cache
finally:
_SHARING_STACK.pop()


def count_cached_ops(cache):
"""Returns a counter of the types of each op in the cache.
This is useful for profiling to increase sharing.
"""
return Counter(key[0] for key in cache.keys())


def _save_tensors(*tensors):
"""Save tensors in the cache to prevent their ids from being recycled.
This is needed to prevent false cache lookups.
"""
cache = _SHARING_STACK[-1]
for tensor in tensors:
cache['tensor', id(tensor)] = tensor


def _memoize(key, fn, *args, **kwargs):
Owner:

Can we add docstrings on the next 4 functions? Not a lot, but just something to indicate their use.

Contributor Author:

done.

cache = _SHARING_STACK[-1]
if key in cache:
return cache[key]
result = fn(*args, **kwargs)
cache[key] = result
return result


def transpose_cache_wrap(transpose):

@functools.wraps(transpose)
def cached_transpose(a, axes, backend='numpy'):
if not _SHARING_STACK:
return transpose(a, axes, backend=backend)

# hash by axes
_save_tensors(a)
axes = tuple(axes)
key = 'transpose', backend, id(a), axes
return _memoize(key, transpose, a, axes, backend=backend)

return cached_transpose


def tensordot_cache_wrap(tensordot):

@functools.wraps(tensordot)
def cached_tensordot(x, y, axes=2, backend='numpy'):
if not _SHARING_STACK:
return tensordot(x, y, axes, backend=backend)

# hash based on the (axes_x,axes_y) form of axes
_save_tensors(x, y)
if isinstance(axes, numbers.Number):
axes = list(range(len(x.shape)))[len(x.shape) - axes:], list(range(len(y.shape)))[:axes]
axes = tuple(axes[0]), tuple(axes[1])
key = 'tensordot', backend, id(x), id(y), axes
return _memoize(key, tensordot, x, y, axes, backend=backend)

return cached_tensordot


def einsum_cache_wrap(einsum):

@functools.wraps(einsum)
def cached_einsum(*args, **kwargs):
if not _SHARING_STACK:
return einsum(*args, **kwargs)

# hash modulo commutativity by computing a canonical ordering and names
backend = kwargs.pop('backend', 'numpy')
equation = args[0]
inputs, output, operands = parse_einsum_input(args)
inputs = inputs.split(',')
_save_tensors(*operands)
canonical = sorted(zip(inputs, map(id, operands)), key=lambda x: x[1])
canonical_ids = tuple(id_ for _, id_ in canonical)
canonical_inputs = ','.join(input_ for input_, _ in canonical)
canonical_equation = alpha_canonicalize('{}->{}'.format(canonical_inputs, output))
key = 'einsum', backend, canonical_equation, canonical_ids
return _memoize(key, einsum, equation, *operands, backend=backend)

return cached_einsum


def to_backend_cache_wrap(to_backend):

@functools.wraps(to_backend)
def cached_to_backend(array):
if not _SHARING_STACK:
return to_backend(array)

# hash by id
key = to_backend.__name__, id(array)
return _memoize(key, to_backend, array)

return cached_to_backend
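
As a closing usage sketch (illustrative, not part of the diff), the cache yielded by shared_intermediates can be inspected with count_cached_ops to gauge how much sharing occurred:

    import numpy as np
    from opt_einsum import contract, shared_intermediates
    from opt_einsum.sharing import count_cached_ops

    inputs = 'ab,bc,cd,de,ef'
    factors = [np.random.rand(4, 4) for _ in range(5)]

    with shared_intermediates() as cache:
        for output in 'abcdef':
            contract('{}->{}'.format(inputs, output), *factors)
        # Counter of cached op types, with keys such as 'tensordot',
        # 'einsum' and 'tensor' (operands saved to pin their ids)
        print(count_cached_ops(cache))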