Skip to content

Commit

Permalink
raise error for optimize='dp' when no possible contractions
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Nov 2, 2020
1 parent 0be911f commit 5123a68
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
5 changes: 5 additions & 0 deletions opt_einsum/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import heapq
import itertools
import random
import operator
from collections import Counter, OrderedDict, defaultdict

import numpy as np
Expand Down Expand Up @@ -961,6 +962,7 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
output = set(symbol2int[c] for c in output)
size_dict = {symbol2int[c]: v for c, v in size_dict.items() if c in symbol2int}
size_dict = [size_dict[j] for j in range(len(size_dict))]
naive_cost = functools.reduce(operator.mul, size_dict)

inputs, inputs_done, inputs_contractions = _dp_parse_out_single_term_ops(inputs, all_inds, ind_counts)

Expand Down Expand Up @@ -1033,6 +1035,9 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
xn, g, all_tensors, inputs, i1_cut_i2_wo_output,
memory_limit, cntrct1, cntrct2)

if (cost_cap >= naive_cost) and (len(x[-1]) == 0):
raise RuntimeError("No contraction found for given `memory_limit`.")

# increase cost cap for next iteration:
cost_cap = cost_increment * cost_cap

Expand Down
16 changes: 16 additions & 0 deletions opt_einsum/tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,22 @@ def test_custom_dp_can_set_cost_cap():
assert info1.opt_cost == info2.opt_cost == info3.opt_cost


def test_dp_errors_when_no_contractions_found():
eq, shapes, size_dict = oe.helpers.rand_equation(10, 3, seed=42, return_size_dict=True)

# first get the actual minimum cost
opt = oe.DynamicProgramming(minimize='size')
path, info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)
mincost = info.largest_intermediate

# check we can still find it without minimizing size explicitly
oe.contract_path(eq, *shapes, shapes=True, memory_limit=mincost, optimize='dp')

# but check just below this threshold raises
with pytest.raises(RuntimeError):
oe.contract_path(eq, *shapes, shapes=True, memory_limit=mincost - 1, optimize='dp')


@pytest.mark.parametrize("optimize", ['greedy', 'branch-2', 'branch-all', 'optimal', 'dp'])
def test_can_optimize_outer_products(optimize):
a, b, c = [np.random.randn(10, 10) for _ in range(3)]
Expand Down

0 comments on commit 5123a68

Please sign in to comment.