In [None]:
!pip install -qU matplotlib
!pip install -qU git+https://github.com/deepmind/synjax.git

# Linear Chain CRF

One common structured prediction problem is labeling the elements of a sequence. For instance, we may want to label a sequence of words with their part-of-speech tags (verb, noun, adjective, etc.).

One simple way of accomplishing that task is to compute a contextual embedding of each word in the sentence and then label each word independently with a classifier. However, that does not account for correlations between the tags in a sequence. For example, this simple model doesn't capture the fact that in English the label sequence "adjective noun" is much more likely than "noun adjective". To address that we can use a type of model called a Linear Chain Conditional Random Field (CRF), which explicitly models this relation among labels.

A CRF assigns a non-negative score, called a potential, to each pair of adjacent labels at each point in the sequence. Concretely, for an input sequence $\boldsymbol{x} = [x_1, x_2, \dots, x_n]$ we compute non-negative potentials $\phi(\boldsymbol{x}, i, a, b)$, each of which scores having label $b$ at position $i$ given that label $a$ is at position $i-1$. The potential can condition on the whole input sequence $\boldsymbol{x}$.
We will simplify notation in the rest of the notebook by dropping the explicit dependence on $\boldsymbol{x}$.
The label of the first element of the sequence of course does not have a preceding label, but we will assume there is a fixed label $0$ at position $0$ that precedes the first element $x_1$. Now we can define the potential of a sequence of labels $\boldsymbol{y}=[y_1, y_2, \dots, y_n]$ as a product of individual potentials:

$$
\phi(\boldsymbol{y}) = \prod_{i=1}^{n} \phi(i, y_{i-1}, y_i)
$$

This potential represents the unnormalized probability of that sequence. To normalize it we divide it by the sum of the potentials of all possible label sequences of the same length.

$$
P(\boldsymbol{y}) = \frac{ \phi(\boldsymbol{y}) }{\sum_{\boldsymbol{y}' \in Y} \phi(\boldsymbol{y}')}
$$
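
To make the two definitions above concrete, here is a minimal brute-force sketch that enumerates every labeling of a tiny sequence. It uses plain JAX rather than SynJax, and the names `phi`, `sequence_potential` and the toy sizes are illustrative assumptions, not part of the library.

```python
import itertools

import jax
import jax.numpy as jnp

n, m = 4, 3  # toy sizes: sequence length and number of labels
# Hypothetical random potentials. phi[i, a, b] scores label b at position i
# (0-based here) given label a at the preceding position.
phi = jax.random.uniform(
    jax.random.PRNGKey(0), (n, m, m), minval=0.1, maxval=1.0)


def sequence_potential(phi, labels):
  """Product of edge potentials along `labels`, starting from fixed label 0."""
  prev, total = 0, 1.0
  for i, label in enumerate(labels):
    total = total * phi[i, prev, label]
    prev = label
  return total


# Enumerate all m**n labelings; this is only feasible for tiny n and m.
all_labelings = list(itertools.product(range(m), repeat=n))
scores = jnp.stack([sequence_potential(phi, y) for y in all_labelings])
partition = scores.sum()    # the denominator of P(y)
probs = scores / partition  # one normalized probability per labeling
```

The enumeration grows as $m^n$, which is why a dynamic program is needed for realistic sequence lengths.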

Computing this normalization requires a dynamic programming algorithm called the forward algorithm. For the details of this algorithm see the [references section](#scrollTo=References). SynJax provides multiple versions of this algorithm that have different pros and cons.
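
As a rough illustration of the idea, the following sketch implements a textbook log-space forward recursion in plain JAX and checks it against brute-force enumeration. It is not SynJax's implementation; `forward_log_partition` and `log_score` are hypothetical helper names introduced here.

```python
import itertools

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def forward_log_partition(log_phi):
  """Log of the normalizer for log-potentials of shape (n, m, m)."""
  # alpha[b] is the log of the total potential of all prefixes ending in b.
  # The label preceding the first element is fixed to 0, so only row 0 is used.
  alpha = log_phi[0, 0, :]
  for i in range(1, log_phi.shape[0]):
    # Sum out the previous label a: alpha[a] + log_phi[i, a, b].
    alpha = logsumexp(alpha[:, None] + log_phi[i], axis=0)
  return logsumexp(alpha)


n, m = 4, 3
log_phi = jax.random.normal(jax.random.PRNGKey(0), (n, m, m))


def log_score(labels):
  """Sum of edge log-potentials along one labeling."""
  prev, total = 0, 0.0
  for i, label in enumerate(labels):
    total = total + log_phi[i, prev, label]
    prev = label
  return total


# Brute-force reference: logsumexp over all m**n labelings.
brute = logsumexp(jnp.stack(
    [log_score(y) for y in itertools.product(range(m), repeat=n)]))
fast = forward_log_partition(log_phi)
```

The recursion touches each of the $n$ positions once and does $O(m^2)$ work per position, instead of visiting all $m^n$ labelings.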

## Defining the distribution

First we import the necessary libraries.

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import synjax

As with most probability distributions, we need to work in log-space in order to avoid issues with numerical stability. Therefore instead of *potentials* we define *log-potentials* that can take any real value, but ideally they should stay within the range $(-10^5, 10^5)$ to be safe from numerical errors with most floating-point data types.
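
A small sketch of why log-space matters: multiplying many small potentials underflows, while summing their logs stays finite. The values below are made up for illustration.

```python
import jax.numpy as jnp

potentials = jnp.full((200,), 1e-3)  # 200 small potentials along one sequence

product = jnp.prod(potentials)          # underflows to 0.0 in float32
log_sum = jnp.sum(jnp.log(potentials))  # about -1381.6, still finite
```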

The shape of these log-potentials is $(n, m, m)$ where $n$ is the length of the input sequence and $m$ is the number of labels. The entry at $[i, j, k]$ is the log-potential of position $i$ having label $k$ given that the preceding position had label $j$.

As mentioned earlier we assume that the label that precedes any input is by convention $0$. That means that log-potentials provided at $[0, 1\!:, :]$ are ignored.
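
The indexing convention can be sketched in plain JAX as follows. This is an illustration of the convention described above, not SynJax code; `sequence_log_potential` and the example `labels` are hypothetical.

```python
import jax
import jax.numpy as jnp

n, m = 6, 4
log_potentials = jax.random.normal(jax.random.PRNGKey(0), (n, m, m))
labels = jnp.array([2, 0, 3, 3, 1, 0])  # hypothetical labeling of length n


def sequence_log_potential(log_potentials, labels):
  """Sums the entries [i, previous label, current label] over positions i."""
  # Position 0 is preceded by the fixed label 0.
  prev = jnp.concatenate([jnp.zeros(1, dtype=labels.dtype), labels[:-1]])
  positions = jnp.arange(labels.shape[0])
  return log_potentials[positions, prev, labels].sum()


score = sequence_log_potential(log_potentials, labels)

# Overwriting the entries at [0, 1:, :] does not change the score, because
# the first position only ever reads row 0.
altered = log_potentials.at[0, 1:, :].set(123.0)
score_altered = sequence_log_potential(altered, labels)
```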

An additional argument can be provided that specifies the length of a sequence. That is useful when we want to process a batch of sequences of different lengths. The log-potentials tensor has to have the same shape for all sequences in the batch, but the provided length tells SynJax how to do the padding correctly.
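
One common way to think about lengths is as a boolean mask over positions. The sketch below shows that idea in plain JAX; it is an assumption about how padding can be handled in general, not a description of SynJax internals.

```python
import jax.numpy as jnp

n = 15
lengths = jnp.array([5, 10, 15])

# mask[b, i] is True where position i is a real element of sequence b.
mask = jnp.arange(n)[None, :] < lengths[:, None]  # shape (3, n)

# Example use: ignore per-position scores that fall in the padding region.
per_position_scores = jnp.ones((3, n))
masked_total = jnp.where(mask, per_position_scores, 0.0).sum(axis=-1)
```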

Both the log-potentials and the lengths parameters can have leading batch dimensions. Here is a simple Linear Chain CRF that is randomly initialized and in which each sequence has a different length.

In [None]:
b, n, m = 3, 15, 5

potentials = jax.random.uniform(jax.random.PRNGKey(0), (b, n, m, m))
log_potentials = synjax.special.safe_log(potentials)
lengths = jnp.array([5, 10, 15])
dist = synjax.LinearChainCRF(log_potentials, lengths=lengths)

dist

## Computing most likely structures and other interesting quantities

We can compute many useful quantities with this object.

Some return (batched) scalars:
* `dist.log_prob(event)` returns the log-probability of a particular sequence of transitions,
* `dist.unnormalized_log_prob(event)` returns the log-potential of a particular sequence of transitions,
* `dist.log_partition()` returns the log of the normalization constant,
* `dist.entropy()` computes the entropy $H(\operatorname{dist})$,
* `dist.cross_entropy(dist2)` computes the cross-entropy against some other distribution `dist2`, $H(\operatorname{dist}, \operatorname{dist2})$,
* `dist.kl_divergence(dist2)` similarly computes $D_{\operatorname{KL}}(\operatorname{dist}||\operatorname{dist2})$.
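
The relationships between these scalars can be checked on a toy distribution over a handful of structures. The sketch below uses plain JAX with made-up scores and does not call SynJax; all variable names are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

key_p, key_q = jax.random.split(jax.random.PRNGKey(0))
# Hypothetical unnormalized log-scores of 8 structures under two models.
unnormalized_p = jax.random.normal(key_p, (8,))
unnormalized_q = jax.random.normal(key_q, (8,))

# log_prob = unnormalized_log_prob - log_partition
log_partition_p = logsumexp(unnormalized_p)
log_prob_p = unnormalized_p - log_partition_p
log_prob_q = unnormalized_q - logsumexp(unnormalized_q)
prob_p = jnp.exp(log_prob_p)

entropy = -jnp.sum(prob_p * log_prob_p)
cross_entropy = -jnp.sum(prob_p * log_prob_q)
kl_divergence = jnp.sum(prob_p * (log_prob_p - log_prob_q))
```

In this toy setting the KL divergence equals the cross-entropy minus the entropy, which is a handy sanity check when comparing two distributions.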

Some return structured objects:
* `dist.argmax()` returns the most probable labeling,
* `dist.top_k(k)` returns the top k most probable labelings,
* `dist.sample(key)` returns a sampled labeling for a given sampling key,
* `dist.marginals()` returns the marginal probability of each edge,
* `dist.log_marginals()` returns the log of the marginal probability of each edge,
* `dist.log_count()` returns the log of the number of valid structures in the support.

These structured objects have ***the same shape*** as the log-potentials, that is $(n, m, m)$ plus any leading batch dimensions. In the case of `argmax`, `top_k` and `sample` they are one-hot encodings that mark each edge present in the output structure with 1 and every other edge with 0. If we want labels instead of edges, they can be retrieved in one line with `jnp.sum(event, axis=-2)`. Here are some examples of this. Notice how SynJax correctly pads each structure depending on its provided length.

In [None]:
event_of_edges = dist.argmax()  # has shape (b, n, m, m)
event_of_labels = jnp.sum(event_of_edges, axis=-2)  # has shape (b, n, m)

for i in range(dist.batch_shape[0]):
  plt.title(f"Best labeling from batch entry {i} with length {dist.lengths[i]}")
  plt.imshow(event_of_labels[i].T)
  plt.show()
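
To see why summing over `axis=-2` recovers labels, the following sketch builds a one-hot edge tensor by hand from a hypothetical labeling and converts it back. It uses only plain JAX and the fixed-start-label convention described earlier.

```python
import jax
import jax.numpy as jnp

m = 4
labels = jnp.array([2, 0, 3, 3, 1])  # hypothetical labeling
# Previous label at each position; position 0 is preceded by fixed label 0.
prev = jnp.concatenate([jnp.zeros(1, dtype=labels.dtype), labels[:-1]])

# edges[i, a, b] == 1 exactly when position i has label b and the
# preceding position has label a.
edges = (jax.nn.one_hot(prev, m)[:, :, None]
         * jax.nn.one_hot(labels, m)[:, None, :])  # shape (n, m, m)

one_hot_labels = jnp.sum(edges, axis=-2)         # sum out the previous label
recovered = jnp.argmax(one_hot_labels, axis=-1)  # back to integer labels
```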

Another useful quantity is the marginal probability of each edge appearing in the correct labeling. Since edge marginals are harder to visualize, we convert them to label marginals in the same way as before: by summing over all the edges that end in the same target label.

In [None]:
marginals_of_edges = dist.marginals()  # has shape (b, n, m, m)
marginals_of_labels = jnp.sum(marginals_of_edges, axis=-2)  # shape (b, n, m)

for i in range(dist.batch_shape[0]):
  plt.title(f"Marginal probability from batch entry {i} "
            f"with length {dist.lengths[i]}")
  plt.imshow(marginals_of_labels[i].T)
  plt.show()
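
A standard fact about models of this kind is that edge marginals equal the gradient of the log-partition function with respect to the log-potentials. The sketch below demonstrates this with a hand-written forward recursion and `jax.grad`; `log_partition` here is a hypothetical helper, and no claim is made that SynJax computes its marginals exactly this way.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_partition(log_phi):
  """Log-space forward recursion for log-potentials of shape (n, m, m)."""
  alpha = log_phi[0, 0, :]  # the first position is preceded by fixed label 0
  for i in range(1, log_phi.shape[0]):
    alpha = logsumexp(alpha[:, None] + log_phi[i], axis=0)
  return logsumexp(alpha)


n, m = 5, 3
log_phi = jax.random.normal(jax.random.PRNGKey(0), (n, m, m))

# d log Z / d log_phi[i, a, b] is the marginal probability of edge (i, a, b).
edge_marginals = jax.grad(log_partition)(log_phi)  # shape (n, m, m)
label_marginals = edge_marginals.sum(axis=-2)      # shape (n, m)
```

Each row of `label_marginals` sums to 1, and the entries at `[0, 1:, :]` of `edge_marginals` are zero because those log-potentials are never used.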

## Available algorithms

All of these quantities are computed with the same forward algorithm, but that algorithm can be implemented in two ways. The standard implementation is (mostly) sequential and processes the sequence from left to right in $O(m^2 n)$ time. An alternative implementation, proposed by Hassan et al. (2021), processes the whole sequence in parallel with a parallel runtime complexity of $O(m^3 \log n)$. Rush (2020) reports that the parallel implementation was faster for long sequences. In our benchmarks the sequential implementation was always faster and used less memory. For that reason the sequential implementation is the default, but the user can override that by passing the keyword argument `forward_algorithm="parallel"` to any of the methods.
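
The difference between the two strategies can be sketched in plain JAX. The sequential version carries a length-$m$ vector through `jax.lax.scan`, while the parallel version combines $m \times m$ matrices with `jax.lax.associative_scan` using a matrix product in the log semiring. The helper names are hypothetical and this is only an illustration of the idea, not SynJax's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_matmul(a, b):
  """Matrix product in the log semiring; supports leading batch dimensions."""
  return logsumexp(a[..., :, :, None] + b[..., None, :, :], axis=-2)


def sequential_log_partition(log_phi):
  """Left-to-right recursion: O(m^2) work per step, n - 1 dependent steps."""
  def step(alpha, log_phi_i):
    return logsumexp(alpha[:, None] + log_phi_i, axis=0), None
  alpha, _ = jax.lax.scan(step, log_phi[0, 0, :], log_phi[1:])
  return logsumexp(alpha)


def parallel_log_partition(log_phi):
  """Prefix products of transition matrices: O(m^3) work per combine."""
  prefix = jax.lax.associative_scan(log_matmul, log_phi)
  # Row 0 of the full product corresponds to the fixed start label 0.
  return logsumexp(prefix[-1, 0, :])


n, m = 16, 3
log_phi = jax.random.normal(jax.random.PRNGKey(0), (n, m, m))
sequential = sequential_log_partition(log_phi)
parallel = parallel_log_partition(log_phi)
```

Both functions return the same value up to floating-point error. The parallel version trades the cheaper vector-matrix step for matrix-matrix products, which is where the extra factor of $m$ comes from.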

If several structures tie for the most probable structure, calling `dist.argmax()` will not return a one-hot tensor but one with fractional counts instead. This is a very unlikely situation in most applications, but if you want to be sure that it doesn't happen you can pass the additional keyword argument `strict_max=True` to `argmax`, which incurs a small runtime penalty.
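
One way fractional counts can arise in general is when an argmax is obtained by differentiating a max: JAX splits the gradient evenly among tied maxima. The sketch below only illustrates that JAX behavior on a toy vector; whether SynJax derives its argmax this way is an assumption that this notebook does not confirm.

```python
import jax
import jax.numpy as jnp

tied_scores = jnp.array([1.0, 1.0, 0.0])    # two candidates tie for the max
unique_scores = jnp.array([2.0, 1.0, 0.0])  # a single clear winner

# The gradient of max acts like a one-hot indicator of the winner, but the
# mass is shared between tied entries.
tied_indicator = jax.grad(jnp.max)(tied_scores)      # [0.5, 0.5, 0.0]
unique_indicator = jax.grad(jnp.max)(unique_scores)  # [1.0, 0.0, 0.0]
```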

## References

* [Michael Collins -- Log-Linear Models, MEMMs, and CRFs](http://www.cs.columbia.edu/~mcollins/crf.pdf)
* [Sutton and McCallum 2012 -- An Introduction to Conditional Random Fields](https://homepages.inf.ed.ac.uk/csutton/publications/crftutv2.pdf)
* [Lafferty et al 2001 -- Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data](https://repository.upenn.edu/cgi/viewcontent.cgi?article=1162&context=cis_papers)
* [Hassan et al 2021 -- Temporal Parallelization of Inference in Hidden Markov Models](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9512397)
* [Rush 2020 -- TorchStruct: Deep Structured Prediction Library](https://aclanthology.org/2020.acl-demos.38.pdf)