Skip to content

Commit

Permalink
update TensorNetwork.gen_inds_loops
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Jul 22, 2024
1 parent aeca738 commit 3982ec2
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 31 deletions.
53 changes: 22 additions & 31 deletions quimb/tensor/tensor_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9831,12 +9831,13 @@ def gen_loops(self, max_loop_length=None):
hg = get_hypergraph(inputs, accel="auto")
return hg.compute_loops(max_loop_length=max_loop_length)

def gen_inds_loops(self, max_loop_length=None):
def gen_inds_loops(self, max_loop_length=None, intersect=False):
"""Generate all sequences of indices, up to a specified length, that
represent loops in this tensor network. Unlike ``gen_loops`` this
function will return the indices of the tensors in the loop rather
than the tensor ids, allowing one to differentiate between e.g. a
double loop and a 'figure of eight' loop.
double loop and a 'figure of eight' loop. Dangling and hyper indices
are ignored.
Parameters
----------
Expand All @@ -9853,41 +9854,28 @@ def gen_inds_loops(self, max_loop_length=None):
--------
gen_loops, gen_inds_connected
"""

def _normalize_loop(seq):
# this returns the lexicographically smallest equivalent
# sequence, up to rolling (rotating) and reversing.
N = len(seq)
i = min(enumerate(seq), key=lambda x: x[1])[0]
el_prev = seq[(i - 1) % N]
el_next = seq[(i + 1) % N]
if el_prev > el_next:
return (*seq[i:], *seq[:i])
else:
return (*seq[i::-1], *seq[-1:i:-1])

queue = []
for ind, tids in self.ind_map.items():
# initial starting points - we store both index and tid to keep track
# of direction properly, (only need one direction initially)
# initial starting points - we store both index and tid to keep
# track of direction properly, (only need one direction initially)
queue.append(((ind, next(iter(tids))),))

seen = set()
while queue:
s = queue.pop(0)
last_ind, last_tid = s[-1]

# get the other connecting tid and tensor
# XXX: this will break for dangling and hyper indices
(next_tid,) = (
tid for tid in self.ind_map[last_ind] if tid != last_tid
)
last_tids = self.ind_map[last_ind]
if len(last_tids) != 2:
# ignore dangling and hyper indices
continue

(next_tid,) = (tid for tid in last_tids if tid != last_tid)
next_t = self.tensor_map[next_tid]

# candidate expansions
expansions = [
nind for nind in next_t.inds
if nind != last_ind
]
expansions = [nind for nind in next_t.inds if nind != last_ind]

current_inds = {x[0] for x in s}
current_tids = {x[1] for x in s}
Expand All @@ -9897,20 +9885,23 @@ def _normalize_loop(seq):

if next_pair == s[0]:
# finished a loop! - normalize it to check for duplicates
loop = _normalize_loop([x[0] for x in s])
if loop not in seen:
seen.add(loop)
# loop = key = _normalize_loop([x[0] for x in s])
loop = tuple(x[0] for x in s)
key = frozenset(loop)

if key not in seen:
seen.add(key)
if max_loop_length is None:
max_loop_length = len(loop)
yield loop

elif (
# don't double up on indices
(nind not in current_inds) and
# and don't self intersect? XXX: make this a switch?
(next_tid not in current_tids) and
# and optionally avoid self intersection
(intersect or (next_tid not in current_tids)) and
# and don't make the loop too long
(max_loop_length is None) or (len(s) < max_loop_length)
((max_loop_length is None) or (len(s) < max_loop_length))
):
# valid candidate extension!
queue.append(s + (next_pair,))
Expand Down
14 changes: 14 additions & 0 deletions tests/test_tensor/test_tensor_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1876,6 +1876,20 @@ def test_gen_inds_loops(self):
loops = tuple(tn.gen_inds_loops())
assert len(loops) == 6

def test_gen_inds_loops_intersect(self):
tn = qtn.TN2D_empty(5, 4, 2)
loops = tuple(tn.gen_inds_loops(8, False))
na = len(loops)
assert na == len(frozenset(loops))
assert na == len(frozenset(map(frozenset, loops)))

loops = tuple(tn.gen_inds_loops(8, True))
nb = len(loops)
assert nb == len(frozenset(loops))
assert nb == len(frozenset(map(frozenset, loops)))

assert nb > na

def test_gen_inds_connected(self):
tn = qtn.TN2D_rand(3, 4, 2)
patches = tuple(tn.gen_inds_connected(2))
Expand Down

0 comments on commit 3982ec2

Please sign in to comment.