In [76]:
class TrieNode:
    """One node of a Trie.

    ``children`` maps each next character to its child node. The Trie stores
    the end-of-word marker ``"*"`` here too, with the value ``None``.
    """

    def __init__(self):
        self.children = dict()

    def is_empty(self):
        """Return True when the node has no children at all."""
        return not self.children

In [103]:
class Trie:
    """Prefix tree of strings built from ``TrieNode`` objects.

    Each node's ``children`` dict maps one character to the child node for
    that character. The end of a stored word is marked by the key ``END``
    (``"*"``) mapped to ``None`` rather than to a node. That lets a node both
    end a word and have longer continuations (e.g. "cat" and "catnap").
    """

    # Sentinel key marking "a word ends here". Its value is None, so it can
    # never be descended into; therefore it must not appear inside a word.
    END = "*"

    def __init__(self):
        self.root_node = TrieNode()

    def insert(self, word):
        """Add ``word`` to the trie, creating any missing nodes on its path.

        Raises:
            ValueError: if ``word`` contains the end marker ``"*"``. Such a
                word would otherwise crash (stepping into the marker's None
                value) or corrupt the trie (a marker edge with children, which
                the collectors would report as a shorter, never-inserted word).
        """
        if self.END in word:
            raise ValueError(
                f"word must not contain the end marker {self.END!r}: {word!r}"
            )

        current_node = self.root_node
        for letter in word:
            if letter not in current_node.children:
                current_node.children[letter] = TrieNode()
            current_node = current_node.children[letter]

        current_node.children[self.END] = None

    def search(self, word):
        """Follow ``word`` letter by letter from the root.

        Returns:
            The node reached after the last letter, or ``None`` when that path
            does not exist. The node's children are the possible continuations,
            including ``"*"`` when ``word`` itself is stored. ``search("")``
            returns the root. The end marker is never followed as a letter.
        """
        current_node = self.root_node
        for letter in word:
            # Check END first so we never step into its None value.
            if letter == self.END or letter not in current_node.children:
                return None
            current_node = current_node.children[letter]
        return current_node

    def collect_all_words(self):
        """Return every stored word, in depth-first insertion order."""
        return self.collect_node_words(self.root_node)

    def collect_node_words(self, node, word="", words=None):
        """Collect the words stored at or below ``node``.

        Args:
            node: node to start from (e.g. the result of ``search``). ``None``
                (a missing path) yields no words.
            word: characters already consumed to reach ``node``; it is
                prepended to every collected word.
            words: accumulator shared across the recursion. A fresh list is
                created when omitted, which avoids a mutable default argument.

        Returns:
            The accumulator list. With the default ``word=""``, entries are
            suffixes relative to ``node``, and ``""`` means ``node`` itself
            ends a word.
        """
        if words is None:
            words = []
        if node is None:
            return words

        for letter, child in node.children.items():
            if letter == self.END:
                words.append(word)
            else:
                self.collect_node_words(child, word + letter, words)

        return words

    def autocomplete(self, prefix):
        """Return all stored words starting with ``prefix``.

        Returns ``None`` when no stored word starts with ``prefix``.
        """
        found_node = self.search(prefix)
        if found_node is None:
            return None

        # Seeding ``word`` with the prefix yields full words directly.
        return self.collect_node_words(found_node, word=prefix)

    def print(self):
        """Draw the trie as an ASCII tree.

        Non-last children are drawn with "|---" and the last child with "L---".
        """
        self.print_node(self.root_node, "")

    def print_node(self, node, filler):
        """Recursively draw ``node``'s children, prefixing each line with ``filler``.

        ``None`` (the value under an end marker) draws nothing. A node without
        children prints "Empty"; with normal inserts, only the root of an
        empty trie can be such a node.
        """
        if node is None:
            return

        if node.is_empty():
            print("Empty")
            return

        *all_but_last, last = node.children.keys()

        for letter in all_but_last:
            # Non-last children keep a "|" rail so later siblings line up.
            print(filler, end="")
            print(f"|---{letter}")
            self.print_node(node.children[letter], filler=filler + "|   ")

        print(filler, end="")
        print(f"L---{last}")
        self.print_node(node.children[last], filler=filler + "    ")

    def traverse_and_print(self):
        """Print every key (end markers included) depth-first on one line.

        Each key is followed by ", ". An empty trie prints "Empty trie".
        """
        if self.root_node.is_empty():
            print("Empty trie")
            return

        self.traverse_and_print_node(self.root_node)

    def traverse_and_print_node(self, node):
        """Recursive helper for ``traverse_and_print``; ``None`` is an end-marker value."""
        if node is None:
            return

        for letter in node.children:
            print(f"{letter}, ", end="")
            self.traverse_and_print_node(node.children[letter])

    # solution from book
    def traverse(self, node=None):
        """Print every key below ``node`` (default: the root), one per line.

        End markers are included. This is the book's reference solution.
        """
        current_node = node if node is not None else self.root_node
        for key, child_node in current_node.children.items():
            print(key)
            if key != self.END:
                self.traverse(child_node)

    def autocorrect(self, word):
        """Suggest corrections for ``word``.

        Walks ``word`` as far as the trie allows. If all of ``word`` is a path
        in the trie (even one that is not a complete stored word), returns
        ``[word]`` unchanged. Otherwise returns every stored word that extends
        the longest matched prefix; this is empty only when the trie is empty.
        """
        current_node = self.root_node
        matched = 0
        for letter in word:
            # The end marker is not a path letter (its value is None).
            if letter == self.END or letter not in current_node.children:
                break
            current_node = current_node.children[letter]
            matched += 1

        if matched == len(word):
            return [word]

        return self.collect_node_words(current_node, word=word[:matched])

    @classmethod
    def from_list(cls, words):
        """Build a trie (of the calling class) containing every word in ``words``."""
        trie = cls()
        for word in words:
            trie.insert(word)
        return trie
            

In [78]:
# "cat" and "can" share the path c -> a, so the trie branches at the third letter.
t = Trie.from_list(["cat", "can"])
t.print()

L---c
    L---a
        |---t
        |   L---*
        L---n
            L---*


In [79]:
# Build the book's example trie, draw it, then list every stored word.
t = Trie.from_list([
    "ace", "act",
    "bad", "bake", "bat", "batter",
    "cab", "cat", "catnap", "catnip",
])
t.print()
t.collect_all_words()

|---a
|   L---c
|       |---e
|       |   L---*
|       L---t
|           L---*
|---b
|   L---a
|       |---d
|       |   L---*
|       |---k
|       |   L---e
|       |       L---*
|       L---t
|           |---*
|           L---t
|               L---e
|                   L---r
|                       L---*
L---c
    L---a
        |---b
        |   L---*
        L---t
            |---*
            L---n
                |---a
                |   L---p
                |       L---*
                L---i
                    L---p
                        L---*


['ace',
 'act',
 'bad',
 'bake',
 'bat',
 'batter',
 'cab',
 'cat',
 'catnap',
 'catnip']

In [80]:
# Suffixes stored under the "cat" node; "" means "cat" itself is a stored word.
t.collect_node_words(node=t.search("cat"))

['', 'nap', 'nip']

In [81]:
# Full words (not suffixes) that begin with "cat".
t.autocomplete(prefix="cat")

['cat', 'catnap', 'catnip']

In [82]:
# 1. List all the words stored in the following trie
# (the trie is drawn first, then its words are collected).
t = Trie.from_list([
    "tag", "tan", "tank", "tap",
    "today", "total",
    "we", "well", "went",
])
t.print()
t.collect_all_words()

|---t
|   |---a
|   |   |---g
|   |   |   L---*
|   |   |---n
|   |   |   |---*
|   |   |   L---k
|   |   |       L---*
|   |   L---p
|   |       L---*
|   L---o
|       |---d
|       |   L---a
|       |       L---y
|       |           L---*
|       L---t
|           L---a
|               L---l
|                   L---*
L---w
    L---e
        |---*
        |---l
        |   L---l
        |       L---*
        L---n
            L---t
                L---*


['tag', 'tan', 'tank', 'tap', 'today', 'total', 'we', 'well', 'went']

In [100]:
# 2. Draw a trie that stores the following words: “get,” “go,” “got,” “gotten,”
# “hall,” “ham,” “hammer,” “hill,” and “zebra.”
# The drawing below is produced by Trie.print().
t = Trie.from_list([
    "get", "go", "got", "gotten",
    "hall", "ham", "hammer", "hill",
    "zebra",
])
t.print()

|---g
|   |---e
|   |   L---t
|   |       L---*
|   L---o
|       |---*
|       L---t
|           |---*
|           L---t
|               L---e
|                   L---n
|                       L---*
|---h
|   |---a
|   |   |---l
|   |   |   L---l
|   |   |       L---*
|   |   L---m
|   |       |---*
|   |       L---m
|   |           L---e
|   |               L---r
|   |                   L---*
|   L---i
|       L---l
|           L---l
|               L---*
L---z
    L---e
        L---b
            L---r
                L---a
                    L---*


In [101]:
# 3. Write a function that traverses each node of a trie and prints each key,
# including all "*" keys.
# Keys come out depth-first on a single line, each followed by ", ".
t.traverse_and_print()

g, e, t, *, o, *, t, *, t, e, n, *, h, a, l, l, *, m, *, m, e, r, *, i, l, l, *, z, e, b, r, a, *, 

In [102]:
# The book's reference solution: same depth-first order, one key per line.
t.traverse()

g
e
t
*
o
*
t
*
t
e
n
*
h
a
l
l
*
m
*
m
e
r
*
i
l
l
*
z
e
b
r
a
*


In [92]:
# 4. Write an autocorrect function that attempts to replace a user’s typo with
# a correct word. The function should accept a string that represents text
# a user typed in. If the user’s string is not in the trie, the function should
# return an alternative word that shares the longest possible prefix with
# the user’s string.
t = Trie.from_list(["cat", "catnap", "catnip"])

# The typo comes after the known prefix "catna", so the only completion is "catnap".
typo_fix = t.autocorrect("catnar")
print(typo_fix)  # ["catnap"]
assert typo_fix == ["catnap"]

# Only "ca" matches, so any of "cat", "catnap" and "catnip" is acceptable.
loose_fix = t.autocorrect("caxasfdij")
print(loose_fix)
assert loose_fix and set(loose_fix) <= {"cat", "catnap", "catnip"}

# If the user’s string is found in the trie, the function should just return the word itself.
exact = t.autocorrect("cat")
print(exact)  # ["cat"]
assert exact == ["cat"]

['catnap']
['cat', 'catnap', 'catnip']
['cat']


In [108]:
# If the user’s string is found in the trie, the function should just return
# the word itself. This should be true even if the user’s text is not a complete
# word, as we’re only trying to correct typos, not suggest endings to the
# user’s prefix.
t = Trie.from_list(["catnap", "catnip"])
t.print()
# "cat" is a path in the trie but not a stored word; it must still come back unchanged.
prefix_result = t.autocorrect("cat")
print(prefix_result)
assert prefix_result == ["cat"]

L---c
    L---a
        L---t
            L---n
                |---a
                |   L---p
                |       L---*
                L---i
                    L---p
                        L---*
['cat']
