Skip to content

Commit

Permalink
add type-hints and explicit regression path testing
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Jan 16, 2022
1 parent f041ca5 commit df54ffd
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 55 deletions.
89 changes: 44 additions & 45 deletions opt_einsum/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,21 +979,21 @@ def _dp_compare_size(


def _dp_compare_write(
cost1,
cost2,
i1_union_i2,
size_dict,
cost_cap,
s1,
s2,
xn,
g,
all_tensors,
inputs,
i1_cut_i2_wo_output,
memory_limit,
cntrct1,
cntrct2,
cost1: int,
cost2: int,
i1_union_i2: Set[int],
size_dict: List[int],
cost_cap: int,
s1: int,
s2: int,
xn: Dict[int, Any],
g: int,
all_tensors: int,
inputs: List[FrozenSet[int]],
i1_cut_i2_wo_output: Set[int],
memory_limit: Optional[int],
contract1: Union[int, Tuple[int]],
contract2: Union[int, Tuple[int]],
):
"""Like ``_dp_compare_flops`` but sieves the potential contraction based
on the total size of memory created, rather than the number of
Expand All @@ -1006,30 +1006,30 @@ def _dp_compare_write(
if cost <= cost_cap:
if s not in xn or cost < xn[s][1]:
if memory_limit is None or mem <= memory_limit:
xn[s] = (i, cost, (cntrct1, cntrct2))
xn[s] = (i, cost, (contract1, contract2))


DEFAULT_COMBO_FACTOR = 64


def _dp_compare_combo(
cost1,
cost2,
i1_union_i2,
size_dict,
cost_cap,
s1,
s2,
xn,
g,
all_tensors,
inputs,
i1_cut_i2_wo_output,
memory_limit,
cntrct1,
cntrct2,
factor=DEFAULT_COMBO_FACTOR,
combine=sum,
cost1: int,
cost2: int,
i1_union_i2: Set[int],
size_dict: List[int],
cost_cap: int,
s1: int,
s2: int,
xn: Dict[int, Any],
g: int,
all_tensors: int,
inputs: List[FrozenSet[int]],
i1_cut_i2_wo_output: Set[int],
memory_limit: Optional[int],
contract1: Union[int, Tuple[int]],
contract2: Union[int, Tuple[int]],
factor: Union[int, float]=DEFAULT_COMBO_FACTOR,
combine: Callable=sum,
):
"""Like ``_dp_compare_flops`` but sieves the potential contraction based
on some combination of both the flops and size,
Expand All @@ -1042,38 +1042,37 @@ def _dp_compare_combo(
if cost <= cost_cap:
if s not in xn or cost < xn[s][1]:
if memory_limit is None or mem <= memory_limit:
xn[s] = (i, cost, (cntrct1, cntrct2))
xn[s] = (i, cost, (contract1, contract2))


minimize_finder = re.compile("(flops|size|write|combo|limit)-*(\d*)")
minimize_finder = re.compile(r"(flops|size|write|combo|limit)-*(\d*)")


@functools.lru_cache(128)
def _parse_minimize(minimize):
def _parse_minimize(minimize: Union[str, Callable]):
"""This works out what local scoring function to use for the dp algorithm
as well as a `naive_scale` to account for the memory_limit checks.
"""
if minimize == "flops":
return _dp_compare_flops, 1
if minimize == "size":
elif minimize == "size":
return _dp_compare_size, 1
if minimize == "write":
elif minimize == "write":
return _dp_compare_write, 1

# default to naive_scale=inf as otherwise memory_limit check can cause problems

if callable(minimize):
elif callable(minimize):
# default to naive_scale=inf for this and remaining options
# as otherwise memory_limit check can cause problems
return minimize, float("inf")

# parse out a customized value for the combination factor
minimize, factor = minimize_finder.fullmatch(minimize).groups()
factor = float(factor) if factor else DEFAULT_COMBO_FACTOR
if minimize == "combo":
return functools.partial(_dp_compare_combo, factor=factor, combine=sum), float("inf")
if minimize == "limit":
elif minimize == "limit":
return functools.partial(_dp_compare_combo, factor=factor, combine=max), float("inf")

raise ValueError(f"Couldn't parse `minimize` value: {minimize}.")
else:
raise ValueError(f"Couldn't parse `minimize` value: {minimize}.")


def simple_tree_tuple(seq: Sequence[Tuple[int, ...]]) -> Tuple[Any, ...]:
Expand Down
21 changes: 11 additions & 10 deletions opt_einsum/tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,21 +271,22 @@ def test_custom_dp_can_set_cost_cap():


@pytest.mark.parametrize(
"minimize,cost,width",
"minimize,cost,width,path",
[
("flops", 663054, 18900),
("size", 1114440, 2016),
("write", 983790, 2016),
("combo", 973518, 2016),
("limit", 983832, 2016),
("combo-256", 983790, 2016),
("limit-256", 983832, 2016),
],
("flops", 663054, 18900, [(4, 5), (2, 5), (2, 7), (5, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)]),
("size", 1114440, 2016, [(2, 7), (3, 8), (3, 7), (2, 6), (1, 5), (1, 4), (1, 3), (1, 2), (0, 1)]),
("write", 983790, 2016, [(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
("combo", 973518, 2016, [(4, 5), (2, 5), (6, 7), (2, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)]),
("limit", 983832, 2016, [(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
("combo-256", 983790, 2016, [(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
("limit-256", 983832, 2016, [(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
]
)
def test_custom_dp_can_set_minimize(minimize, cost, width):
def test_custom_dp_can_set_minimize(minimize, cost, width, path):
eq, shapes = oe.helpers.rand_equation(10, 4, seed=43)
opt = oe.DynamicProgramming(minimize=minimize)
info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)[1]
assert info.path == path
assert info.opt_cost == cost
assert info.largest_intermediate == width

Expand Down

0 comments on commit df54ffd

Please sign in to comment.