tweaks to dp to allow FLOPS+WRITE minimization
jcmgray committed Mar 24, 2021
1 parent 8be6349 commit 37e9573
Showing 1 changed file with 58 additions and 28 deletions.
86 changes: 58 additions & 28 deletions opt_einsum/paths.py
@@ -12,6 +12,7 @@
import numpy as np

from . import helpers
from .helpers import compute_size_by_dict

__all__ = [
"optimal", "BranchBound", "branch", "greedy", "auto", "auto_hq", "get_path_fn", "DynamicProgramming",
@@ -216,7 +217,7 @@ def _optimal_iterate(path, remaining, inputs, flops):
try:
size12 = size_cache[k12]
except KeyError:
size12 = size_cache[k12] = helpers.compute_size_by_dict(k12, size_dict)
size12 = size_cache[k12] = compute_size_by_dict(k12, size_dict)

# possibly terminate this path with an all-terms einsum
if size12 > memory_limit:
@@ -355,7 +356,7 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
inputs = tuple(map(frozenset, inputs))
output = frozenset(output)

size_cache = {k: helpers.compute_size_by_dict(k, size_dict) for k in inputs}
size_cache = {k: compute_size_by_dict(k, size_dict) for k in inputs}
result_cache = {}

def _branch_iterate(path, inputs, remaining, flops, size):
@@ -377,7 +378,7 @@ def _assess_candidate(k1, k2, i, j):
try:
size12 = size_cache[k12]
except KeyError:
size12 = size_cache[k12] = helpers.compute_size_by_dict(k12, size_dict)
size12 = size_cache[k12] = compute_size_by_dict(k12, size_dict)

new_flops = flops + flops12
new_size = max(size, size12)
@@ -464,7 +465,7 @@ def _get_candidate(output, sizes, remaining, footprints, dim_ref_counts, k1, k2,
two = k1 & k2
one = either - two
k12 = (either & output) | (two & dim_ref_counts[3]) | (one & dim_ref_counts[2])
cost = cost_fn(helpers.compute_size_by_dict(k12, sizes), footprints[k1], footprints[k2], k12, k1, k2)
cost = cost_fn(compute_size_by_dict(k12, sizes), footprints[k1], footprints[k2], k12, k1, k2)
id1 = remaining[k1]
id2 = remaining[k2]
if id1 > id2:
@@ -560,7 +561,7 @@ def ssa_greedy_optimize(inputs, output, sizes, choose_fn=None, cost_fn='memory-r
}

# Compute separable part of the objective function for contractions.
footprints = {key: helpers.compute_size_by_dict(key, sizes) for key in remaining}
footprints = {key: compute_size_by_dict(key, sizes) for key in remaining}

# Find initial candidate contractions.
queue = []
@@ -592,7 +593,7 @@ def ssa_greedy_optimize(inputs, output, sizes, choose_fn=None, cost_fn='memory-r
dim_to_keys[dim].add(k12)
remaining[k12] = next(ssa_ids)
_update_ref_counts(dim_to_keys, dim_ref_counts, k1 | k2 - output)
footprints[k12] = helpers.compute_size_by_dict(k12, sizes)
footprints[k12] = compute_size_by_dict(k12, sizes)

# Find new candidate contractions.
k1 = k12
@@ -602,14 +603,14 @@ def ssa_greedy_optimize(inputs, output, sizes, choose_fn=None, cost_fn='memory-r
_push_candidate(output, sizes, remaining, footprints, dim_ref_counts, k1, k2s, queue, push_all, cost_fn)

# Greedily compute pairwise outer products.
queue = [(helpers.compute_size_by_dict(key & output, sizes), ssa_id, key) for key, ssa_id in remaining.items()]
queue = [(compute_size_by_dict(key & output, sizes), ssa_id, key) for key, ssa_id in remaining.items()]
heapq.heapify(queue)
_, ssa_id1, k1 = heapq.heappop(queue)
while queue:
_, ssa_id2, k2 = heapq.heappop(queue)
ssa_path.append((min(ssa_id1, ssa_id2), max(ssa_id1, ssa_id2)))
k12 = (k1 | k2) & output
cost = helpers.compute_size_by_dict(k12, sizes)
cost = compute_size_by_dict(k12, sizes)
ssa_id12 = next(ssa_ids)
_, ssa_id1, k1 = heapq.heappushpop(queue, (cost, ssa_id12, k12))

@@ -807,12 +808,12 @@ def _dp_compare_flops(cost1, cost2, i1_union_i2, size_dict, cost_cap, s1, s2, xn
3. If the intermediate tensor corresponding to ``s`` is going to break the
memory limit.
"""
cost = cost1 + cost2 + helpers.compute_size_by_dict(i1_union_i2, size_dict)
cost = cost1 + cost2 + compute_size_by_dict(i1_union_i2, size_dict)
if cost <= cost_cap:
s = s1 | s2
if s not in xn or cost < xn[s][1]:
i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2)
mem = helpers.compute_size_by_dict(i, size_dict)
mem = compute_size_by_dict(i, size_dict)
if memory_limit is None or mem <= memory_limit:
xn[s] = (i, cost, (cntrct1, cntrct2))

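For intuition, the flops sieve above charges a candidate contraction the product of the sizes of every index involved (``i1_union_i2``). A minimal illustration, with made-up index sizes:

```python
from opt_einsum.helpers import compute_size_by_dict

# hypothetical index sizes
size_dict = {'a': 2, 'b': 3, 'c': 4}

# contracting operands indexed by {'a', 'b'} and {'b', 'c'} involves the
# union {'a', 'b', 'c'}, so this candidate is charged 2 * 3 * 4 = 24 flops
assert compute_size_by_dict({'a', 'b', 'c'}, size_dict) == 24
```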
@@ -825,14 +826,43 @@ def _dp_compare_size(cost1, cost2, i1_union_i2, size_dict, cost_cap, s1, s2, xn,
"""
s = s1 | s2
i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2)
mem = helpers.compute_size_by_dict(i, size_dict)
mem = compute_size_by_dict(i, size_dict)
cost = max(cost1, cost2, mem)
if cost <= cost_cap:
if s not in xn or cost < xn[s][1]:
if memory_limit is None or mem <= memory_limit:
xn[s] = (i, cost, (cntrct1, cntrct2))


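# NOTE: the 'size' objective above tracks a maximum (the largest intermediate
# produced anywhere in the contraction tree), whereas 'flops' and the new
# 'write' and 'combo' objectives below accumulate running sums.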
def _dp_compare_write(cost1, cost2, i1_union_i2, size_dict, cost_cap, s1, s2, xn, g, all_tensors, inputs,
i1_cut_i2_wo_output, memory_limit, cntrct1, cntrct2):
"""Like ``_dp_compare_flops`` but sieves the potential contraction based
on the total size of memory created, rather than the number of
operations, and so calculates that first.
"""
s = s1 | s2
i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2)
mem = compute_size_by_dict(i, size_dict)
cost = cost1 + cost2 + mem
if cost <= cost_cap:
if s not in xn or cost < xn[s][1]:
if memory_limit is None or mem <= memory_limit:
xn[s] = (i, cost, (cntrct1, cntrct2))


def _dp_compare_combo(cost1, cost2, i1_union_i2, size_dict, cost_cap, s1, s2, xn, g, all_tensors, inputs,
i1_cut_i2_wo_output, memory_limit, cntrct1, cntrct2):
"""Like ``_dp_compare_flops`` but sieves the potential contraction based
on a weighted sum of the flop count and the total size of memory
created, here hard-coded as ``cost = flops + 256 * write``.
"""
s = s1 | s2
i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2)
mem = compute_size_by_dict(i, size_dict)
f = compute_size_by_dict(i1_union_i2, size_dict)
cost = cost1 + cost2 + f + 256 * mem
if cost <= cost_cap:
if s not in xn or cost < xn[s][1]:
if memory_limit is None or mem <= memory_limit:
xn[s] = (i, cost, (cntrct1, cntrct2))


def simple_tree_tuple(seq):
"""Make a simple left to right binary tree out of iterable ``seq``.
@@ -899,21 +929,10 @@ class DynamicProgramming(PathOptimizer):
slow down the path finding considerably on all but very small graphs.
"""
def __init__(self, minimize='flops', cost_cap=True, search_outer=False):

# set whether inner function minimizes against flops or size
self.minimize = minimize
self._check_contraction = {
'flops': _dp_compare_flops,
'size': _dp_compare_size,
}[self.minimize]

# set whether inner function considers outer products
self.search_outer = search_outer
self._check_outer = {
False: lambda x: x,
True: lambda x: True,
}[self.search_outer]

self.cost_cap = cost_cap

def __call__(self, inputs, output, size_dict, memory_limit=None):
@@ -953,6 +972,17 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
>>> o(i_all, set(), s)
[(1, 2), (0, 4), (1, 2), (0, 2), (0, 1)]
"""
_check_contraction = {
'flops': _dp_compare_flops,
'size': _dp_compare_size,
'write': _dp_compare_write,
'combo': _dp_compare_combo,
}[self.minimize]
_check_outer = {
False: lambda x: x,
True: lambda x: True,
}[self.search_outer]

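# NOTE: dispatching here, per call, rather than once in ``__init__`` means
# ``self.minimize`` and ``self.search_outer`` can be reassigned on an
# existing instance and will take effect on the next invocation.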
ind_counts = Counter(itertools.chain(*inputs, output))
all_inds = tuple(ind_counts)

@@ -1005,7 +1035,7 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
# output index dimensions as initial cost_cap
subgraph_inds = set.union(*_bitmap_select(g, inputs))
if self.cost_cap is True:
cost_cap = helpers.compute_size_by_dict(subgraph_inds & output, size_dict)
cost_cap = compute_size_by_dict(subgraph_inds & output, size_dict)
elif self.cost_cap is False:
cost_cap = float('inf')
else:
@@ -1028,12 +1058,12 @@
i1_cut_i2_wo_output = (i1 & i2) - output

# maybe ignore outer products:
if self._check_outer(i1_cut_i2_wo_output):
if _check_outer(i1_cut_i2_wo_output):

i1_union_i2 = i1 | i2
self._check_contraction(cost1, cost2, i1_union_i2, size_dict, cost_cap, s1, s2,
xn, g, all_tensors, inputs, i1_cut_i2_wo_output,
memory_limit, cntrct1, cntrct2)
_check_contraction(cost1, cost2, i1_union_i2, size_dict, cost_cap, s1, s2,
xn, g, all_tensors, inputs, i1_cut_i2_wo_output,
memory_limit, cntrct1, cntrct2)

if (cost_cap > naive_cost) and (len(x[-1]) == 0):
raise RuntimeError("No contraction found for given `memory_limit`.")
@@ -1043,7 +1073,7 @@

i, cost, contraction = list(x[-1].values())[0]
subgraph_contractions.append(contraction)
subgraph_contractions_size.append(helpers.compute_size_by_dict(i, size_dict))
subgraph_contractions_size.append(compute_size_by_dict(i, size_dict))

# sort the subgraph contractions by the size of the subgraphs in
# ascending order (will give the cheapest contractions); note that

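A sketch of how the objectives added in this commit might be selected; the equation and operand shapes below are made up, and ``minimize='write'`` targets the total size of intermediates created while ``minimize='combo'`` targets the ``flops + 256 * write`` sum hard-coded above:

```python
import numpy as np
import opt_einsum as oe
from opt_einsum.paths import DynamicProgramming

# hypothetical chain of matrix multiplications
x, y, z = (np.random.rand(16, 16) for _ in range(3))

# minimize the total size of all intermediates written, rather than flops
opt = DynamicProgramming(minimize='write')
path, info = oe.contract_path('ab,bc,cd->ad', x, y, z, optimize=opt)

# or jointly minimize flops plus 256x the total memory written
opt = DynamicProgramming(minimize='combo')
path, info = oe.contract_path('ab,bc,cd->ad', x, y, z, optimize=opt)
```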