Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more custom dp scores for faster contractions #181

Merged
merged 7 commits into from
Jan 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/paths/dp_path.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,23 @@ oe.contract(eq, *arrays, optimize=optimizer)
!!! warning
Note that searching outer products will most likely drastically slow down
the optimizer on all but the smallest examples.


The values that `minimize` can take are:

- `'flops'`: minimize the total number of scalar operations.
- `'size'`: minimize the size of the largest intermediate.
- `'write'`: minimize the combined size of all intermediate tensors -
approximately speaking the amount of memory that will be written. This is
relevant if you were to automatically differentiate through the
contraction, which naively would require storing all intermediates.
- `'combo'`: minimize `flops + alpha * write` summed over intermediates, a
  default ratio of `alpha=64` is used, or it can be customized with
  `f'combo-{alpha}'`.
- `'limit'`: minimize `max(flops, alpha * write)` summed over intermediates, a
  default ratio of `alpha=64` is used, or it can be customized with `f'limit-{alpha}'`.

The last two take into account the fact that real contraction performance can
be bound by memory speed, and so favor paths with higher arithmetic
intensity. The default value of `alpha=64` is reasonable for both typical
CPUs and GPUs.
171 changes: 135 additions & 36 deletions opt_einsum/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
import itertools
import operator
import random
import re
from collections import Counter, OrderedDict, defaultdict
from typing import Any, Callable
from typing import Counter as CounterType
from typing import Dict, FrozenSet, Generator, List, Optional, Sequence, Set, Tuple, Union

import numpy as np

from . import helpers
from .helpers import compute_size_by_dict, flop_count
from .typing import ArrayIndexType, PathType

__all__ = [
Expand Down Expand Up @@ -162,7 +163,7 @@ def calc_k12_flops(
keep = frozenset.union(output, *map(inputs.__getitem__, remaining - {i, j}))

k12 = either & keep
cost = helpers.flop_count(either, bool(shared - keep), 2, size_dict)
cost = flop_count(either, bool(shared - keep), 2, size_dict)

return k12, cost

Expand All @@ -180,7 +181,7 @@ def _compute_oversize_flops(
idx_contraction = frozenset.union(*map(inputs.__getitem__, remaining)) # type: ignore
inner = idx_contraction - output
num_terms = len(remaining)
return helpers.flop_count(idx_contraction, bool(inner), num_terms, size_dict)
return flop_count(idx_contraction, bool(inner), num_terms, size_dict)


def optimal(
Expand Down Expand Up @@ -252,7 +253,7 @@ def _optimal_iterate(path, remaining, inputs, flops):
try:
size12 = size_cache[k12]
except KeyError:
size12 = size_cache[k12] = helpers.compute_size_by_dict(k12, size_dict)
size12 = size_cache[k12] = compute_size_by_dict(k12, size_dict)

# possibly terminate this path with an all-terms einsum
if size12 > memory_limit:
Expand Down Expand Up @@ -398,7 +399,7 @@ def __call__(
inputs: Tuple[FrozenSet[str]] = tuple(map(frozenset, inputs_)) # type: ignore
output: FrozenSet[str] = frozenset(output_)

size_cache = {k: helpers.compute_size_by_dict(k, size_dict) for k in inputs}
size_cache = {k: compute_size_by_dict(k, size_dict) for k in inputs}
result_cache: Dict[Tuple[FrozenSet[str], FrozenSet[str]], Tuple[FrozenSet[str], int]] = {}

def _branch_iterate(path, inputs, remaining, flops, size):
Expand All @@ -420,7 +421,7 @@ def _assess_candidate(k1: FrozenSet[str], k2: FrozenSet[str], i: int, j: int) ->
try:
size12 = size_cache[k12]
except KeyError:
size12 = size_cache[k12] = helpers.compute_size_by_dict(k12, size_dict)
size12 = size_cache[k12] = compute_size_by_dict(k12, size_dict)

new_flops = flops + flops12
new_size = max(size, size12)
Expand Down Expand Up @@ -528,7 +529,7 @@ def _get_candidate(
one = either - two
k12 = (either & output) | (two & dim_ref_counts[3]) | (one & dim_ref_counts[2])
cost = cost_fn(
helpers.compute_size_by_dict(k12, sizes),
compute_size_by_dict(k12, sizes),
footprints[k1],
footprints[k2],
k12,
Expand Down Expand Up @@ -649,7 +650,7 @@ def ssa_greedy_optimize(
}

# Compute separable part of the objective function for contractions.
footprints = {key: helpers.compute_size_by_dict(key, sizes) for key in remaining}
footprints = {key: compute_size_by_dict(key, sizes) for key in remaining}

# Find initial candidate contractions.
queue: List[GreedyContractionType] = []
Expand Down Expand Up @@ -692,7 +693,7 @@ def ssa_greedy_optimize(
dim_to_keys[dim].add(k12)
remaining[k12] = next(ssa_ids)
_update_ref_counts(dim_to_keys, dim_ref_counts, k1 | k2 - output)
footprints[k12] = helpers.compute_size_by_dict(k12, sizes)
footprints[k12] = compute_size_by_dict(k12, sizes)

# Find new candidate contractions.
k1 = k12
Expand All @@ -713,16 +714,14 @@ def ssa_greedy_optimize(
)

# Greedily compute pairwise outer products.
final_queue = [
(helpers.compute_size_by_dict(key & output, sizes), ssa_id, key) for key, ssa_id in remaining.items()
]
final_queue = [(compute_size_by_dict(key & output, sizes), ssa_id, key) for key, ssa_id in remaining.items()]
heapq.heapify(final_queue)
_, ssa_id1, k1 = heapq.heappop(final_queue)
while final_queue:
_, ssa_id2, k2 = heapq.heappop(final_queue)
ssa_path.append((min(ssa_id1, ssa_id2), max(ssa_id1, ssa_id2)))
k12 = (k1 | k2) & output
cost = helpers.compute_size_by_dict(k12, sizes)
cost = compute_size_by_dict(k12, sizes)
ssa_id12 = next(ssa_ids)
_, ssa_id1, k1 = heapq.heappushpop(final_queue, (cost, ssa_id12, k12))

Expand Down Expand Up @@ -937,12 +936,12 @@ def _dp_compare_flops(
"""

# TODO: Odd usage with an Iterable[int] to map a dict of type List[int]
cost = cost1 + cost2 + helpers.compute_size_by_dict(i1_union_i2, size_dict)
cost = cost1 + cost2 + compute_size_by_dict(i1_union_i2, size_dict)
if cost <= cost_cap:
s = s1 | s2
if s not in xn or cost < xn[s][1]:
i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2)
mem = helpers.compute_size_by_dict(i, size_dict)
mem = compute_size_by_dict(i, size_dict)
if memory_limit is None or mem <= memory_limit:
xn[s] = (i, cost, (contract1, contract2))

Expand Down Expand Up @@ -971,14 +970,115 @@ def _dp_compare_size(

s = s1 | s2
i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2)
mem = helpers.compute_size_by_dict(i, size_dict)
mem = compute_size_by_dict(i, size_dict)
cost = max(cost1, cost2, mem)
if cost <= cost_cap:
if s not in xn or cost < xn[s][1]:
if memory_limit is None or mem <= memory_limit:
xn[s] = (i, cost, (contract1, contract2))


def _dp_compare_write(
    cost1: int,
    cost2: int,
    i1_union_i2: Set[int],
    size_dict: List[int],
    cost_cap: int,
    s1: int,
    s2: int,
    xn: Dict[int, Any],
    g: int,
    all_tensors: int,
    inputs: List[FrozenSet[int]],
    i1_cut_i2_wo_output: Set[int],
    memory_limit: Optional[int],
    contract1: Union[int, Tuple[int]],
    contract2: Union[int, Tuple[int]],
) -> None:
    """Like ``_dp_compare_flops`` but sieves candidate contractions by the
    total size of the intermediates written (i.e. memory created) rather
    than by operation count, and so computes the new tensor's size first.
    """
    s12 = s1 | s2
    new_legs = _dp_calc_legs(g, all_tensors, s12, inputs, i1_cut_i2_wo_output, i1_union_i2)
    write_size = compute_size_by_dict(new_legs, size_dict)
    total_cost = cost1 + cost2 + write_size

    # reject candidates over the current cost-cap early
    if total_cost > cost_cap:
        return
    # only keep a strictly cheaper contraction for this subgraph
    current = xn.get(s12)
    if current is not None and total_cost >= current[1]:
        return
    # enforce the memory limit on the intermediate, if one is set
    if memory_limit is not None and write_size > memory_limit:
        return
    xn[s12] = (new_legs, total_cost, (contract1, contract2))


DEFAULT_COMBO_FACTOR = 64


def _dp_compare_combo(
    cost1: int,
    cost2: int,
    i1_union_i2: Set[int],
    size_dict: List[int],
    cost_cap: int,
    s1: int,
    s2: int,
    xn: Dict[int, Any],
    g: int,
    all_tensors: int,
    inputs: List[FrozenSet[int]],
    i1_cut_i2_wo_output: Set[int],
    memory_limit: Optional[int],
    contract1: Union[int, Tuple[int]],
    contract2: Union[int, Tuple[int]],
    factor: Union[int, float] = DEFAULT_COMBO_FACTOR,
    combine: Callable = sum,
) -> None:
    """Like ``_dp_compare_flops`` but sieves candidate contractions on a
    combination of flop count and write size: ``combine((flops, factor * write))``,
    e.g. ``combine=sum`` gives the 'combo' score and ``combine=max`` the 'limit' score.
    """
    s12 = s1 | s2
    legs = _dp_calc_legs(g, all_tensors, s12, inputs, i1_cut_i2_wo_output, i1_union_i2)
    write = compute_size_by_dict(legs, size_dict)
    flops = compute_size_by_dict(i1_union_i2, size_dict)
    new_cost = cost1 + cost2 + combine((flops, factor * write))

    # reject candidates over the current cost-cap early
    if new_cost > cost_cap:
        return
    # only keep a strictly cheaper contraction for this subgraph
    if (s12 in xn) and (xn[s12][1] <= new_cost):
        return
    # enforce the memory limit on the intermediate, if one is set
    if memory_limit is not None and write > memory_limit:
        return
    xn[s12] = (legs, new_cost, (contract1, contract2))


# allow an optional '-{alpha}' suffix where alpha may be an int or a float,
# e.g. 'combo', 'combo-64', 'limit-0.5'
minimize_finder = re.compile(r"(flops|size|write|combo|limit)-*(\d*\.?\d*)")


@functools.lru_cache(128)
def _parse_minimize(minimize: Union[str, Callable]) -> Tuple[Callable, Union[int, float]]:
    """Work out which local scoring function the dp algorithm should use,
    as well as a ``naive_scale`` to account for the memory_limit checks.

    **Parameters:**

    - **minimize** - one of ``{'flops', 'size', 'write', 'combo', 'limit'}``,
      a factor-customized variant such as ``'combo-0.5'``, or a custom callable.

    **Returns:** ``(check_contraction_fn, naive_scale)``.

    **Raises:** ``ValueError`` if a string ``minimize`` cannot be parsed.
    """
    if minimize == "flops":
        return _dp_compare_flops, 1
    elif minimize == "size":
        return _dp_compare_size, 1
    elif minimize == "write":
        return _dp_compare_write, 1
    elif callable(minimize):
        # default to naive_scale=inf for this and remaining options
        # as otherwise memory_limit check can cause problems
        return minimize, float("inf")

    # parse out a customized value for the combination factor
    match = minimize_finder.fullmatch(minimize)
    if match is None:
        raise ValueError(f"Couldn't parse `minimize` value: {minimize}.")

    # keep the original string intact for error messages below
    method, custom_factor = match.groups()
    try:
        factor = float(custom_factor) if custom_factor else DEFAULT_COMBO_FACTOR
    except ValueError:
        # e.g. a bare trailing '.' matched the factor group
        raise ValueError(f"Couldn't parse `minimize` value: {minimize}.") from None

    if method == "combo":
        return functools.partial(_dp_compare_combo, factor=factor, combine=sum), float("inf")
    elif method == "limit":
        return functools.partial(_dp_compare_combo, factor=factor, combine=max), float("inf")
    else:
        # a recognised keyword with a factor where none is supported, e.g. 'flops-64'
        raise ValueError(f"Couldn't parse `minimize` value: {minimize}.")


def simple_tree_tuple(seq: Sequence[Tuple[int, ...]]) -> Tuple[Any, ...]:
"""Make a simple left to right binary tree out of iterable `seq`.

Expand Down Expand Up @@ -1035,8 +1135,17 @@ class DynamicProgramming(PathOptimizer):

**Parameters:**

- **minimize** - *({'flops', 'size'}, optional)* Whether to find the contraction that minimizes the number of
operations or the size of the largest intermediate tensor.
- **minimize** - *({'flops', 'size', 'write', 'combo', 'limit', callable}, optional)* What to minimize:

- 'flops' - minimize the number of flops
- 'size' - minimize the size of the largest intermediate
- 'write' - minimize the size of all intermediate tensors
- 'combo' - minimize `flops + alpha * write` summed over intermediates, a default ratio of alpha=64
is used, or it can be customized with `f'combo-{alpha}'`
- 'limit' - minimize `max(flops, alpha * write)` summed over intermediates, a default ratio of alpha=64
is used, or it can be customized with `f'limit-{alpha}'`
- callable - a custom local cost function

- **cost_cap** - *({True, False, int}, optional)* How to implement cost-capping:

- True - iteratively increase the cost-cap
Expand All @@ -1049,21 +1158,8 @@ class DynamicProgramming(PathOptimizer):
"""

def __init__(self, minimize: str = "flops", cost_cap: bool = True, search_outer: bool = False) -> None:

# set whether inner function minimizes against flops or size
self.minimize = minimize
self._check_contraction = {
"flops": _dp_compare_flops,
"size": _dp_compare_size,
}[self.minimize]

# set whether inner function considers outer products
self.search_outer = search_outer
self._check_outer = {
False: lambda x: x,
True: lambda x: True,
}[self.search_outer]

self.cost_cap = cost_cap

def __call__(
Expand Down Expand Up @@ -1106,6 +1202,9 @@ def __call__(
#> [(1, 2), (0, 4), (1, 2), (0, 2), (0, 1)]
```
"""
_check_contraction, naive_scale = _parse_minimize(self.minimize)
_check_outer = (lambda x: True) if self.search_outer else (lambda x: x)

ind_counts = Counter(itertools.chain(*inputs_, output_))
all_inds = tuple(ind_counts)

Expand All @@ -1115,7 +1214,7 @@ def __call__(
output = frozenset(symbol2int[c] for c in output_)
size_dict_canonical = {symbol2int[c]: v for c, v in size_dict_.items() if c in symbol2int}
size_dict = [size_dict_canonical[j] for j in range(len(size_dict_canonical))]
naive_cost = len(inputs) * functools.reduce(operator.mul, size_dict)
naive_cost = naive_scale * len(inputs) * functools.reduce(operator.mul, size_dict)

inputs, inputs_done, inputs_contractions = _dp_parse_out_single_term_ops(inputs, all_inds, ind_counts)

Expand Down Expand Up @@ -1158,7 +1257,7 @@ def __call__(
# output index dimensions as initial cost_cap
subgraph_inds = frozenset.union(*_bitmap_select(bitmap_g, inputs))
if self.cost_cap is True:
cost_cap = helpers.compute_size_by_dict(subgraph_inds & output, size_dict)
cost_cap = compute_size_by_dict(subgraph_inds & output, size_dict)
elif self.cost_cap is False:
cost_cap = float("inf") # type: ignore
else:
Expand All @@ -1184,10 +1283,10 @@ def __call__(
i1_cut_i2_wo_output = (i1 & i2) - output

# maybe ignore outer products:
if self._check_outer(i1_cut_i2_wo_output):
if _check_outer(i1_cut_i2_wo_output):

i1_union_i2 = i1 | i2
self._check_contraction(
_check_contraction(
cost1,
cost2,
i1_union_i2,
Expand All @@ -1213,7 +1312,7 @@ def __call__(

i, cost, contraction = list(x[-1].values())[0]
subgraph_contractions.append(contraction)
subgraph_contractions_size.append(helpers.compute_size_by_dict(i, size_dict))
subgraph_contractions_size.append(compute_size_by_dict(i, size_dict))

# sort the subgraph contractions by the size of the subgraphs in
# ascending order (will give the cheapest contractions); note that
Expand Down
21 changes: 21 additions & 0 deletions opt_einsum/tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,27 @@ def test_custom_dp_can_set_cost_cap():
assert info1.opt_cost == info2.opt_cost == info3.opt_cost


@pytest.mark.parametrize(
    "minimize,cost,width,path",
    [
        ("flops", 663054, 18900, [(4, 5), (2, 5), (2, 7), (5, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)]),
        ("size", 1114440, 2016, [(2, 7), (3, 8), (3, 7), (2, 6), (1, 5), (1, 4), (1, 3), (1, 2), (0, 1)]),
        ("write", 983790, 2016, [(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
        ("combo", 973518, 2016, [(4, 5), (2, 5), (6, 7), (2, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)]),
        ("limit", 983832, 2016, [(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
        ("combo-256", 983790, 2016, [(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
        ("limit-256", 983832, 2016, [(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
    ],
)
def test_custom_dp_can_set_minimize(minimize, cost, width, path):
    """Each supported ``minimize`` mode (including factor-customized
    ``'combo-256'`` / ``'limit-256'`` variants) should reproduce the pinned
    contraction path, total cost and largest-intermediate size for a fixed
    seeded random equation.
    """
    # fixed seed so the parametrized expected values stay reproducible
    eq, shapes = oe.helpers.rand_equation(10, 4, seed=43)
    opt = oe.DynamicProgramming(minimize=minimize)
    info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)[1]
    assert info.path == path
    assert info.opt_cost == cost
    assert info.largest_intermediate == width
dgasmith marked this conversation as resolved.
Show resolved Hide resolved


def test_dp_errors_when_no_contractions_found():
eq, shapes, size_dict = oe.helpers.rand_equation(10, 3, seed=42, return_size_dict=True)

Expand Down