In [114]:
import re

from collections import defaultdict
from textblob import TextBlob

In [115]:
def tokenize(text):
    """Split ``text`` into tokens with TextBlob and return them as plain strings.
    """
    blob = TextBlob(text)
    return list(map(str, blob.tokens))

In [17]:
def make_key(text):
    """Convert text -> normalized index key.

    Lower-cases the text, deletes periods, turns commas and hyphens into
    spaces, collapses every whitespace run to a single space and trims
    both ends.

    Args:
        text (str): Raw token or phrase.

    Returns:
        str: The normalized key.
    """
    text = text.lower()

    text = text.replace('.', '')
    text = re.sub(r'[,-]', ' ', text)

    # Collapse and trim only *after* punctuation has been turned into
    # whitespace; trimming first would leave 'CA,' as 'ca ' (trailing
    # space), which is a different key from 'ca'.
    text = re.sub(r'\s+', ' ', text)

    return text.strip()

In [192]:
class TrieNode:
    """Base node of a token trie.

    ``children`` maps a normalized key (built with ``make_key``) to the list
    of child nodes filed under that key. ``final`` is the set of ids whose
    token path ends at this node.

    Subclasses must implement ``__str__`` (text the index key is built
    from), ``__call__`` (does this node accept an input token?) and
    ``__hash__`` (identity used by ``__eq__`` when merging).
    """

    def __init__(self):
        self.children = defaultdict(list)
        self.final = set()

    def __str__(self):
        """Index key.
        """
        raise NotImplementedError

    def __call__(self):
        """Accept fn.
        """
        raise NotImplementedError

    def __hash__(self):
        """Used for pairwise equality checks.
        """
        raise NotImplementedError

    def __eq__(self, other):
        """Nodes are equal when they share a type and a hash."""
        return type(self) == type(other) and hash(self) == hash(other)

    def __add__(self, other):
        """Merge ``other.final`` into this node in place and return this node."""
        self.final.update(other.final)
        return self

    def __getitem__(self, token):
        """Return the child nodes filed under ``token``'s key that accept it.

        Returns an empty list when nothing matches.
        """
        # Use .get() rather than indexing: indexing a defaultdict creates an
        # empty bucket for every unknown key, so plain lookups would keep
        # growing the trie.
        siblings = self.children.get(make_key(token), ())
        return [c for c in siblings if c(token)]

    def __len__(self):
        """Number of nodes below this one (this node itself is not counted)."""
        return sum(1 + len(n) for sibs in self.children.values() for n in sibs)

    def insert(self, children):
        """Insert a path of nodes beneath this node.

        The first node is merged into an equal existing child if there is
        one (their ``final`` sets are combined), otherwise it is appended.
        The remaining nodes are inserted recursively under whichever node
        ended up in the trie.

        Args:
            children: Sequence of nodes ordered from direct child downward.
                An empty sequence is a no-op.
        """
        if not children:
            return

        head = children[0]

        key = make_key(str(head))

        merged = False
        for other in self.children[key]:
            if head == other:
                # Keep the node already in the trie; fold the new one into it.
                head = other + head
                merged = True
                break

        # Trace line: the node kept in the trie and whether it was merged.
        print(head, merged)

        if not merged:
            self.children[key].append(head)

        if len(children) > 1:
            head.insert(children[1:])

    def query(self, tokens):
        """Yield the ids stored at the end of every path matching ``tokens``.

        Args:
            tokens: Sequence of input token strings. An empty sequence
                yields nothing.
        """
        if not tokens:
            return

        rest = tokens[1:]

        for match in self[tokens[0]]:
            if rest:
                yield from match.query(rest)
            else:
                yield from match.final

In [193]:
class RootNode(TrieNode):
    """Trie entry point whose ``insert`` also records which id a path belongs to."""

    def insert(self, id, children):
        """Add ``id`` to the last node's ``final`` set, then insert the path.
        """
        leaf = children[-1]
        leaf.final.add(id)
        super().insert(children)

In [194]:
 class Token(TrieNode):
    """Trie node that accepts one literal token.

    Args:
        token (str): Token text as given; returned unchanged by ``__str__``.
        ignore_case (bool): When true, tokens are lower-cased before
            comparison.
        scrub_re (str): Regex whose matches are deleted before comparison;
            a falsy value disables scrubbing. Defaults to a literal period.
    """

    # Raw string for the default pattern: '\.' in a normal literal is an
    # invalid escape sequence (same value, but warns on newer Pythons).
    def __init__(self, token, ignore_case=True, scrub_re=r'\.'):

        super().__init__()

        self.ignore_case = ignore_case
        self.scrub_re = scrub_re

        self.token = token
        self.token_clean = self._clean(token)

    def _clean(self, token):
        """Apply this node's case folding and scrubbing to ``token``."""

        if self.ignore_case:
            token = token.lower()

        if self.scrub_re:
            token = re.sub(self.scrub_re, '', token)

        return token

    def __str__(self):
        """The original, uncleaned token text."""
        return self.token

    def __call__(self, input_token):
        """Return True when ``input_token`` cleans to the same text as this token."""
        return self._clean(input_token) == self.token_clean

    def __hash__(self):
        """Hash of the cleaned token together with the matching options."""
        return hash((self.token_clean, self.ignore_case, self.scrub_re))

    def __repr__(self):
        return '%s<%s>' % (self.__class__.__name__, self.token)

In [207]:
# Empty trie root that the example entries are inserted into.
idx = RootNode()

In [208]:
# Two token paths for the same id (1). Both start with Los -> Angeles, so the
# second insert merges into the existing nodes (printed as True) and only adds
# the California leaf. The 'CA' leaf is case-sensitive (ignore_case=False).
idx.insert(1, [Token('Los'), Token('Angeles'), Token('CA', ignore_case=False)])
idx.insert(1, [Token('Los'), Token('Angeles'), Token('California')])

Los False
Angeles False
CA False
Los True
Angeles True
California False


In [209]:
# Counts every node under the root: Los, Angeles, CA and California.
len(idx)

4

In [214]:
# Upper-case 'CA' is accepted by the case-sensitive Token('CA') leaf.
list(idx.query(tokenize('Los Angeles CA')))

[1]

In [211]:
# The alternative spelling reaches the same id through the California leaf.
list(idx.query(tokenize('Los Angeles California')))

[1]

In [212]:
# Lower-case 'ca' is rejected: the 'CA' leaf was created with
# ignore_case=False, so nothing is returned.
list(idx.query(tokenize('Los Angeles ca')))

[]