In [1]:
from random import randrange

In [2]:
test_words = [
    "bbb$",
    "aabbabd$",
    "ababcd$",
    "abaababaabaabaabab$",
    "".join([chr(randrange(97, 98)) for _ in range(1000)])]

with open("1997_714_head.txt", "r", encoding='UTF-8') as f:
    test_words.append(f.read())

test_words[-1] = test_words[-1] + '$'
test_words[-2] = test_words[-2] + '$'

In [3]:
def build_kmp_table(pattern):
    m = len(pattern)
    pi = [0 for _ in range(m)]
    k = 0
    for i in range(1, m):
        while k > 0 and pattern[i] != pattern[k]:
            k = pi[k-1]
        if pattern[i] == pattern[k]:
            k += 1
        pi[i] = k
    return pi

def kmp(pattern, text):
    m = len(pattern)
    pi = build_kmp_table(pattern)
    ans = []
    i = 0
    j = 0
    while j < len(text):
        while i > 0 and pattern[i] != text[j]:
            i = pi[i - 1]
        if pattern[i] == text[j]:
            i += 1
        if i == m:
            ans.append(j)
            i = pi[-1]
        j += 1
    return ans

In [4]:
from time import perf_counter

def timeit(func):
    def wrapper(*args, **kwargs):
        t1 = perf_counter()
        res = func(*args, **kwargs)
        t2 = perf_counter()
        return res, round(t2 - t1, 4)
    return wrapper

In [5]:
from abc import ABC, abstractmethod

class Tree(ABC):
    def __init__(self, text):
        self.text = text
        self.root = None

    @abstractmethod
    def find_word(self, word):
        pass

    @abstractmethod
    def initiate_trie(self, *args):
        pass

In [13]:
from math import inf

class Node:
    def __init__(self, begin=0, length=inf, link=None):
        self.begin = begin
        self.length = length
        self.edges = {}
        self.suffix = link

def add_word(root, word):
    cur = root
    for letter in word:
        if letter not in cur.edges:
            cur.edges[letter] = Node()
        cur = cur.edges[letter]

@timeit
def build_regular_trie(text):
    root = Node()
    for i in range(len(text)):
        add_word(root, text[i:])
    return root


def find_word_trie(root, word):
    cur = root
    for x in word:
        if x not in cur.edges:
            return False
        cur = cur.edges[x]
    return True

class RegularTrie(Tree):
    def find_word(self, word):
        return find_word_trie(self.root, word)

    def initiate_trie(self):
        self.root, time = build_regular_trie(self.text)
        return time

In [14]:
@timeit
def build_linked_trie(text):
    root = Node()
    deepest = root
    for letter in text:
        cur = deepest
        prev = deepest = cur.edges[letter] = Node(link=root)
        cur = cur.suffix
        while cur is not None:
            if letter not in cur.edges:
                cur.edges[letter] = Node(link=root)
            prev.suffix = cur.edges[letter]
            prev = cur.edges[letter]
            cur = cur.suffix
    return root

class LinkedTrie(RegularTrie):
    def initiate_trie(self):
        self.root, time = build_linked_trie(self.text)
        return time

In [15]:
def add_suffix(node, ind, text):
    length = 0
    edge = None
    for i in range(ind, len(text)):
        letter = text[i]
        if length == 0:
            if letter not in node.edges:
                node.edges[letter] = Node(i)
                return
            edge = node.edges[letter]
        if text[edge.begin + length] != letter:
            break
        length += 1
        if length >= edge.length:
            node = edge
            length = 0
    else:
        return

    new_node = Node(edge.begin, length)
    node.edges[text[edge.begin]] = new_node
    new_node.edges[text[edge.begin + length]] = edge
    new_node.edges[text[i]] = Node(i)
    edge.begin = edge.begin + length
    edge.length -= length


@timeit
def build_suffix_tree(text):
    root = Node()
    for i in range(len(text)):
        add_suffix(root, i, text)
    return root


def find_word(node, text, word):
    length = 0
    edge = None
    for letter in word:
        if length == 0:
            if letter not in node.edges:
                return False
            edge = node.edges[letter]
        if text[edge.begin + length] != letter:
            return False
        length += 1
        if length >= edge.length:
            node = edge
            length = 0
    return True


class SuffixTrie(Tree):
    def find_word(self, word):
        return find_word(self.root, self.text, word)

    def initiate_trie(self):
        self.root, time = build_suffix_tree(self.text)
        return time

In [25]:
class ActivePoint:
    def __init__(self, root, text):
        self.tree_root = root
        self.length = 0
        self.edge = 0
        self.node = root
        self.remainder = 0
        self.text = text

    def walk_down(self, node):
        if self.length >= node.length:
            self.edge += node.length
            self.length -= node.length
            self.node = node
            return True
        return False

@timeit
def ukkonen(text):
    root = Node(0, 0)
    root.suffix = root
    point = ActivePoint(root, text)

    for i, c in enumerate(text):
        prev = None
        point.remainder += 1
        while point.remainder > 0:
            if point.length == 0:
                point.edge = i
            if text[point.edge] not in point.node.edges:
                new_node = Node(i)
                point.node.edges[text[point.edge]] = new_node
                if prev is not None:
                    prev.suffix = point.node
                prev = point.node
            else:
                nxt = point.node.edges[text[point.edge]]
                if point.walk_down(nxt):
                    continue
                if text[nxt.begin + point.length] == c:
                    point.length += 1
                    if prev is not None:
                        prev.suffix = point.node
                    break

                split = Node(nxt.begin, point.length, root)
                point.node.edges[text[point.edge]] = split
                leaf = Node(i)
                split.edges[c] = leaf
                nxt.begin += point.length
                nxt.length -= point.length
                split.edges[text[nxt.begin]] = nxt
                if prev is not None:
                    prev.suffix = split
                prev = split

            point.remainder -= 1
            if point.node == root and point.length > 0:
                point.length -= 1
                point.edge = i - point.remainder + 1
            else:
                point.node = point.node.suffix

    return root

class SuffixTree(SuffixTrie):
    def initiate_trie(self):
        self.root, time = ukkonen(self.text)
        return time

In [26]:
def test_correctness(text, TreeType):
    tree = TreeType(text)
    build_time = tree.initiate_trie()
    print(f"Text size: {len(text)}, Build time: {build_time}")
    for i in range(100):
        left = randrange(len(text) - 1)
        right = randrange(left + 1, len(text))
        substring = text[left:right]
        if i % 2:
            m = len(substring) // 2
            substring = substring[:m] + 'a' + substring[m:]
        found = len(kmp(substring, text)) > 0
        assert tree.find_word(substring) == found, "Not OK"
    else:
        print("Tests passed\n")

In [27]:
for t in test_words:
    test_correctness(t, SuffixTree)

Text size: 4, Build time: 0.0
Tests passed

Text size: 8, Build time: 0.0001
Tests passed

Text size: 7, Build time: 0.0
Tests passed

Text size: 19, Build time: 0.0001
Tests passed

Text size: 1001, Build time: 0.0028
Tests passed

Text size: 2482, Build time: 0.0069
Tests passed



In [28]:
for t in test_words:
    test_correctness(t, RegularTrie)

Text size: 4, Build time: 0.0
Tests passed

Text size: 8, Build time: 0.0
Tests passed

Text size: 7, Build time: 0.0
Tests passed

Text size: 19, Build time: 0.0002
Tests passed

Text size: 1001, Build time: 0.0418
Tests passed

Text size: 2482, Build time: 4.8752
Tests passed



In [29]:
for t in test_words:
    test_correctness(t, SuffixTrie)

Text size: 4, Build time: 0.0
Tests passed

Text size: 8, Build time: 0.0
Tests passed

Text size: 7, Build time: 0.0
Tests passed

Text size: 19, Build time: 0.0
Tests passed

Text size: 1001, Build time: 0.0809
Tests passed

Text size: 2482, Build time: 0.0107
Tests passed



In [30]:
for t in test_words:
    test_correctness(t, LinkedTrie)

Text size: 4, Build time: 0.0
Tests passed

Text size: 8, Build time: 0.0
Tests passed

Text size: 7, Build time: 0.0
Tests passed

Text size: 19, Build time: 0.0001
Tests passed

Text size: 1001, Build time: 0.0663
Tests passed

Text size: 2482, Build time: 6.7894
Tests passed

