
Implement shared_intermediates context manager #43

Merged
19 commits merged on Aug 22, 2018
Changes from 14 commits
1 change: 1 addition & 0 deletions LICENSE
@@ -1,6 +1,7 @@
The MIT License (MIT)

Copyright (c) 2014 Daniel Smith
Copyright (c) 2018 Uber Technologies
Contributor Author
@dgasmith my employer requires me to add a copyright line somewhere. Is it ok here, or would you like me to move it to sharing.py or somewhere else?

Owner
Yea, could we move this to sharing.py? We should probably look at changing the copyright to the "opt_einsum developers" in the future. I need to look into this angle of things a bit more.


Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
1 change: 1 addition & 0 deletions opt_einsum/__init__.py
@@ -4,6 +4,7 @@

from .contract import contract, contract_path, contract_expression
from .parser import get_symbol
from .sharing import shared_intermediates
from . import paths
from . import blas
from . import helpers
2 changes: 2 additions & 0 deletions opt_einsum/backends/cupy.py
@@ -4,8 +4,10 @@

from __future__ import absolute_import
import numpy as np
from ..sharing import to_backend_cache_wrap


@to_backend_cache_wrap
def to_cupy(array): # pragma: no cover
import cupy

2 changes: 2 additions & 0 deletions opt_einsum/backends/torch.py
@@ -6,6 +6,7 @@
import numpy as np

from ..parser import convert_to_valid_einsum_chars, einsum_symbols_base
from ..sharing import to_backend_cache_wrap

_TORCH_DEVICE = None

@@ -84,6 +85,7 @@ def tensordot(x, y, axes=2):
return einsum(einsum_str, x, y)


@to_backend_cache_wrap
def to_torch(array):
torch, device = _get_torch_and_device()

4 changes: 4 additions & 0 deletions opt_einsum/contract.py
@@ -9,6 +9,7 @@
from . import helpers
from . import parser
from . import paths
from . import sharing


def contract_path(*operands, **kwargs):
@@ -267,6 +268,7 @@ def contract_path(*operands, **kwargs):
return path, path_print


@sharing.einsum_cache_wrap
def _einsum(*operands, **kwargs):
"""Base einsum, but with pre-parse for valid characters if string given.
"""
@@ -289,6 +291,7 @@ def _einsum(*operands, **kwargs):
return fn(einsum_str, *operands, **kwargs)


@sharing.transpose_cache_wrap
def _transpose(x, axes, backend='numpy'):
"""Base transpose.
"""
@@ -300,6 +303,7 @@ def _transpose(x, axes, backend='numpy'):
return fn(x, axes)


@sharing.tensordot_cache_wrap
def _tensordot(x, y, axes, backend='numpy'):
"""Base tensordot.
"""
147 changes: 147 additions & 0 deletions opt_einsum/sharing.py
@@ -0,0 +1,147 @@
import contextlib
import functools
import numbers
from collections import Counter, OrderedDict

from .parser import get_symbol, parse_einsum_input

_SHARING_STACK = []


@contextlib.contextmanager
def shared_intermediates(cache=None):
"""Context in which contract intermediate results are shared.

Note that intermediate computations will not be garbage collected until
1. this context exits, and
2. the yielded cache is garbage collected (if it was captured).

Parameters
----------
cache : dict
If specified, a user-stored dict in which intermediate results will
be stored. This can be used to interleave sharing contexts.

Returns
-------
cache : dict
A dictionary in which sharing results are stored. If ignored,
sharing results will be garbage collected when this context is
exited. This dict can be passed to another context to resume
sharing.
"""
if cache is None:
cache = {}
try:
_SHARING_STACK.append(cache)
yield cache
finally:
_SHARING_STACK.pop()
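
For illustration, a minimal usage sketch of this context manager (the shapes, equations, and the reuse of the returned cache are assumptions, not part of this diff):

import numpy as np
from opt_einsum import contract, shared_intermediates

x, y, z = (np.random.rand(10, 10) for _ in range(3))

with shared_intermediates() as cache:
    # If both contractions choose the same path, the x-y intermediate is
    # computed once and then reused from the cache.
    contract('ij,jk,kl->il', x, y, z)
    contract('ij,jk,kl->li', x, y, z)

# Passing the captured cache to a later context resumes sharing.
with shared_intermediates(cache):
    contract('ij,jk,kl->il', x, y, z)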


def count_cached_ops(cache):
"""Returns a counter of the types of each op in the cache.
This is useful for profiling to increase sharing.
"""
return Counter(key[0] for key in cache.keys())
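
A hedged example of profiling with this counter (the counts are illustrative and depend on the contraction path actually taken):

import numpy as np
from opt_einsum import contract, shared_intermediates
from opt_einsum.sharing import count_cached_ops

x, y, z = (np.random.rand(8, 8) for _ in range(3))
with shared_intermediates() as cache:
    contract('ab,bc,cd->ad', x, y, z)
    contract('ab,bc,cd->da', x, y, z)

# Keys are grouped by their first element (the op type), giving something
# like Counter({'tensor': 3, 'tensordot': 3, 'transpose': 1}).
print(count_cached_ops(cache))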


def _alpha_canonicalize(equation):
"""Alpha convert in an order-independent canonical way.
"""
rename = OrderedDict()
for name in equation:
if name in ',->':
continue
if name not in rename:
rename[name] = get_symbol(len(rename))
return ''.join(rename.get(x, x) for x in equation)
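
A worked example of this renaming (values assumed): equations that differ only in their choice of index letters collapse to the same canonical string.

from opt_einsum.sharing import _alpha_canonicalize

assert _alpha_canonicalize('ij,jk->ik') == 'ab,bc->ac'
assert _alpha_canonicalize('xy,yz->xz') == 'ab,bc->ac'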


def _save_tensors(*tensors):
"""Save tensors in the cache to prevent their ids from being recycled.
This is needed to prevent false cache lookups.
"""
cache = _SHARING_STACK[-1]
for tensor in tensors:
cache['tensor', id(tensor)] = tensor
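
A sketch of the failure mode this guards against (assumed illustration): in CPython, id() values can be recycled once an object is garbage collected, so a cache keyed only on id() could return a stale result for a brand-new array.

import numpy as np

a = np.ones(3)
stale_key = ('tensor', id(a))
del a             # without a saved reference, this id may be freed...
b = np.zeros(3)   # ...and reused for an unrelated array, so
                  # id(b) == stale_key[1] is possible, producing a false
                  # cache hit. Storing the tensor in the cache pins its id
                  # for the lifetime of the sharing context.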


def _memoize(key, fn, *args, **kwargs):
Owner
Can we add docstrings on the next 4 functions? Not a lot, but just something to indicate their use.

Contributor Author
done.

cache = _SHARING_STACK[-1]
if key in cache:
return cache[key]
result = fn(*args, **kwargs)
cache[key] = result
return result


def transpose_cache_wrap(transpose):

@functools.wraps(transpose)
def cached_transpose(a, axes, backend='numpy'):
if not _SHARING_STACK:
return transpose(a, axes, backend=backend)

# hash by axes
_save_tensors(a)
axes = tuple(axes)
key = 'transpose', backend, id(a), axes
return _memoize(key, transpose, a, axes, backend=backend)

return cached_transpose


def tensordot_cache_wrap(tensordot):

@functools.wraps(tensordot)
def cached_tensordot(x, y, axes=2, backend='numpy'):
if not _SHARING_STACK:
return tensordot(x, y, axes, backend=backend)

# hash based on the (axes_x,axes_y) form of axes
_save_tensors(x, y)
if isinstance(axes, numbers.Number):
axes = list(range(len(x.shape)))[len(x.shape) - axes:], list(range(len(y.shape)))[:axes]
axes = tuple(axes[0]), tuple(axes[1])
key = 'tensordot', backend, id(x), id(y), axes
return _memoize(key, tensordot, x, y, axes, backend=backend)

return cached_tensordot
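
A small sketch of the axes normalization above (assumed values): an integer axes argument and its explicit (axes_x, axes_y) form reduce to the same key, so either spelling hits the same cache entry.

import numbers

ndim_x = ndim_y = 2
axes = 1
if isinstance(axes, numbers.Number):
    axes = list(range(ndim_x))[ndim_x - axes:], list(range(ndim_y))[:axes]
axes = tuple(axes[0]), tuple(axes[1])
assert axes == ((1,), (0,))   # same as passing axes=((1,), (0,)) explicitly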


def einsum_cache_wrap(einsum):

@functools.wraps(einsum)
def cached_einsum(*args, **kwargs):
if not _SHARING_STACK:
return einsum(*args, **kwargs)

# hash modulo commutativity by computing a canonical ordering and names
backend = kwargs.pop('backend', 'numpy')
equation = args[0]
inputs, output, operands = parse_einsum_input(args)
inputs = inputs.split(',')
_save_tensors(*operands)
canonical = sorted(zip(inputs, map(id, operands)), key=lambda x: x[1])
canonical_ids = tuple(id_ for _, id_ in canonical)
canonical_inputs = ','.join(input_ for input_, _ in canonical)
canonical_equation = _alpha_canonicalize('{}->{}'.format(canonical_inputs, output))
key = 'einsum', backend, canonical_equation, canonical_ids
return _memoize(key, einsum, equation, *operands, backend=backend)

return cached_einsum
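
A sketch of why operand order does not change the key; the helper key_for below is hypothetical and only restates the logic above.

import numpy as np
from opt_einsum.parser import parse_einsum_input
from opt_einsum.sharing import _alpha_canonicalize

def key_for(*args):
    # Re-derive the cache key: pair subscripts with operand ids, sort by id,
    # then alpha-canonicalize the reordered equation.
    inputs, output, operands = parse_einsum_input(args)
    canon = sorted(zip(inputs.split(','), map(id, operands)), key=lambda p: p[1])
    ids = tuple(i for _, i in canon)
    eq = _alpha_canonicalize(','.join(s for s, _ in canon) + '->' + output)
    return eq, ids

x, y = np.ones((2, 3)), np.ones((3, 4))
# Swapping the operands (and their subscripts) yields the same key.
assert key_for('ij,jk->ik', x, y) == key_for('jk,ij->ik', y, x)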


def to_backend_cache_wrap(to_backend):

@functools.wraps(to_backend)
def cached_to_backend(array):
if not _SHARING_STACK:
return to_backend(array)

# hash by id
key = to_backend.__name__, id(array)
return _memoize(key, to_backend, array)

return cached_to_backend
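
Finally, an end-to-end sketch of how backend-conversion caching is intended to combine with sharing. This assumes a working torch install and that numpy operands are converted via to_torch when backend='torch' is requested; the counts are illustrative only.

import numpy as np
from opt_einsum import contract, shared_intermediates
from opt_einsum.sharing import count_cached_ops

x, y = np.random.rand(5, 5), np.random.rand(5, 5)

with shared_intermediates() as cache:
    contract('ab,bc->ac', x, y, backend='torch')
    contract('ab,bc->ca', x, y, backend='torch')   # x and y are not re-converted

# Conversion entries are keyed ('to_torch', id(array)), so the counter might
# show something like Counter({'to_torch': 2, 'tensor': 2, ...}).
print(count_cached_ops(cache))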