# Install
Installing SynJax is simple: a single `pip install` is all it takes.

In [None]:
# !pip install -qU git+https://github.com/deepmind/synjax.git

# Imports and Visualizations

We will import just JAX, matplotlib and SynJax.
Matplotlib will be used only through the simple `show` function that gives a heat map of a matrix. This function will be handy in visualizing representations of different structures.

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import synjax

In [None]:
def show(x, title):
  plt.title(title)
  plt.imshow(jnp.asarray(x))
  plt.show()

key = jax.random.PRNGKey(42)

# Linear Chain CRF

One common structured prediction problem is the task of labeling the elements of a sequence. For instance, we may want to label a sequence of words with their part-of-speech tags (verb, noun, adjective, etc.).

One simple way of accomplishing that task is to compute a contextual embedding of each word in the sentence and then label each word independently with a classifier. However, that does not account for correlations between the tags in a sequence. For example, this simple model doesn't capture the fact that in English the label sequence "adjective noun" is much more likely than "noun adjective". To solve that we can use a type of model called a Linear Chain Conditional Random Field (CRF), which explicitly models this relation among labels.

A CRF assigns a non-negative score, called a potential, to each pair of adjacent labels at each position in the sequence. Concretely, for an input sequence $\boldsymbol{x} = [x_1, x_2, \dots, x_n]$ we compute non-negative potentials $\phi(\boldsymbol{x}, i, a, b)$, each of which scores having label $b$ at position $i$ given label $a$ at position $i-1$. The potential can condition on the whole input sequence $\boldsymbol{x}$.
To simplify notation, the rest of the notebook drops $\boldsymbol{x}$ from the arguments.
The first element of the sequence has no preceding label, so we assume a fixed label $y_0 = 0$ at position $0$ that precedes the first element $x_1$. Now we can define the potential of a sequence of labels $\boldsymbol{y}=[y_1, y_2, \dots, y_n]$ as a product of individual potentials:

$$
\phi(\boldsymbol{y}) = \prod_{i=1}^{n} \phi(i, y_{i-1}, y_i)
$$

This potential represents the unnormalized probability of that sequence. To normalize it we divide it by the sum of the potentials of all possible label sequences of the same length, a set we denote by $Y$.

$$
P(\boldsymbol{y}) = \frac{ \phi(\boldsymbol{y}) }{\sum_{\boldsymbol{y}' \in Y} \phi(\boldsymbol{y}')}
$$

Computing this normalizer efficiently requires a dynamic programming algorithm, the forward algorithm, which SynJax implements in an optimized, vectorized way that can be compiled with XLA.
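
To make the normalization concrete, here is a minimal pure-Python sketch that compares a brute-force sum over all labelings with the forward recursion. The toy potentials in `phi` are made up for illustration, and this is not how SynJax implements the algorithm (SynJax works in log-space and is vectorized).

```python
import itertools

# Hypothetical toy potentials: phi[i][a][b] scores label b at position i + 1
# given label a at the previous position (n=3 positions, m=2 labels).
phi = [
    [[1.0, 2.0], [9.0, 9.0]],  # first position: only row a=0 is ever used
    [[0.5, 1.5], [2.0, 1.0]],
    [[3.0, 0.2], [1.0, 4.0]],
]
n, m = len(phi), len(phi[0])


def sequence_potential(labels):
  """Product of edge potentials along one labeling, starting from label 0."""
  score, prev = 1.0, 0
  for i, label in enumerate(labels):
    score *= phi[i][prev][label]
    prev = label
  return score


# Brute force: sum over all m**n labelings (exponential in n).
labelings = list(itertools.product(range(m), repeat=n))
z_brute = sum(sequence_potential(y) for y in labelings)

# Forward algorithm: alpha[b] is the total potential of all prefixes ending in
# label b, so each step costs O(m**2) and the whole pass costs O(n * m**2).
alpha = [phi[0][0][b] for b in range(m)]
for i in range(1, n):
  alpha = [sum(alpha[a] * phi[i][a][b] for a in range(m)) for b in range(m)]
z_forward = sum(alpha)

prob = sequence_potential((1, 0, 1)) / z_forward
print(z_brute, z_forward, prob)
```

Both computations agree on the normalizer, but only the forward recursion scales to realistic sequence lengths.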

## Defining the distribution

As with most probability distributions, we work in log-space to avoid numerical instability. Therefore, instead of *potentials* we define *log-potentials*, which can take any real value; ideally they should stay within the range $(-10^5, 10^5)$ to be safe from numerical errors with most floating-point data types.
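
The reason for working in log-space can be seen in a small standard-library sketch. The helper `logsumexp` below is written for this illustration only (it is not SynJax's implementation); it shows the usual trick of factoring out the maximum before exponentiating.

```python
import math


def logsumexp(values):
  """Computes log(sum(exp(v))) stably by factoring out the maximum."""
  top = max(values)
  if top == -math.inf:  # every potential is zero
    return -math.inf
  return top + math.log(sum(math.exp(v - top) for v in values))


log_potentials = [1000.0, 1000.0]

try:
  naive = math.log(sum(math.exp(v) for v in log_potentials))
except OverflowError:
  naive = None  # exp(1000) does not fit in a 64-bit float

stable = logsumexp(log_potentials)
print(naive, stable)
```

The naive computation overflows, while the shifted version returns a finite value close to `1000 + log(2)`.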

The shape of these log-potentials is $(n, m, m)$ where $n$ is the length of the input sequence and $m$ is the number of labels. The entry at index $[i, j, k]$ is the log-potential of position $i$ having label $k$ given that the preceding position had label $j$.

As mentioned earlier we assume that the label that precedes any input is by convention $0$. That means that log-potentials provided at $[0, 1\!:, :]$ are ignored.

An additional argument can be provided that specifies the length of a sequence. That is useful in case we want to process a batch of sequences of different length. The log-potentials tensor will have to be of the same shape for all sequences in the batch, but the provided length will inform SynJax how to do the padding correctly.

Both the log-potentials and the lengths can have leading batch dimensions. Here is a simple Linear Chain CRF that is randomly initialized and in which each sequence has a different length.

In [None]:
b, n, states = 3, 100, 50

potentials = jax.random.uniform(jax.random.PRNGKey(0), (b, n, states, states))
log_potentials = synjax.special.safe_log(potentials)
lengths = jnp.array([n//3, n//2, n])
dist = synjax.LinearChainCRF(log_potentials, lengths=lengths)

dist

## Computing most likely structures and other interesting quantities

We can compute many useful quantities with this object.

Some return (batched) scalars:
* `dist.log_prob(event)` returns the log-probability of a particular sequence of transitions,
* `dist.unnormalized_log_prob(event)` returns the log-potential of a particular sequence of transitions,
* `dist.log_partition()` returns the log of the normalization constant,
* `dist.entropy()` computes the entropy $H(\operatorname{dist})$,
* `dist.cross_entropy(dist2)` computes the cross-entropy against another distribution `dist2`, $H(\operatorname{dist}, \operatorname{dist2})$,
* `dist.kl_divergence(dist2)` similarly computes $D_{\operatorname{KL}}(\operatorname{dist}||\operatorname{dist2})$,
* `dist.log_count()` returns the log of the number of valid structures in the support.

Some return structured objects:
* `dist.argmax()` returns the most probable labeling,
* `dist.top_k(k)` returns the top k most probable labelings,
* `dist.sample(key)` returns a sampled labeling for a given sampling key,
* `dist.marginals()` returns the marginal probability of each edge,
* `dist.log_marginals()` returns the log of the marginal probability of each edge.

These structured objects have ***the same shape*** as the log-potentials. For `argmax`, `top_k` and `sample` they are binary tensors that mark each edge present in the output structure with 1 and every absent edge with 0. For a single unbatched sequence the shape of this tensor is $(n, m, m)$. If we want labels instead of edges, they can be recovered with one line: `jnp.sum(event, axis=-2)`. Here are some examples. Notice how SynJax correctly pads each structure according to its provided length.

In [None]:
event_of_edges = dist.argmax()  # has shape (b, n, states, states)
event_of_labels = jnp.sum(event_of_edges, axis=-2)  # has shape (b, n, states)

for i in range(dist.batch_shape[0]):
  plt.title(f"Best labeling from batch entry {i} with length {dist.lengths[i]}")
  plt.imshow(event_of_labels[i].T)
  plt.show()
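
The edge-to-label conversion can be checked by hand on a tiny example. The sketch below builds the edge tensor for a made-up labeling with plain Python lists and then sums over the preceding-label axis, which is what `jnp.sum(event, axis=-2)` does on an array. The labeling and sizes are hypothetical.

```python
# Hypothetical labeling of a length-4 sequence with m=3 labels.
labels = [2, 0, 1, 1]
m = 3

# Edge tensor of shape (n, m, m): entry [i][a][b] is 1 when position i has
# label b and the preceding position has label a.
edges = [[[0] * m for _ in range(m)] for _ in labels]
prev = 0  # by convention the label before the first element is 0
for i, label in enumerate(labels):
  edges[i][prev][label] = 1
  prev = label

# Pure-Python equivalent of summing over axis -2 (the preceding label).
label_one_hot = [[sum(edges[i][a][b] for a in range(m)) for b in range(m)]
                 for i in range(len(labels))]

# Integer labels are the positions of the ones in each row.
recovered = [row.index(1) for row in label_one_hot]
print(recovered)
```

Each position has exactly one incoming edge, so every row of `label_one_hot` contains a single 1 and the original labels are recovered.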

Another useful quantity is the marginal probability of each edge, i.e. the probability that the edge appears in a labeling drawn from the distribution. Since visualizing edge marginals is more difficult, we convert them to label marginals in the same way as before: by summing over all the edges that end in the same target label.

In [None]:
marginals_of_edges = dist.marginals()  # has shape (b, n, states, states)
marginals_of_labels = jnp.sum(marginals_of_edges, axis=-2)  # shape (b, n, states)

for i in range(dist.batch_shape[0]):
  plt.title(f"Marginal probability from batch entry {i} "
            f"with length {dist.lengths[i]}")
  plt.imshow(marginals_of_labels[i].T)
  plt.show()
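
For intuition about what the marginals contain, the following brute-force sketch computes edge marginals for a toy chain by enumerating every labeling, then reduces them to label marginals. The potentials are made up, and the enumeration is only feasible for tiny problems; it is not how SynJax computes marginals.

```python
import itertools

# Hypothetical toy potentials phi[i][a][b] (n=3 positions, m=2 labels).
phi = [
    [[1.0, 2.0], [1.0, 1.0]],
    [[0.5, 1.5], [2.0, 1.0]],
    [[3.0, 0.2], [1.0, 4.0]],
]
n, m = len(phi), len(phi[0])


def edges_of(labels):
  """Lists the (position, previous label, label) edges of one labeling."""
  previous = (0,) + tuple(labels[:-1])  # fixed label 0 before the first element
  return list(zip(range(n), previous, labels))


def potential(labels):
  score = 1.0
  for i, a, b in edges_of(labels):
    score *= phi[i][a][b]
  return score


labelings = list(itertools.product(range(m), repeat=n))
z = sum(potential(y) for y in labelings)

# Edge marginal: total probability of the labelings that contain the edge.
edge_marginals = [[[0.0] * m for _ in range(m)] for _ in range(n)]
for y in labelings:
  for i, a, b in edges_of(y):
    edge_marginals[i][a][b] += potential(y) / z

# Label marginal: sum the edge marginals over the preceding label a.
label_marginals = [[sum(edge_marginals[i][a][b] for a in range(m))
                    for b in range(m)] for i in range(n)]
print(label_marginals)
```

At every position the label marginals sum to 1, and edges that leave a label other than 0 at the first position have zero marginal probability.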

# HMM

A Hidden Markov Model (HMM) is a generative, locally normalized model for sequential tagging.

In [None]:
n, states, voc = 100, 50, 30
keys = jax.random.split(jax.random.PRNGKey(0), 4)
dist = synjax.HMM(init_logits=jax.random.normal(keys[0], (states,)),
                  transition_logits=jax.random.normal(keys[1], (states,states)),
                  emission_dist=jax.random.normal(keys[2], (states, voc)),
                  observations=jax.random.randint(keys[3], (n,), 0, voc))

def show_chain(chain, title):
  show(chain.sum(-2).T, title)

show_chain(dist.marginals(), "marginals")
show_chain(dist.argmax(), "argmax")
show_chain(dist.sample(key, 2).sum((0)), "samples")
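
"Generative and locally normalized" means that every factor is itself a probability distribution, so the joint probability of hidden states and observations sums to 1 without any global normalizer. The pure-Python sketch below checks this by enumeration on a toy model with made-up logits; it illustrates the model family and is not SynJax's implementation.

```python
import itertools
import math


def softmax(logits):
  top = max(logits)
  exps = [math.exp(v - top) for v in logits]
  total = sum(exps)
  return [e / total for e in exps]


# Hypothetical logits for 2 hidden states and a vocabulary of 3 symbols.
states, voc, n = 2, 3, 3
init = softmax([0.3, -1.0])  # p(z_1)
transition = [softmax(row) for row in ([0.5, 0.1], [-0.2, 1.0])]  # p(z_i | z_{i-1})
emission = [softmax(row) for row in ([1.0, 0.0, -1.0], [0.2, 0.2, 2.0])]  # p(x_i | z_i)


def joint(hidden, observed):
  """p(hidden, observed) as a product of locally normalized factors."""
  p = init[hidden[0]] * emission[hidden[0]][observed[0]]
  for i in range(1, n):
    p *= transition[hidden[i - 1]][hidden[i]] * emission[hidden[i]][observed[i]]
  return p


def likelihood(observed):
  """p(observed): the normalizer of the distribution over hidden states."""
  return sum(joint(h, observed)
             for h in itertools.product(range(states), repeat=n))


total = sum(likelihood(x) for x in itertools.product(range(voc), repeat=n))
print(total)
```

Once the observations are fixed, the distribution over hidden states is a chain-structured model whose normalizer is the likelihood of the observations.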

# Alignment CRF

AlignmentCRF can model both monotone and non-monotone alignments. The log-potentials form a rectangular table of weights for matching pairs of items. Monotone alignments support all probabilistic inference operations, while non-monotone alignments support only argmax because the other quantities are intractable.

In [None]:
n, m = 100, 150
key = jax.random.PRNGKey(0)
show(synjax.AlignmentCRF(jax.random.normal(key, (n, n)), alignment_type='non_monotone_one_to_one').argmax(), "non_monotone_one_to_one")
show(synjax.AlignmentCRF(jax.random.normal(key, (n, n)), alignment_type='monotone_many_to_many').argmax(), "monotone_many_to_many")
show(synjax.AlignmentCRF(jax.random.normal(key, (n, m)), alignment_type='monotone_one_to_many').argmax(), "monotone_one_to_many")
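
To see why non-monotone one-to-one alignment supports only argmax: the best alignment is an instance of the assignment problem, which has polynomial-time solvers, whereas the partition function sums over all permutations and equals a matrix permanent, which is #P-hard to compute. The sketch below finds the argmax by brute force on a made-up 3x3 table; it is an illustration only and says nothing about the solver SynJax uses.

```python
import itertools

# Hypothetical 3x3 table of log-potentials: scores[i][j] is the weight of
# aligning item i of the first sequence with item j of the second.
scores = [
    [0.1, 2.0, 0.3],
    [1.5, 0.2, 0.1],
    [0.4, 0.3, 0.9],
]
n = len(scores)


def alignment_score(permutation):
  """Log-potential of the alignment that maps item i to permutation[i]."""
  return sum(scores[i][j] for i, j in enumerate(permutation))


# Brute force over all n! one-to-one alignments; feasible only for tiny n.
best = max(itertools.permutations(range(n)), key=alignment_score)

# Binary matrix in the same layout as the table of log-potentials.
one_hot = [[int(best[i] == j) for j in range(n)] for i in range(n)]
print(best, one_hot)
```

Every row and every column of `one_hot` contains exactly one 1, which is the one-to-one constraint.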

Here are some additional quantities that can be computed for monotone alignment.

In [None]:
dist = synjax.AlignmentCRF(jax.random.normal(key, (n, m)),
                           alignment_type='monotone_many_to_many')
print("Entropy", dist.entropy())
print("Log-partition", dist.log_partition())
print("Log-count", dist.log_count())

show(dist.marginals(), "marginals")
show(dist.sample(key, 2).sum(0), "2 samples")

# CTC

[Connectionist Temporal Classification (CTC)](https://distill.pub/2017/ctc/) is often used in speech recognition, where the alignment between the input (speech signal) and the gold output (words) is not observed in the training data. The model assumes that the alignment is monotone (no reorderings), which is clearly the case in speech recognition. CTC is a wrapper on top of the Alignment CRF from above. The SynJax implementation of CTC provides not only the loss but also other useful quantities such as the argmax (forced alignment), samples of alignments, etc.

In [None]:
gold_label_n, prediction_n, voc = 16, 32, 10_000
logits = jax.random.normal(key, (prediction_n, voc))
gold_labels = jax.random.randint(key, (gold_label_n,), 0, voc)
dist = synjax.CTC(log_potentials=logits, labels=gold_labels)
show(dist.marginals(), "marginals")
show(dist.argmax(), "forced alignment")
show(dist.sample(key), "sampled alignment")
print("CTC loss", dist.loss())
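
In CTC, many frame-level alignments correspond to the same output sequence: repeated frame labels are merged and blank symbols are then removed. Here is a minimal sketch of that collapsing rule, assuming the blank symbol has id 0; the helper name `collapse` is made up for this illustration.

```python
import itertools

BLANK = 0  # assumed blank symbol id for this sketch


def collapse(frame_labels, blank=BLANK):
  """Merges repeated frame labels, then removes blanks."""
  merged = [label for label, _ in itertools.groupby(frame_labels)]
  return [label for label in merged if label != blank]


# Two different frame-level alignments that produce the same output [3, 3, 5].
first = collapse([3, 3, 0, 3, 5, 5, 0])
second = collapse([0, 3, 0, 0, 3, 0, 5])
print(first, second)
```

The blank between the two 3s is what allows a repeated output label to survive the merge. The CTC loss is the negative log of the total probability of all alignments that collapse to the gold labels.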

# Spanning Trees

All types of spanning trees are accessed through the `synjax.SpanningTreeCRF` class. It takes an (optionally batched) matrix of log-potentials, lengths per instance, and flags that specify whether the spanning tree is directed (i.e. an arborescence), whether it is projective (used mostly in NLP), and whether it has only one edge leaving the root. Below are some examples of how different variants of the spanning tree distribution can be instantiated.
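
The `projective` and `single_root_edge` constraints can be illustrated on dependency trees written as head lists. The encoding below (`heads[i]` is the head of word `i + 1`, with 0 as an artificial root) is an assumption made for this sketch and is not SynJax's representation, which uses matrices of edges.

```python
def root_edges(heads):
  """Counts words attached directly to the artificial root (head 0)."""
  return sum(1 for head in heads if head == 0)


def is_projective(heads):
  """Returns True if no two arcs of the dependency tree cross.

  heads[i] is the head of word i + 1; words are numbered from 1 and 0 is the
  artificial root (an encoding assumed for this sketch).
  """
  arcs = [tuple(sorted((head, dep))) for dep, head in enumerate(heads, start=1)]
  for left1, right1 in arcs:
    for left2, right2 in arcs:
      if left1 < left2 < right1 < right2:
        return False
  return True


projective_tree = [2, 0, 2, 3]      # arcs: 2->1, 0->2, 2->3, 3->4
non_projective_tree = [3, 0, 2, 2]  # arcs 3->1 and 2->4 cross
print(is_projective(projective_tree), is_projective(non_projective_tree))
print(root_edges(projective_tree), root_edges(non_projective_tree))
```

Both example trees have a single root edge, but only the first one can be drawn above the sentence without crossing arcs.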

## Directed Non-Projective Dependency CRF

In [None]:
n = 20

log_potentials = jax.random.normal(key, (n, n))
dist = synjax.SpanningTreeCRF(log_potentials, directed=True,
                              projective=False, single_root_edge=True)

show(dist.marginals(), "marginals")
# Line below may be slightly slower on first run because of Numba compilation.
show(dist.argmax(), "argmax")
show(dist.sample(key, 2).sum((0)), "samples")

## Undirected Non-Projective Dependency CRF

In [None]:
n = 20

log_potentials = jax.random.normal(key, (n, n))
dist = synjax.SpanningTreeCRF(log_potentials, directed=False,
                              projective=False, single_root_edge=True)

show(dist.marginals(), "marginals")
show(dist.argmax(), "argmax")
show(dist.sample(key, 3).sum((0)), "samples")

## Directed Projective Dependency CRF

In [None]:
n = 20

log_potentials = jax.random.normal(jax.random.PRNGKey(0), (n, n))
dist = synjax.SpanningTreeCRF(log_potentials, directed=True,
                              projective=True, single_root_edge=True)

show(dist.marginals(), "marginals")
show(dist.argmax(), "argmax")
show(dist.sample(key, 3).sum((0)), "samples")

# Constituency Trees

## Tree CRF

The model structure is very similar to [Stern et al. (2017)](https://aclanthology.org/P17-1076.pdf), except that SynJax
additionally supports properly normalizing the distribution.

In [None]:
nt = 1
n = 15

log_potentials = jax.random.normal(key, (n, n, nt))
dist = synjax.TreeCRF(log_potentials, lengths=None)

show(dist.marginals().sum(-1), "marginals")
show(dist.argmax().sum(-1), "argmax")
show(dist.sample(key, 2).sum((0, -1)), "samples")
show(dist.top_k(2)[0].sum((0, -1)), "top_k")

## PCFG

  
Note that this is a conditional PCFG, i.e. a distribution over trees
defined by a PCFG and conditioned on a provided sentence. Because of that, calling
`dist.log_prob(tree)` returns `log p(tree | sentence; pcfg)`. To get the
joint log-probability of a tree and a sentence, `log p(tree, sentence; pcfg)`, call
`dist.unnormalized_log_prob(tree)`.
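
The conditional and joint quantities differ only by the log-likelihood of the sentence, which follows from the definition of conditional probability. The identity below assumes that the normalizer of this distribution, as returned by `dist.log_partition()`, is the log of the sum of joint probabilities over all trees of the sentence; that reading is an assumption and was not checked against the implementation.

$$
\log p(\text{tree} \mid \text{sentence}) = \log p(\text{tree}, \text{sentence}) - \log \sum_{\text{tree}'} p(\text{tree}', \text{sentence})
$$

The subtracted term is $\log p(\text{sentence})$, the marginal likelihood of the sentence under the grammar.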


In [None]:
nt, pt, n, voc = 4, 8, 10, 10  # non-terminals, pre-terminals, length, vocabulary

normal = jax.random.normal
keys = jax.random.split(jax.random.PRNGKey(0), 4)

dist = synjax.PCFG(
    root=normal(keys[0], (nt,)),
    rule=normal(keys[1], (nt, nt+pt, nt+pt)),
    emission=normal(keys[2], (pt, voc)),
    word_ids = jax.random.randint(keys[3], (n,), 0, voc)
)


show(dist.marginals().chart.sum(-1), "marginals")
show(dist.argmax().chart.sum(-1), "argmax")

## Tensor-Decomposition PCFG

[Cohen et al. (2013)](https://aclanthology.org/N13-1052.pdf#page=8) showed that a PCFG with a large number of non-terminals can be
approximated using CPD tensor decomposition. [Yang et al. (2022)](https://aclanthology.org/2022.naacl-main.353.pdf) used this to
do efficient grammar induction with a large number of non-terminals and a
relatively small rank dimension. They avoid the tensor-decomposition step by
always keeping all parameters in the rank space and enforcing that all decomposed
rules are normalized. Just like the regular PCFG implementation, the implementation of Tensor-Decomposition PCFG is a conditional model for a given sentence.
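
The idea of keeping the rules in rank space can be sketched in a few lines of pure Python. The factorization used here, $p(A \to B\,C) = \sum_r p(r \mid A)\, p(B \mid r)\, p(C \mid r)$, is an assumed reading of the decomposition based on the parameter names; the logits are made up and the code is not SynJax's implementation. It shows that if every factor is normalized, the implied rule distribution is normalized too, without ever building the full rule tensor.

```python
import math


def softmax(logits):
  top = max(logits)
  exps = [math.exp(v - top) for v in logits]
  total = sum(exps)
  return [e / total for e in exps]


nt, pt, rank = 2, 3, 2
symbols = nt + pt  # children can be non-terminals or pre-terminals

# Hypothetical logits, normalized row by row so every factor is a distribution.
nt_to_rank = [softmax([0.1 * (a + r) for r in range(rank)]) for a in range(nt)]
rank_to_left = [softmax([0.2 * (r - b) for b in range(symbols)])
                for r in range(rank)]
rank_to_right = [softmax([0.3 * ((r * c) % 3) for c in range(symbols)])
                 for r in range(rank)]


def rule_probability(a, b, c):
  """p(A -> B C) = sum_r p(r | A) p(B | r) p(C | r), never stored as a tensor."""
  return sum(nt_to_rank[a][r] * rank_to_left[r][b] * rank_to_right[r][c]
             for r in range(rank))


totals = [sum(rule_probability(a, b, c)
              for b in range(symbols) for c in range(symbols))
          for a in range(nt)]
print(totals)
```

The three factors need `nt * rank + 2 * rank * (nt + pt)` parameters instead of `nt * (nt + pt) ** 2` for the full rule tensor, which is where the savings for large grammars come from.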

In [None]:
nt, pt, n, voc, rank = 4, 8, 10, 10, 6

normal = jax.random.normal
keys = jax.random.split(jax.random.PRNGKey(0), 6)

dist = synjax.TensorDecompositionPCFG(
    root=normal(keys[0], (nt,)),
    nt_to_rank=normal(keys[1], (nt, rank)),
    rank_to_left_nt=normal(keys[2], (rank, nt+pt)),
    rank_to_right_nt=normal(keys[3], (rank, nt+pt)),
    emission=normal(keys[4], (pt, voc)),
    # Separate key so the word ids do not reuse the emission key.
    word_ids = jax.random.randint(keys[5], (n,), 0, voc)
)

show(dist.marginals().chart.sum(-1), "marginals")
show(dist.mbr(marginalize_labels=True).chart.sum(-1), "MBR")