In [1]:
from collections import defaultdict, namedtuple, Counter

def make_suffix_tree(text):
    text += "$"
    trie = dict()
    node_counter = 0
    for i, _ in enumerate(text):
        node_iter = 0
        for c in text[i:]:
            if node_iter not in trie:
                node_counter += 1
                trie[node_iter] = {c: node_counter}
                node_iter += 1
            else:
                if c in trie[node_iter]:
                    node_iter = trie[node_iter][c]
                else:
                    node_counter += 1
                    trie[node_iter][c] = node_counter
                    node_iter = node_counter
    return trie

def iterate_trie(trie, start=0):
    for label, nid in trie.get(start,{}).items():
        print(f"{start}->{nid}:{label}")
        iterate_trie(trie, nid)

In [60]:
BLUE = 0
RED = 1
PURPLE = 2

class Node(object):
    def __init__(self, children, position, label, color=-1):
        self.children = children
        self.position = position
        self.length = len(label)
        self.label = label
        self.color = color
        
    def __repr__(self):
        children = "-".join(self.children.keys())
        return f"{self.label}|{self.color}|{self.position}|{self.length}:{children}"
              
    def get_color(self):
        if self.color != -1:
            return self.color
        else:
            colors = [child.get_color() for child in self.children.values()]
            if len(set(colors)) > 1:
                self.color = PURPLE
            elif len(colors)>0:
                color_sum = sum(colors)
                if color_sum == 0:
                    self.color = BLUE
                elif color_sum == len(colors):
                    self.color = RED
                elif color_sum == 2*len(colors):
                    self.color = PURPLE
                else:
                    raise ValueError(f"all colors should be same value {colors}")            
            elif self.node.label[-1] == "$": #leaf
                pass
            else:
                raise ValueError(f"{self} should be leaf")
                
            return self.color

class SuffixTree(object):
    #position of leaf node is the start position of the suffix 
    
    #Node = namedtuple('Node', ["children","position","label"])
    stop_symbol = "$"
    comparison_symbol = "#"
    
    def __init__(self, text):
        self.text = text + "$"
        self.separating_position = -1
        
        self._root = Node({}, -1, "")
        
        for i, s in enumerate(self.text):
            self._add_suffix(self.text[i:], i)
            if s == self.comparison_symbol:
                self.separating_position = i
            
        self.compact()
        if self.separating_position != -1:
            #print("coloring nodes")
            self.color_nodes()
        
    def _color_leafs(self, node):
        for child in node.children.values():
            if child.label[-1] == self.stop_symbol:
                #print(f"coloring leaf {child}")
                if child.position <= self.separating_position:
                    child.color = BLUE
                else:
                    child.color = RED
            else:
                self._color_leafs(child)
                
    def color_nodes(self):
        self._color_leafs(self._root)
        return self._root.get_color()
        
    def _add_suffix(self, suffix, start_position):
        self._current_node = self._root
        for i, symbol in enumerate(suffix):
            if symbol in self._current_node.children:
                self._current_node = self._current_node.children[symbol]
            else:
                if symbol == self.stop_symbol:
                    self._add_leaf(start_position)
                else:
                    newNode = self._add_symbol(symbol, start_position+i)
                    self._current_node = newNode
        
    def _add_symbol(self, symbol, position):
        newNode = Node({}, position, symbol)
        self._current_node.children[symbol] = newNode
        return newNode
    
    def _add_leaf(self, start_position):
        newNode = Node({}, start_position, self.stop_symbol)
        self._current_node.children[self.stop_symbol] = newNode
         
    def thread(self, pattern):
        self._current_node = self._root
        return self._thread(pattern, 0)
            
    def _thread(self, pattern, position):
        if position == len(pattern):
            return self._current_node
            
        symbol = pattern[position]
        self._current_node = self._current_node.children.get(symbol)
        if self._current_node is None:
            return None
        
        step = min(len(self._current_node.label), len(pattern)-position)
        if pattern[position:position+step] == self._current_node.label[:step]:
            position += step
            if position == len(pattern):
                return self._current_node
        else:
            return None        
        
        return self._thread(pattern, position)
    
    def compact(self):
        self._compact(self._root)
    
    def _compact(self, node):
        if len(node.children) == 1:
            child = next(iter(node.children.values()))
            node.label += child.label
            node.length = len(node.label)
            node.children = child.children
            if child.label == self.stop_symbol:
                node.position = child.position
            self._compact(node=node)
        else:
            for label, child in node.children.items():
                self._compact(node=child)
            
    def labels(self):
        node_labels = []
        self._labels(self._root, node_labels)
        return node_labels
    
    def _labels(self, node, node_labels=[]):
        for label, child in node.children.items():
            node_labels.append(child.label)
            self._labels(child, node_labels)
            
    def get_leafs(self, node):
        if node.label!="" and node.label[-1] == self.stop_symbol:
            return 1
        else:
            nleafs = 0
            for child in node.children.values():
                nleafs += self.get_leafs(child)
            return nleafs
            
    def find_longest_repeat(self):
        substring_count = Counter()
        self._find_longest_repeat(self._root, substring_count)
        return sorted(filter(lambda tpl: tpl[1]>=2, substring_count.items()), 
               key=lambda tpl: len(tpl[0]), reverse=True)[0][0]
    
    def _find_longest_repeat(self, node, substring_count, substring=""):
        for child in node.children.values():
            nleafs = self.get_leafs(child)
            substring_count[substring + child.label] = nleafs
            self._find_longest_repeat(child, substring_count, substring + child.label)
    
    def _find_common_substrings(self, node, substring="", common_strings=set()):
        if node.color != PURPLE:
            return
        
        stop = True
        for child in node.children.values():
            if child.color == PURPLE:
                self._find_common_substrings(child, substring + child.label, common_strings)
                stop = False
                
        if stop:
            common_strings.add(substring)
            
    def find_common_substrings(self):
        common_strings = set()
        self._find_common_substrings(self._root, "", common_strings)
        return common_strings
        
    def find_longest_substring(self):
        #todo remove special symbols from string if they appear in there
        substrings = self.find_common_substrings()
        return sorted([(s, len(s)) for s in substrings], key=lambda tpl: tpl[1])[-1][0]
    
    """
    def find_non_common_substrings(self, in_first_string=True):
        color = BLUE if in_first_string else RED
        non_common_strings = set()
        self._find_non_common_substrings(self._root, color, "", non_common_strings)
        return non_common_strings
        
    def _find_non_common_substrings(self, node, color, substring="", non_common_strings=set()):
        for child in node.children.values():
            if child.label[-1] == self.stop_symbol:
                if node.color == color:
                    s = substring + child.label
                    non_common_strings.add(s)
                    return
            else:
                self._find_non_common_substrings(child, color, substring+child.label, non_common_strings)
                
    def find_shortest_nonshared_substring(self, in_first_string=True):
        substrings = self.find_non_common_substrings(in_first_string)
        return sorted([(s, len(s)) for s in substrings], key=lambda tpl: tpl[1])[0][0]
    """
    
    def find_shortest_nonshared_substring(self, text):
        for 

In [62]:
suffix_tree = SuffixTree("panamabananas$")

In [66]:
#print(suffix_tree.thread("nam"))
print(suffix_tree.thread("nam"))
print(suffix_tree.thread("ananas"))
print(suffix_tree.thread("test"))
print(suffix_tree.thread("mab"))
print(suffix_tree.thread("an"))
print(suffix_tree.thread("ana"))
print(suffix_tree.thread("nas"))

mabananas$|-1|2|10:
nas$|-1|7|4:
None
mabananas$|-1|4|10:
na|-1|2|2:m-n-s
na|-1|2|2:m-n-s
s$|-1|10|2:


In [53]:
suffix_tree.find_longest_substring()

'ana'

In [119]:
suffix_tree.get_leafs(suffix_tree._root)

14

In [120]:
suffix_tree.find_longest_repeat()

'ana'

In [74]:
suffix_tree = SuffixTree("ATAAATG")
suffix_tree.compact()
labels = []
suffix_tree.labels()
labels = sorted(labels)

In [30]:
expected = sorted(["AAATG$","G$","T","ATG$","TG$","A","A","AAATG$","G$","T","G$","$"])
expected

['$', 'A', 'A', 'AAATG$', 'AAATG$', 'ATG$', 'G$', 'G$', 'G$', 'T', 'T', 'TG$']

In [31]:
labels == expected

False

In [32]:
with open("../data/dataset_296_4.txt") as fin:
    input_string = fin.read().strip()[:-1]
    suffix_tree = SuffixTree(input_string)
    suffix_tree.compact()
    labels = suffix_tree.labels()
    for l in labels:
        print(l)

C
C
C
C
TCAACAGCTGATGTGTGACTGCATCGTCTTCTCGGCGTGTCCCCGAGTCGCGGACTGGGGACAAAGAGCCAAGTGAGTATACTCTGGGATCATATCGCAGCACCCGTGCACTAATGCCTATATACAATTGCCTAAATTTACCCCCAGGGCGTACTAGAGAGTTAGTAGCACAGTCCGTGCAGAGTTAGGATGCCGCCTCGTAGAGTAGGCCCTATGAGGTTGCGTGCCTTAACAGTATGATACGGGAACCTATTCCAGTTTTCAGCGCCATCTTCATTTGCCGATCCACCAGAAGAGAAACGAAAACCCGACCCATCGCAGGAGCTTGGGGCACTCCCACTACCACGACGCAGCTGCTTAAGGACGATTTAAGCACTTGTTAGTGCACCATCGAGATGTTCCTGTCACGCCCGCAACCATCGATGCGGTAAAGGGGCAGGCCTGTAACGTGGTCGCAAACCATAGTGGAATTTTAACTCGTCACACAATTTCCTATCTCCTAAGCGGGAACTTGTAGTATCAGCCTGTAATCTGTCAGCCTGGACCCTTATTACTGACGGCGTAATGGTATCGGTCCCCGGTGGATTACGACCGTGTGAGTCAAGTGAGAACTTTTGACGGCGCATACTCGTCGCATTCCTGTTACACTAAGCGAGTCCACAGTCCGTCCGTGATGGTCTTGCAAATTGGTTCAATCAGCCGCGAAAACTGGAACAAATAGTCCAGCTACCACAACGGTTGGAATACCCATCAAACGGTAGGTTTAGGGTAATTAAGAGTCGGAAGCCATAGAGTGCGACAAATTCAACCCGGAACCCGATTGCGGCCCCAGCGCATCCTGACGTAACTCAGCCAGTAAGACCTGGCACCTCCTCTATGAATGGACATAGCTAATGGTGTGGGCAGCGGGGATCGGAAATGTTCCTCTCAAGAAATCCTTCGCGGACAGACGTGGTCTGGCTCATACTCTAC$
G
AGTCGCGGACTGGGGA

In [121]:
#find longest repeat in a string
text = "ATATCGTTTTATCGTT"

In [123]:
suffix_tree = SuffixTree(text)
suffix_tree.find_longest_repeat()

'TATCGTT'

In [124]:
with open("../data/dataset_296_5.txt") as fin:
    text = fin.read().strip()
    suffix_tree = SuffixTree(text)
    print(suffix_tree.find_longest_repeat())

AACACTCTGTGTACCCGAAGTCATAGGTGTCGACGGGATACAACCGGCGAGACCGTCTAGAGAATAGTACAGGAATCTGAG


In [54]:
text = "TCGGTAGATTGCGCCCACTC#AGGGGCTCGCAGTGTAAGAA"
suffix_tree = SuffixTree(text)
print(suffix_tree.find_longest_substring())

AGA


In [55]:
#find longest common string
with open("../data/dataset_296_6.txt") as fin:
    text = [line.strip() for line in fin]
    text = "#".join(text)
    suffix_tree = SuffixTree(text)
    print(suffix_tree.find_longest_substring())

ATTGTGCCC


In [61]:
#find shortest non-shared string
text = "CCAAGCTGCTAGAGG#CATGCTGGGCTGGCT"
suffix_tree = SuffixTree(text)
print(suffix_tree.find_non_common_substrings())

{'AGCTGCTAGAGG#CATGCTGGGCTGGCT$'}
