Skip to content

Commit

Permalink
fixes and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Sep 5, 2023
1 parent d000106 commit 997772a
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 8 deletions.
3 changes: 2 additions & 1 deletion cotengra/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import autoray as ar

from .oe import get_path_fn, find_output_str
from .oe import get_path_fn
from .core import ContractionTree
from .utils import find_output_str


class Variadic:
Expand Down
15 changes: 12 additions & 3 deletions cotengra/pathfinders/path_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def pop_node(self, i):
try:
ix_nodes = self.edges[ix]
ix_nodes.pop(i, None)
if len(ix_nodes) == 1:
if len(ix_nodes) == 0:
del self.edges[ix]
except KeyError:
# repeated index already removed
Expand All @@ -345,10 +345,19 @@ def add_node(self, legs):
i = self.ssa
self.ssa += 1
self.nodes[i] = legs
for j, _ in legs:
self.edges.setdefault(j, {})[i] = None
for ix, _ in legs:
self.edges.setdefault(ix, {})[i] = None
return i

def check(self):
"""Check that the current graph is valid, useful for debugging."""
for node, legs in self.nodes.items():
for ix, _ in legs:
assert node in self.edges[ix]
for ix, ix_nodes in self.edges.items():
for node in ix_nodes:
assert ix in {jx for jx, _ in self.nodes[node]}

def contract_nodes(self, i, j):
"""Contract the nodes ``i`` and ``j``, adding a new node to the graph
and returning its index.
Expand Down
18 changes: 18 additions & 0 deletions cotengra/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,24 @@ def eq_to_inputs_output(eq):
return inputs, output


def inputs_output_to_eq(inputs, output):
"""Convert an explicit list of inputs and output to a str einsum equation.
Parameters
----------
inputs : list[list[str]]
The input terms.
output : list[str]
The output term.
Returns
-------
eq : str
The einsum equation.
"""
return f"{','.join(map(''.join, inputs))}->{''.join(output)}"


def make_rand_size_dict_from_inputs(inputs, d_min=2, d_max=3, seed=None):
"""Get a random size dictionary for a given set of inputs.
Expand Down
5 changes: 2 additions & 3 deletions tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_rand_equation(
seed=seed,
)
arrays = [np.random.normal(size=s) for s in shapes]
eq = ",".join(map("".join, inputs)) + "->" + "".join(output)
eq = ctg.utils.inputs_output_to_eq(inputs, output)

path, info = oe.contract_path(eq, *arrays, optimize="greedy")
if info.largest_intermediate > 2**20:
Expand Down Expand Up @@ -253,8 +253,7 @@ def test_contract_expression(
sort_contraction_indices,
):
inputs, output, shapes, size_dict = ctg.utils.lattice_equation([4, 8])

eq = f"{','.join(map(''.join, inputs))}->{''.join(output)}"
eq = ctg.utils.inputs_output_to_eq(inputs, output)
arrays = [np.random.rand(*s) for s in shapes]
x0 = oe.contract(eq, *arrays)

Expand Down
23 changes: 22 additions & 1 deletion tests/test_paths_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_basic_rand(seed, which):
d_max=3,
seed=seed,
)
eq = ",".join(map("".join, inputs)) + "->" + "".join(output)
eq = ctg.utils.inputs_output_to_eq(inputs, output)

path = {
"greedy": pb.optimize_greedy,
Expand All @@ -137,6 +137,27 @@ def test_basic_rand(seed, which):
)


@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("which", ["greedy", "optimal"])
def test_basic_perverse(seed, which):
inputs, output, shapes, size_dict = ctg.utils.perverse_equation(
10, seed=seed
)
eq = ctg.utils.inputs_output_to_eq(inputs, output)
print(eq)
path = {
"greedy": pb.optimize_greedy,
"optimal": pb.optimize_optimal,
}[
which
](inputs, output, size_dict)
tree = ctg.ContractionTree.from_path(inputs, output, size_dict, path=path)
arrays = [np.random.randn(*s) for s in shapes]
assert_allclose(
tree.contract(arrays), np.einsum(eq, *arrays, optimize=True)
)


def test_optimal_lattice_eq():
inputs, output, _, size_dict = ctg.utils.lattice_equation(
[4, 5], d_max=3, seed=42
Expand Down

0 comments on commit 997772a

Please sign in to comment.