black formatting
jcmgray committed Jan 15, 2022
1 parent 7f89e8d commit d1e560a
Showing 2 changed files with 71 additions and 99 deletions.
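
The commit does not record which black version or configuration was used, so purely as a rough illustration, a formatting pass like this one can be reproduced with black's Python API (the line length below is an assumption, not taken from the repository):

    import black

    # a deliberately mis-wrapped signature, similar in spirit to those in the diff below
    src = (
        "def auto(inputs,\n"
        "         output,\n"
        "         size_dict,\n"
        "         memory_limit=None):\n"
        "    pass\n"
    )

    # format_str applies the same rewrites that running `black` on the file would
    print(black.format_str(src, mode=black.Mode(line_length=120)))
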
132 changes: 56 additions & 76 deletions opt_einsum/paths.py
@@ -60,10 +60,7 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
"""

def _check_args_against_first_call(
self,
inputs: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
self, inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int],
) -> None:
"""Utility that stateful optimizers can use to ensure they are not
called with different contractions across separate runs.
@@ -169,10 +166,7 @@ def calc_k12_flops(


def _compute_oversize_flops(
inputs: Tuple[FrozenSet[str]],
remaining: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
inputs: Tuple[FrozenSet[str]], remaining: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int],
) -> int:
"""
Compute the flop count for a contraction of all remaining arguments. This
@@ -185,10 +179,7 @@ def _compute_oversize_flops(


def optimal(
inputs: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
memory_limit: Optional[int] = None,
inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int], memory_limit: Optional[int] = None,
) -> PathType:
"""
Computes all possible pair contractions in a depth-first recursive manner,
@@ -346,11 +337,7 @@ class BranchBound(PathOptimizer):
"""

def __init__(
self,
nbranch=None,
cutoff_flops_factor=4,
minimize="flops",
cost_fn="memory-removed",
self, nbranch=None, cutoff_flops_factor=4, minimize="flops", cost_fn="memory-removed",
):
self.nbranch = nbranch
self.cutoff_flops_factor = cutoff_flops_factor
@@ -528,14 +515,7 @@ def _get_candidate(
two = k1 & k2
one = either - two
k12 = (either & output) | (two & dim_ref_counts[3]) | (one & dim_ref_counts[2])
cost = cost_fn(
compute_size_by_dict(k12, sizes),
footprints[k1],
footprints[k2],
k12,
k1,
k2,
)
cost = cost_fn(compute_size_by_dict(k12, sizes), footprints[k1], footprints[k2], k12, k1, k2,)
id1 = remaining[k1]
id2 = remaining[k2]
if id1 > id2:
@@ -566,9 +546,7 @@ def _push_candidate(


def _update_ref_counts(
dim_to_keys: Dict[str, Set[ArrayIndexType]],
dim_ref_counts: Dict[int, Set[str]],
dims: ArrayIndexType,
dim_to_keys: Dict[str, Set[ArrayIndexType]], dim_ref_counts: Dict[int, Set[str]], dims: ArrayIndexType,
) -> None:
for dim in dims:
count = len(dim_to_keys[dim])
@@ -659,16 +637,7 @@ def ssa_greedy_optimize(
for i, k1 in enumerate(dim_keys_list[:-1]):
k2s_guess = dim_keys_list[1 + i :]
_push_candidate(
output,
sizes,
remaining,
footprints,
dim_ref_counts,
k1,
k2s_guess,
queue,
push_all,
cost_fn,
output, sizes, remaining, footprints, dim_ref_counts, k1, k2s_guess, queue, push_all, cost_fn,
)

# Greedily contract pairs of tensors.
@@ -701,22 +670,11 @@ def ssa_greedy_optimize(
k2s.discard(k1)
if k2s:
_push_candidate(
output,
sizes,
remaining,
footprints,
dim_ref_counts,
k1,
list(k2s),
queue,
push_all,
cost_fn,
output, sizes, remaining, footprints, dim_ref_counts, k1, list(k2s), queue, push_all, cost_fn,
)

# Greedily compute pairwise outer products.
final_queue = [
(compute_size_by_dict(key & output, sizes), ssa_id, key) for key, ssa_id in remaining.items()
]
final_queue = [(compute_size_by_dict(key & output, sizes), ssa_id, key) for key, ssa_id in remaining.items()]
heapq.heapify(final_queue)
_, ssa_id1, k1 = heapq.heappop(final_queue)
while final_queue:
@@ -980,8 +938,23 @@ def _dp_compare_size(
xn[s] = (i, cost, (contract1, contract2))


def _dp_compare_write(cost1, cost2, i1_union_i2, size_dict, cost_cap, s1, s2, xn, g, all_tensors, inputs,
i1_cut_i2_wo_output, memory_limit, cntrct1, cntrct2):
def _dp_compare_write(
cost1,
cost2,
i1_union_i2,
size_dict,
cost_cap,
s1,
s2,
xn,
g,
all_tensors,
inputs,
i1_cut_i2_wo_output,
memory_limit,
cntrct1,
cntrct2,
):
"""Like ``_dp_compare_flops`` but sieves the potential contraction based
on the total size of memory created, rather than the number of
operations, and so calculates that first.
@@ -999,9 +972,25 @@ def _dp_compare_write(cost1, cost2, i1_union_i2, size_dict, cost_cap, s1, s2, xn
DEFAULT_COMBO_FACTOR = 64


def _dp_compare_combo(cost1, cost2, i1_union_i2, size_dict, cost_cap, s1, s2, xn, g, all_tensors, inputs,
i1_cut_i2_wo_output, memory_limit, cntrct1, cntrct2,
factor=DEFAULT_COMBO_FACTOR, combine=sum):
def _dp_compare_combo(
cost1,
cost2,
i1_union_i2,
size_dict,
cost_cap,
s1,
s2,
xn,
g,
all_tensors,
inputs,
i1_cut_i2_wo_output,
memory_limit,
cntrct1,
cntrct2,
factor=DEFAULT_COMBO_FACTOR,
combine=sum,
):
"""Like ``_dp_compare_flops`` but sieves the potential contraction based
on some combination of both the flops and size,
"""
@@ -1024,25 +1013,25 @@ def _parse_minimize(minimize):
"""This works out what local scoring function to use for the dp algorithm
as well as a `naive_scale` to account for the memory_limit checks.
"""
if minimize == 'flops':
if minimize == "flops":
return _dp_compare_flops, 1
if minimize == 'size':
if minimize == "size":
return _dp_compare_size, 1
if minimize == 'write':
if minimize == "write":
return _dp_compare_write, 1

# default to naive_scale=inf as otherwise memory_limit check can cause problems

if callable(minimize):
return minimize, float('inf')
return minimize, float("inf")

# parse out a customized value for the combination factor
minimize, factor = minimize_finder.fullmatch(minimize).groups()
factor = float(factor) if factor else DEFAULT_COMBO_FACTOR
if minimize == 'combo':
return functools.partial(_dp_compare_combo, factor=factor, combine=sum), float('inf')
if minimize == 'limit':
return functools.partial(_dp_compare_combo, factor=factor, combine=max), float('inf')
if minimize == "combo":
return functools.partial(_dp_compare_combo, factor=factor, combine=sum), float("inf")
if minimize == "limit":
return functools.partial(_dp_compare_combo, factor=factor, combine=max), float("inf")

raise ValueError(f"Couldn't parse `minimize` value: {minimize}.")

@@ -1287,10 +1276,7 @@ def __call__(
# outer products should be performed pairwise (to use BLAS functions)
subgraph_contractions = [
subgraph_contractions[j]
for j in sorted(
range(len(subgraph_contractions_size)),
key=subgraph_contractions_size.__getitem__,
)
for j in sorted(range(len(subgraph_contractions_size)), key=subgraph_contractions_size.__getitem__,)
]

# build the final contraction tree
@@ -1321,10 +1307,7 @@ def dynamic_programming(


def auto(
inputs: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
memory_limit: Optional[int] = None,
inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int], memory_limit: Optional[int] = None,
) -> PathType:
"""Finds the contraction path by automatically choosing the method based on
how many input arguments there are.
Expand All @@ -1341,10 +1324,7 @@ def auto(


def auto_hq(
inputs: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
memory_limit: Optional[int] = None,
inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int], memory_limit: Optional[int] = None,
) -> PathType:
"""Finds the contraction path by automatically choosing the method based on
how many input arguments there are, but targeting a more generous
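
A note on the greedy hunks above: `_get_candidate` calls a pluggable `cost_fn` as `cost_fn(size12, size1, size2, k12, k1, k2)`, where `size12` is the size of the candidate intermediate and `k1`/`k2` are the index sets of the two operands (`BranchBound` defaults this to "memory-removed"). A minimal sketch of a custom function honouring that interface, modelled on that heuristic; the built-in implementation is not shown in this diff, so treat the body as an assumption:

    def my_cost_fn(size12, size1, size2, k12, k1, k2):
        # favour contractions whose intermediate is small relative to the two
        # operands it replaces; a negative value means net memory is freed
        return size12 - size1 - size2
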
38 changes: 15 additions & 23 deletions opt_einsum/tests/test_paths.py
@@ -12,16 +12,8 @@
import opt_einsum as oe

explicit_path_tests = {
"GEMM1": (
[set("abd"), set("ac"), set("bdc")],
set(""),
{"a": 1, "b": 2, "c": 3, "d": 4},
),
"Inner1": (
[set("abcd"), set("abc"), set("bc")],
set(""),
{"a": 5, "b": 2, "c": 3, "d": 4},
),
"GEMM1": ([set("abd"), set("ac"), set("bdc")], set(""), {"a": 1, "b": 2, "c": 3, "d": 4},),
"Inner1": ([set("abcd"), set("abc"), set("bc")], set(""), {"a": 5, "b": 2, "c": 3, "d": 4},),
}

# note that these tests have no unique solution due to the chosen dimensions
@@ -51,10 +43,7 @@

# note that these tests have no unique solution due to the chosen dimensions
path_scalar_tests = [
[
"a,->a",
1,
],
["a,->a", 1,],
["ab,->ab", 1],
[",a,->a", 2],
[",,a,->a", 3],
@@ -270,15 +259,18 @@ def test_custom_dp_can_set_cost_cap():
assert info1.opt_cost == info2.opt_cost == info3.opt_cost


@pytest.mark.parametrize('minimize,cost,width', [
('flops', 663054, 18900),
('size', 1114440, 2016),
('write', 983790, 2016),
('combo', 973518, 2016),
('limit', 983832, 2016),
('combo-256', 983790, 2016),
('limit-256', 983832, 2016),
])
@pytest.mark.parametrize(
"minimize,cost,width",
[
("flops", 663054, 18900),
("size", 1114440, 2016),
("write", 983790, 2016),
("combo", 973518, 2016),
("limit", 983832, 2016),
("combo-256", 983790, 2016),
("limit-256", 983832, 2016),
],
)
def test_custom_dp_can_set_minimize(minimize, cost, width):
eq, shapes = oe.helpers.rand_equation(10, 4, seed=43)
opt = oe.DynamicProgramming(minimize=minimize)
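
As a usage sketch tying the test above to the `_parse_minimize` hunk in paths.py: each `minimize` string selects a different local scoring function for the dynamic-programming optimizer. The call below assumes opt_einsum's public `contract_path(..., shapes=True)` interface, which is not itself part of this diff:

    import opt_einsum as oe

    eq, shapes = oe.helpers.rand_equation(10, 4, seed=43)

    for minimize in ("flops", "size", "write", "combo", "limit", "combo-256"):
        opt = oe.DynamicProgramming(minimize=minimize)
        path, info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)
        # opt_cost is the total flop count of the chosen path,
        # largest_intermediate the size of the biggest tensor it produces
        print(minimize, info.opt_cost, info.largest_intermediate)
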
