Skip to content

Commit

Permalink
fix get_can_dot for output
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed May 20, 2024
1 parent 08a1ffb commit 1eef888
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions cotengra/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def __init__(
for ix in self.output:
self.appearances[ix] = self.appearances.get(ix, 0) + 1

#
# this stores potentialy preprocessing steps that are not part of the
# main contraction tree, but assumed to have been applied, for example
# tracing or summing over indices that appear only once
self.preprocessing = {}

# mapping of parents to children - the core binary tree object
Expand Down Expand Up @@ -738,6 +740,13 @@ def has_preprocessing(self):
self.get_legs(node)
return bool(self.preprocessing)

def has_hyper_indices(self):
"""Check if there are any 'hyper' indices in the contraction, i.e.
indices that don't appear exactly twice, when considering the inputs
and output.
"""
return any(ix_count != 2 for ix_count in self.appearances.values())

@cached_node_property("legs")
def get_legs(self, node):
"""Get the effective 'outer' indices for the collection of tensors
Expand Down Expand Up @@ -795,15 +804,7 @@ def get_can_dot(self, node):
"""
l, r = self.children[node]
sp, sl, sr = map(self.get_legs, (node, l, r))

srl_symmdiff = sl.copy()
for ix, ix_count in sr.items():
if ix in srl_symmdiff:
srl_symmdiff.pop(ix)
else:
srl_symmdiff[ix] = ix_count

return srl_symmdiff == sp
return set(sp) == set(sl).symmetric_difference(sr)

@cached_node_property("inds")
def get_inds(self, node):
Expand Down

0 comments on commit 1eef888

Please sign in to comment.