In [1]:
# Load the Cython IPython extension so the `%%cython` cells below can be compiled.
%load_ext Cython

In [None]:
%%cython
cdef class TrieNode:
    """A single character node of a trie.

    All four attributes are declared ``public`` so notebook cells can read
    and write them directly from Python.
    """
    # Character stored at this node.
    cdef public str letter
    # Child TrieNode objects, in the order they were appended.
    cdef public list children
    # Whether a word ends at this node.
    cdef public bint is_word
    # Starts at 1 when the node is created.
    cdef public long counter

    def __init__(self, letter):
        self.letter = letter
        self.counter = 1
        self.is_word = False
        self.children = []


In [5]:
# Build a single node holding the letter 'a' to inspect its default attributes.
node = TrieNode('a')

In [6]:
# The attributes are declared `public`, so all four are readable from Python.
node.letter, node.children, node.is_word, node.counter

('a', [], False, 1)

In [7]:
# Cost of reading a `public` cdef attribute from Python.
%timeit node.counter

49.6 ns ± 1.96 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)


In [77]:
%%cython

cdef class TrieNode:
    """A single character node of a trie.

    ``letter`` is ``readonly`` (fixed once the node is built); the other
    attributes are ``public`` and may be updated from Python code.
    """
    # Character stored at this node; cannot be reassigned from Python.
    cdef readonly str letter
    # Child TrieNode objects, in the order they were appended.
    cdef public list children
    # Whether a word ends at this node.
    cdef public bint is_word
    # Starts at 1 when the node is created.
    cdef public long counter

    def __init__(self, letter):
        self.letter = letter
        self.counter = 1
        self.is_word = False
        self.children = []

        
class Trie(object):
    """Pure-Python character trie used as a prefix autocompleter.

    Every ``TrieNode`` keeps a ``counter`` equal to the number of added words
    whose path passes through (or ends at) that node; the root's counter is
    never incremented.
    """

    def __init__(self):
        self.is_trained = False  # not read by any method of this class
        self.root = TrieNode('*')

    @property
    def words(self):
        """List every stored word in depth-first (insertion) order."""
        words = []
        self._iterate_until_leave(self.root, prefix='', words=words)
        return words

    def fit(self, words):
        """Add every word of the iterable `words` to the trie."""
        for word in words:
            self.add_word(word)

    def add_word(self, word):
        """Insert `word` into the trie.

        Existing nodes on the path get their ``counter`` incremented, missing
        letters get a new ``TrieNode`` (counter 1), and the last node is
        marked as the end of a word.
        """
        node = self.root
        for char in word:
            found_in_children, node = self.__char_in_children_update(char, node)
            if not found_in_children:
                new_node = TrieNode(char)
                node.children.append(new_node)
                node = new_node

        node.is_word = True

    def word_count(self, word: str):
        """Return how many times `word` has been added to the trie.

        A node's ``counter`` also counts the longer words that merely pass
        through it. Each of those continues into exactly one child and is
        counted in that child's ``counter``, so subtracting the children's
        counters leaves only the words that end here.

        Returns 0 for the empty string and for words that were never added.
        """
        if not word:
            # The root's counter is never updated, so no count can be derived.
            return 0
        node = self._prefix_node_in_trie(self.root, word)
        if node is False or not node.is_word:
            return 0
        return node.counter - sum(child.counter for child in node.children)

    def find_words_with_prefix(self, prefix):
        """Return every stored word that starts with `prefix`, depth-first.

        Returns an empty list when `prefix` is not a path in the trie; an
        empty `prefix` returns every stored word.
        """
        words = []
        root_pref = self._prefix_node_in_trie(self.root, prefix)
        if root_pref is not False:
            self._iterate_until_leave(root_pref, prefix, words)
        return words

    def _iterate_until_leave(self, node, prefix, words):
        """Append to `words` every word ending at `node` or below it.

        `prefix` is the string spelled by the path from the root to `node`.
        Recursion depth grows with the length of the stored words.
        """
        if node.is_word:
            words.append(prefix)
        for child in node.children:
            self._iterate_until_leave(child, prefix + child.letter, words)

    def _prefix_node_in_trie(self, root, word: str):
        """Return the node reached by following `word` from `root`.

        Returns ``root`` itself for an empty `word`, and ``False`` as soon as
        a letter has no matching child.
        """
        node = root
        for letter in word:
            found_in_children, node = self._check_letter_in_children(letter, node)
            if not found_in_children:
                return False
        return node

    def _check_letter_in_children(self, letter, node):
        """Look `letter` up among the children of `node` (linear scan).

        Returns ``(True, child)`` for the first child whose ``letter``
        matches, otherwise ``(False, node)`` with the incoming node unchanged.
        """
        for child in node.children:
            if child.letter == letter:
                return True, child
        return False, node

    def __char_in_children_update(self, letter, node):
        """Like ``_check_letter_in_children``, but also increments the
        matched child's ``counter`` (one more word passes through it).

        Returns ``(True, child)`` on a match, otherwise ``(False, node)``.
        """
        for child in node.children:
            if child.letter == letter:
                child.counter += 1
                return True, child
        return False, node


In [9]:
# Toy vocabulary; 'have' appears twice.
words = ['have', 'has', 'money', 'have', 'having', 'havana']

In [10]:
# Fit a fresh trie on the toy list; the duplicate 'have' is listed only once.
autocompleter = Trie()
autocompleter.fit(words)
autocompleter.words

['have', 'having', 'havana', 'has', 'money']

In [11]:
# Prefix lookup on the small toy trie built above.
%timeit autocompleter.find_words_with_prefix('hav')

2.85 µs ± 307 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [12]:
import re

import sklearn
from sklearn.datasets import fetch_20newsgroups

# 20 Newsgroups training documents (scikit-learn downloads and caches them on first use).
dataset = fetch_20newsgroups()

# Tokens are runs of two or more word characters.
token_pattern = re.compile(r"(?u)\b\w\w+\b")

# Fresh, empty trie; it is filled from the corpus in the next cell.
autocompleter = Trie()




In [13]:
%%time
for doc in dataset.data:
    words = token_pattern.findall(doc)
    autocompleter.fit(words)

CPU times: user 13.1 s, sys: 127 ms, total: 13.2 s
Wall time: 13.4 s


In [14]:
%%time
# Collect every stored word (a full depth-first walk of the trie) and count them.
len(autocompleter.words)

CPU times: user 182 ms, sys: 12.1 ms, total: 194 ms
Wall time: 197 ms


155448

In [15]:
# Baseline for the typed trie benchmarked further down; `aux` only absorbs the result.
%timeit aux = autocompleter.find_words_with_prefix('hou')

22.2 µs ± 1.58 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


### Cythonizing the Trie class

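The previous `Trie` is an ordinary Python class that Cython compiles with almost no type information. Below, both `TrieNode` and the trie (`Trie01`) become `cdef` classes with statically typed attributes, arguments and locals. This makes `autocompleter.find_words_with_prefix('hou')` roughly 4x faster (22.2 µs → 5.68 µs) and fitting the whole corpus about 2.6x faster (13.2 s → 5.0 s).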

In [78]:
%%cython

cdef class TrieNode:
    """A single character node of the typed trie.

    The attributes have no ``public``/``readonly`` qualifier, so they can only
    be reached from Cython code holding a typed ``TrieNode`` reference, not
    from Python.
    """
    cdef str letter      # character stored at this node
    cdef list children   # child TrieNode objects, in the order they were appended
    cdef bint is_word    # whether a word ends at this node
    cdef long counter    # starts at 1 when the node is created

    def __init__(self, letter):
        self.letter = letter
        self.counter = 1
        self.is_word = False
        self.children = []

        
cdef class Trie01(object):
    """Statically typed (Cython) character trie used as a prefix autocompleter.

    ``TrieNode.counter`` holds the number of added words whose path passes
    through (or ends at) that node; the root's counter is never incremented.
    """
    cdef:
        public bint is_trained   # set in __init__, not read by this class
        public TrieNode root     # sentinel node holding '*'

    def __init__(self):
        self.is_trained = False
        self.root = TrieNode('*')

    @property
    def words(self):
        """List every stored word in depth-first (insertion) order."""
        cdef list words = []
        self._iterate_until_leave(self.root, '', words)
        return words

    cpdef fit(self, list words):
        """Add every word of `words` to the trie."""
        cdef unicode word

        for word in words:
            self.add_word(word)

    cpdef add_word(self, unicode word):
        """Insert `word` into the trie.

        Existing nodes on the path get their ``counter`` incremented, missing
        letters get a new ``TrieNode`` (counter 1), and the last node is
        marked as the end of a word.
        """
        cdef:
            TrieNode node, new_node
            unicode char
            bint found_in_children

        node = self.root
        for char in word:
            found_in_children, node = self.__char_in_children_update(char, node)
            if not found_in_children:
                new_node = TrieNode(char)
                node.children.append(new_node)
                node = new_node

        node.is_word = True

    cpdef word_count(self, word: str):
        """Return how many times `word` has been added to the trie.

        A node's ``counter`` also counts the longer words that merely pass
        through it. Each of those continues into exactly one child and is
        counted in that child's ``counter``, so subtracting the children's
        counters leaves only the words that end here.

        Returns 0 for the empty string and for words that were never added.
        """
        cdef:
            TrieNode child
            TrieNode node = self.root
            bint found_in_children
            long word_counts
            unicode letter

        if not word:
            # The root's counter is never updated, so no count can be derived.
            return 0

        for letter in word:
            found_in_children, node = self._check_letter_in_children(letter, node)
            if not found_in_children:
                return 0

        if not node.is_word:
            return 0

        word_counts = node.counter
        for child in node.children:
            word_counts -= child.counter
        return word_counts

    cpdef find_words_with_prefix(self, unicode prefix):
        """Return every stored word that starts with `prefix`, depth-first.

        Returns an empty list when `prefix` is not a path in the trie; an
        empty `prefix` matches the root and returns every stored word.
        """
        cdef:
            list words = []
            TrieNode root_pref
            bint found_in_children

        root_pref, found_in_children = self._prefix_node_in_trie(self.root, prefix)
        if found_in_children:
            self._iterate_until_leave(root_pref, prefix, words)
        return words

    cdef inline _iterate_until_leave(self, TrieNode node, unicode prefix, list words):
        """Append to `words` every word ending at `node` or below it.

        `prefix` is the string spelled by the path from the root to `node`.
        Recursion depth grows with the length of the stored words.
        """
        cdef TrieNode child

        if node.is_word:
            words.append(prefix)
        for child in node.children:
            self._iterate_until_leave(child, prefix + child.letter, words)

    cdef inline _prefix_node_in_trie(self, TrieNode root, unicode word):
        """Follow `word` down from `root`.

        Returns ``(node, True)`` when every letter matched (``root`` itself
        for an empty `word`), otherwise ``(last_matched_node, False)``.
        """
        # Start as True: an empty `word` trivially matches `root`, and the flag
        # must never be returned uninitialised when the loop body does not run.
        cdef:
            bint found_in_children = True
            TrieNode node = root
            unicode letter

        for letter in word:
            found_in_children, node = self._check_letter_in_children(letter, node)
            if not found_in_children:
                break
        return node, found_in_children

    cdef inline _check_letter_in_children(self, unicode letter, TrieNode node):
        """Look `letter` up among the children of `node` (linear scan).

        Returns ``(True, child)`` for the first child whose ``letter``
        matches, otherwise ``(False, node)`` with the incoming node unchanged.
        """
        cdef TrieNode node_children

        for node_children in node.children:
            if node_children.letter == letter:
                return True, node_children
        return False, node

    cdef inline __char_in_children_update(self, unicode letter, TrieNode node):
        """Like ``_check_letter_in_children``, but also increments the
        matched child's ``counter`` (one more word passes through it).

        Returns ``(True, child)`` on a match, otherwise ``(False, node)``.
        """
        cdef TrieNode node_children

        for node_children in node.children:
            if node_children.letter == letter:
                node_children.counter += 1
                return True, node_children
        return False, node

In [79]:
# Sanity check on the toy list: the output matches the pure-Python trie's word list above.
autocompleter = Trie01()
words = ["have", 'has', 'money', 'have', 'having', 'havana']
autocompleter.fit(words)
autocompleter.words

['have', 'having', 'havana', 'has', 'money']

In [80]:
%%time
for doc in dataset.data:
    words = token_pattern.findall(doc)
    autocompleter.fit(words)

CPU times: user 4.97 s, sys: 47.2 ms, total: 5.02 s
Wall time: 5.04 s


In [81]:
# Same benchmark as for the pure-Python trie, now on the typed Trie01.
%timeit aux = autocompleter.find_words_with_prefix('hou')

5.68 µs ± 157 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [83]:
# Show the matches that the benchmark above was timing.
autocompleter.find_words_with_prefix('hou')

['hou',
 'hour',
 'hours',
 'hourglass',
 'hourly',
 'house',
 'housekeeping',
 'housed',
 'houses',
 'household',
 'households',
 'housewares',
 'housewarming',
 'housewives',
 'houselights',
 'housing',
 'housings',
 'houston',
 'houghton',
 'hougen',
 'hould',
 'hoult',
 'hou281',
 'hound',
 'hounds',
 'hounding',
 'hounded',
 'houxa']