In [1]:
import numpy as np
from genlm_control.experimental.subtoken import SubtokenPotential
from genlm_control import EOT, EOS, BoolFSA, PromptedLLM
from arsenal.maths import sample_dict

  from .autonotebook import tqdm as notebook_tqdm


## Subtoken potentials

> Reinterpreting a potential defined on a coarser token alphabet $\mathcal{A}$ in terms of the finer-grained alphabet $\mathcal{B}$.

Suppose we have a potential $\Phi$ defined over a (finite) alphabet $\mathcal{A}$ with end-of-sequence token $\textsf{eos}$. Let $\mathcal{B}$ be a finer-grained alphabet such that $\mathcal{A} \subseteq \mathcal{B}^*$. We reinterpret $\Phi$ as a potential $\Psi_{\bm{x}}$ over $\mathcal{B}$ by projecting $\Phi$ from tokens in $\mathcal{A}$ to sequences of subtokens in $\mathcal{B}^*$, conditioned on a token context $\bm{x} \in \mathcal{A}^*$.

The complete potential is defined as:

$$
\psi_{\bm{x}}(\bm{s}) = \begin{cases}
\phi(\bm{s} \mid \bm{x}) &\text{if }\bm{s} \in \mathcal{A} \cup \{\textsf{eos}\}\\
0  &\text{otherwise }
\end{cases}
$$

And the next-(sub)token potential is defined as:

$$
\Psi_{\bm{x}}(s \mid \bm{s}) = \begin{cases}
\frac{\overrightarrow{\psi}_{\bm{x}}(\bm{s}s)}{\overrightarrow{\psi}_{\bm{x}}(\bm{s})} &\text{ if } s \in \mathcal{B}\\
\frac{\psi_{\bm{x}}(\bm{s})}{\overrightarrow{\psi}_{\bm{x}}(\bm{s})} &\text{ if } s = \textsf{eot}
\end{cases}
$$

where we use $\textsf{eot}$ ("end of token") as the distinguished end symbol.

### Marginalized subtoken potential

As is the case for token-level potentials, the **prefix potential** can be chosen by the user so long as it satisfies absolute continuity and consistency. The **optimal prefix potential** is the marginal of the complete potential (TODO: prove):

$$
\overrightarrow{\psi}_{\bm{x}}(\bm{s}) = \sum_{\bm{s}' \in \mathcal{B}^* \cup \{\textsf{eos}\} : \bm{s}' \succeq \bm{s}} \psi_{\bm{x}}(\bm{s}')
$$

The optimal prefix potential can be efficiently parameterized for all $\bm{s} \in \mathcal{A} \cup \{\textsf{eos}\}$ using a single sparse matrix multiplication.


In [2]:
# Load the "gpt2" model via the "hf" backend at temperature 0.5, then set
# the prompt that all subsequent next-token weights are conditioned on.
llm = PromptedLLM.from_name("gpt2", backend="hf", temperature=0.5)
llm.set_prompt_from_str("the big red")



In [3]:
# Wrap the token-level LLM potential so it can be queried one subtoken at a time.
subtoken_llm = SubtokenPotential(llm)

In [4]:
# Display the prompt as stored on the LLM (the output shows a list of byte-string tokens).
llm.prompt

[b'the', b' big', b' red']

In [5]:
# Next-subtoken log weights after the subtoken prefix b" i", given the token context [b" box"].
logps = await subtoken_llm.logw_next(b" i", [b" box"])
# Exponentiate, keep the top 5 entries, and render each byte value as a character for display.
logps.exp().materialize(top=5).project(lambda x: bytes([x]).decode("utf-8"))

0,1
key,value
n,0.5599006419846053
s,0.43918110663135934
t,0.00043727665006992487
f,0.0004310077968950805
c,2.1143516265010492e-05


### Token sampling by sampling subtokens

As a sanity check, here we demonstrate that sequentially sampling subtokens from the subtoken potential is equivalent to directly sampling tokens from the original potential.

In [6]:
async def sample_token(subtoken_potential, token_ctx, draw=sample_dict, g=bytes):
    """Sample a single token by drawing subtokens until the token is ended.

    Subtokens are drawn one at a time from the normalized next-subtoken
    weights. Sampling stops when EOT (end-of-token) is drawn, or when EOS
    (end-of-sequence) is drawn as the very first symbol.

    Args:
        subtoken_potential: Object providing `prefix(subtokens, token_ctx)` and
            `logw_next(subtokens, token_ctx)`.
        token_ctx: Token context the subtoken potential is conditioned on.
        draw: Callable that picks one outcome from the exponentiated,
            normalized next-subtoken weights.
        g: Converts the list of drawn subtokens into a token (default: `bytes`).

    Returns:
        A triple `(token, log_w, log_p)`: the sampled token (or `EOS`), the
        accumulated log weight, and the accumulated log probability of the draws.
    """
    # Start from the prefix weight of the empty subtoken sequence.
    total_logw = await subtoken_potential.prefix([], token_ctx)
    total_logp = 0
    drawn = []

    while True:
        # Weights over the next symbol: a subtoken, EOS (end of sequence,
        # itself a subtoken), or EOT (end of token).
        next_logws = await subtoken_potential.logw_next(drawn, token_ctx)
        next_logps = next_logws.normalize()

        choice = draw(next_logps.exp())

        # `sum()` of the unnormalized log weights is folded into the running
        # weight (presumably the log normalizer -- confirm against LazyWeights).
        total_logw += next_logws.sum()
        total_logp += next_logps[choice]

        if choice == EOS:
            # EOS is only valid as the first symbol of a token. One could
            # instead append EOS and then sample EOT with probability 1.
            assert not drawn, "EOS can't come in the middle of a token."
            return EOS, total_logw, total_logp

        if choice == EOT:
            # The token is complete: package up the subtokens drawn so far.
            return g(drawn), total_logw, total_logp

        drawn.append(choice)

In [7]:
# Sample one token from an empty token context and display the
# (token, log weight, log probability) triple.
token_ctx = []
token, logw, logp = await sample_token(subtoken_llm, token_ctx)
(token, logw, logp)

(b' guy', 2.028800662970882e-05, -3.8421235250695993)

Because this is a probability distribution, the weights will all be 1, i.e., the log weight will be approximately 0 (modulo floating point precision). Thus, we expect the log probability of each sample of subtokens to be equal to the log probability of the token.

In [8]:
# Token-level next-token log weights for the same context; look up the token
# sampled above to compare against the subtoken-level log probability.
logps = await llm.logw_next(token_ctx)
logps[token]

-3.8421032

This strategy is a properly weighted sampler (the following should work for non-probabilistic potentials):

In [10]:
from genlm_control.tracer import TraceSWOR
from arsenal.maths import logsumexp

# Enumerate sampling traces of `sample_token` with the tracer used as the `draw`
# function, accumulating weight * probability per token, then compare against
# the LLM's own next-token weights.
context = []
tracer = TraceSWOR()
swor_logws = llm.alloc_logws()

# Keep tracing until the tracer reports no remaining mass at its root.
while tracer.root.mass > 0:
    with tracer:
        token, logw, logp = await sample_token(subtoken_llm, context, draw=tracer)
        token_id = llm.lookup[token]
        # Log-space accumulation: add exp(logw + logp) to this token's running total.
        swor_logws[token_id] = logsumexp([swor_logws[token_id], logw + logp])

have = llm.make_lazy_weights(swor_logws)
want = await llm.logw_next(context)

# The accumulated weights should match the direct token-level weights within tolerance.
have.assert_equal(want, rtol=1e-5, atol=1e-5)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


### Sampling arbitrary units by sampling subtokens

In [11]:
def bern(p):
    """Return a Bernoulli distribution as a dict: `True` has mass `p`, `False` has mass `1 - p`."""
    outcomes = {}
    outcomes[True] = p
    outcomes[False] = 1 - p
    return outcomes

In [12]:
async def sample_unit(
    subtoken_potential, unit_potential, tokens, subtokens, draw=sample_dict, g=bytes
):
    """Sample one unit (e.g., a word) subtoken by subtoken, crossing token boundaries as needed.

    Args:
        subtoken_potential: Potential queried via `logw_next(subtokens, tokens)` for
            next-subtoken weights; the last entry of its `vocab` is the EOS token.
        unit_potential: Boolean (0/1-weighted) potential over the subtokens of the
            current unit. A positive EOS weight means the unit may end here.
        tokens: Tokens completed so far. Not mutated; an updated copy is returned.
        subtokens: Subtokens of the token currently in progress. Not mutated; an
            updated copy is returned.
        draw: Callable that samples one key from a `{outcome: weight}` mapping.
        g: Converts a list of subtokens into a token/unit (default: `bytes`).

    Returns:
        A triple `(unit, tokens, subtokens)`: the sampled unit (`EOS` if the
        end-of-sequence symbol was drawn), the updated list of completed tokens,
        and the subtokens of the token still in progress (the "slop" that spills
        past the unit boundary).
    """
    # Work on copies so the caller's lists are never modified. Previously `tokens`
    # was appended to in place, and `subtokens` was mutated in place until the first
    # EOT and then rebound, which corrupted any context reused across calls.
    tokens = list(tokens)
    subtokens = list(subtokens)

    unit = []
    # Weights of each subtoken as the *first* symbol of a fresh unit.
    starts_unit = await unit_potential.logw_next(g(unit))
    starts_unit = starts_unit.exp()

    V = subtoken_potential.vocab[:-1]  # Excluding EOS token.

    while True:
        # Get subtoken weights.
        u_ws = (await unit_potential.logw_next(g(unit))).exp()
        s_ps = (await subtoken_potential.logw_next(subtokens, tokens)).normalize().exp()

        # Assume, for now, that the unit potential is boolean.
        assert np.all(np.logical_or(u_ws.weights == 0, u_ws.weights == 1))

        # Maybe sample the EOS token (special case this decision).
        # We assume that EOS is a unit.
        eos_p = s_ps[EOS]
        if draw(bern(eos_p)):
            assert not subtokens, "EOS can't come in the middle of a token."
            return EOS, tokens, subtokens

        # Maybe end the token. This closes the current token but not the unit,
        # so we loop again with an empty subtoken buffer.
        p_eot = s_ps[EOT]
        if draw(bern(p_eot)):
            tokens.append(g(subtokens))
            subtokens = []
            continue

        # Maybe end the unit.
        # EOS in the unit potential is the unit boundary.
        can_stop = u_ws[EOS]
        if can_stop > 0:
            # Compare the mass of next subtokens that would start a new unit
            # against the mass of those that would extend the current one.
            p_starts = sum(s_ps[x] * starts_unit[x] for x in V)
            p_continues = sum(s_ps[x] * u_ws[x] for x in V)
            p_eou = p_starts / (p_starts + p_continues)
            if draw(bern(p_eou)):
                return g(unit), tokens, subtokens

        # Sample next subtoken, restricted to those the unit potential allows.
        p_next = {x: u_ws[x] * s_ps[x] for x in V}
        subtoken = draw(p_next)

        unit.append(subtoken)
        subtokens.append(subtoken)

#### Sample words

In [13]:
# Subtoken view of the LLM, plus a boolean FSA describing a "word" unit:
# one whitespace character followed by one or more alphanumerics.
subtoken_potential = SubtokenPotential(llm)
word_potential = BoolFSA.from_regex(r"\s[A-Za-z0-9]+")

In [14]:
# Sample 10 word units in sequence, threading the token/subtoken state
# returned by each call into the next one.
tokens = []
subtokens = []

for _ in range(10):
    unit, tokens, subtokens = await sample_unit(
        subtoken_potential, word_potential, tokens, subtokens
    )
    print(unit, tokens, subtokens)

b' dollar' [b' dollar'] []
b' in' [b' dollar', b' in'] []
b' the' [b' dollar', b' in', b' the'] []
b' world' [b' dollar', b' in', b' the', b' world'] []
b' and' [b' dollar', b' in', b' the', b' world', b' and'] []
b' the' [b' dollar', b' in', b' the', b' world', b' and', b' the'] []
b' big' [b' dollar', b' in', b' the', b' world', b' and', b' the', b' big'] []
b' blue' [b' dollar', b' in', b' the', b' world', b' and', b' the', b' big', b' blue'] []
b' dollar' [b' dollar', b' in', b' the', b' world', b' and', b' the', b' big', b' blue', b' dollar'] []
b' in' [b' dollar', b' in', b' the', b' world', b' and', b' the', b' big', b' blue', b' dollar', b' in'] []


In [15]:
# Redefine the word unit: optional leading whitespace, alphanumerics, then a
# mandatory trailing whitespace character.
word_potential = BoolFSA.from_regex(
    r"\s?[A-Za-z0-9]+\s"
)  # Words must end with a space!

In [16]:
# Sample 10 units with the trailing-space word potential. In the output below each
# unit ends with a space that is left over in `subtokens` as byte 32.
tokens = []
subtokens = []
# Problem: I don't think we will ever sample EOS here, because we'll almost always have token slop.
for _ in range(10):
    unit, tokens, subtokens = await sample_unit(
        subtoken_potential, word_potential, tokens, subtokens
    )
    print(unit, tokens, subtokens)

b' one ' [b' one'] [32]
b'from ' [b' one', b' from'] [32]
b'the ' [b' one', b' from', b' the'] [32]
b'left ' [b' one', b' from', b' the', b' left'] [32]
b'is ' [b' one', b' from', b' the', b' left', b' is'] [32]
b'the ' [b' one', b' from', b' the', b' left', b' is', b' the'] [32]
b'one ' [b' one', b' from', b' the', b' left', b' is', b' the', b' one'] [32]
b'that ' [b' one', b' from', b' the', b' left', b' is', b' the', b' one', b' that'] [32]
b'has ' [b' one', b' from', b' the', b' left', b' is', b' the', b' one', b' that', b' has'] [32]
b'been ' [b' one', b' from', b' the', b' left', b' is', b' the', b' one', b' that', b' has', b' been'] [32]


In [18]:
# Units defined purely by length: the regex matches exactly 10 characters,
# so unit boundaries need not line up with token boundaries.
sentence = BoolFSA.from_regex(r".{10}")
tokens = []
subtokens = []

for _ in range(4):
    unit, tokens, subtokens = await sample_unit(
        subtoken_potential, sentence, tokens, subtokens
    )
    print(unit, tokens, subtokens)

b' boxes of ' [b' boxes', b' of'] [32]
b'the first-' [b' boxes', b' of', b' the', b' first', b'-'] []
b'time voter' [b' boxes', b' of', b' the', b' first', b'-', b'time'] [32, 118, 111, 116, 101, 114]
b's, with th' [b' boxes', b' of', b' the', b' first', b'-', b'time', b' voters', b',', b' with'] [32, 116, 104]
