# Exploring FSTs for Token Alignment

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
from arsenal import timeit
from IPython.display import HTML
from genparse import FST, Float, EarleyLM as CFGLM, MockLLM, locally_normalize, EOS
from genparse.proposal import TokenProposal
from genparse.util import LarkStuff, interegular_to_wfsa
from genparse.trace import TraceSWOR

In [None]:
def bpe2term_approx(tokenizer, bpe_sequence):
    """Cheap stand-in for the BPE->characters transducer on one token sequence.

    Rather than composing with the full (ambiguous) BPE transducer, this
    builds the single canonical path for ``bpe_sequence``. It starts with an
    empty ``[] -> ''`` pair and chains one ``[token_id] -> text`` pair per
    token onto it with ``*``, left to right.

    Args:
        tokenizer: Hugging Face-style tokenizer; only
            ``convert_ids_to_tokens`` is used.
        bpe_sequence: iterable of token ids.

    Returns:
        The resulting ``FST`` over the ``Float`` semiring.

    Note:
        Only the ``'Ġ'`` marker is rewritten (to a space). Any other
        byte-level symbols in a token's string are kept verbatim.
        Per the original author, the unpruned transducer should agree with
        this path; the hard part is the uncertainty over BPE segmentations.
    """
    from genparse import FST, Float

    def token_text(token_id):
        # Vocabulary string for one id, with GPT-2's space marker undone.
        return tokenizer.convert_ids_to_tokens(token_id).replace('Ġ', ' ')

    # Materialize every (input, output) pair before building any FST piece.
    pairs = tuple(([token_id], token_text(token_id)) for token_id in bpe_sequence)

    result = FST.from_pairs([([], '')], Float)
    for pair in pairs:
        result = result * FST.from_pairs([pair], Float)
    # TODO (original): derive this path from the real transducer instead, e.g.
    #   return c2t(pairs, None).trim.epsremove.trim
    return result

In [None]:
# Formerly a method on LarkStuff; `self` is a LarkStuff instance.
def lark_stuff_transducer(self, decay=0.99):
    """Build a weighted transducer from characters to lark terminal names.

    The machine reads characters and writes terminal names
    (``token_class.name``). Because ``STOP`` loops back to ``START``, any
    split of the input into pieces that each match some terminal's regex is
    accepted. In other words, it encodes *ambiguous* lexing.

    Structure:
      * ``START`` (initial, weight 1) takes an epsilon-input arc that emits a
        terminal's name and enters that terminal's regex WFSA. The arc
        weight is the WFSA's initial weight.
      * Every character arc of the WFSA reads one character, emits nothing,
        and has its weight multiplied by ``decay``.
      * WFSA final states take an epsilon arc to ``STOP`` (weighted by the
        WFSA's final weight). ``STOP`` is the only final state, with final
        weight ``decay``. It also returns to ``START`` with weight 1 so more
        terminals can follow.

    Args:
        self: object exposing ``terminals``. Each terminal must provide
            ``.name`` and ``.pattern.to_regexp()``.
        decay: factor multiplied into every character arc and used as the
            final weight of ``STOP``. NOTE(review): values < 1 presumably
            keep the total weight of this cyclic machine finite; confirm.

    Returns:
        An ``FST`` over the ``Float`` semiring whose arc labels are
        ``(input, output)`` pairs.
    """
    from genparse import EPSILON, FST, Float

    m = FST(Float)
    START = 0
    STOP = 1
    m.add_I(START, 1)
    m.add_F(STOP, decay)  # STOP is the only final state
    # After a terminal ends, allow another terminal to start.
    m.add_arc(STOP, (EPSILON, EPSILON), START, 1)
    for token_id, token_class in enumerate(self.terminals):
        # WFSA recognizing this terminal's regular expression.
        fsm = interegular_to_wfsa(token_class.pattern.to_regexp())
        # Enter the terminal: read nothing, emit its name. States are tagged
        # (token_id, state) so different terminals' sub-automata stay disjoint
        # from each other and from the integer START/STOP states.
        for i, w in fsm.I:
            m.add_arc(START, (EPSILON, token_class.name), (token_id, i), w)
        # Leave the terminal from any of its final states.
        for i, w in fsm.F:
            m.add_arc((token_id, i), (EPSILON, EPSILON), STOP, w)
        # Copy the terminal's character arcs: read the character, emit
        # nothing, and multiply in `decay` once per character.
        for state in fsm.states:
            for char, next_state, w in fsm.arcs(state):
                m.add_arc(
                    (token_id, state),
                    (char, EPSILON),
                    (token_id, next_state),
                    w * decay,
                )
    return m

## Accounting for BPE's Tokenization Ambiguity with Transduction 

In [None]:
# Toy grammar with a single terminal NAME matching /(a|b)*c/.
lark_stuff = LarkStuff(
    r"""
    start: NAME
    NAME: /(a|b)*c/
    """
)
# Character-level grammar derived from the lark grammar, with locally normalized weights.
foo = lark_stuff.char_cfg()
foo = locally_normalize(foo)
# Sanity check: the grammar is not empty after trimming.
assert len(foo.trim()) > 0

In [None]:
# Display the grammar.
foo

In [None]:
# Enumerate (with weights) a bounded portion of the language of the CNF version of `foo`.
foo.cnf.language(3)

In [None]:
# Character-level language model over `foo` (`CFGLM` is imported as an alias of `EarleyLM`).
lm = CFGLM(foo)

In [None]:
# Draw 15 samples, all routed through one shared `TraceSWOR` tracer.
# NOTE(review): judging by its name, TraceSWOR implements sampling without replacement; confirm.
trace = TraceSWOR()
for t in range(15):
    with trace:
        print(t, lm.sample(draw=trace))

In [None]:
def about(m):
    """Print the number of states of automaton ``m``, e.g. ``'42 states'``."""
    n_states = len(m.states)
    print(f'{n_states} states')

In [None]:
# Late imports: Hugging Face tokenizers and genparse's BPE helpers.
import transformers
from genparse.tokenization import decode_tokenizer_vocab
from genparse.segmentation import bpe_wfst

In [None]:
# GPT-2 tokenizer and its decoded vocabulary; later cells use `decode[token_id]` to map ids to strings.
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
decode = decode_tokenizer_vocab(tokenizer)
# Weighted transducer built from (token_id, token_string) pairs; later cells use it as the BPE transducer `b2c`.
T = bpe_wfst(enumerate(decode))
about(T)

Let's shrink the BPE transducer down to something manageable by limiting its alphabet to the characters used by our grammar:

In [None]:
# Restrict the BPE transducer `T` to the alphabet of the toy grammar `foo`
# (plus '' for epsilon), then renumber its states. This keeps `b2c` small
# enough to compose with `foo` below.
# NOTE(review): the `None` first argument presumably leaves the BPE side
# unrestricted; confirm against genparse's `prune_to_alphabet`.
b2c = T.prune_to_alphabet(None, foo.V | {''}).renumber

We can look at our little language's strings through the lens of their possible BPE sequences. Notice that these strings map ambiguously to BPE: many different BPE sequences give rise to the same string!

In [None]:
# For each short string of the toy language, show every BPE sequence that spells it.
for x in foo.cnf.language(3):
    display(HTML('<hr/>'))  # `display` is provided by IPython/Jupyter; it is not imported here
    print(x)
    # NOTE(review): b2c(None, x) appears to fix the character side to `x` and
    # leave the BPE side free; confirm against genparse's FST call semantics.
    bpe_x = b2c(None, x).epsremove.trim
    print('total weight of BPE sequences (i.e., ambiguity):', bpe_x.total_weight())
    display(bpe_x)
    print()

In [None]:
# Compose the BPE transducer with the character-level grammar, then trim. The
# result is a grammar whose strings are BPE token-id sequences (they are passed
# to `tokenizer.decode` below).
tmp = (b2c @ foo).trim()

In [None]:
# Maximum character length of the strings compared in the consistency check below.
L = 5

In [None]:
# Aggregate the BPE-level language by the character string each token sequence
# decodes to, keeping only strings of at most L characters.
# The enumeration bound now tracks `L` instead of a hard-coded 5 (identical while L == 5).
# NOTE(review): this assumes a bound of L suffices to reach every BPE sequence of a
# string with at most L characters; confirm against genparse's `language(n)`.
c = Float.chart()
for x, w in tmp.cnf.language(L).items():
    y = tokenizer.decode(x)
    if len(y) > L:
        continue
    c[y] += w

In [None]:
# Ambiguity of each character string: total weight of the BPE sequences that spell it
# (the same quantity printed as "ambiguity" earlier).
ambig = Float.chart({x: b2c(None, x).total_weight() for x in c})

In [None]:
# Divide out the BPE ambiguity so each string's mass is comparable to the character-level grammar.
ccc = Float.chart()
for x in c:
    ccc[x] = c[x] / ambig[x]

In [None]:
# Reference: weights of the same strings (length <= L) under the character-level grammar `foo`.
cc = Float.chart()
for x, w in foo.cnf.language(L + 2).items():
    if len(x) > L:
        continue
    cc[''.join(x)] += w
# cc

In [None]:
# After correcting for ambiguity, the BPE-level grammar must agree with the character-level one.
ccc.assert_equal(cc, tol=1e-10)

In [None]:
# cc.metric(c)

In [None]:
# tmp.trim(bottomup_only=True)

In [None]:
# show_grammar(tmp, showzero=True)

In [None]:
# print(tmp.agenda().__str__(style_value=lambda k, v: (colors.light.red % v) if v > 1.000001 or v < 0 else v))

In [None]:
# for q in c2t.states:
#    for (a,b), r, w in c2t.arcs(q):
#        print(f'--{a or "ε"}:{b or "ε"}/{w}-->', r)

In [None]:
# {x: v for x,v in tmp.agenda().items() if v > 1.001 or v < 0}

In [None]:
# len(tmp.N - tmp.agenda(tol=1e-40, maxiter=np.inf).trim().keys()), len(tmp.N), len(tmp.agenda(tol=1e-40).trim())

In [None]:
# tmp.cnf.language(4)

In [None]:
# show_grammar(tmp)

In [None]:
# Locally normalize the BPE-level grammar (tolerance 1e-20, no iteration cap) and trim it.
p = locally_normalize(tmp, tol=1e-20, maxiter=np.inf).trim()

In [None]:
# Language model over BPE token ids.
lm2 = CFGLM(p.cnf)

In [None]:
# lm2.sample(verbose=1)

In [None]:
# A context of GPT-2 token ids (a longer alternative is kept commented out).
# context = (64,65,6485,39305)
context = (
    64,
    65,
    6485,
)

In [None]:
# The same context as a character string.
char_context = tokenizer.decode(context)
char_context

In [None]:
# Next-token distribution under the BPE-level LM, sorted by decreasing probability.
# Rows: token id, its string (EOS is shown as-is rather than looked up in `decode`), probability.
df = []
for x, w in sorted(lm2.p_next(context).normalize().items(), key=lambda kv: -kv[1]):
    df.append((x, (decode[x] if x != EOS else EOS), w))
pd.DataFrame(df, columns=['token_id', 'chars', 'prob']).set_index('token_id')

In [None]:
# Compare with the character-level LM's next-character distribution after the same text.
lm.p_next(char_context).normalize()

## Lexing

In [None]:
# Toy SQL-like grammar: SELECT * FROM <name> </s>, separated by single spaces.
lark_stuff = LarkStuff(
    r"""
    start: "SELECT" WS STAR WS "FROM" WS NAME WS EOS
    EOS: "</s>"
    NAME: /[a-z]+/
    STAR: "*"
    WS: /[ ]/
    """
)

In [None]:
# Character-level grammar.
foo = lark_stuff.char_cfg()

In [None]:
# foo['NAME'].trim().agenda()

In [None]:
# foo.agenda()

In [None]:
# Locally normalize with a tight tolerance, trim, and check the result is non-empty.
foo = locally_normalize(foo, tol=1e-100).trim()
assert len(foo) > 0

In [None]:
# foo

In [None]:
lm = CFGLM(foo)

In [None]:
# Draw 15 samples through one shared tracer, printing `trace.root.mass` before each draw.
# NOTE(review): presumably that is the probability mass TraceSWOR has not yet explored; confirm.
trace = TraceSWOR()
for _ in range(15):
    print('mass=', trace.root.mass)
    with trace:
        print(''.join(lm.sample(draw=trace, prob=False)))

In [None]:
# Grammar over lark terminal names (its strings are terminal sequences; see `cfg.prefix_weight` below).
cfg = lark_stuff.convert().renumber()

In [None]:
# Character->terminal transducer; decay=0.0125 multiplies every character arc (and the final stop).
c2t = lark_stuff_transducer(lark_stuff, decay=0.0125)
len(c2t.states)

In [None]:
c2t

The `lark` library lexes this string in only one way, because its prioritized lexing is deterministic:

In [None]:
x = 'SELECT * FROM data'

In [None]:
# lark's own (deterministic) lexing of `x`, for comparison with the ambiguous transducer below.
list(lark_stuff.lex(x))

However, this string can lex many different ways:

In [None]:
# Every terminal sequence the characters of `x` can be lexed into. The steps:
# compose `x` with the char->terminal transducer, keep the terminal-name tape
# (project(1)), remove epsilons, and enumerate the language with a bound of 15.
# NOTE: this rebinds `ambig`, which earlier held per-string BPE ambiguity weights.
ambig = (
    (FST.from_string(x, Float) @ c2t)
    .trim.project(1)
    .epsremove.trim.to_cfg()
    .cnf.language(15)
)

In [None]:
# ambig

It might be fine to allow ambiguous lexing because very few of the possible lexing options will survive the parser.

In [None]:
# Print only the lexings the parser can still extend, i.e. those whose
# prefix weight under the terminal-level grammar `cfg` is nonzero.
for y in ambig:
    v = cfg.prefix_weight(y)
    if v != 0:
        print(v, y)

In [None]:
# ((FST.from_string('SELECT', Float) @ c2t) @ P.T @ cfg).trim().cnf.language(15)

In [None]:
# (P.T @ cfg).trim().cnf.language(100)

In [None]:
# Compose the char->terminal transducer with the terminal-level grammar. The result is a
# grammar over characters that accounts for every lexing; then locally normalize it.
cfg_t = (c2t.renumber @ cfg).trim()
pcfg_t = locally_normalize(cfg_t, tol=1e-100, maxiter=10_000_000)

In [None]:
# Weight of the full query under the composed grammar (presumably summed over all of its lexings).
# The identical duplicate cell that followed has been removed.
cfg_t('SELECT * FROM data </s>')

In [None]:
# Character-level LM over the composed, lexing-aware grammar.
lm = CFGLM(pcfg_t.cnf)

In [None]:
for _ in range(10):
    print(''.join(lm.sample(prob=False)))

In [None]:
# Next-character distribution after a partial query.
lm.p_next('SELECT * FROM ')

## BPE Basics

In [None]:
# From here on, `b2c` is the full (unpruned) BPE transducer `T`.
b2c = T
len(b2c.states)

In [None]:
# Canonical GPT-2 encoding of `x`.
x = 'SELECT * FROM data'
b = tokenizer.encode(x)
b

In [None]:
# String of each individual token in that encoding.
[tokenizer.decode(bb) for bb in b]

In [None]:
# Push the canonical token sequence through the BPE transducer (timed). Rebinds `c`.
with timeit('composition'):
    c = FST.from_string(tuple(b), Float) @ b2c
about(c)

In [None]:
c.trim

We can build this "transducer" more efficiently by following only the single canonical tokenization (see `bpe2term_approx`):

In [None]:
# Single-canonical-path version, intended to match `c.trim` above without composing with `b2c`.
t = bpe2term_approx(tokenizer, tokenizer.encode(x)).epsremove.trim
t

## BPE Ambiguity

In [None]:
# String whose BPE segmentations are explored below.
x = 'SELECT * FROM data'

In [None]:
# All BPE sequences that decode to `x`: compose the BPE transducer with the string (timed).
with timeit('composition'):
    bs = b2c @ FST.from_string(x, Float)
# Time trimming only. The trimmed machine is not stored, so `about(bs)` reports the untrimmed size.
with timeit('trim'):
    bs.trim
about(bs)

In [None]:
# bs.trim

The automaton below describes all the BPE sequences that generate the string `x`, and the number below is the total weight of these paths (in the counting semiring, this equals the number of distinct paths):

In [None]:
# Total weight of the BPE side of `bs`, i.e., the BPE ambiguity of `x`.
bs.trim.project(0).epsremove.trim.total_weight()

In [None]:
# The automaton of all BPE sequences that spell `x`.
bs.trim.project(0).epsremove.trim

To see all the BPE sequences that generate `x` run the cell below:

In [None]:
# for y in bs.trim.project(0).epsremove.trim.to_cfg().language(10):
#    print(tokenizer.decode(y), y)

## The Grafting Heuristic

In [None]:
# A richer SQL-like grammar (select lists, WHERE, GROUP BY, ORDER BY) over a fixed table `data`.
lark_stuff = LarkStuff(
    r"""
start: "SELECT" WS select_expr WS "FROM" WS from_expr [WS "WHERE" WS bool_condition] [WS "GROUP BY" WS var_list] [WS "ORDER BY" WS orderby_expr] WS EOS
EOS: "</s>"
select_expr: STAR | select_list
bool_condition: bool_expr | "(" bool_condition WS "AND" WS bool_condition ")" | "(" bool_condition WS "OR" WS bool_condition ")"
bool_expr: var "=" value | var ">" value | var "<" value
from_expr: "data"
orderby_expr: var_list WS "ASC" | var_list WS "DESC"
select_list: select_var ("," WS select_var)*
var_list: var ("," WS var)*
select_var: var | "AVG(" var ")" | "MEDIAN(" var ")" | "COUNT(" var ")"
var: "age" | "gender" | "year" | "state_color" | "zipcode" | "vote" | "race_ethnicity"
value: NUMBER | "red" | "blue" | "white" | "black" | "latino" | "republican" | "democrat" | "male" | "female"
STAR: "*"
NUMBER: /\d+/
WS: " "
"""
)

# Character-level grammar with locally normalized weights, trimmed and checked non-empty, plus its LM.
foo = lark_stuff.char_cfg()
foo = locally_normalize(foo, tol=1e-100).trim()
assert len(foo) > 0
lm = CFGLM(foo)

In [None]:
# One sample from the character-level LM.
print(''.join(lm.sample(prob=False)))

In [None]:
# Token-level proposal: the character-level grammar LM `lm` acts as the guide,
# combined with a mock LLM whose vocabulary is the set of GPT-2 token strings.
# `set(decode)` replaces the equivalent but unnecessary `{x for x in decode}`.
bpe_lm = TokenProposal(
    guide=lm, llm=MockLLM(V=set(decode), eos=tokenizer.eos_token)
)

In [None]:
# Compare next-step distributions: character-level `lm` vs. token-level `bpe_lm`.
lm.p_next('')

In [None]:
# NOTE(review): uses private members. `_prompt` is reset to an empty tuple before calling
# `_p_next`, presumably because `_p_next` reads it; confirm against TokenProposal.
bpe_lm._prompt = ()
bpe_lm._p_next(())

In [None]:
lm.p_next('SELECT')

In [None]:
bpe_lm._p_next(('SELECT',))

In [None]:
# Sample a complete sequence from the token-level proposal.
ys = bpe_lm.sample()
ys