# Token Alignment

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from genparse import CFGLM
from genparse.cfglm import locally_normalize, EOS
from genparse.align import CharAlignedCFGLM
from genparse.util import display_table
from genparse.steer import generation_tree
from genparse.segmentation import bpe_wfst, segmentation_pfst

$$
\newcommand{\aa}[0]{\boldsymbol{a}}
\newcommand{\bb}[0]{\boldsymbol{b}}
\newcommand{\AA}[0]{\mathcal{A}}
\newcommand{\BB}[0]{\mathcal{B}}
$$

Let $p$ be a distribution over character strings $\aa \in \AA^*$.

Let $p'$ be a distribution over BPE strings $\bb \in \BB^*$.

Let $\phi\colon \BB^* \to \AA^*$ be a decoding function. The decoding function satisfies: $\phi(\bb \, \bb') = \phi(\bb) \, \phi(\bb')$ for all $\bb, \bb' \in \BB^*$.

In [None]:
p = CFGLM.from_string(
    """

1: S -> a
1: S -> a a
2: S -> a a a

"""
)

A = p.cfg.V
B = {'a', 'aa', 'aaa', EOS}
ϕ = lambda b: ''.join(b).strip(EOS)

Our goal is to transform the distribution $p$ into a distribution $p'$ such that the following correctness condition holds:

**Correctness Condition:**

$$
\forall \aa \in \AA^*\colon\quad   p(\aa) = \sum_{\bb\colon\ \phi(\bb) = \aa} p'(\bb)
$$

The correctness condition ensures that the process $\bb \sim p'$, $\aa = \phi(\bb)$ generates strings $\aa$ that are distributed according to $p$.
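As a sanity check, the correctness condition can be verified by brute force on a small example. The sketch below is plain Python and does not use `genparse`. It assumes the toy grammar's rule weights 1, 1, 2 normalize to probabilities 1/4, 1/4, 1/2; the hand-written `p_prime` table and the `pushforward` helper are illustrative names, and EOS is omitted for simplicity.

```python
from collections import defaultdict
from fractions import Fraction

# Character-level target p (assumed: rule weights 1:1:2, normalized).
p = {'a': Fraction(1, 4), 'aa': Fraction(1, 4), 'aaa': Fraction(1, 2)}

# A hand-picked (hypothetical) token-level distribution p' over token tuples.
p_prime = {
    ('a',): Fraction(1, 4),
    ('aa',): Fraction(1, 8),
    ('a', 'a'): Fraction(1, 8),
    ('aaa',): Fraction(1, 4),
    ('a', 'aa'): Fraction(1, 8),
    ('aa', 'a'): Fraction(1, 8),
}


def phi(tokens):
    """Decode a token sequence by concatenation (EOS handling omitted)."""
    return ''.join(tokens)


def pushforward(dist):
    """Sum the mass of all token sequences that decode to the same string."""
    out = defaultdict(Fraction)
    for tokens, weight in dist.items():
        out[phi(tokens)] += weight
    return dict(out)


# The correctness condition: the pushforward of p' through phi equals p.
assert pushforward(p_prime) == p
```

Many different choices of `p_prime` pass this check; the condition constrains only the total mass assigned to each decoded string.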

**A stochastic tokenization model:**

$$
p'(\bb) = \sum_{\aa} \, p(\aa) \, \!\!\!\!\!\!\!\!\!\!\underbrace{p(\bb \mid \aa)}_{\substack{\text{probabilistic transducer} \\ \text{where } \phi(\bb) = \aa \text{ holds w.p. 1}}}
$$


When $p$ is a PCFG-LM, we may compose it with any segmentation PFST to construct $p'$.
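To make the stochastic tokenization model concrete, here is a minimal pure-Python sketch in which $p(\bb \mid \aa)$ is uniform over all segmentations of $\aa$. The uniform choice, the helper names `segmentations` and `tokenize_uniform`, and the normalized toy probabilities are assumptions for illustration; this is not how `genparse` performs the composition.

```python
from fractions import Fraction


def segmentations(s, vocab):
    """Yield every way to split `s` into tokens from `vocab` (no empty tokens)."""
    if not s:
        yield ()
        return
    for tok in sorted(vocab):
        if s.startswith(tok):
            for rest in segmentations(s[len(tok):], vocab):
                yield (tok,) + rest


def tokenize_uniform(p, vocab):
    """p'(b) = sum_a p(a) p(b | a), with p(b | a) uniform over segmentations of a.

    Each b decodes to exactly one a, so the sum has a single nonzero term.
    """
    p_prime = {}
    for a, weight in p.items():
        segs = list(segmentations(a, vocab))
        for b in segs:
            p_prime[b] = p_prime.get(b, 0) + weight / len(segs)
    return p_prime


# Toy character-level distribution (assumed normalized weights 1:1:2).
p = {'a': Fraction(1, 4), 'aa': Fraction(1, 4), 'aaa': Fraction(1, 2)}
p_prime = tokenize_uniform(p, {'a', 'aa', 'aaa'})

# Marginalizing p' over segmentations recovers p.
marginal = {}
for b, w in p_prime.items():
    marginal[''.join(b)] = marginal.get(''.join(b), 0) + w
assert marginal == p
```

Because $p(\bb \mid \aa)$ places all of its mass on sequences that decode to $\aa$, the correctness condition holds by construction, whatever distribution over segmentations is chosen.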

**Tokenization Preferences.**
The correctness condition can be satisfied in undesirable ways that adversely affect downstream components.
Consider the following **unwanted workaround** for BPE tokenization.  By construction, the token vocabulary includes every individual character, so we can define a trivial segmentation in which the unit-length tokenization is the only one with nonzero probability.  This will not be the LLM's preferred prediction scheme: the LLM generally favors the segmentations most representative of those that appear in similar contexts in its training data, and these tend to use the longest-matching tokens.  Some LMs are trained with subword-regularization schemes, which may make them more robust to the specific segmentation.  These are design choices; we suggest that relatively flat distributions over segmentations will likely work best.  However, the maximum-match version has computational benefits.
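The two extremes discussed here can be sketched as deterministic tokenizers. The code below is a hedged illustration in plain Python: `unit_segmentation` is the trivial character-level workaround, and `longest_match` is a greedy left-to-right maximum-match tokenizer, which is an assumption standing in for "longest-matching tokens" and is not the same as running BPE merges. Both function names are hypothetical.

```python
def unit_segmentation(s):
    """The trivial workaround: every token is a single character."""
    return tuple(s)


def longest_match(s, vocab):
    """Greedy left-to-right tokenization that always takes the longest token."""
    max_len = max(map(len, vocab))
    tokens = []
    i = 0
    while i < len(s):
        # Try the longest candidate first, then shorter ones.
        for n in range(min(max_len, len(s) - i), 0, -1):
            if s[i:i + n] in vocab:
                tokens.append(s[i:i + n])
                i += n
                break
        else:
            raise ValueError(f'cannot tokenize {s!r} at position {i}')
    return tuple(tokens)


vocab = {'a', 'aa', 'aaa'}
unit = unit_segmentation('aaa')
greedy = longest_match('aaa', vocab)

# Both tokenizations decode to the same string, so either choice
# satisfies the correctness condition when used deterministically.
assert ''.join(unit) == ''.join(greedy) == 'aaa'
```

Either tokenizer yields a valid $p'$; they differ only in which token sequences the downstream LM is asked to score.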

### Grafting Heuristic

Below, we explore some preliminary attempts to understand the distortion introduced by the char-alignment adaptor (which I might refer to as "grafting"), as opposed to a more global transduction-based approach.

The class `CharAlignedCFGLM` implements an LM $q$ over tokens based on the following conditional factorization:

$$
q(b_{N+1} \mid b_1, \ldots, b_N) \propto p( a_{N+1} \mid a_1 \cdots a_N )
$$

where $b_1, \ldots, b_N$ is a sequence of token ids, and $a_1, \ldots, a_N$ are their respective strings in $\mathcal{A}^*$ (i.e., $\phi(b_k) = a_k$ for each $k = 1, \ldots, N$).
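To see why this factorization can distort the target, the following pure-Python sketch applies the formula literally to the toy distribution and enumerates the resulting string marginal exactly. It is one reading of the equation above, not the implementation of `CharAlignedCFGLM`; the EOS symbol `'$'`, the helper names, and the normalized toy probabilities are assumptions.

```python
from fractions import Fraction

EOS = '$'  # hypothetical end-of-string marker for this sketch
p = {'a': Fraction(1, 4), 'aa': Fraction(1, 4), 'aaa': Fraction(1, 2)}
B = ['a', 'aa', 'aaa', EOS]


def prefix_prob(x):
    """Total mass of strings s such that s + EOS starts with x."""
    return sum(w for s, w in p.items() if (s + EOS).startswith(x))


def string_marginal(context='', weight=Fraction(1), out=None):
    """Enumerate q token by token and accumulate mass per decoded string."""
    if out is None:
        out = {}
    # q(b | context) is proportional to p(phi(b) | phi(context)); the common
    # factor 1 / prefix_prob(context) cancels under local normalization.
    scores = {b: prefix_prob(context + b) for b in B}
    Z = sum(scores.values())
    for b, s in scores.items():
        if s == 0:
            continue
        w = weight * s / Z
        if b == EOS:
            out[context] = out.get(context, 0) + w
        else:
            string_marginal(context + b, w, out)
    return out


q = string_marginal()
# Under these assumptions q puts mass 2/27, 5/27, 20/27 on 'a', 'aa', 'aaa',
# whereas the target p puts 1/4, 1/4, 1/2.
assert q != p
```

The mismatch arises because overlapping tokens such as `a` and `aa` each receive the full prefix probability of the characters they cover, so the local renormalization over tokens shifts mass toward longer strings.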

In [None]:
graft = CharAlignedCFGLM(p, B, EOS)

Our target distribution is the following:

In [None]:
generation_tree(p)

In [None]:
generation_tree(graft)

In [None]:
generation_tree(graft, chunked=True)

### Weighted Transducer

The following WFST simulates BPE's tendency to group character sequences into chunks.

In [None]:
T = bpe_wfst((b, tuple(b)) for b in B).T

We can push some specific character strings through the transducer to see all of the ways they can be chunked.
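For intuition about how many chunkings to expect, the number of segmentations of a string can be counted with a short dynamic program. This is a plain-Python sketch with a hypothetical helper name; whether the count coincides with the transducer's total weight depends on the arc weights that `bpe_wfst` assigns, which are not assumed here.

```python
def count_segmentations(s, vocab):
    """Count the ways to split `s` into tokens drawn from `vocab`."""
    # counts[i] is the number of segmentations of the prefix s[:i].
    counts = [1] + [0] * len(s)
    for end in range(1, len(s) + 1):
        for tok in vocab:
            # Skip empty tokens and tokens longer than the current prefix.
            if tok and end >= len(tok) and s[end - len(tok):end] == tok:
                counts[end] += counts[end - len(tok)]
    return counts[-1]


n_chunkings = count_segmentations('aaa', {'a', 'aa', 'aaa'})
```

For `'aaa'` the four chunkings are `a a a`, `a aa`, `aa a`, and `aaa`.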

In [None]:
T('aaa', None).epsremove.trim

In [None]:
T('aaa', None).total_weight()

In [None]:
b_lm = CFGLM(locally_normalize((p.cfg @ T).trim(), tol=1e-100).trim())

In [None]:
L = b_lm.cfg.language(10)

In [None]:
generation_tree(b_lm)

In [None]:
PL = L.project(ϕ)
PL

In [None]:
display_table(
    [[p.cfg.language(100).project(ϕ), generation_tree(graft).D, PL]],
    headings=['target', 'grafting-heuristic', 'composition'],
)

### The Probabilistic Segmentation Model

In [None]:
PT = segmentation_pfst(B, p.cfg.V - {EOS}, canonical=True)

In [None]:
PT

In [None]:
# pb_lm = CFGLM(locally_normalize((p.cfg @ PT).trim()).trim())
pb_lm = CFGLM((p.cfg @ PT).trim())

In [None]:
generation_tree(pb_lm)

In [None]:
L_PB = pb_lm.cfg.language(100).project(ϕ)
L_PB

In [None]:
L_PB.assert_equal(p.cfg.language(100).project(ϕ))  # character-level distribution matches!

In [None]:
display_table(
    [[p.cfg.language(100).project(ϕ), L_PB, generation_tree(graft).D, PL]],
    headings=['target', 'pfst', 'grafting-heuristic', 'composition'],
)