Skip to content

Commit

Permalink
black v21 formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Jan 15, 2022
1 parent d1e560a commit f041ca5
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 14 deletions.
71 changes: 60 additions & 11 deletions opt_einsum/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
"""

def _check_args_against_first_call(
self, inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int],
self,
inputs: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
) -> None:
"""Utility that stateful optimizers can use to ensure they are not
called with different contractions across separate runs.
Expand Down Expand Up @@ -166,7 +169,10 @@ def calc_k12_flops(


def _compute_oversize_flops(
inputs: Tuple[FrozenSet[str]], remaining: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int],
inputs: Tuple[FrozenSet[str]],
remaining: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
) -> int:
"""
Compute the flop count for a contraction of all remaining arguments. This
Expand All @@ -179,7 +185,10 @@ def _compute_oversize_flops(


def optimal(
inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int], memory_limit: Optional[int] = None,
inputs: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
memory_limit: Optional[int] = None,
) -> PathType:
"""
Computes all possible pair contractions in a depth-first recursive manner,
Expand Down Expand Up @@ -337,7 +346,11 @@ class BranchBound(PathOptimizer):
"""

def __init__(
self, nbranch=None, cutoff_flops_factor=4, minimize="flops", cost_fn="memory-removed",
self,
nbranch=None,
cutoff_flops_factor=4,
minimize="flops",
cost_fn="memory-removed",
):
self.nbranch = nbranch
self.cutoff_flops_factor = cutoff_flops_factor
Expand Down Expand Up @@ -515,7 +528,14 @@ def _get_candidate(
two = k1 & k2
one = either - two
k12 = (either & output) | (two & dim_ref_counts[3]) | (one & dim_ref_counts[2])
cost = cost_fn(compute_size_by_dict(k12, sizes), footprints[k1], footprints[k2], k12, k1, k2,)
cost = cost_fn(
compute_size_by_dict(k12, sizes),
footprints[k1],
footprints[k2],
k12,
k1,
k2,
)
id1 = remaining[k1]
id2 = remaining[k2]
if id1 > id2:
Expand Down Expand Up @@ -546,7 +566,9 @@ def _push_candidate(


def _update_ref_counts(
dim_to_keys: Dict[str, Set[ArrayIndexType]], dim_ref_counts: Dict[int, Set[str]], dims: ArrayIndexType,
dim_to_keys: Dict[str, Set[ArrayIndexType]],
dim_ref_counts: Dict[int, Set[str]],
dims: ArrayIndexType,
) -> None:
for dim in dims:
count = len(dim_to_keys[dim])
Expand Down Expand Up @@ -637,7 +659,16 @@ def ssa_greedy_optimize(
for i, k1 in enumerate(dim_keys_list[:-1]):
k2s_guess = dim_keys_list[1 + i :]
_push_candidate(
output, sizes, remaining, footprints, dim_ref_counts, k1, k2s_guess, queue, push_all, cost_fn,
output,
sizes,
remaining,
footprints,
dim_ref_counts,
k1,
k2s_guess,
queue,
push_all,
cost_fn,
)

# Greedily contract pairs of tensors.
Expand Down Expand Up @@ -670,7 +701,16 @@ def ssa_greedy_optimize(
k2s.discard(k1)
if k2s:
_push_candidate(
output, sizes, remaining, footprints, dim_ref_counts, k1, list(k2s), queue, push_all, cost_fn,
output,
sizes,
remaining,
footprints,
dim_ref_counts,
k1,
list(k2s),
queue,
push_all,
cost_fn,
)

# Greedily compute pairwise outer products.
Expand Down Expand Up @@ -1276,7 +1316,10 @@ def __call__(
# outer products should be performed pairwise (to use BLAS functions)
subgraph_contractions = [
subgraph_contractions[j]
for j in sorted(range(len(subgraph_contractions_size)), key=subgraph_contractions_size.__getitem__,)
for j in sorted(
range(len(subgraph_contractions_size)),
key=subgraph_contractions_size.__getitem__,
)
]

# build the final contraction tree
Expand Down Expand Up @@ -1307,7 +1350,10 @@ def dynamic_programming(


def auto(
inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int], memory_limit: Optional[int] = None,
inputs: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
memory_limit: Optional[int] = None,
) -> PathType:
"""Finds the contraction path by automatically choosing the method based on
how many input arguments there are.
Expand All @@ -1324,7 +1370,10 @@ def auto(


def auto_hq(
inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int], memory_limit: Optional[int] = None,
inputs: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
memory_limit: Optional[int] = None,
) -> PathType:
"""Finds the contraction path by automatically choosing the method based on
how many input arguments there are, but targeting a more generous
Expand Down
17 changes: 14 additions & 3 deletions opt_einsum/tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,16 @@
import opt_einsum as oe

explicit_path_tests = {
"GEMM1": ([set("abd"), set("ac"), set("bdc")], set(""), {"a": 1, "b": 2, "c": 3, "d": 4},),
"Inner1": ([set("abcd"), set("abc"), set("bc")], set(""), {"a": 5, "b": 2, "c": 3, "d": 4},),
"GEMM1": (
[set("abd"), set("ac"), set("bdc")],
set(""),
{"a": 1, "b": 2, "c": 3, "d": 4},
),
"Inner1": (
[set("abcd"), set("abc"), set("bc")],
set(""),
{"a": 5, "b": 2, "c": 3, "d": 4},
),
}

# note that these tests have no unique solution due to the chosen dimensions
Expand Down Expand Up @@ -43,7 +51,10 @@

# note that these tests have no unique solution due to the chosen dimensions
path_scalar_tests = [
["a,->a", 1,],
[
"a,->a",
1,
],
["ab,->ab", 1],
[",a,->a", 2],
[",,a,->a", 3],
Expand Down

0 comments on commit f041ca5

Please sign in to comment.