Skip to content

Commit

Permalink
stateful optimizers: check args against first call
Browse files Browse the repository at this point in the history
i.e. BranchBound and RandomGreedy
  • Loading branch information
jcmgray committed May 12, 2020
1 parent 07bfda4 commit 00096c4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
2 changes: 2 additions & 0 deletions opt_einsum/path_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def setup(self, inputs, output, size_dict):
raise NotImplementedError

def __call__(self, inputs, output, size_dict, memory_limit):
self.check_args_against_first_call(inputs, output, size_dict)

# start a timer?
if self.max_time is not None:
t0 = time.time()
Expand Down
13 changes: 13 additions & 0 deletions opt_einsum/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
where ``path`` is a list of int-tuples specifiying a contraction order.
"""

def check_args_against_first_call(self, inputs, output, size_dict):
"""Utility that stateful optimizers can use to ensure they are not
called with different contractions across separate runs.
"""
args = (inputs, output, size_dict)
if not hasattr(self, '_first_call_args'):
self._first_call_args = args
elif args != self._first_call_args:
raise ValueError("The arguments specifiying the contraction that this path optimizer "
"instance was called with have changed - try creating a new instance.")

def __call__(self, inputs, output, size_dict, memory_limit=None):
raise NotImplementedError

Expand Down Expand Up @@ -336,6 +348,7 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
>>> optimal(isets, oset, idx_sizes, 5000)
[(0, 2), (0, 1)]
"""
self.check_args_against_first_call(inputs, output, size_dict)

inputs = tuple(map(frozenset, inputs))
output = frozenset(output)
Expand Down

0 comments on commit 00096c4

Please sign in to comment.