Skip to content

Commit

Permalink
further updates to new greedy/optimal paths
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Sep 5, 2023
1 parent 5636650 commit d000106
Show file tree
Hide file tree
Showing 4 changed files with 392 additions and 80 deletions.
111 changes: 77 additions & 34 deletions cotengra/pathfinders/path_basic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import heapq
import bisect
import functools
import itertools

Expand Down Expand Up @@ -281,14 +282,14 @@ def __init__(self, inputs, output, size_dict):
if ix is None:
# index not processed yet
ix = self.indmap[ind] = c
self.edges[ix] = {i}
self.edges[ix] = {i: None}
self.appearances.append(1)
self.sizes.append(size_dict[ind])
c += 1
else:
# seen index already
self.appearances[ix] += 1
self.edges[ix].add(i)
self.edges[ix][i] = None
legs.append((ix, 1))

legs.sort()
Expand All @@ -303,12 +304,10 @@ def __init__(self, inputs, output, size_dict):
def neighbors(self, i):
"""Get all neighbors of node ``i``."""
# only want to yield each neighbor once and not i itself
seen = {i}
for ix, _ in self.nodes[i]:
for j in self.edges[ix]:
if j not in seen:
if j != i:
yield j
seen.add(j)

def print_current_terms(self):
return ",".join(
Expand All @@ -328,12 +327,15 @@ def pop_node(self, i):
the legs of the node.
"""
legs = self.nodes.pop(i)
for j, _ in legs:
es = self.edges[j]
if len(es) == 1:
del self.edges[j]
else:
self.edges[j].discard(i)
for ix, _ in legs:
try:
ix_nodes = self.edges[ix]
ix_nodes.pop(i, None)
if len(ix_nodes) == 1:
del self.edges[ix]
except KeyError:
# repeated index already removed
pass
return legs

def add_node(self, legs):
Expand All @@ -344,7 +346,7 @@ def add_node(self, legs):
self.ssa += 1
self.nodes[i] = legs
for j, _ in legs:
self.edges.setdefault(j, set()).add(i)
self.edges.setdefault(j, {})[i] = None
return i

def contract_nodes(self, i, j):
Expand All @@ -368,7 +370,6 @@ def simplify_batch(self):
if len(ix_nodes) >= len(self.nodes):
ix_to_remove.append(ix)
for ix in ix_to_remove:
# print("removing batch", ix)
self.remove_ix(ix)

def simplify_single_terms(self):
Expand Down Expand Up @@ -461,7 +462,12 @@ def subgraphs(self):
groups.sort()
return groups

def optimize_greedy(self, costmod=1.0, temperature=0.0):
def optimize_greedy(
self,
costmod=1.0,
temperature=0.0,
seed=None,
):
""" """

if temperature == 0.0:
Expand All @@ -472,16 +478,16 @@ def local_score(sa, sb, sab):
else:
from ..utils import GumbelBatchedGenerator
import numpy as np
import math

gmblgen = GumbelBatchedGenerator(np.random.default_rng())
rng = np.random.default_rng(seed)
gmblgen = GumbelBatchedGenerator(rng)

def local_score(sa, sb, sab):
score = sab - costmod * (sa + sb)
if score < 0:
return -math.log(-score) - temperature * gmblgen()
return -np.log(-score) - temperature * gmblgen()
else:
return math.log(score) - temperature * gmblgen()
return np.log(score) - temperature * gmblgen()

# return sab - costmod * (sa + sb) - temperature * gmblgen()

Expand Down Expand Up @@ -686,7 +692,33 @@ def optimize_remaining_by_size(self):
heapq.heappush(nodes_sizes, (ksize, k))


def optimize_simplify(inputs, output, size_dict):
def ssa_to_linear(ssa_path, N=None):
    """Convert a path with static single assignment ids to a path with
    recycled linear ids. For example:

    ```python
    ssa_to_linear([(0, 3), (2, 4), (1, 5)])
    #> [[0, 3], [1, 2], [0, 1]]
    ```

    Parameters
    ----------
    ssa_path : sequence[sequence[int]]
        The contraction path in 'SSA' form: each contraction refers to its
        operands by the id the tensor had when it was created, and each
        intermediate takes the next unused id.
    N : int, optional
        The number of initial input tensors. If not given it is inferred
        from ``ssa_path``, assuming every tensor id appears exactly once as
        an operand and each contraction produces a single output.

    Returns
    -------
    path : list[list[int]]
        The same contractions expressed as positions in the current list of
        tensors, where contracted tensors are removed and each intermediate
        is appended at the end.
    """
    if N is None:
        # each contraction consumes len(scon) tensors and produces one, and
        # the overall result accounts for the final '+ 1'
        N = sum(map(len, ssa_path)) - len(ssa_path) + 1

    # ids[k] is the ssa id of the tensor currently at linear position k;
    # it remains sorted because new ssa ids only ever increase and are
    # appended at the end
    ids = list(range(N))
    path = []
    ssa = N
    for scon in ssa_path:
        # translate ssa ids -> current linear positions via binary search
        con = sorted(bisect.bisect_left(ids, s) for s in scon)
        # pop from the highest position first so lower positions stay valid
        for j in reversed(con):
            ids.pop(j)
        ids.append(ssa)
        path.append(con)
        ssa += 1
    return path


def optimize_simplify(inputs, output, size_dict, use_ssa=False):
"""Find the (likely only partial) contraction path corresponding to
simplifications only. Those simplifications are:
Expand All @@ -704,17 +736,20 @@ def optimize_simplify(inputs, output, size_dict):
The indices of the output tensor.
size_dict : dict[str, int]
A dictionary mapping indices to their dimension.
use_ssa : bool, optional
Whether to return the contraction path in 'SSA' format (i.e. as if each
intermediate is appended to the list of inputs, without removals).
Returns
-------
ssa_path : list[list[int]]
The contraction path, given as a sequence of pairs of node indices in
'SSA' format (i.e. as if each intermediate is appended to the list of
inputs, without removals).
path : list[list[int]]
The contraction path, given as a sequence of pairs of node indices.
"""
cp = ContractionProcessor(inputs, output, size_dict)
cp.simplify()
return cp.ssa_path
if use_ssa:
return cp.ssa_path
return ssa_to_linear(cp.ssa_path)


def optimize_greedy(
Expand All @@ -724,6 +759,7 @@ def optimize_greedy(
costmod=1.0,
temperature=0.0,
simplify=True,
use_ssa=False,
):
"""Find a contraction path using a greedy algorithm.
Expand Down Expand Up @@ -761,21 +797,24 @@ def optimize_greedy(
Such simplifications may be required in the general case for the proper
functioning of the core optimization, but may be skipped if the input
indices are already in a simplified form.
use_ssa : bool, optional
Whether to return the contraction path in 'SSA' format (i.e. as if each
intermediate is appended to the list of inputs, without removals).
Returns
-------
ssa_path : list[list[int]]
The contraction path, given as a sequence of pairs of node indices in
'SSA' format (i.e. as if each intermediate is appended to the list of
inputs, without removals).
path : list[list[int]]
The contraction path, given as a sequence of pairs of node indices.
"""
cp = ContractionProcessor(inputs, output, size_dict)
if simplify:
cp.simplify()
cp.optimize_greedy(costmod=costmod, temperature=temperature)
# handle disconnected subgraphs
cp.optimize_remaining_by_size()
return cp.ssa_path
if use_ssa:
return cp.ssa_path
return ssa_to_linear(cp.ssa_path)


def optimize_optimal(
Expand All @@ -786,6 +825,7 @@ def optimize_optimal(
cost_cap=2,
search_outer=False,
simplify=True,
use_ssa=False,
):
"""Find the optimal contraction path using a dynamic programming
algorithm (by default excluding outer products).
Expand Down Expand Up @@ -838,13 +878,14 @@ def optimize_optimal(
Such simplifications may be required in the general case for the proper
functioning of the core optimization, but may be skipped if the input
indices are already in a simplified form.
use_ssa : bool, optional
Whether to return the contraction path in 'SSA' format (i.e. as if each
intermediate is appended to the list of inputs, without removals).
Returns
-------
ssa_path : list[list[int]]
The contraction path, given as a sequence of pairs of node indices in
'SSA' format (i.e. as if each intermediate is appended to the list of
inputs, without removals).
path : list[list[int]]
The contraction path, given as a sequence of pairs of node indices.
"""
cp = ContractionProcessor(inputs, output, size_dict)
if simplify:
Expand All @@ -854,4 +895,6 @@ def optimize_optimal(
)
# handle disconnected subgraphs
cp.optimize_remaining_by_size()
return cp.ssa_path
if use_ssa:
return cp.ssa_path
return ssa_to_linear(cp.ssa_path)
Loading

0 comments on commit d000106

Please sign in to comment.