stateful path optimizers check args + minor doc tweaks (dgasmith#141)
* update contract docstring

also rename internal variable dimension_dict -> size_dict for consistency

* stateful optimizers: check args against first call

i.e. BranchBound and RandomGreedy

* test the path optimizer arg check

* make arg check private and add note about init
jcmgray committed May 13, 2020
1 parent 95c1c80 commit e88983f
Showing 4 changed files with 56 additions and 17 deletions.
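In short, the behavioural change here is that a stateful path optimizer now remembers the contraction it was first called with and refuses to be reused on a different one. A minimal sketch of that behaviour, using the same ``rand_equation`` helper as the new tests below (the ``RandomGreedy`` settings are purely illustrative):

```python
import numpy as np
import opt_einsum as oe

# a stateful optimizer instance; max_repeats is an arbitrary illustrative budget
optimizer = oe.RandomGreedy(max_repeats=8)

eq, shapes = oe.helpers.rand_equation(10, 4, seed=42)
views = list(map(np.ones, shapes))
path, info = oe.contract_path(eq, *views, optimize=optimizer)  # first call fixes the reference args

# reusing the same instance on a *different* contraction now raises
other_eq, other_shapes = oe.helpers.rand_equation(10, 4, seed=41)
other_views = list(map(np.ones, other_shapes))
try:
    oe.contract_path(other_eq, *other_views, optimize=optimizer)
except ValueError as exc:
    print(exc)  # the message suggests creating a new optimizer instance
```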
opt_einsum/contract.py (45 changes: 28 additions & 17 deletions)
@@ -214,7 +214,7 @@ def contract_path(*operands, **kwargs):
indices = set(input_subscripts.replace(',', ''))

# Get length of each unique dimension and ensure all dimensions are correct
dimension_dict = {}
size_dict = {}
for tnum, term in enumerate(input_list):
sh = input_shps[tnum]

@@ -224,18 +224,18 @@ def contract_path(*operands, **kwargs):
for cnum, char in enumerate(term):
dim = int(sh[cnum])

if char in dimension_dict:
if char in size_dict:
# For broadcasting cases we always want the largest dim size
if dimension_dict[char] == 1:
dimension_dict[char] = dim
elif dim not in (1, dimension_dict[char]):
if size_dict[char] == 1:
size_dict[char] = dim
elif dim not in (1, size_dict[char]):
raise ValueError("Size of label '{}' for operand {} ({}) does not match previous "
"terms ({}).".format(char, tnum, dimension_dict[char], dim))
"terms ({}).".format(char, tnum, size_dict[char], dim))
else:
dimension_dict[char] = dim
size_dict[char] = dim

# Compute size of each input array plus the output array
size_list = [helpers.compute_size_by_dict(term, dimension_dict) for term in input_list + [output_subscript]]
size_list = [helpers.compute_size_by_dict(term, size_dict) for term in input_list + [output_subscript]]
memory_arg = _choose_memory_arg(memory_limit, size_list)

num_ops = len(input_list)
@@ -245,7 +245,7 @@ def contract_path(*operands, **kwargs):
# indices_in_input = input_subscripts.replace(',', '')

inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
naive_cost = helpers.flop_count(indices, inner_product, num_ops, dimension_dict)
naive_cost = helpers.flop_count(indices, inner_product, num_ops, size_dict)

# Compute the path
if not isinstance(path_type, (str, paths.PathOptimizer)):
@@ -256,10 +256,10 @@ def contract_path(*operands, **kwargs):
path = [tuple(range(num_ops))]
elif isinstance(path_type, paths.PathOptimizer):
# Custom path optimizer supplied
path = path_type(input_sets, output_set, dimension_dict, memory_arg)
path = path_type(input_sets, output_set, size_dict, memory_arg)
else:
path_optimizer = paths.get_path_fn(path_type)
path = path_optimizer(input_sets, output_set, dimension_dict, memory_arg)
path = path_optimizer(input_sets, output_set, size_dict, memory_arg)

cost_list = []
scale_list = []
@@ -275,10 +275,10 @@ def contract_path(*operands, **kwargs):
out_inds, input_sets, idx_removed, idx_contract = contract_tuple

# Compute cost, scale, and size
cost = helpers.flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict)
cost = helpers.flop_count(idx_contract, idx_removed, len(contract_inds), size_dict)
cost_list.append(cost)
scale_list.append(len(idx_contract))
size_list.append(helpers.compute_size_by_dict(out_inds, dimension_dict))
size_list.append(helpers.compute_size_by_dict(out_inds, size_dict))

tmp_inputs = [input_list.pop(x) for x in contract_inds]
tmp_shapes = [input_shps.pop(x) for x in contract_inds]
@@ -312,7 +312,7 @@ def contract_path(*operands, **kwargs):
return operands, contraction_list

path_print = PathInfo(contraction_list, input_subscripts, output_subscript, indices, path, scale_list, naive_cost,
opt_cost, size_list, dimension_dict)
opt_cost, size_list, size_dict)

return path, path_print

@@ -393,15 +393,26 @@ def contract(*operands, **kwargs):
- ``'optimal'`` An algorithm that explores all possible ways of
contracting the listed tensors. Scales factorially with the number of
terms in the contraction.
- ``'dp'`` A faster (but essentially optimal) algorithm that uses
dynamic programming to exhaustively search all contraction paths
without outer-products.
- ``'greedy'`` A cheap algorithm that heuristically chooses the best
pairwise contraction at each step. Scales linearly in the number of
terms in the contraction.
- ``'random-greedy'`` Run a randomized version of the greedy algorithm
32 times and pick the best path.
- ``'random-greedy-128'`` Run a randomized version of the greedy
algorithm 128 times and pick the best path.
- ``'branch-all'`` An algorithm like optimal but that restricts itself
to searching 'likely' paths. Still scales factorially.
- ``'branch-2'`` An even more restricted version of 'branch-all' that
only searches the best two options at each step. Scales exponentially
with the number of terms in the contraction.
- ``'greedy'`` An algorithm that heuristically chooses the best pair
contraction at each step.
- ``'auto'`` Choose the best of the above algorithms whilst aiming to
keep the path finding time below 1ms.
- ``'auto-hq'`` Aim for a high quality contraction, choosing the best
of the above algorithms whilst aiming to keep the path finding time
below 1sec.
memory_limit : {None, int, 'max_input'} (default: None)
Give the upper bound of the largest intermediate tensor contract will build.
@@ -412,7 +412,7 @@ def contract(*operands, **kwargs):
The default is None. Note that imposing a limit can make contractions
exponentially slower to perform.
backend : str, optional (default: ``numpy``)
backend : str, optional (default: ``auto``)
Which library to use to perform the required ``tensordot``, ``transpose``
and ``einsum`` calls. Should match the types of arrays supplied. See
:func:`contract_expression` for generating expressions which convert
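For context on the docstring update above, the listed strategies can be compared directly by asking ``contract_path`` for its cost estimates. A rough sketch (the equation and the particular strategies chosen are arbitrary):

```python
import numpy as np
import opt_einsum as oe

eq, shapes = oe.helpers.rand_equation(8, 3, seed=0)
arrays = list(map(np.ones, shapes))

for strategy in ['greedy', 'random-greedy', 'branch-2', 'dp', 'auto']:
    path, info = oe.contract_path(eq, *arrays, optimize=strategy)
    # opt_cost is the estimated flop count, largest_intermediate the biggest tensor built
    print(strategy, info.opt_cost, info.largest_intermediate)
```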
opt_einsum/path_random.py (2 changes: 2 additions & 0 deletions)
@@ -163,6 +163,8 @@ def setup(self, inputs, output, size_dict):
raise NotImplementedError

def __call__(self, inputs, output, size_dict, memory_limit):
self._check_args_against_first_call(inputs, output, size_dict)

# start a timer?
if self.max_time is not None:
t0 = time.time()
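Note the check compares against the first call rather than forbidding reuse outright, so repeated calls with the same contraction remain valid; that is how a ``RandomGreedy`` instance keeps accumulating results across runs. A sketch (whether extra trials actually run on later calls depends on the ``max_repeats``/``max_time`` budget):

```python
import numpy as np
import opt_einsum as oe

eq, shapes = oe.helpers.rand_equation(10, 4, seed=42)
views = list(map(np.ones, shapes))

optimizer = oe.RandomGreedy(max_time=0.5)
for _ in range(2):
    # same contraction each time: passes the new check, best result so far is kept
    path, info = oe.contract_path(eq, *views, optimize=optimizer)

print(optimizer.best['flops'], optimizer.best['size'])
```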
opt_einsum/paths.py (14 changes: 14 additions & 0 deletions)
@@ -43,6 +43,19 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
where ``path`` is a list of int-tuples specifying a contraction order.
"""

def _check_args_against_first_call(self, inputs, output, size_dict):
"""Utility that stateful optimizers can use to ensure they are not
called with different contractions across separate runs.
"""
args = (inputs, output, size_dict)
if not hasattr(self, '_first_call_args'):
# simply set the attribute as currently there is no global PathOptimizer init
self._first_call_args = args
elif args != self._first_call_args:
raise ValueError("The arguments specifying the contraction that this path optimizer "
"instance was called with have changed - try creating a new instance.")

def __call__(self, inputs, output, size_dict, memory_limit=None):
raise NotImplementedError

@@ -336,6 +349,7 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
>>> optimal(isets, oset, idx_sizes, 5000)
[(0, 2), (0, 1)]
"""
self._check_args_against_first_call(inputs, output, size_dict)

inputs = tuple(map(frozenset, inputs))
output = frozenset(output)
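The new ``_check_args_against_first_call`` hook can equally be used from custom optimizers. A toy sketch of a stateful ``PathOptimizer`` subclass that caches its result and rejects a different contraction; the class and its trivial left-to-right path are invented for illustration:

```python
import numpy as np
import opt_einsum as oe

class CachedLeftToRight(oe.paths.PathOptimizer):
    """Toy stateful optimizer: computes a trivial path once and caches it."""

    def __call__(self, inputs, output, size_dict, memory_limit=None):
        # refuse reuse on a different contraction, as BranchBound and RandomGreedy now do
        self._check_args_against_first_call(inputs, output, size_dict)
        if not hasattr(self, '_path'):
            # always contract the first two remaining terms
            self._path = [(0, 1)] * (len(inputs) - 1)
        return self._path

eq, shapes = oe.helpers.rand_equation(6, 3, seed=1)
arrays = list(map(np.ones, shapes))
opt = CachedLeftToRight()
path, info = oe.contract_path(eq, *arrays, optimize=opt)  # fine, and fine again with the same eq
```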
opt_einsum/tests/test_paths.py (12 changes: 12 additions & 0 deletions)
@@ -299,6 +299,12 @@ def test_custom_random_greedy():
assert path_info.largest_intermediate == optimizer.best['size']
assert path_info.opt_cost == optimizer.best['flops']

# check error if we try and reuse the optimizer on a different expression
eq, shapes = oe.helpers.rand_equation(10, 4, seed=41)
views = list(map(np.ones, shapes))
with pytest.raises(ValueError):
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)


def test_custom_branchbound():
eq, shapes = oe.helpers.rand_equation(8, 4, seed=42)
@@ -320,6 +326,12 @@ def test_custom_branchbound():
assert path_info.largest_intermediate == optimizer.best['size']
assert path_info.opt_cost == optimizer.best['flops']

# check error if we try and reuse the optimizer on a different expression
eq, shapes = oe.helpers.rand_equation(8, 4, seed=41)
views = list(map(np.ones, shapes))
with pytest.raises(ValueError):
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)


@pytest.mark.skipif(sys.version_info < (3, 2), reason="requires python3.2 or higher")
def test_parallel_random_greedy():