# Credit
This implementation of the _constrained_ forward-backward algorithm is based on an [implementation](http://www.katrinerk.com/courses/python-worksheets/demo-the-forward-backward-algorithm) of the forward-backward algorithm from Katrin Erk.

# Constrained Forward-Backward Algorithm (Culotta & McCallum 2004)
Input:
- A sequence of observations
- A subsequence of labels (states)
- Transition scores

Output: Marginal probability of that subsequence of labels

## Example

Suppose you want to assign part-of-speech tags to each token in the sentence below:

> The dog jumped over the cat.

In particular, suppose you want to know the probabilty that "jumped over" is "V P". To calculate this, we calculate the score of each possible sequence that contains "V P" and sum those scores to get $Z'$. Then we calculate the score of all possible sequences and sum those scores to get $Z$. Then the marginal probablity of the label subsequence "V P" for "jumped over" is $Z' / Z$ (or $\exp(Z' - Z)$) when dealing with log likelihood.


- [Confidence Estimation for Information Extraction](https://www.aclweb.org/anthology/N04-4028.pdf)
- [The Forward-Backward Algorithm](http://www.cs.columbia.edu/~mcollins/fb.pdf)

Forward probability:

$$\alpha_{t+1}(s_i) = \sum_{s'} \left[ \alpha_t(s') \exp \left( \sum_k \lambda_k f_k (s', s_i, \mathbf{o}, t) \right) \right]$$

This can be read as: the forward probability for a state $s_i$ at the next time step is the sum of each state's score times $e$ to the CRF's output for transitioning from the current state to $s_i$. The constrained forward backward algorithm uses a modified forward pass that includes a state only if it is in the constraints. For example, any sequence passing through $s_{\text{jumped}}(N)$ should be scored 0 if our constraints are "V P" for "jumped over".

Confidence estimate is equal to $\exp (Z'_o - Z_o)$, where:

- $Z'_o = \sum_i \alpha'_T(s_i)$ ($\alpha'$ is a constrained forward value)
- $Z_o = \sum_i \alpha_T(s_i)$

In [1]:
import math
import random

import numpy as np


# Placeholder for the output of the CRF. The scores are coerced to be useful for this example.
def score(prev_state, state, word):
    if word == "jumped" and state == 3:
        return 1
    elif word == "over" and state == 2:
        return 1
    else:
        return random.uniform(0, .001)

In [2]:
def get_marginal(states, observations, constraints):
    cl, ul = get_lattices(states, observations, constraints)
    
    z_prime = sum(cl[:, -1])
    z = sum(ul[:, -1])
    
    return math.exp(z_prime - z)

def get_lattices(states, observations, constraints):
    # Constrained lattice (i.e. all sequences of labels that adhere to constraints)
    cl = np.zeros((len(states), len(observations)))
    
    # Unconstrained lattice (i.e. all possible sequences of labels)
    ul = np.zeros((len(states), len(observations)))
    
    # Initialization
    for s in range(len(states)):
        cl[s, 0] = score(None, states[s], observations[0])
        ul[s, 0] = score(None, states[s], observations[0])
    
    # Time steps 2 through T
    for t in range(1, len(observations)):
        for s in range(len(states)):
            ul[s, t] = sum([ul[s2, t - 1] * math.exp(score(s2, s, observations[t])) 
                            for s2 in range(len(states))])
            
            if t in constraints.keys():
                if s == states.index(constraints[t]):
                    cl[s, t] = sum([cl[s2, t - 1] * math.exp(score(s2, s, observations[t]))
                                    for s2 in range(len(states))])
                else:
                    cl[s, t] = 0
            else:
                cl[s, t] = sum([cl[s2, t - 1] * math.exp(score(s2, s, observations[t])) 
                                for s2 in range(len(states))])
    
    return cl, ul

In [3]:
# List of possible states (labels)
states = ["D", "N", "P", "V"]

# List of observed tokens
observations = ["the", "dog", "jumped", "over", "the", "cat"]

# Dictionary of constraints containing observation indexes as keys and POS tags as values
constraints = {2: "V", 3: "P"}

get_marginal(states, observations, constraints)

0.0793078287784557