## A* decoding and language modeling

Matthew Stone, CS 533, Spring 2019

This is a homework assignment: it asks you to extend and adapt the $A^*$ segmentation code explained in class to do vowel restoration.  

This file demonstrates and explains the use of generic search techniques for decoding using a language model.  The code is based on an implementation of A$^*$ search, which generalizes in interesting ways to other problems.

A$^*$ search is a generic search method that provably finds the least-cost solution to a specified problem.  It assumes that you have two bits of information about a partial solution: you know the cost that you've incurred so far, and you have a heuristic which gives a lower bound on the cost you still have to incur to flesh it out into a complete solution.  By combining the cost and the heuristic, A$^*$ explores only the parts of the search space that look at least as good as the best solution.  In cases where the heuristic function provides a tight lower bound and no suboptimal steps need to be taken in the course of the search, A$^*$ only explores states that are part of optimal paths.  Wikipedia has [a comprehensive introduction to A$^*$][1].  You can also find [good visualizations of A$^*$][2], especially for path planning.

To think of segmentation as search, you can think of the decoding problem as finding the shortest path in an abstract graph.  The vertices in the graph correspond to hypotheses that explain the input symbols up to position $j$ and ending in word $w$.  When the search process reaches one of these vertices, it finds the best path ending in $(j, w)$.  

The analogy between probability and path planning depends on thinking of the cost of a path as its negative log probability.  By working with log probabilities rather than probabilities, we make the score associated with a path the sum of the scores associated with each edge.  We use negative log probabilities to make more likely steps smaller (since each step represents something that happens only a fraction of the time, log probabilities themselves are negative numbers, with strongly negative numbers representing extremely unlikely outcomes).
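
The correspondence between multiplying probabilities and adding costs can be checked numerically. The following is a minimal sketch; the step probabilities are made-up values chosen only for illustration.

```python
import math

# Hypothetical step probabilities along one path (illustrative values only).
step_probs = [0.5, 0.25, 0.1]

# The probability of the whole path is the product of the step probabilities.
path_prob = math.prod(step_probs)

# The cost of each step is its negative log probability; costs add along the path.
step_costs = [-math.log(p) for p in step_probs]
path_cost = sum(step_costs)

# The two views agree: the summed cost equals -log of the product.
print(math.isclose(path_cost, -math.log(path_prob)))

# More likely steps are cheaper, and no step has negative cost.
print(step_costs[0] < step_costs[1] < step_costs[2])
```

Because every cost is non-negative, extending a path can never make it cheaper, which is the property shortest-path search relies on.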

Document outline:
1. [Generic search implementation](#search)
1. [Basic bigram language model in NLTK](#bigrams)
1. [Your code to define search space and heuristic](#code) Fill in the code cells here.
1. [Best-first search implementation](#bfs)
1. [Best-first search examples and demonstrations](#bfseg)
1. [A-star search implementation](#astar)
1. [A-star search examples and demonstrations](#astareg)
1. [Analysis](#conc) Reflect on the difference between $A^*$ and BFS, including code cells as appropriate to back up your conclusions.

[1]:https://en.wikipedia.org/wiki/A*_search_algorithm
[2]:http://www.redblobgames.com/pathfinding/a-star/introduction.html


### Generic A$^*$ search implementation
<a id="search"/>

In [1]:
import heapq
import itertools
import math
import functools
import re
import numbers
import nltk
from nltk.corpus import brown

$A^*$ depends on a special priority queue data structure.  We need to be able to do a number of operations efficiently:
* Find the node with the least cost so far.  This is an essential operation of any priority queue.
* Add a new node to the queue with a specified priority.  Ditto.
* Find a record indicating the status of a state in the queue.  We need to answer whether there's already a node for this state in the queue, and find out what priority it currently has, and we also need to be able to tell whether we've explored a node for this state already.
* Replace the node for a given state with a different node with a new path and a lower priority.

Normally, doing all these things efficiently and with a minimal amount of space requires some pointer manipulation that Python doesn't easily support.  So the standard way of achieving this functionality in Python is to use extra memory and amortize the cost of reprioritizing nodes: mark nodes as redundant (rather than deleting them) and then ignore them at removal time.  There's a good explanation of the basic idea [in the Python documentation][1].  This code is modified from there to add a bit of object-oriented cleanliness and to make explicit the distinction in $A^*$ between states and nodes, and between costs, heuristics, and priorities.

[1]:https://docs.python.org/2/library/heapq.html
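
The mark-and-skip idea can be sketched in a few lines with plain `heapq`. This is a simplified illustration, not the class used in this notebook: the names `push`, `pop`, `live` and the sentinel `REMOVED` are assumptions made for the example, and unlike `AStarQ` it forgets an item once it has been popped.

```python
import heapq
import itertools

heap = []                      # heap of [priority, sequence, item] entries
live = {}                      # item -> its current (non-stale) entry
counter = itertools.count()    # tie-breaker so items themselves are never compared
REMOVED = object()             # sentinel marking stale entries

def push(item, priority):
    """Add item, or reprioritize it by invalidating its old entry."""
    if item in live:
        live.pop(item)[2] = REMOVED    # mark the old entry stale; leave it in the heap
    entry = [priority, next(counter), item]
    live[item] = entry
    heapq.heappush(heap, entry)

def pop():
    """Return the live item with the smallest priority, skipping stale entries."""
    while heap:
        priority, _, item = heapq.heappop(heap)
        if item is not REMOVED:
            del live[item]
            return item, priority
    raise KeyError('pop from an empty priority queue')

push('a', 5.0)
push('b', 3.0)
push('a', 1.0)                 # 'a' is reprioritized; its 5.0 entry becomes stale
first = pop()
second = pop()
print(first, second, len(heap))
```

The stale entry for `'a'` stays in the heap until a later `pop` discards it, which is the extra memory the approach trades for simplicity.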

In [2]:
class AStarQ(object) :
    """Priority Queue that distinguishes states and nodes,
       keeps track of priority as cost plus heuristic,
       enables changing the node and priority associated with a state,
       and keeps a record of explored states."""
    
    # state label for zombie nodes in the queue
    REMOVED = '<removed-state>'     

    # sequence count lets us keep heap order stable
    # despite changes to mutable objects 
    counter = itertools.count()     
    
    @functools.total_ordering
    class QEntry(object) :
        """AStarQ.QEntry objects package together
        state, node, cost, heuristic and priority
        and enable the queue comparison operations"""

        def __init__(self, state, node, cost, heuristic) :
            self.state = state
            self.node = node
            self.cost = cost
            self.heuristic = heuristic
            self.priority = cost + heuristic
            self.sequence = next(AStarQ.counter)
        
        def __le__(self, other) :
            return ((self.priority, self.sequence) <= 
                    (other.priority, other.sequence))
        
        def __eq__(self, other) :
            return ((self.priority, self.sequence) == 
                    (other.priority, other.sequence))
   
    def __init__(self) :
        """Set up a new problem with empty queue and nothing explored"""
        self.pq = []   
        self.state_info = {} 
        self.added = 0
        self.pushed = 0

    def add_node(self, state, node, cost, heuristic):
        """Add a new state or update the priority of an existing state
           Returns outcome (added or not) for visualization"""        
        self.added = self.added + 1
        if state in self.state_info:
            already = self.state_info[state].priority
            if already <= cost + heuristic :
                return False
            self.remove_state_entry(state)
        entry = AStarQ.QEntry(state, node, cost, heuristic)
        self.state_info[state] = entry
        heapq.heappush(self.pq, entry)
        self.pushed = self.pushed + 1
        return True

    def remove_state_entry(self, state):
        'Mark an existing task as REMOVED.  Raise KeyError if not found.'
        entry = self.state_info.pop(state)
        entry.state = AStarQ.REMOVED

    def pop_node(self):
        'Remove and return the lowest priority task. Raise KeyError if empty.'
        while self.pq:
            entry = heapq.heappop(self.pq)
            if entry.state is not AStarQ.REMOVED:
                return entry
        raise KeyError('pop from an empty priority queue')
        
    def statistics(self):
        return {'added': self.added, 'pushed': self.pushed}

When you're thinking about efficient implementation of algorithms, you need to be careful not to handle incidental bookkeeping in a way that inadvertently incurs asymptotic complexity penalties.  The treatment of history information in $A^*$ nodes is one of those cases.  The same path history needs to be able to grow efficiently into the paths for all of its children.  That basically rules out list structures in Python: a fast in-place update to a shared list affects every node that shares it, while copying a list takes time linear in its length, which can be as long as the input string.  To be efficient, you have to build linked lists out of nested tuples (the way it's historically done in languages like LISP).  At the end, you'll have a history representation that looks like

$$(e, (t_n, (t_{n-1}, \ldots (s, \emptyset) \ldots )))$$

This `unpack_tuples` function gets you back something reasonable.

In [3]:
def unpack_tuples(t) :
    result = []
    while t :
        (a, t) = t
        result.append(a)
    result.reverse()
    return result
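
To see why the nested-tuple representation is cheap to extend, here is a small sketch. The helper `unpack` repeats the logic of `unpack_tuples` so the example runs on its own; the words are arbitrary, and `None` stands in for the empty history.

```python
def unpack(t):
    """Flatten a nested-tuple history, oldest item first (same logic as unpack_tuples)."""
    result = []
    while t:
        a, t = t
        result.append(a)
    result.reverse()
    return result

root = ('s', None)             # history containing only the start symbol
shared = ('the', root)         # extending a history is O(1); root is untouched
child1 = ('cat', shared)       # two children reuse the same parent history...
child2 = ('dog', shared)       # ...without copying it

print(unpack(child1))
print(unpack(child2))
print(child1[1] is child2[1])  # both children point at the very same tuple
```

Each extension allocates one two-element tuple regardless of how long the history already is, and the linear-time flattening happens only once, for the final solution.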

The following unit test code lets the assignment include example inputs and outputs for the functions that you should be writing, so you can check your work.

In [4]:
def match(v1, v2, tolerance) :
    if isinstance(v1, numbers.Number) :
        return isinstance(v2, numbers.Number) and v2 < v1 + tolerance and v1 < v2 + tolerance
    elif isinstance(v1, str) :
        return v1 == v2
    else: 
        try: 
            if len(v1) != len(v2) :
                return False
            return all(map(lambda x,y: match(x,y,tolerance), v1, v2))
        except TypeError:
            return v1 == v2
                       
def test(thunk, result) :
    r = thunk()
    if match(r, result, 1e-10) :
        print("test", thunk.__doc__, "passed")
    else:
        print("test", thunk.__doc__, "got {} instead of {} (failed)".format(str(r), str(result)))

### A simple bigram language model
<a id="bigrams"/>

We'll continue to use counts from the Brown corpus and a simple smoothing technique to associate bigrams with probabilities.

In [5]:
def unknown_prob(w) :
    return 1.e-6 / (5.**len(w))
alpha = 0.1

cfreq_brown_2gram = nltk.ConditionalFreqDist(nltk.bigrams(s.lower() for s in brown.words()))
cprob_brown_2gram = nltk.ConditionalProbDist(cfreq_brown_2gram, nltk.MLEProbDist)
freq_brown_1gram = nltk.FreqDist(s.lower() for s in brown.words())
len_brown = len(brown.words())
def unigram_prob(word):
    return freq_brown_1gram[word.lower()] / len_brown
def bigram_prob(word1, word2) :
    return cprob_brown_2gram[word1.lower()].prob(word2.lower())
def prob(word1, word2) :
    if word2 not in freq_brown_1gram:
        return unknown_prob(word2)
    elif not word1 or word1 not in freq_brown_1gram:
        return unigram_prob(word2)
    else:
        return alpha * unigram_prob(word2) + (1-alpha) * bigram_prob(word1, word2)
def score(word1, word2) :
    return -math.log(prob(word1,word2))
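
The interpolation used in `prob` can be illustrated without NLTK. The sketch below is an assumption-laden toy: the nine-word corpus and the names `toy_prob`, `toy_alpha` and `contexts` are invented for the example, and it omits the unknown-word case that `prob` handles separately.

```python
from collections import Counter

# Hypothetical toy corpus, used only to illustrate the interpolation formula.
tokens = "the cat sat on the mat the cat ran".split()
unigrams = Counter(tokens)                   # count of each word
contexts = Counter(tokens[:-1])              # count of each word as the first half of a bigram
bigrams = Counter(zip(tokens, tokens[1:]))   # count of each adjacent word pair
toy_alpha = 0.1                              # same weight as `alpha` above

def toy_prob(w1, w2):
    """Interpolate the unigram estimate of w2 with the bigram estimate of w2 given w1."""
    p_unigram = unigrams[w2] / len(tokens)
    p_bigram = bigrams[(w1, w2)] / contexts[w1]
    return toy_alpha * p_unigram + (1 - toy_alpha) * p_bigram

seen = toy_prob('the', 'cat')      # this bigram occurs in the toy corpus
unseen = toy_prob('the', 'sat')    # this bigram never occurs
total = sum(toy_prob('the', w) for w in unigrams)
print(seen, unseen, total)
```

The unseen bigram still receives a small nonzero probability from the unigram term, so its negative log score stays finite, while the probabilities of all continuations of `'the'` still sum to one.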

### Your new code
<a id="code"/>

### Step 1 Create a lexical resource that describes all the reasonable ways of restoring vowels.

Write a function `remove_vowels(s)` that takes a string `s` and returns a corresponding string with the letters a, e, i, o and u removed.

In [6]:
def remove_vowels(s) :
    return re.sub('[aeiou]', '', s)

In [7]:
def t1() :
    "removing vowels from {empty string}"
    return remove_vowels("")

test(t1, "")

def t2() :
    "removing vowels from oui"
    return remove_vowels("oui")
    
test(t2, "")

def t3() :
    "removing vowels from ohio"
    return remove_vowels("ohio")

test(t3, "h")

def t4() :
    "removing vowels from learning"
    return remove_vowels("learning")

test(t4, "lrnng")

test removing vowels from {empty string} passed
test removing vowels from oui passed
test removing vowels from ohio passed
test removing vowels from learning passed


Starting from the words in the Brown corpus (which you can obtain with `freq_brown_1gram.keys()`), create a dictionary `expansions` which describes all the ways to add vowels to tokens without vowels to create valid words.  In other words, if `k` is a string without vowels, then `expansions[k]` is defined if some English word maps to `k` when its vowels are removed, and `expansions[k]` is the set of English words that map to `k` when their vowels are removed.

In [8]:
expansions = dict()
for word_with_vowels in freq_brown_1gram.keys() :
    word_without_vowels = remove_vowels(word_with_vowels)
    if word_without_vowels in expansions :
        expansions[word_without_vowels].add(word_with_vowels)
    else:
        expansions[word_without_vowels] = {word_with_vowels}

In [9]:
def t5() :
    "expansions of {empty string}"
    return sorted(list(expansions['']))

test(t5, ['a', 'aa', 'aaa', 'ai', 'aia', 'aiee', 'e', 'i', 'io', 'o', 'oooo', 'oui', 'u'])

def t6() :
    "expansions of nw"
    return sorted(list(expansions['nw']))

test(t6, ['anew', 'naw', 'new', 'now', 'nw'])

def t7() :
    "expansions of xprmnt"
    return sorted(list(expansions['xprmnt']))

test(t7, ['experiment'])

test expansions of {empty string} passed
test expansions of nw passed
test expansions of xprmnt passed
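
The same grouping can be tried on a handful of words without loading the Brown corpus. This is a sketch under stated assumptions: the six-word vocabulary and the name `toy_expansions` are made up for illustration, and `collections.defaultdict` is used here as one alternative to the explicit membership test above.

```python
import re
from collections import defaultdict

# Hypothetical miniature vocabulary standing in for the Brown corpus word list.
toy_vocab = ['new', 'now', 'anew', 'experiment', 'a', 'i']

# Group every word under the key obtained by deleting its vowels.
toy_expansions = defaultdict(set)
for word in toy_vocab:
    key = re.sub('[aeiou]', '', word)
    toy_expansions[key].add(word)

print(sorted(toy_expansions['nw']))
print(sorted(toy_expansions['']))
print(sorted(toy_expansions['xprmnt']))
```

One caveat with `defaultdict`: looking up a missing key with `toy_expansions[k]` silently creates an empty entry, so membership checks such as `k in toy_expansions` should be done before indexing.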


Write a function `expand` that takes a list of tokens and returns a list of sets, where the set at position `i` in the result gives the expansions of token `i` in the input.

In [10]:
def expand(tokens) :
    def e(t) :
        if t in expansions:
            return expansions[t]
        else:
            return {t}
    return list(map(e, tokens))

In [11]:
def t8() :
    "expanding ['', 'nw', 'xprmnt']"
    return list(map(lambda s: sorted(list(s)), expand(['', 'nw', 'xprmnt'])))

test(t8, [['a', 'aa', 'aaa', 'ai', 'aia', 'aiee', 'e', 'i', 'io', 'o', 'oooo', 'oui', 'u'], 
          ['anew', 'naw', 'new', 'now', 'nw'],
          ['experiment']])

test expanding ['', 'nw', 'xprmnt'] passed


### Step Two: Compute heuristics

The essence of a heuristic function in $A^*$ search is to find a quantity that tracks the lowest cost that could possibly be incurred in completing a partial solution to the search problem.  In the case of vowel restoration, a partial solution will say how the first $k$ tokens should be transformed into valid words with vowels.  The heuristic should therefore track the best possible cost that might be incurred in transforming the words at position $k+1$ through the end of the string.  It's complicated to compute that cost exactly, but you do know that at every step, you have to transition from one of the possible words at position $k$ to one of the possible words at position $k+1$.   So you can get a heuristic function by taking the smallest possible cost for each of these transitions (without worrying about whether the word you transition to at $k+1$ in turn allows the best possible transition to position $k+2$, and so on).

Write a function `compute_heuristics` that takes as input the kind of list produced by `expand`: a specification of the set of possible reconstructed words at each position.   Return a list of heuristic values that should be one item longer than the input list.  In the output, the value at position $k$ should be the sum of the best scores from each position to the next, starting from the transition to position $k$ and continuing up to the end of the string.  Thus, if the length of the input is $n$, the value at position $n$ should be 0, the value at position $n-1$ should be the best score from position $n-2$ to $n-1$, the value at position $n-2$ should be the best score from position $n-3$ to $n-2$ plus the best score from position $n-2$ to $n-1$, and so forth, all the way back to the beginning of the string, when you use `None` as the preceding word to factor in the best unigram probability for the initial word at position $0$. 

In [12]:
def best_steps(altlist) :
    def best_step(i) :
        if i==0:
            prev = {None}
        else:
            prev = altlist[i-1]
        return min(score(wp, wi) for wp in prev for wi in altlist[i])
    return list(map(best_step, range(0,len(altlist))))

def compute_heuristics(altlist) :
    # accumulate the best step scores from the end of the string backwards
    steps = reversed(best_steps(altlist))
    total = 0
    h = [0.]
    for s in steps :
        total = total + s
        h.append(total)
    return list(reversed(h))

In [13]:
def t9() :
    "computing heuristics for ['', 'nw', 'xprmnt']"
    return compute_heuristics(expand(['', 'nw', 'xprmnt']))

test(t9, [20.55769324990428, 16.644427645405976, 12.12440798828666, 0.0])

test computing heuristics for ['', 'nw', 'xprmnt'] passed
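
The heuristic list is just a table of suffix sums over the best step scores. The following sketch shows the arithmetic on made-up numbers; the values in `best_step_costs` are hypothetical and do not come from the language model.

```python
from itertools import accumulate

# Hypothetical best step costs into positions 0, 1 and 2 of a three-token input.
best_step_costs = [3.0, 1.5, 2.0]

# Running totals taken from the end of the string backwards...
suffix = list(accumulate(reversed(best_step_costs)))
# ...then flipped, with 0.0 appended for the position past the last token.
heuristics = list(reversed(suffix)) + [0.0]

print(heuristics)
```

Entry $k$ bounds from below the cost of everything still to be chosen from position $k$ onward, so the first entry is a lower bound on the cost of any complete solution and the last entry is zero.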


### Best-first search for vowel restoration
<a id="bfs"/>

Let's start with best-first search.  This illustrates the key ideas of approaching text reconstruction as a search problem.  Best-first search is basically $A^*$ with heuristic 0. 

The basic structure of a search algorithm looks like this:
- add the initial state to the queue
- then, until you have a solution or you've run out of options
    - get the next item from the queue
    - if it's a solution, return it    
    - create nodes for all of its children and add them to the queue
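
The loop outlined above can be sketched on a tiny explicit graph. Everything here is an assumption made for illustration: the graph, its edge costs, and the function name `best_first` are invented, and for brevity the sketch copies path tuples rather than using the linked history representation described earlier.

```python
import heapq

# Hypothetical toy graph: node -> list of (neighbor, edge cost).
graph = {
    'S': [('A', 1.0), ('B', 4.0)],
    'A': [('B', 2.0), ('G', 6.0)],
    'B': [('G', 1.0)],
    'G': [],
}

def best_first(start, goal):
    """Expand the cheapest known path first; return (path, cost) or None."""
    queue = [(0.0, start, (start,))]       # add the initial state to the queue
    explored = set()
    while queue:                           # until we run out of options
        cost, state, path = heapq.heappop(queue)   # get the next item
        if state == goal:                  # if it's a solution, return it
            return list(path), cost
        if state in explored:              # a cheaper path already expanded this state
            continue
        explored.add(state)
        for nxt, step in graph[state]:     # create nodes for all of its children
            heapq.heappush(queue, (cost + step, nxt, path + (nxt,)))
    return None

result = best_first('S', 'G')
print(result)
```

The direct edge from `S` to `B` costs more than the detour through `A`, so the search reaches `B` by the detour first and discards the more expensive entry when it surfaces later.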

Recall the general features of this search implementation.
- You need to calculate the cost using our probabilistic language model scoring function
- It's convenient to be able to visualize the actions of the algorithm on small data sets, but you also want to be able to run the same code efficiently on interesting problems.  The design pattern I've used here is to swap out the basic queue operations with new functions that print diagnostic output in the case that the keyword argument `verbose` is `True`.

The key things for vowel restoration:
- We compute the expansion at the start of the search using your `expand` function.
- Whenever we process a node, we create nodes for its children by considering all the valid expansions for the next token in the input string.

In [14]:
def bfs(characters, verbose=False) :
    sentence = characters.split(' ')
    altlist = expand(sentence)
    print("Analyzing", sentence)
    
    queue = AStarQ()

    def loud_pop() :
        entry = queue.pop_node()
        print("looking at", entry.state[0], entry.state[1], entry.priority)
        return entry
    def loud_add(i, w, n, c) :
        did = queue.add_node((i, w), (w, n), c, 0.)
        if did :
            print("added node for", i, w, c)
        else :
            print("redundant node for", i, w, c)
            
    if verbose :
        pop, add = loud_pop, loud_add
    else :
        pop, add = (queue.pop_node, 
                    lambda i,w,n,c: queue.add_node((i,w),(w,n),c,0.))
        
    add(0, None,None, 0.)
    while True:
        entry = pop()
        j, w = entry.state
        if j == len(sentence) :
            return unpack_tuples(entry.node)[1:], entry.cost, queue.statistics()

        if j < len(sentence) :
            for w2 in altlist[j] :
                new_score = score(w, w2)
                cost = entry.cost + new_score 
                add(j+1, w2, entry.node, cost)


### Examples and demonstrations with best-first search
<a id="bfseg"/>

In [15]:
bfs(remove_vowels("a new experiment"), verbose=False)

Analyzing ['', 'nw', 'xprmnt']


(['a', 'new', 'experiment'], 20.55769324990428, {'added': 84, 'pushed': 21})

### $A^*$ implementation
<a id="astar"/>

Now we can look at the heuristic.  The point of the heuristic is that it should be fast to calculate (without having to do a search that's anywhere near as complicated as the overall problem) but still give a good bound on the solution quality.  In a tagging problem, once you've tagged the words up to position $j$, what you have left to do is to tag the words from position $j+1$ through to the end of the string (and supply the extra `END` token at the end).  That will incur some cost because of the transitions that you have to use, and exactly what that cost will be is hard to figure out.  But you know that at each step you will have to incur at least the cost of the most likely word-to-word transition available at that step.  Since the heuristic depends only on the index and not on the word chosen, you know you're going to need all of the heuristic values to solve the problem, and you can just compute and store them all in advance.

Now literally the only difference between `astar` and `bfs` is that we precompute the heuristic list, using your `compute_heuristics` function at the beginning and instrument our custom `add` operation to look up the appropriate heuristic value corresponding to the position in the string that we're considering!

In [16]:
def astar(characters, verbose=False) :
    sentence = characters.split(' ')
    altlist = expand(sentence)
    print("Analyzing", sentence)
    queue = AStarQ()

    heuristics = compute_heuristics(altlist)
    
    def loud_pop() :
        entry = queue.pop_node()
        print("looking at", entry.state[0], entry.state[1], entry.priority)
        return entry
    def loud_add(i, t, n, c) :
        did = queue.add_node((i,t), (t,n), c, heuristics[i])
        if did :
            print("added node for", i, t, c + heuristics[i])
        else :
            print("redundant node for", i, t, c + heuristics[i])
            
    if verbose :
        pop, add = loud_pop, loud_add
    else :
        pop, add = (queue.pop_node, 
                    lambda i,t,n,c: queue.add_node((i,t), (t,n), c, 
                                                   heuristics[i]))

    add(0, None, None, 0.)
    while True:
        entry = pop()
        j, w = entry.state
        if j == len(sentence) :
            return unpack_tuples(entry.node)[1:], entry.cost, queue.statistics()
        if j < len(sentence) :
            for w2 in altlist[j] :
                new_score = score(w, w2)
                cost = entry.cost + new_score 
                add(j+1, w2, entry.node, cost)


### $A^*$ examples and demonstrations
<a id="astareg"/>

In [20]:
astar(remove_vowels('i now leave'), verbose=False)

Analyzing ['', 'nw', 'lv']


(['i', 'now', 'live'], 17.2941643545853, {'added': 46, 'pushed': 32})

### Analysis 
<a id="conc"/>

Looking at the empirical behavior of best-first search and $A^*$, how do they differ and why? Do they get different answers? Do they differ in efficiency?  More generally, how should you think about the time complexity and performance of the two different algorithms on different problems?  Justify your answer with sample runs or mathematical analyses, as appropriate.

### model answer

Best-first search will enumerate all incomplete solutions with cost no greater than the best overall solution.  A* search will only consider incomplete solutions if the heuristic function predicts that they might lead to a solution comparable to the best.  The A* search will therefore expand substantially fewer nodes than the best-first search.  However, the A* search does incur an initial cost in computing the overall heuristic tables, so for A* search to make sense the problem needs to have enough complexity to justify this initial cost.
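
One way to back up this claim is to count queue pops with and without the heuristic on a small lattice. The sketch below is a deliberately simplified stand-in for the notebook's search, under assumptions that should be kept in mind: the lattice and its costs are invented, each step cost depends only on the word chosen (so the heuristic happens to be exact), and paths reaching the same state are not merged.

```python
import heapq

# Hypothetical lattice: one dict per position, mapping candidate word -> step cost.
lattice = [{'a': 1.0, 'b': 2.0}, {'c': 1.0, 'd': 3.0}, {'e': 1.0, 'f': 2.0}]

def suffix_bounds(lattice):
    """Lower bound on the remaining cost from each position to the end."""
    bounds = [0.0]
    for options in reversed(lattice):
        bounds.append(bounds[-1] + min(options.values()))
    return bounds[::-1]

def search(lattice, use_heuristic):
    """Return (words, cost, number of queue pops) for the cheapest path."""
    n = len(lattice)
    h = suffix_bounds(lattice) if use_heuristic else [0.0] * (n + 1)
    queue = [(h[0], 0.0, 0, ())]           # (priority, cost so far, position, words)
    popped = 0
    while queue:
        _, cost, j, words = heapq.heappop(queue)
        popped += 1
        if j == n:
            return list(words), cost, popped
        for w, step in lattice[j].items():
            c = cost + step
            heapq.heappush(queue, (c + h[j + 1], c, j + 1, words + (w,)))

plain_result = search(lattice, use_heuristic=False)
astar_result = search(lattice, use_heuristic=True)
print(plain_result)
print(astar_result)
```

Both runs return the same words and the same cost, but the run with the heuristic pops fewer entries because it never expands a partial solution whose bound already exceeds the best total. With a looser heuristic the gap would shrink, and on inputs this small the savings would not repay the cost of building the table.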
