In [1]:
# Input layout: one comma-separated line of patterns, one separator line,
# then one towel design per line.
with open('input.txt', 'r') as fl:
    # Drop empty entries (e.g. from a trailing comma) so that no zero-length
    # pattern is ever handed to the matcher.
    patterns = [s.strip() for s in next(fl).split(',') if s.strip()]
    next(fl)  # skip the separator line between the two sections
    # Skip blank lines (e.g. extra newlines at the end of the file) so they are
    # not treated as empty towel designs.
    towels = [ln.strip() for ln in fl if ln.strip()]

In [2]:
# Sanity check: report how many patterns and towel designs were parsed.
pattern_count = len(patterns)
towel_count = len(towels)
print('patterns:', pattern_count, 'towels:', towel_count)

patterns: 447 towels: 400


In [3]:
class SuffixTreeNode:
    """A single node of the structure built by ``SuffixTree``.

    Args:
        label: Text carried by the node (a matched pattern, or ``'root'`` for
            the sentinel node).
        start: Start offset associated with the node; ``-1`` when not given.
        end: End offset associated with the node; ``-1`` when not given.
    """

    def __init__(self, label="", start=-1, end=-1):
        self.label, self.start, self.end = label, start, end
        # Links in both directions; both start empty and are filled in by the
        # tree builder. ``children`` is keyed by the child's label.
        self.parents = list()
        self.children = dict()
        # Running count accumulated by the tree builder; starts at zero.
        self.n_root_parents = 0


In [4]:
class SuffixTree:
    """Counts the ways ``text`` can be written as a concatenation of patterns.

    Despite the name, this is not a classical suffix tree. For every position
    in ``text`` (visited from right to left) and every pattern that matches at
    that position, a node is created and linked to the nodes that start where
    the pattern ends. Each node's ``n_root_parents`` is the number of ways the
    remainder of the text, from the node's start, can be covered by patterns.

    After construction, ``node_cache[i]`` lists the nodes starting at index
    ``i``; summing ``n_root_parents`` over ``node_cache[0]`` (if present) gives
    the number of arrangements for the whole text.

    Args:
        text: The string to decompose.
        patterns: Iterable of candidate pattern strings.
    """

    def __init__(self, text, patterns):
        self.text = text
        # Sentinel node sitting just past the last character. It represents
        # "nothing left to match" and therefore counts as exactly one way.
        self.root = SuffixTreeNode('root', end=len(text))
        self.root.n_root_parents = 1
        # Keep only usable vocabulary, in first-seen order:
        #  * patterns that do not occur in the text can never match,
        #  * duplicate entries would count the same arrangement repeatedly,
        #  * an empty pattern would link a node to the very cache bucket that
        #    is still being filled, corrupting the counts.
        self.patterns = [p for p in dict.fromkeys(patterns) if p and p in text]
        self.node_cache = {len(text): [self.root]}
        self.build_tree()

    def build_tree(self):
        """
        Insert every suffix of the text, shortest first.

        Working from the end of the text backwards guarantees that all nodes
        starting after index ``i`` already exist when index ``i`` is handled.
        """
        for i in range(len(self.text)-1, -1, -1):
            self._insert_suffix(self.text[i:], i)

    def _insert_suffix(self, suffix, start_index):
        """
        Create one node per pattern that ``suffix`` starts with.

        Args:
            suffix: The part of the text beginning at ``start_index``.
            start_index: Offset of ``suffix`` within the full text.
        """
        for gram in self.patterns:
            if not suffix.startswith(gram):
                continue
            end_index = start_index + len(gram)
            new_node = SuffixTreeNode(label=gram, start=start_index, end=end_index)
            # Link to every node that begins exactly where this pattern ends.
            # When there is none, the node is still recorded below but keeps
            # a count of zero (a dead end).
            for next_node in self.node_cache.get(end_index, ()):
                new_node.n_root_parents += next_node.n_root_parents
                next_node.children[gram] = new_node
                new_node.parents.append(next_node)
                new_node.end = max(new_node.end, next_node.end)
            # Append in place rather than rebuilding the bucket list each time.
            self.node_cache.setdefault(start_index, []).append(new_node)
    


In [5]:
# part 1: count the towel designs that can be assembled in at least one way
total = 0
for towel in towels:
    test = SuffixTree(towel, patterns)
    # Nodes starting at index 0 carry the number of full arrangements;
    # a missing entry means nothing matches at the very start.
    ways = sum(node.n_root_parents for node in test.node_cache.get(0, ()))
    total += 1 if ways > 0 else 0
print(total)

302


In [6]:
# part 2: add up the number of arrangements over all towel designs
total = 0
for towel in towels:
    test = SuffixTree(towel, patterns)
    # No entry at index 0 means the design cannot be started: contributes 0.
    first_nodes = test.node_cache.get(0, ())
    total += sum(node.n_root_parents for node in first_nodes)
total

771745460576799