## trie

In [1]:
from typing import Dict
from queue import LifoQueue

In [2]:
class TrieNode:

    def __init__(self, letter: str = "", parent: 'TrieNode' = None, depth: int = 0) -> None:
        super().__init__()
        self.letter: str = letter
        self.children: Dict[str, TrieNode] = dict()
        self.depth: int = depth
        self.parent: TrieNode = parent
        self.link: TrieNode = None

    def graft(self, text: str, sibling: 'TrieNode' = None) -> 'TrieNode':
        node = self
        for current_letter in list(text):
            if current_letter not in node.children:
                node.children[current_letter] = TrieNode(current_letter, node, node.depth + 1)
            node = node.children[current_letter]
            if sibling:
                sibling = sibling.children[current_letter]
                sibling.link = node
        return node

    def __contains__(self, item):
        node = self
        if not isinstance(item, str):
            return False
        while item:
            if not item[0] in node.children.keys():
                return False
            node = node.children[item[0]]
            item = item[1:]
        return True


class Trie:

    def __init__(self, text: str) -> None:
        self.root: TrieNode = TrieNode()
        leaf = self.root.graft(text)
        self.root.children[text[0]].link = self.root
        for i in range(1, len(text)):
            head, sibling = self.up_link_down(leaf)
            if not head:
                sibling = self.root.children[text[i - 1]]
                sibling.link = self.root
                head = self.root
            leaf = head.graft(text[i + head.depth:], sibling)

    def __contains__(self, item):
        return isinstance(item, str) and item in self.root

    def up_link_down(self, sibling: TrieNode) -> (TrieNode, TrieNode):
        letters = LifoQueue()
        while sibling and not sibling.link:
            letters.put(sibling.letter)
            sibling = sibling.parent
        if not sibling:
            return None, None
        node = sibling.link
        while not letters.empty():
            current_letter = letters.get()
            if current_letter in node.children.keys():
                node = node.children[current_letter]
                sibling = sibling.children[current_letter]
                sibling.link = node
            else:
                break
        return node, sibling

## suffix tree

In [3]:
class SuffixTreeNode:

    def __init__(self, text: str, start: int = 0, end: int = 0, depth: int = 0, parent: 'SuffixTreeNode' = None) -> None:
        super().__init__()
        self.depth = depth
        self.start = start
        self.end = end
        self.full_text = text
        self.children: Dict[str, SuffixTreeNode] = dict()
        self.parent: SuffixTreeNode = parent
        self.link: SuffixTreeNode = None

    def graft(self, start) -> 'SuffixTreeNode':
        start = start + self.depth
        text = self.full_text[start:]
        child = SuffixTreeNode(self.full_text, start, len(self.full_text), self.depth + len(text), self)
        self.children[text[0]] = child
        return child

    def break_path(self, text: str) -> 'SuffixTreeNode':
        length = len(text)
        child = self.children[text[0]]
        new_node = SuffixTreeNode(self.full_text, child.start, child.start + length, self.depth + length, self)
        child.start = child.start + length
        child.parent = new_node
        new_node.children[child.label[0]] = child
        self.children[text[0]] = new_node
        return new_node

    def fast_find(self, text: str) -> 'SuffixTreeNode':
        if len(text) == 0:
            return self
        child = self.children[text[0]]
        if len(child.label) < len(text):
            return child.fast_find(text[len(child.label):])
        elif len(child.label) == len(text):
            return child
        else:
            return self.break_path(text)

    def slow_find(self, text: str) -> 'SuffixTreeNode':
        if len(text) == 0 or text[0] not in self.children.keys():
            return self
        child = self.children[text[0]]
        for i in range(1, len(child.label)):
            if child.label[i] != text[i]:
                return self.break_path(text[:i])
        return child.slow_find(text[len(child.label):])

    @property
    def label(self):
        return self.full_text[self.start:self.end]

    def __contains__(self, item):
        if len(item) == 0:
            return True
        if not isinstance(item, str) or item[0] not in self.children:
            return False
        child = self.children[item[0]]
        for i in range(1, min(len(child.label), len(item))):
            if child.label[i] != item[i]:
                return False
        return len(item) < len(child.label) or item[len(child.label):] in child

    def __repr__(self) -> str:
        return f"[{self.start}:{self.end}] {self.full_text[self.start:self.end]}"


class SuffixTree:

    def __init__(self, text: str, slow_mode = False) -> None:
        self.root = last_head = SuffixTreeNode(text)
        leaf = self.root.graft(0)
        if slow_mode:
            for i in range(1, len(text)):
                head = self.root.slow_find(text[i:])
                head.graft(i)
        else:
            for i in range(1, len(text)):
                if last_head == self.root:
                    last_head = self.root.slow_find(leaf.label[1:])
                    leaf = last_head.graft(i)
                    continue
                parent = last_head.parent
                if parent == self.root:
                    link = parent.fast_find(last_head.label[1:])
                else:
                    link = parent.link.fast_find(last_head.label)
                if len(link.children) == 1:
                    head = link
                else:
                    head = link.slow_find(leaf.label)
                leaf = head.graft(i)
                last_head.link = link
                last_head = head

    def __contains__(self, item):
        return isinstance(item, str) and item in self.root

## test

In [4]:
from time import perf_counter
def time_eval(func, args, w_print=False, name=None, count=10):
    start = perf_counter()
    for i in range(count):
        func(*args)
    end = perf_counter()
    avg = (end-start)/count
    if w_print:
        print(f"{name} average time: {avg}")
    else:
        return avg


with open('1997_714.txt', 'r') as file:
    text = file.read()
    text = text[:2000]
    text += '\0'
    time_eval(Trie, [text], True, "Trie",count=1)
    time_eval(SuffixTree, [text], True, "McCreight",count=1)
    time_eval(SuffixTree, [text, False], True, "Slow McCreight",count=1)


Trie average time: 9.797388134997163
McCreight average time: 0.011338224001519848
Slow McCreight average time: 0.010572726001555566
