Skip to content

Commit

Permalink
Small fixes: 'dp' and memory_limit + tensordot axes order (dgasmith#154)
Browse files Browse the repository at this point in the history
* raise error for optimize='dp' when no possible contractions

* use a standardized axes order for tensordot calls

* use infer_backend instead of isinstance

removes numpy import and fixes bug with some conversion backens

* fix comment typo

* dp: fix cost_cap check

* fix empty tensordot axes, raise dp max cost_cap
  • Loading branch information
jcmgray committed Nov 4, 2020
1 parent 0be911f commit 32fa384
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 9 deletions.
23 changes: 14 additions & 9 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from collections import namedtuple
from decimal import Decimal

import numpy as np

from . import backends, blas, helpers, parser, paths, sharing

__all__ = ["contract_path", "contract", "format_const_einsum_str", "ContractExpression", "shape_only"]
Expand Down Expand Up @@ -563,14 +561,21 @@ def _core_contract(operands, contraction_list, backend='auto', evaluate_constant

tensor_result = "".join(s for s in input_left + input_right if s not in idx_rm)

# Find indices to contract over
left_pos, right_pos = [], []
for s in idx_rm:
left_pos.append(input_left.find(s))
right_pos.append(input_right.find(s))
if idx_rm:
# Find indices to contract over
left_pos, right_pos = [], []
for s in idx_rm:
left_pos.append(input_left.find(s))
right_pos.append(input_right.find(s))

# Construct the axes tuples in a canonical order
axes = tuple(zip(*sorted(zip(left_pos, right_pos))))
else:
# Ensure axes is always pair of tuples
axes = ((), ())

# Contract!
new_view = _tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)), backend=backend)
new_view = _tensordot(*tmp_operands, axes=axes, backend=backend)

# Build a new view if needed
if (tensor_result != results_index) or handle_out:
Expand Down Expand Up @@ -757,7 +762,7 @@ def __call__(self, *arrays, **kwargs):
try:
# Check if the backend requires special preparation / calling
# but also ignore non-numpy arrays -> assume user wants same type back
if backends.has_backend(backend) and all(isinstance(x, np.ndarray) for x in arrays):
if backends.has_backend(backend) and all(infer_backend(x) == 'numpy' for x in arrays):
return self._contract_with_conversion(ops, out, backend, evaluate_constants=evaluate_constants)

return self._contract(ops, out, backend, evaluate_constants=evaluate_constants)
Expand Down
5 changes: 5 additions & 0 deletions opt_einsum/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import heapq
import itertools
import random
import operator
from collections import Counter, OrderedDict, defaultdict

import numpy as np
Expand Down Expand Up @@ -961,6 +962,7 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
output = set(symbol2int[c] for c in output)
size_dict = {symbol2int[c]: v for c, v in size_dict.items() if c in symbol2int}
size_dict = [size_dict[j] for j in range(len(size_dict))]
naive_cost = len(inputs) * functools.reduce(operator.mul, size_dict)

inputs, inputs_done, inputs_contractions = _dp_parse_out_single_term_ops(inputs, all_inds, ind_counts)

Expand Down Expand Up @@ -1033,6 +1035,9 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
xn, g, all_tensors, inputs, i1_cut_i2_wo_output,
memory_limit, cntrct1, cntrct2)

if (cost_cap > naive_cost) and (len(x[-1]) == 0):
raise RuntimeError("No contraction found for given `memory_limit`.")

# increase cost cap for next iteration:
cost_cap = cost_increment * cost_cap

Expand Down
16 changes: 16 additions & 0 deletions opt_einsum/tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,22 @@ def test_custom_dp_can_set_cost_cap():
assert info1.opt_cost == info2.opt_cost == info3.opt_cost


def test_dp_errors_when_no_contractions_found():
eq, shapes, size_dict = oe.helpers.rand_equation(10, 3, seed=42, return_size_dict=True)

# first get the actual minimum cost
opt = oe.DynamicProgramming(minimize='size')
path, info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)
mincost = info.largest_intermediate

# check we can still find it without minimizing size explicitly
oe.contract_path(eq, *shapes, shapes=True, memory_limit=mincost, optimize='dp')

# but check just below this threshold raises
with pytest.raises(RuntimeError):
oe.contract_path(eq, *shapes, shapes=True, memory_limit=mincost - 1, optimize='dp')


@pytest.mark.parametrize("optimize", ['greedy', 'branch-2', 'branch-all', 'optimal', 'dp'])
def test_can_optimize_outer_products(optimize):
a, b, c = [np.random.randn(10, 10) for _ in range(3)]
Expand Down

0 comments on commit 32fa384

Please sign in to comment.