diff --git a/cotengra/interface.py b/cotengra/interface.py index b026dd8..df95820 100644 --- a/cotengra/interface.py +++ b/cotengra/interface.py @@ -2,8 +2,9 @@ import autoray as ar -from .oe import get_path_fn, find_output_str +from .oe import get_path_fn from .core import ContractionTree +from .utils import find_output_str class Variadic: diff --git a/cotengra/pathfinders/path_basic.py b/cotengra/pathfinders/path_basic.py index 53794c7..b5034e5 100644 --- a/cotengra/pathfinders/path_basic.py +++ b/cotengra/pathfinders/path_basic.py @@ -331,7 +331,7 @@ def pop_node(self, i): try: ix_nodes = self.edges[ix] ix_nodes.pop(i, None) - if len(ix_nodes) == 1: + if len(ix_nodes) == 0: del self.edges[ix] except KeyError: # repeated index already removed @@ -345,10 +345,19 @@ def add_node(self, legs): i = self.ssa self.ssa += 1 self.nodes[i] = legs - for j, _ in legs: - self.edges.setdefault(j, {})[i] = None + for ix, _ in legs: + self.edges.setdefault(ix, {})[i] = None return i + def check(self): + """Check that the current graph is valid, useful for debugging.""" + for node, legs in self.nodes.items(): + for ix, _ in legs: + assert node in self.edges[ix] + for ix, ix_nodes in self.edges.items(): + for node in ix_nodes: + assert ix in {jx for jx, _ in self.nodes[node]} + def contract_nodes(self, i, j): """Contract the nodes ``i`` and ``j``, adding a new node to the graph and returning its index. diff --git a/cotengra/utils.py b/cotengra/utils.py index 3d876e6..6bd3475 100644 --- a/cotengra/utils.py +++ b/cotengra/utils.py @@ -1082,6 +1082,24 @@ def eq_to_inputs_output(eq): return inputs, output +def inputs_output_to_eq(inputs, output): + """Convert an explicit list of inputs and output to a str einsum equation. + + Parameters + ---------- + inputs : list[list[str]] + The input terms. + output : list[str] + The output term. + + Returns + ------- + eq : str + The einsum equation. + """ + return f"{','.join(map(''.join, inputs))}->{''.join(output)}" + + def make_rand_size_dict_from_inputs(inputs, d_min=2, d_max=3, seed=None): """Get a random size dictionary for a given set of inputs. diff --git a/tests/test_compute.py b/tests/test_compute.py index 56a872a..612655a 100644 --- a/tests/test_compute.py +++ b/tests/test_compute.py @@ -134,7 +134,7 @@ def test_rand_equation( seed=seed, ) arrays = [np.random.normal(size=s) for s in shapes] - eq = ",".join(map("".join, inputs)) + "->" + "".join(output) + eq = ctg.utils.inputs_output_to_eq(inputs, output) path, info = oe.contract_path(eq, *arrays, optimize="greedy") if info.largest_intermediate > 2**20: @@ -253,8 +253,7 @@ def test_contract_expression( sort_contraction_indices, ): inputs, output, shapes, size_dict = ctg.utils.lattice_equation([4, 8]) - - eq = f"{','.join(map(''.join, inputs))}->{''.join(output)}" + eq = ctg.utils.inputs_output_to_eq(inputs, output) arrays = [np.random.rand(*s) for s in shapes] x0 = oe.contract(eq, *arrays) diff --git a/tests/test_paths_basic.py b/tests/test_paths_basic.py index a72169b..578f523 100644 --- a/tests/test_paths_basic.py +++ b/tests/test_paths_basic.py @@ -121,7 +121,7 @@ def test_basic_rand(seed, which): d_max=3, seed=seed, ) - eq = ",".join(map("".join, inputs)) + "->" + "".join(output) + eq = ctg.utils.inputs_output_to_eq(inputs, output) path = { "greedy": pb.optimize_greedy, @@ -137,6 +137,27 @@ def test_basic_rand(seed, which): ) +@pytest.mark.parametrize("seed", range(10)) +@pytest.mark.parametrize("which", ["greedy", "optimal"]) +def test_basic_perverse(seed, which): + inputs, output, shapes, size_dict = ctg.utils.perverse_equation( + 10, seed=seed + ) + eq = ctg.utils.inputs_output_to_eq(inputs, output) + print(eq) + path = { + "greedy": pb.optimize_greedy, + "optimal": pb.optimize_optimal, + }[ + which + ](inputs, output, size_dict) + tree = ctg.ContractionTree.from_path(inputs, output, size_dict, path=path) + arrays = [np.random.randn(*s) for s in shapes] + assert_allclose( + tree.contract(arrays), np.einsum(eq, *arrays, optimize=True) + ) + + def test_optimal_lattice_eq(): inputs, output, _, size_dict = ctg.utils.lattice_equation( [4, 5], d_max=3, seed=42