Skip to content

Commit

Permalink
Merge branch 'sssr'
Browse files Browse the repository at this point in the history
  • Loading branch information
stsouko committed Feb 13, 2019
2 parents 6dbf4a5 + 36c8b8e commit 5fddd16
Show file tree
Hide file tree
Showing 2 changed files with 329 additions and 33 deletions.
132 changes: 99 additions & 33 deletions CGRtools/algorithms/sssr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,59 +16,125 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, see <https://www.gnu.org/licenses/>.
#
from itertools import combinations
from itertools import combinations, product
from networkx import shortest_simple_paths, NetworkXNoPath, number_connected_components
from ..cache import cached_property


class SSSR:
""" SSSR calculation. based on idea from:
""" SSSR calculation. based on idea of PID matrices from:
Lee, C. J., Kang, Y.-M., Cho, K.-H., & No, K. T. (2009).
A robust method for searching the smallest set of smallest rings with a path-included distance matrix.
Proceedings of the National Academy of Sciences of the United States of America, 106(41), 17355–17358.
http://doi.org/10.1073/pnas.0813040106
"""
@cached_property
def sssr(self):
"""
SSSR search.
adj = self._adj
if len(adj) < 3:
return []

n_sssr = self.bonds_count - len(self) + 1

atoms = {x for x, y in adj.items() if len(y)} # ignore isolated atoms
terminals = {x for x, y in adj.items() if len(y) == 1}
if terminals:
bubble = terminals
while True:
bubble = {y for x in bubble for y in adj[x].keys() - terminals if len(adj[y].keys() - terminals) < 2}
if not bubble:
break
terminals.update(bubble)
atoms.difference_update(terminals) # skip not-cycle chains

:return: list of lists of rings nodes
"""
n_sssr = self.bonds_count - len(self) + number_connected_components(self)
if not n_sssr:
if not atoms:
return []

pid1 = {}
pid2 = {}
terminated = {}
tail = atoms.pop()
next_stack = {x: [[tail, x]] for x in adj[tail].keys() & atoms}

for ij in combinations(self, 2):
paths = shortest_simple_paths(self, *ij) # slowest part of algorithm
try: # stop if path not reachable
path = next(paths)
except NetworkXNoPath:
break
while True:
next_front = set()
found_odd = set()
stack, next_stack = next_stack, {}
for broom in stack.values():
tail = broom[0][-1]
next_front.add(tail)
neighbors = adj[tail].keys() & atoms
if len(neighbors) == 1:
n = neighbors.pop()
if n in found_odd:
continue
next_broom = [branch + [n] for branch in broom]
if n in stack: # odd rings
found_odd.add(tail)
if n in next_stack:
next_stack[n].extend(next_broom)
else:
next_broom.extend(stack[n])
terminated[n] = next_stack[n] = next_broom
elif n in next_stack: # even rings
next_stack[n].extend(next_broom)
if n not in terminated:
terminated[n] = next_stack[n]
else:
next_stack[n] = next_broom
elif neighbors:
for n in neighbors:
if n in found_odd:
continue
next_broom = [[tail, n]]
for branch in broom:
next_broom.append(branch + [n])
if n in stack: # odd rings
found_odd.add(tail)
if n in next_stack:
next_stack[n].extend(next_broom)
else:
next_broom.extend(stack[n])
terminated[n] = next_stack[n] = next_broom
elif n in next_stack: # even rings
next_stack[n].extend(next_broom)
if n not in terminated:
terminated[n] = next_stack[n]
else:
next_stack[n] = next_broom

pid1[ij] = [path]
ls = len(path)
lsp = ls + 1
atoms.difference_update(next_front)
if not atoms:
break
elif not next_stack:
n_sssr += 1
tail = atoms.pop()
next_stack = {x: [[tail, x]] for x in adj[tail].keys() & atoms}

pid1 = {}
pid2 = {}
for j, paths in terminated.items():
for path in paths:
lp = len(path)
if lp == ls:
pid1[ij].append(path)
elif lp == lsp:
if ij not in pid2:
pid2[ij] = [path]
else:
pid2[ij].append(path)
i = path[0]
k = (i, j)
if k in pid1:
ls = len(pid1[k][0])
lp = len(path)
if lp == ls:
pid1[k].append(path)
elif ls - lp == 1:
pid2[k], pid1[k] = pid1[k], [path]
elif lp - ls == 1:
pid2[k].append(path)
elif lp < ls:
pid1[k] = [path]
pid2[k] = []
else:
break
pid1[k] = [path]
pid2[k] = []

c_set = []
for ij, p1ij in pid1.items():
for k, p1ij in pid1.items():
dij = len(p1ij[0]) - 1
p2ij = pid2.get(ij)
p2ij = pid2[k]

if not p2ij and len(p1ij) == 1:
continue
Expand All @@ -83,9 +149,9 @@ def sssr(self):
for c_num, p1ij, p2ij in sorted(c_set):
if c_num % 2:
c1 = p1ij[0]
cs1 = set(c1)
c11 = c1[1]
for c2 in p2ij:
if len(cs1.intersection(c2)) == 2:
if c11 != c2[1]:
c = c1 + c2[-2:0:-1]
ck = tuple(sorted(c))
if ck not in c_sssr:
Expand All @@ -95,7 +161,7 @@ def sssr(self):
return list(c_sssr.values())
else:
for c1, c2 in zip(p1ij, p1ij[1:]):
if len(set(c1).intersection(c2)) == 2:
if c1[1] != c2[1]:
c = c1 + c2[-2:0:-1]
ck = tuple(sorted(c))
if ck not in c_sssr:
Expand Down

0 comments on commit 5fddd16

Please sign in to comment.