In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
#hide
from nbdev.showdoc import *

In [41]:
from typing import List


class TrieNode:
    """A node in a trie (prefix tree) with one child slot per code point 0-255.

    The module's ``insert``/``query`` functions use ``ord(ch)`` as the slot
    index, so slot ``i`` holds the child reached by the character ``chr(i)``.
    """

    def __init__(self, terminal=False):
        # Slot i holds the child for character chr(i), or None if absent.
        self.children = [None]*256
        # True when the path from the root to this node spells a stored word.
        self.terminal = terminal

    def find_words(self, prefix):
        """Return every stored word in the subtree rooted at this node.

        Args:
            prefix: The string spelled by the path leading to this node. It is
                prepended to every returned word.

        Returns:
            A list of words in pre-order: a word comes before its extensions,
            and siblings are visited in ascending code-point order, which
            yields lexicographically sorted output.

        The walk uses an explicit stack rather than recursion, so very long
        stored words cannot exceed Python's recursion limit.
        """
        words = []
        stack = [(self, prefix)]
        while stack:
            node, path = stack.pop()
            if node.terminal:
                words.append(path)
            # Push children highest-code-point first so the lowest is popped
            # next, reproducing the ascending pre-order of a recursive walk.
            for i in range(len(node.children) - 1, -1, -1):
                child = node.children[i]
                if child is not None:
                    stack.append((child, path + chr(i)))
        return words

    def __str__(self):
        """Summarise the node as its child characters and its terminal flag."""
        children = [chr(i) for i, c in enumerate(self.children) if c is not None]
        return f"(children:{','.join(children)}, terminal:{self.terminal})"
    
    
def print_trie(r: "TrieNode", level=0):
    """Print the trie rooted at *r* as an indented outline.

    Each node is printed on its own line as ``str(node)``. The line is
    prefixed by two dashes per depth level, and *r* itself is at depth
    *level*. A node's subtree is printed right after the node, with children
    taken in ascending slot order. Nothing is printed when *r* is None.

    Uses an explicit stack instead of recursion, so deep tries (long words)
    cannot exceed Python's recursion limit.
    """
    if r is None:
        return

    stack = [(r, level)]
    while stack:
        node, depth = stack.pop()
        print(f"{'-' * 2 * depth}{node}")
        # Push in reverse so the lowest-index child is popped (printed) first.
        for child in reversed(node.children):
            if child is not None:
                stack.append((child, depth + 1))
    
    
def insert(root: TrieNode, s:str) -> TrieNode:
    """Add the word *s* to the trie and return the trie's root.

    An empty *s* is ignored and *root* is returned unchanged (even if it is
    None). Otherwise a new root node is created when *root* is None. Missing
    nodes are created along the path of ``ord`` codes of *s*, and the last
    node is marked terminal.
    """
    # len() rather than truthiness keeps the TypeError for non-sized input.
    if len(s) == 0:
        return root

    head = TrieNode() if root is None else root
    node = head
    for code in map(ord, s):
        nxt = node.children[code]
        if nxt is None:
            nxt = node.children[code] = TrieNode()
        node = nxt

    node.terminal = True
    return head


def query(root: "TrieNode", prefix:str) -> List[str]:
    """Return all stored words that start with *prefix*.

    Args:
        root: Trie root as returned by ``insert``. May be None, meaning an
            empty trie.
        prefix: The prefix to search for.

    Returns:
        The matching words, in the order produced by ``find_words`` on the
        node reached by *prefix*. Returns an empty list when *prefix* is
        empty, the trie is empty, or no stored word starts with *prefix*.
    """
    if not prefix or root is None:
        return []

    node = root
    for ch in prefix:
        code = ord(ch)
        # A character beyond the node's slot table (code point >= 256) can
        # never have been inserted, so there is no match rather than an error.
        if code >= len(node.children) or node.children[code] is None:
            return []
        node = node.children[code]
    return node.find_words(prefix)
    
    

In [46]:
# Build a small demo trie; insert() creates the root on the first call.
root = None
for _word in ("was", "word", "war", "what", "where"):
    root = insert(root, _word)
str(root)

'(children:w, terminal:False)'

In [47]:
# Print the demo trie as an outline: one node per line, two dashes per depth level.
print_trie(root)

(children:w, terminal:False)
--(children:a,h,o, terminal:False)
----(children:r,s, terminal:False)
------(children:, terminal:True)
------(children:, terminal:True)
----(children:a,e, terminal:False)
------(children:t, terminal:False)
--------(children:, terminal:True)
------(children:r, terminal:False)
--------(children:e, terminal:False)
----------(children:, terminal:True)
----(children:r, terminal:False)
------(children:d, terminal:False)
--------(children:, terminal:True)


In [50]:
# List every stored word that starts with the prefix "wh".
query(root, "wh")

['what', 'where']