Want to eventually add support for trie with word-level and maybe char-level
attributes. Ex: word embeddings, word frequencies, char->char transition probs,
parts of speech, etc.). Also experimenting with a slightly different interface
than the existing trie in htools. Note that names are the same so if you import
htools * in a notebook, things may get confusing. Might want to rename these.

In [1]:
from htools.core import listlike

In [18]:
word_dict = {
    'app': 18,
    'a': 6,
    'apple': 17,
    'about': 4,
    'able': 6,
    'zoo': 13,
    'zen': 11,
    'zesty': 14,
    'apply': 4,
    'cow': 18,
    'zigzag': 12
}

## Misc TODOs (still deciding whether to do these)

- support for char level features (e.g. prob of transitioning from "ap" -> "app". Have to think about if that use pattern is common enough to be needed.
- support non string seqs like in existing trie
- validate inputs like in existing trie
- support suffix trie, flip, etc.

In [163]:
class TrieNode:

    def __init__(self, edges=None, is_terminal=False, is_root=False,
                 **kwargs):
        self.edges = edges or {}
        self.is_terminal = is_terminal
        self.is_root = is_root
        self.kwarg_names = set(kwargs)
        self.set_kwargs(**kwargs)

    def set_kwargs(self, **kwargs):
        self.kwarg_names.update(kwargs.keys())
        self.__dict__.update(**kwargs)
        
    def get(self, key, default=None):
        return self.edges.get(key, default)

    def __contains__(self, char):
        return char in self.edges

    def __getitem__(self, char):
        return self.edges[char]

    def __setitem__(self, char, val):
        self.edges[char] = val
        
    def __delitem__(self, char):
        del self.edges[char]

    def __repr__(self):
        res = f'TrieNode(edges={list(self.edges)}, '\
              f'is_terminal={self.is_terminal}, ' \
              f'is_root={self.is_root}'
        if self.kwarg_names:
            kwarg_str = ', '.join(f'{kwarg}={getattr(self, kwarg)}'
                                  for kwarg in self.kwarg_names)
            res += ', ' + kwarg_str
        return res + ')'


class Trie:

    def __init__(self, vocab=None):
        self.root = TrieNode(is_root=True)
        # Preview is used in repr. This could become outdated if we delete
        # nodes from our trie, however.
        self._vocab_preview = []
        self._initialize(vocab)

    def _initialize(self, vocab):
        # Case 1: vocab is list/tuple. Must assign empty kwargs.
        if listlike(vocab):
            vocab = {word: {} for word in vocab}
        # Case 2: vocab is dict but values are not dicts. Assign default name.
        elif not isinstance(next(iter(vocab.values())), dict):
            vocab = {word: {'val': val} for word, val in vocab.items()}
        for i, (word, kwargs) in enumerate(vocab.items()):
            if i < 5:
                self._vocab_preview.append(word)
            self.add(word, **kwargs)
            
    def _find(self, word, add_missing=False):
        node = self.root
        for char in word:
            if char not in node:
                if add_missing:
                    node[char] = TrieNode()
                else:
                    # This is truthy but it means we can always check if the
                    # result is_terminal.
                    return TrieNode()
            node = node[char]
        return node

    def add(self, word, **kwargs):
        # These kwargs are associated with the whole word, e.g. if you want to
        # pass in word counts or word embeddings. Still need to implement 
        # support for character-level attributes if I want that (e.g. if we
        # want some kind of transition probability from 1 character to the
        # next).
        node = self._find(word, add_missing=True)
        node.is_terminal = True
        node.set_kwargs(**kwargs)

    def update(self, words):
        for word in words:
            self.add(word)
            
    def path(self, word, func=list.append, start=None):
        res = start or []
        # TODO: method to accumulate or reduce attrs along the path to a word.
        # Current pseudocode relies on unfinished coro/gen method I imagined.
        # E.g. the product or sequence of probs of moving from one char/word
        # to the next.
#         for node in self._iter_nodes(word):
#             res = func(res, node) or res
#         return res
            
    def __contains__(self, word):
        node = self._find(word, add_missing=False)
        return node.is_terminal
            
    def __repr__(self):
        return f'Trie({str(self._vocab_preview)[:-1] + ", ...]"})'
    
    # TODO: think about if I like this interface. Currently returns the node
    # corresponding to the input word, which is nice for getting word 
    # attributes. However, it means we can't cast the trie to a list directly.
    def __getitem__(self, word):
        node = self._find(word, add_missing=False)
        if not node.is_terminal:
            raise KeyError(f'Key "{word}" not found.')
        return node
    
    def __setitem__(self, word, val):
        self.add(word, val=val)
        
    def _keys(self, node=None, seq=None):
        node = node or self.root
        seq = seq or []
        if node.is_terminal:
            yield ''.join(seq)
        for char, new_node in node.edges.items():
            yield from self._keys(self, new_node, seq + [char])
            
    def keys(self):
        return list(self._keys())

In [169]:
# TODO - eventually want method that yields nodes as we add/search for a new
# word. Based on my coroutine/generator pattern. Still debugging.
def _find(self, word):
    node = self.root
    for i, char in enumerate(word):
        cur = yield node
        if cur:
            node = cur.get(char)
    # Recall that word of length n involves n+1 nodes, where the last 1 
    # is terminal. Note that this node may have edges, e.g. if we have both
    # "app" and "apple" in our trie and try to find the word "app".
    yield node
        
t = Trie(word_dict)
# print(t)

# coro = _find(t, 'apks')
# print(next(coro))
# for x in coro:
#     print('\nbefore send', x)
#     coro.send(x)
#     print('after send')

# Think this works on words that are present but not those that aren't.
coro = _find(t, 'apple')
x = next(coro)
while True:
    try:
        print('\nbefore send', x)
        x = coro.send(x)
        print('after send', x)
    except StopIteration:
        break


before send TrieNode(edges=['a', 'z', 'c'], is_terminal=False, is_root=True)
after send TrieNode(edges=['p', 'b'], is_terminal=True, is_root=False, val=6)

before send TrieNode(edges=['p', 'b'], is_terminal=True, is_root=False, val=6)
after send TrieNode(edges=['p'], is_terminal=False, is_root=False)

before send TrieNode(edges=['p'], is_terminal=False, is_root=False)
after send TrieNode(edges=['l'], is_terminal=True, is_root=False, val=18)

before send TrieNode(edges=['l'], is_terminal=True, is_root=False, val=18)
after send TrieNode(edges=['e', 'y'], is_terminal=False, is_root=False)

before send TrieNode(edges=['e', 'y'], is_terminal=False, is_root=False)
after send TrieNode(edges=[], is_terminal=True, is_root=False, val=17)

before send TrieNode(edges=[], is_terminal=True, is_root=False, val=17)


In [154]:
# Broken attempt at iterative _keys() method.
# def _dfs(self, node=None):
#     """
#     Traverse trie depth first until we hit a node with no edges.
#     At each step, add to our current sequence.
#     """
#     res = []
#     stack = [node or self.root]
#     cur = []
#     while stack:
#         node = stack.pop(-1)
#         print(node)
#         if not node.edges:
#             print('TERMINAL', cur)
#             cur.clear()
#             continue
#         for char in node.edges:
#             cur.append(char)
#         stack.extend(node.edges.values())


for word in _dfs(t):
    print(word)

a
app
apple
apply
about
able
zoo
zen
zesty
zigzag
cow
dog


In [156]:
for word in word_dict:
    assert word in t, f'Could not find word {word}.'
    assert word + 'ZZZ' not in t, f'Found unexpected word {word + "ZZZ"}.'

In [157]:
t['apple']

TrieNode(edges=[], is_terminal=True, is_root=False, val=17)

In [158]:
t['dog']

KeyError: 'Key "dog" not found.'

In [159]:
t['dog'] = 44

In [160]:
t['dog']

TrieNode(edges=[], is_terminal=True, is_root=False, val=44)

In [161]:
t.keys()

['a',
 'app',
 'apple',
 'apply',
 'about',
 'able',
 'zoo',
 'zen',
 'zesty',
 'zigzag',
 'cow',
 'dog']

In [162]:
list(t)

TypeError: 'int' object is not iterable