In [1]:
import re

import networkx as nx

from IPython.display import Image, display
from collections import defaultdict, Counter
from textblob import TextBlob
from itertools import combinations
from tqdm import tqdm

from litecoder.db import City

In [2]:
import matplotlib as mpl
import matplotlib.pyplot as plt

# IPython magic: select the inline matplotlib backend for this notebook.
%matplotlib inline
# Apply the 'seaborn-muted' style sheet to all subsequent figures.
# NOTE(review): newer matplotlib releases renamed the bundled seaborn styles
# (e.g. 'seaborn-v0_8-muted') -- confirm against the installed version.
mpl.style.use('seaborn-muted')

In [3]:
def tokenize(text):
    """Split ``text`` into tokens with TextBlob and return them as plain strings."""
    blob = TextBlob(text)
    return list(map(str, blob.tokens))

In [4]:
def keyify(text):
    """Normalize ``text`` into a lookup key.

    The text is lower-cased and every character outside ``a-z`` / ``0-9``
    (whitespace, punctuation, non-ASCII letters, ...) is dropped.
    """
    lowered = text.lower()
    return re.sub('[^a-z0-9]', '', lowered)

In [5]:
class Token:
    """Accept function that matches an input token against a reference token.

    Both the reference token and every candidate are normalized by
    ``_clean`` (optional lower-casing, optional regex scrubbing) before
    being compared.

    Args:
        token: The reference token text.
        ignore_case: If true, comparison is case-insensitive.
        scrub_re: Regex whose matches are removed before comparison
            (by default, literal periods). A falsy value disables scrubbing.
    """

    def __init__(self, token, ignore_case=True, scrub_re=r'\.'):

        self.ignore_case = ignore_case
        self.scrub_re = scrub_re

        self.token = token
        self.token_clean = self._clean(token)

    def _clean(self, token):
        """Return ``token`` normalized per this instance's settings."""
        if self.ignore_case:
            token = token.lower()

        if self.scrub_re:
            token = re.sub(self.scrub_re, '', token)

        return token

    def __call__(self, input_token):
        """Return True if ``input_token`` matches the reference token."""
        return self._clean(input_token) == self.token_clean

    def __repr__(self):
        return '%s<%s>' % (self.__class__.__name__, self.token_clean)

    def _identity(self):
        """Tuple of the fields that define equality and hashing."""
        return (self.token_clean, self.ignore_case, self.scrub_re)

    def __hash__(self):
        # The class itself is not part of the hash; __eq__ is what keeps
        # instances of different classes from comparing equal.
        return hash(self._identity())

    def __eq__(self, other):
        # Compare the actual fields instead of the hashes: two unrelated
        # objects may share a hash value without being the same token.
        if type(other) is not type(self):
            return NotImplemented
        return self._identity() == other._identity()

    def label(self):
        """Return the edge label for this token, e.g. ``<st>``."""
        return '<%s>' % self.token_clean

    def key(self):
        """Return the ``keyify``-normalized form of the original token."""
        return keyify(self.token)

In [6]:
class GeoFSA(nx.MultiDiGraph):
    """Finite-state acceptor over token sequences, stored as a multigraph.

    Nodes are integer state ids carrying a ``final`` attribute: the set of
    entities marked final at that state via ``set_final``.

    Edges carry three attributes:
      * ``accept_fn`` -- callable deciding whether an input token may take
        the transition, or ``None`` for an epsilon (skip) transition;
      * ``entities`` -- non-empty set of entities associated with the edge;
      * ``label`` -- display label for the edge.
    """

    def __init__(self):
        super().__init__()
        # Counter backing next_node(); node ids are 0, 1, 2, ...
        self._next_id = 0

    def add_node(self, node, **kwargs):
        """Add ``node``, giving it an empty ``final`` set unless one is passed."""
        defaults = dict(final=set())
        super().add_node(node, **{**defaults, **kwargs})

    def add_edge(self, u, v, entity=None, **kwargs):
        """Add an edge u -> v, ensuring it has a non-empty entity set.

        Args:
            u, v: Source and target nodes.
            entity: Entity used to seed the default ``entities`` set.
            **kwargs: Edge attributes; ``accept_fn``, ``entities`` and
                ``label`` override the defaults when given.

        Returns:
            The edge key assigned by networkx.

        Raises:
            RuntimeError: If the resulting ``entities`` attribute is empty.
        """
        defaults = dict(accept_fn=None, entities=set([entity]), label=None)
        attrs = {**defaults, **kwargs}

        if not attrs['entities']:
            raise RuntimeError('All edges must have a non-empty entity set.')

        return super().add_edge(u, v, **attrs)

    def next_node(self):
        """Create and return the next integer node id, counting up."""
        node = self._next_id

        self.add_node(node)
        self._next_id += 1

        return node

    def add_token(self, accept_fn, entity, parent=None, optional=False):
        """Register a token transition.

        Args:
            accept_fn: Object used as the edge's accept function; its
                ``label()`` provides the edge label.
            entity: Entity associated with the new edge(s).
            parent: Node to start from; a fresh node is created when None.
            optional: If true, add epsilon edges so the token can be skipped.

        Returns:
            The node reached after the (possibly optional) token.
        """
        # Compare against None explicitly: node id 0 is a valid parent.
        s1 = parent if parent is not None else self.next_node()
        s2 = self.next_node()

        self.add_edge(
            s1, s2, entity,
            accept_fn=accept_fn,
            label=accept_fn.label(),
        )

        last_node = s2

        # Add skip links if optional: both the "token taken" state (s2) and
        # the "token skipped" state (s1) lead to a shared node s3.
        if optional:
            s3 = self.next_node()
            self.add_edge(s2, s3, entity, label='ε')
            self.add_edge(s1, s3, entity, label='ε')
            last_node = s3

        return last_node

    def set_final(self, state, entity):
        """Mark ``state`` as final for ``entity``."""
        self.nodes[state]['final'].add(entity)

    def start_nodes(self):
        """Return nodes with no incoming edges."""
        return [n for n in self.nodes() if self.in_degree(n) == 0]

    def inner_nodes(self):
        """Return nodes with at least one outgoing edge."""
        return [n for n in self.nodes() if self.out_degree(n) > 0]

    def end_nodes(self):
        """Return nodes with no outgoing edges."""
        return [n for n in self.nodes() if self.out_degree(n) == 0]

    def _merge_nodes(self, u, v):
        """Merge node ``v`` into node ``u`` and remove ``v``.

        ``v``'s final entities are added to ``u``'s, and every edge entering
        or leaving ``v`` is re-created on ``u`` with the same attributes.
        """
        # Add v finals -> u finals.
        self.nodes[u]['final'].update(self.nodes[v]['final'])

        # Redirect in edges. Snapshot first, since add_edge mutates the graph.
        for s, _, data in list(self.in_edges(v, data=True)):
            self.add_edge(s, u, **data)

        # Redirect out edges.
        for _, t, data in list(self.out_edges(v, data=True)):
            self.add_edge(u, t, **data)

        self.remove_node(v)

    def _merge_edges(self, u, v, k1, k2):
        """Merge parallel edge ``k2`` into ``k1`` between ``u`` and ``v``."""
        # Add k2 entities -> k1 entities.
        self[u][v][k1]['entities'].update(self[u][v][k2]['entities'])

        self.remove_edge(u, v, k2)

    def _end_node_merge_key(self, node):
        """Build merge key for end node: the set of incoming accept functions."""
        return frozenset([
            data.get('accept_fn')
            for _, _, data in self.in_edges(node, data=True)
        ])

    def _inner_node_merge_key(self, node):
        """Build merge key for inner node.

        The key pairs the set of outgoing accept functions with the set of
        all nodes reachable from ``node``.
        """
        out_edges = frozenset([
            data.get('accept_fn')
            for _, _, data in self.out_edges(node, data=True)
        ])

        descendants = frozenset(nx.descendants(self, node))

        return (out_edges, descendants)

    def reduce_end_nodes(self):
        """Reduce all redundant end nodes (those sharing a merge key)."""
        seen = {}
        for v in self.end_nodes():

            key = self._end_node_merge_key(v)
            u = seen.get(key)

            # `is not None`, not truthiness: node id 0 is a valid target.
            if u is not None:
                self._merge_nodes(u, v)
            else:
                seen[key] = v

    def _reduce_inner_nodes_iter(self):
        """Perform one iteration of inner node reduction."""
        seen = {}
        for v in self.inner_nodes():

            key = self._inner_node_merge_key(v)
            u = seen.get(key)

            # `is not None`, not truthiness: node id 0 is a valid target.
            if u is not None:
                self._merge_nodes(u, v)
            else:
                seen[key] = v

    def reduce_inner_nodes(self):
        """Reduce inner nodes until no more merges are possible."""
        while True:
            nc1 = len(self.nodes)
            self._reduce_inner_nodes_iter()
            nc2 = len(self.nodes)
            # A pass that removed no node means a fixed point was reached.
            if nc2 == nc1:
                break

    def _reduce_node_out_edges(self, node):
        """Reduce out edges from node.

        Edges leaving ``node`` that share both target and accept function
        are collapsed into the first such edge, pooling their entities.
        """
        # Materialize the edges up front; merging removes edges as we go.
        out_edges = list(self.out_edges(node, data=True, keys=True))

        seen = {}
        for s, t, k2, data in out_edges:

            key = (t, data['accept_fn'])
            k1 = seen.get(key)

            if k1 is not None:
                self._merge_edges(s, t, k1, k2)

            else:
                seen[key] = k2

    def reduce_out_edges(self):
        """Reduce all out edges."""
        for node in self.nodes():
            self._reduce_node_out_edges(node)

    def start_index_kv_iter(self):
        """Generate key -> start node pairs.

        One pair is yielded per non-epsilon edge leaving a start node; the
        key comes from the edge's ``accept_fn.key()``.
        """
        for node in self.start_nodes():
            for _, _, data in self.out_edges(node, data=True):
                if data['accept_fn']:
                    yield data['accept_fn'].key(), node

    def start_index(self):
        """Map key -> start nodes (a ``defaultdict(list)``)."""
        idx = defaultdict(list)
        for k, n in self.start_index_kv_iter():
            idx[k].append(n)

        return idx

In [7]:
def plot(g):
    """Render graph ``g`` as a PNG (Graphviz rankdir ``LR``) and display it inline."""
    pydot_graph = nx.drawing.nx_pydot.to_pydot(g)
    pydot_graph.set_rankdir('LR')
    png_bytes = pydot_graph.create_png()
    display(Image(png_bytes))

In [10]:
# Build the acceptor. For each city, the graph gets the paths
#   <name tokens> [","] <state-name tokens>   and
#   <name tokens> [","] <state abbreviation>
# with the end of each path marked final for that city's entity.
g = GeoFSA()

# Query is restricted to country_iso == 'US' and capped at 20000 rows.
for city in tqdm(City.query.filter(City.country_iso=='US').limit(20000)):
    
    # Entity identifier: (table name, wof_id).
    entity = (City.__tablename__, city.wof_id)
    
    name_tokens = tokenize(city.name)
    # presumably name_a1 holds the state (admin-1) name -- verify against the City model
    state_tokens = tokenize(city.name_a1)
    
    # City name: chain one transition per token, starting from a fresh node.
    parent = None
    for token in name_tokens:
        parent = g.add_token(Token(token), entity, parent)
        
    # Optional comma: `comma` is the node reached whether or not "," was seen.
    comma = g.add_token(Token(','), entity, parent, optional=True)
    
    # State name: chain its tokens after the optional comma.
    parent = comma
    for token in state_tokens:
        parent = g.add_token(Token(token), entity, parent)

    g.set_final(parent, entity)
        
    # Or, state abbr: a single-token alternative branching from the same node.
    leaf = g.add_token(Token(city.us_state_abbr), entity, comma)
    g.set_final(leaf, entity)


0it [00:00, ?it/s][A
1it [00:01,  1.27s/it][A
378it [00:01, 275.98it/s][A
745it [00:01, 506.88it/s][A
1072it [00:01, 654.34it/s][A
1356it [00:01, 780.20it/s][A
1724it [00:01, 937.85it/s][A
2082it [00:01, 1073.95it/s][A
2452it [00:02, 1202.69it/s][A
2807it [00:02, 1312.28it/s][A
3134it [00:02, 1335.30it/s][A
3502it [00:02, 1431.02it/s][A
3871it [00:02, 1519.55it/s][A
4240it [00:02, 1601.45it/s][A
4600it [00:02, 1674.07it/s][A
4956it [00:02, 1740.34it/s][A
5304it [00:03, 1714.78it/s][A
5667it [00:03, 1774.78it/s][A
6035it [00:03, 1832.47it/s][A
6380it [00:03, 1880.12it/s][A
6733it [00:03, 1927.29it/s][A
7072it [00:03, 1967.17it/s][A
7454it [00:03, 2017.26it/s][A
7828it [00:03, 2062.59it/s][A
8186it [00:04, 2010.53it/s][A
8483it [00:04, 2033.49it/s][A
8834it [00:04, 2067.99it/s][A
9187it [00:04, 2101.28it/s][A
9536it [00:04, 2132.26it/s][A
9880it [00:04, 2160.66it/s][A
10231it [00:04, 2189.25it/s][A
10575it [00:04, 2215.66it/s][A
10933it [00:04, 2243.71it/

In [11]:
# Merge end nodes (no outgoing edges) that share the same set of incoming accept functions.
g.reduce_end_nodes()

In [12]:
# Repeatedly merge inner nodes with identical outgoing accept functions and
# descendants, until a pass leaves the node count unchanged.
g.reduce_inner_nodes()

In [13]:
# Collapse parallel edges that share a target and accept function, pooling their entities.
g.reduce_out_edges()

In [22]:
class Matcher:
    """Step an FSA through a token stream, tracking the active state set.

    Args:
        fsa: Graph providing ``start_index()`` (key -> start nodes) and
            ``out_edges(node, data=True)`` whose edge data has an
            ``accept_fn`` entry (``None`` for epsilon transitions).
    """

    def __init__(self, fsa):
        self.fsa = fsa
        self._start_index = fsa.start_index()
        # States the matcher currently occupies; empty means "not started".
        self._states = set()

    def _get_next_states(self, start_state, token, visited=None):
        """Return the states reachable from ``start_state`` by consuming ``token``.

        Epsilon edges are followed recursively without consuming the token;
        ``visited`` records states already expanded so epsilon cycles
        terminate.
        """
        if visited is None:
            visited = set()

        visited.add(start_state)

        next_states = set()
        for _, state, data in self.fsa.out_edges(start_state, data=True):

            accept_fn = data['accept_fn']

            # If non-empty transition, evaluate input.
            if accept_fn:
                if accept_fn(token):
                    next_states.add(state)

            # Recursively resolve epsilons.
            elif state not in visited:
                next_states.update(self._get_next_states(state, token, visited))

        return next_states

    def __call__(self, token):
        """Consume ``token``, update the active state set, and print it."""
        # With no active states, seed from the start index. Use .get() so an
        # unknown token does not insert an empty entry into the index (which
        # is what plain indexing does on a defaultdict).
        if not self._states:
            self._states.update(self._start_index.get(keyify(token), ()))

        next_states = set()
        for state in self._states:
            next_states.update(self._get_next_states(state, token))

        self._states = next_states

        print(self._states)

In [23]:
# Matcher over the reduced graph; it snapshots g.start_index() at construction.
m = Matcher(g)

In [17]:
# Time a single matcher step; the call prints the resulting state set.
%time m('new')

{116992, 127488, 8454, 85766, 62986, 41998, 41486, 40466, 127763, 128278, 45592, 103704, 108315, 26142, 45599, 65058, 802, 61731, 123178, 19756, 23085, 121904, 79922, 54072, 125240, 66106, 25659, 37948, 27709, 77115, 25154, 103746, 78406, 126023, 55366, 66886, 88650, 79180, 90189, 89420, 128845, 48208, 64081, 15698, 42067, 91219, 128600, 67416, 121948, 112734, 126303, 111712, 39779, 25188, 125285, 124772, 51559, 43624, 67176, 69738, 55144, 85865, 623, 53359, 49007, 54899, 110197, 47477, 12152, 86137, 52089, 50299, 81786, 25213, 47998, 86143, 88441, 83585, 128897, 79745, 39556, 80517, 49798, 118661, 128390, 57736, 37770, 109195, 85643, 49546, 87694, 101772, 72591, 47762, 50835, 128662, 42648, 44696, 107672, 50075, 125084, 108187, 54943, 48290, 40866, 50594, 51106, 110501, 125348, 102056, 125858, 126889, 125355, 99755, 69547, 82606, 431, 47792, 110510, 89778, 50608, 76208, 100022, 74375, 80313, 48826, 53435, 54459, 111038, 74945, 66242, 43457, 75714, 3524, 1990, 122306, 124867, 29898, 11

In [19]:
# Feed a token sequence one token at a time; each call prints the active states.
# NOTE(review): `m` keeps its state between cells, so the result depends on
# which calls ran before this one -- confirm the matcher is fresh/empty here.
m('pine')
m('RIDGE')
m(',')
m('AL')

{512, 17408, 111749, 120326, 100999, 30474, 116875, 29586, 6932, 2455, 91545, 66588, 81055, 10656, 32, 21028, 26791, 27180, 29614, 80175, 107824, 79153, 12850, 81597, 101438, 33474, 100552, 114252, 25037, 12238, 103761, 98644, 103894, 15577, 29913, 108252, 49633, 31842, 29412, 79460, 32871, 12647, 65002, 13045, 26486, 5879, 505, 79103}
{2, 95164, 26461, 103762}
{3, 95165, 26462}
{6}


In [21]:
# Scratch cell: evaluates to an empty tuple; nothing else in the visible notebook uses it.
tuple()

()