# Structured Generation

## Motivation

When we use an LLM to answer a question, the output is free-form text. But in many real-world applications we need the output in a **specific format**: a number, a JSON object, a date, a choice from a list, etc. The usual workaround is to generate free-form text and then parse it with a regular expression:


In [None]:
import re

answer = """The first 10 digits of pi (π) are as follows:

3.1415926535
"""

regex = r"([0-9]+)?\.[0-9]+"
print(re.search(regex, answer))

This works, but it is fragile: the LLM might produce text that doesn't match the pattern at all, forcing us to retry or fail. **Structured generation** solves this by constraining the LLM _during_ generation so that every output is guaranteed to match the desired format.

## Goal of this practical

In this practical you will implement structured generation from scratch, progressing through three increasingly efficient approaches:

1. **Naive approach** -- at each generation step, try every token in the vocabulary against a regex partial match. Simple but $O(V)$ per step.
2. **DFA-based approach** -- compile the regex to a Deterministic Finite Automaton (DFA), then precompute which tokens are valid from each DFA state. Finding the allowed tokens at each step becomes an $O(1)$ lookup (we still fill a length-$V$ mask array).
3. **Coalescence** -- observe that many DFA states allow the exact same set of tokens, so we can precompute and _share_ mask arrays across equivalent states, eliminating even the array-allocation overhead.

We follow the ideas of [Efficient Guided Generation for Large Language Models](https://arxiv.org/abs/2307.09702) by Brandon T. Willard and Remi Louf.

## Setup

Throughout the notebook we use a tiny toy vocabulary of 4 tokens: `["a", ".", ".2", "1"]` and the regex pattern `([0-9]+)?\.[0-9]+` (a decimal number like `3.14`). This keeps things small enough to inspect by hand. At the end, we scale up to GPT-2's real 50k-token vocabulary.


## Part 1: Naive constrained generation with regex partial matching

The key idea is simple: at each generation step, we **mask out** every token that would make the output incompatible with the target regex. The LLM can then only sample from tokens that keep a valid match possible.

Concretely, the algorithm works as follows:

1. Start with an empty string as prefix.
2. For every token in the vocabulary, concatenate it to the prefix and check whether the result is a **partial match** of the regex (i.e. could still lead to a full match if we kept appending characters).
3. If not, set that token's logit to $-\infty$ so it cannot be sampled.
4. Sample the next token from the masked logits.
5. Append the sampled token to the prefix and go back to step 2.
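Steps 3 and 4 rely on a property of softmax: a logit of $-\infty$ turns into a probability of exactly 0, and the remaining tokens are renormalized among themselves. Below is a minimal NumPy sketch of that masking step. The mask is written by hand (pretending only `"a"` is invalid) rather than computed from the regex, and `masked_softmax` is a hypothetical helper name:

```python
import numpy as np

def masked_softmax(logits, mask):
    """Softmax over logits + mask, where mask holds 0 (allowed) or -inf (banned)."""
    z = logits + mask
    z = z - np.max(z)   # shift for numerical stability (max is finite if any token is allowed)
    e = np.exp(z)       # exp(-inf) == 0.0, so banned tokens get zero weight
    return e / e.sum()

vocabulary = ["a", ".", ".2", "1"]
logits = np.ones(len(vocabulary))            # uniform stand-in for model logits
mask = np.array([-np.inf, 0.0, 0.0, 0.0])    # hand-written: pretend only "a" is invalid
probs = masked_softmax(logits, mask)
print(dict(zip(vocabulary, probs.round(3))))
```

The three allowed tokens end up with probability 1/3 each, and `"a"` can never be sampled.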

The diagram below illustrates this process for the regex `[0-9]+\.[0-9]` (a slightly simpler pattern than the one we use in the code) with the vocabulary `["a", ".", ".2", "1"]`:

![](https://cdn.prod.website-files.com/665725b00d910f65bec567fc/668c29d45780ee71a367c839_naive.png)

### Partial matching

A string is a **partial match** of a regex if it is a prefix of some string that fully matches it: we could still append characters to reach a full match. Python's standard `re` library does not support partial matching, but the [regex](https://github.com/mrabarnett/mrab-regex) library does, via the `partial=True` flag:


In [None]:
import regex as re

regex = r"([0-9]+)?\.[0-9]+"
print(re.fullmatch(regex, '.2', partial=True))

In [None]:
print(re.fullmatch(regex, '1.2', partial=True))

In [None]:
print(re.fullmatch(regex, '1.2a', partial=True))

### Exercise 1: Implement naive constrained generation

Use `re.fullmatch(regex, string, partial=True)` to build a logit mask at each generation step. The mask should be a numpy array with `0` for tokens whose concatenation with the current prefix gives a partial match, and `-math.inf` otherwise. We simulate an LLM with uniform logits (all equal to 1) so the generation is purely driven by the mask.


In [None]:
import math
import regex as re
import numpy as np

from scipy.special import softmax

np.random.seed(12349)

logits = np.array([1., 1., 1., 1.])  # Dummy model: uniform logits, so every token is equally likely
vocabulary = ["a", ".", ".2", "1"]

regex = r"([0-9]+)?\.[0-9]+"

completion = ""
for _ in range(7):

    # Build the logit mask
    # For each token in the vocabulary, check if appending it to the current
    # completion could lead to a valid match (using partial matching).
    # Set mask to 0 for valid tokens, -inf for invalid ones.
    # The result should be a numpy array called `mask`.
    #
    # your code here
    #
    
    masked_logits = logits + mask

    # Sample the next token
    probs = softmax(masked_logits)
    next_token_id = np.random.choice(len(vocabulary), p=probs)

    completion += vocabulary[next_token_id]

print(completion)

## Part 2: DFA-based constrained generation

The naive approach works correctly but is **far too slow** for real LLMs. With a vocabulary of $V \approx 50{,}000$ tokens, we perform 50k regex partial matches _per generated token_. In Python this easily dominates inference time.

The key insight is that every regular expression can be compiled into an equivalent **Deterministic Finite Automaton (DFA)**. A DFA is a directed graph where:
- Each **node** is a state.
- Each **edge** is a transition labeled with a character (or character class).
- There is one **initial state** and one or more **accept (final) states**.

To check whether a string matches, you walk the DFA one character at a time:

1. Start in the initial state with the full string.
2. Read the next character. If there is a matching transition, follow it to the next state. Otherwise, **reject**.
3. After consuming the entire string, **accept** if you are in a final state; otherwise **reject**.
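The walk above fits in a few lines of plain Python. This sketch hand-codes a DFA for `([0-9]+)?\.[0-9]+`. The state names (`start`, `int`, `dot`, `frac`), the character classes and the dict layout are illustrative assumptions, not what interegular produces:

```python
DIGITS = set("0123456789")

# Hand-written DFA for ([0-9]+)?\.[0-9]+ (state names are illustrative).
# transitions[state][char_class] -> next state; a missing entry means "reject".
transitions = {
    "start": {"digit": "int", "dot": "dot"},
    "int":   {"digit": "int", "dot": "dot"},
    "dot":   {"digit": "frac"},
    "frac":  {"digit": "frac"},
}
initial, finals = "start", {"frac"}

def char_class(c):
    """Collapse characters into the classes this DFA distinguishes."""
    if c in DIGITS:
        return "digit"
    if c == ".":
        return "dot"
    return "other"  # plays the role of interegular's anything_else

def accepts(s):
    state = initial
    for c in s:
        nxt = transitions[state].get(char_class(c))
        if nxt is None:        # no matching transition: reject
            return False
        state = nxt
    return state in finals     # accept only if we end in a final state

for s in ["3.14", ".5", "3.", "a1.2"]:
    print(s, accepts(s))
```

`accepts("3.")` is `False` because the walk ends in `dot`, which is not a final state, even though `"3."` is a valid *prefix*. That distinction is exactly what partial matching captures.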

### Why does this help?

Instead of running a regex engine on every `(prefix + token)` pair at every step, we can:
1. **Precompute**, for each DFA state, which tokens lead to valid transitions.
2. At generation time, just **look up** the current state to get the set of allowed tokens -- no regex matching needed.

This trades an $O(V)$-per-step cost for a one-time $O(V \times |\text{states}|)$ precomputation.

### Building the DFA

We use the [interegular](https://github.com/MegaIng/interegular) library to convert our regex into its equivalent DFA. Let's see what it looks like:


In [None]:
import interegular

regex = r"([0-9]+)?\.[0-9]+"
fsm = interegular.parse_pattern(regex).to_fsm()

print(fsm)

In [None]:
print(fsm.alphabet)

In [None]:
for (start, transitions) in fsm.map.items():
    print(start, transitions)

In [None]:
print(fsm.initial)

In [None]:
print(fsm.finals)

In [None]:
fsm.alphabet.keys()

### Reading the DFA

The DFA has three key components:

- **`fsm.alphabet`** maps each character to a **symbol index**. Characters that behave identically (e.g. all digits `0`-`9`) share the same index. There is also an `anything_else` symbol for characters not explicitly listed.
- **`fsm.map`** is the transition table: `fsm.map[state][symbol_index] = next_state`. If a `(state, symbol_index)` pair is missing, there is no valid transition (the string is rejected).
- **`fsm.initial`** and **`fsm.finals`** are the start state and set of accept states.

Let's visualize this DFA as a graph:


In [None]:
from interegular import fsm as fsm_module
from collections import defaultdict

# Build edge labels dynamically from fsm.alphabet
idx_to_chars = defaultdict(list)
for char, idx in fsm.alphabet.items():
    if char is fsm_module.anything_else:
        idx_to_chars[idx].append("*")
    else:
        idx_to_chars[idx].append(char)

# Collapse groups into readable labels, e.g. ['0','1',...,'9'] -> "[0-9]"
idx_to_label = {}
for idx, chars in idx_to_chars.items():
    chars = sorted(chars)  # make sure the range label uses the smallest and largest character
    if len(chars) > 3:
        idx_to_label[idx] = f"[{chars[0]}-{chars[-1]}]"
    else:
        idx_to_label[idx] = ",".join(chars)

In [None]:
import graphviz
from IPython.display import display

# Define the regex pattern
regex = r"([0-9]+)?\.[0-9]+"

# Convert regex to a finite state machine (FSM)
fsm = interegular.parse_pattern(regex).to_fsm()

# Generate Graphviz DOT format representation
dot = graphviz.Digraph(format="png")

# Add states to the graph
for state in fsm.states:
    shape = "doublecircle" if state in fsm.finals else "circle"
    dot.node(str(state), shape=shape)

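# Mark the initial state with an incoming arrow
dot.node("start", shape="point")
dot.edge("start", str(fsm.initial))

# Add transitions to the graph (transition keys are symbol indices, not characters)
for (start, transitions) in fsm.map.items():
    for symbol_idx, end in transitions.items():
        dot.edge(str(start), str(end), label=idx_to_label.get(symbol_idx, str(symbol_idx)))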

display(dot)

In [None]:
fsm.map

### From characters to tokens: walking the DFA

The DFA operates on **characters**, but LLMs generate **tokens** (which can be multi-character strings like `".2"` or `"1"`). To check whether a token is compatible with a given DFA state, we need to "walk" the token through the DFA one character at a time.

Given a starting state and a token string, we:
1. Look up the first character in `fsm.alphabet` to get its symbol index.
2. Check if there is a transition from the current state for that symbol. If not, the token is **rejected** from this state.
3. Follow the transition to the next state and repeat for the remaining characters.
4. If we consume all characters without rejection, the token is **valid** from this state.

For example, starting from state 0 with token `".2"`:
- Character `"."` has symbol index 2, and `fsm.map[0][2] = 2` $\Rightarrow$ move to state 2.
- Character `"2"` has symbol index 0 (a digit), and `fsm.map[2][0] = 4` $\Rightarrow$ move to state 4.
- We traversed states $(0, 2, 4)$. State 4 is a final state, so `".2"` is a complete valid match!

Conversely, token `"a"` from state 0: character `"a"` maps to `anything_else` (symbol index 1), and there is no transition `fsm.map[0][1]`, so `"a"` is rejected.
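The token walk described above can be sketched as follows. The layout loosely mimics interegular's: an alphabet mapping characters to symbol indices, with a fallback for unlisted characters, and a `map[state][symbol]` table. However, the sentinel `ANYTHING_ELSE`, the state numbers and the symbol indices are hypothetical, so they need not match the real `fsm` printed above:

```python
ANYTHING_ELSE = object()  # hypothetical sentinel for characters not listed explicitly

# Hypothetical symbol indices: all digits share index 0, '.' has index 2.
alphabet = {**{d: 0 for d in "0123456789"}, ".": 2, ANYTHING_ELSE: 1}

# Hypothetical DFA for ([0-9]+)?\.[0-9]+ : 0=start, 1=digits, 2=after '.', 3=final.
fsm_map = {
    0: {0: 1, 2: 2},
    1: {0: 1, 2: 2},
    2: {0: 3},
    3: {0: 3},
}
finals = {3}

def walk_token(state, token):
    """Return the tuple of traversed states, or None if some character is rejected."""
    traversed = (state,)
    for ch in token:
        symbol = alphabet.get(ch, alphabet[ANYTHING_ELSE])
        nxt = fsm_map.get(state, {}).get(symbol)
        if nxt is None:
            return None
        state = nxt
        traversed += (state,)
    return traversed

print(walk_token(0, ".2"))   # (0, 2, 3): ends in a final state
print(walk_token(0, "a"))    # None: 'a' maps to index 1, which state 0 has no edge for
print(walk_token(1, ".21"))  # (1, 2, 3, 3)
```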

### Exercise 2: Implement `partial_match`

Write a function that walks a token through the DFA and returns the tuple of traversed states, or `None` if the token is rejected.


In [None]:
def partial_match(state, token):
    """Partially match the token to the DFA starting from `state`.

    We iterate over the token's characters and, at each step, transition to
    the next state if we find a valid transition.
    If some character has no valid transition, we return None; otherwise
    we return a tuple containing the sequence of traversed states.

    Hints:
    - Use fsm.alphabet[symbol] to get the alphabet index of a character.
    - Use fsm.map[state] to get the transitions from a state.
    - Return a tuple of all traversed states (including the starting state).
    """
    
    traversed_states = (state,)
    # Iterate over the token's symbols, trying at each step to transition
    # to a new DFA state.
    #
    # your code here
    #
    
    return traversed_states

In [None]:
token = ".21"
print(partial_match(0, token))

In [None]:
token = ".21."
print(partial_match(0, token))

### Exercise 3: Build the state-to-token index

Now we need to precompute, for every DFA state, which tokens from our vocabulary correspond to a valid transition. We build two data structures:

- **`states_to_vocab[state]`**: the set of token IDs that are valid from `state`.
- **`states_token_states[state][token_id]`**: the DFA state we land on after consuming that token from `state`.

This is the one-time precomputation that makes generation fast: instead of running the regex for every token at every step, we just look up `states_to_vocab[current_state]`.
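Here is a minimal sketch of this precomputation on the toy vocabulary. It assumes a hand-written DFA (states 0 to 3 are hypothetical and need not match interegular's numbering) whose transitions are keyed by a `"digit"` class or by the literal character:

```python
from collections import defaultdict

vocabulary = ["a", ".", ".2", "1"]

# Hypothetical character-level DFA for ([0-9]+)?\.[0-9]+
# (0=start, 1=digits, 2=after '.', 3=final); keys are "digit" or a literal character.
fsm_map = {
    0: {"digit": 1, ".": 2},
    1: {"digit": 1, ".": 2},
    2: {"digit": 3},
    3: {"digit": 3},
}

def walk(state, token):
    """Landing state after consuming `token` from `state`, or None if rejected."""
    for ch in token:
        key = "digit" if ch in "0123456789" else ch
        state = fsm_map[state].get(key)
        if state is None:
            return None
    return state

states_to_vocab = defaultdict(set)       # state -> allowed token ids
states_token_states = defaultdict(dict)  # state -> {token_id: landing state}
for token_id, token in enumerate(vocabulary):  # one pass over the vocabulary...
    for state in fsm_map:                       # ...times every DFA state
        end = walk(state, token)
        if end is not None:
            states_to_vocab[state].add(token_id)
            states_token_states[state][token_id] = end

print(dict(states_to_vocab))
```

The double loop touches every (token, state) pair exactly once, which is the one-time $O(V \times |\text{states}|)$ cost mentioned earlier (times the token length for the character walk).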


In [None]:
from collections import defaultdict

vocabulary = ["a", ".", ".2", "1"]

# Map from the DFA states to the tokens that correspond to a valid transition
# from this state.
states_to_vocab = defaultdict(set)
states_token_states = defaultdict(dict)

# Iterate (once) through the vocabulary and for each token, check from each
# DFA state whether partial_match finds a valid path.
# If so, record the token_id in states_to_vocab[state] and the landing state
# in states_token_states[state][token_id].
#
# your code here
#

In [None]:
states_to_vocab

### Exercise 4: DFA-based generation

With the index built, generation becomes straightforward:

1. Start in the initial DFA state. Look up `states_to_vocab[state]` to get the allowed tokens.
2. Build a mask: $-\infty$ everywhere, then set the allowed positions to 0.
3. Add the mask to the logits and sample.
4. Look up `states_token_states[state][token_id]` to transition to the next DFA state.
5. Repeat.

You should get the **same result** as the naive approach (`11.21111`): the random seed is the same and both approaches produce identical masks.
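For reference, here is a self-contained sketch of steps 1 to 5. It uses a hard-coded index for the toy regex instead of the one built in Exercise 3 (the state numbering is hypothetical), and NumPy's `default_rng` instead of `np.random.seed`, so its output will differ from `11.21111`:

```python
import numpy as np

vocabulary = ["a", ".", ".2", "1"]
# Hard-coded index for ([0-9]+)?\.[0-9]+ over the toy vocabulary
# (hypothetical state numbering: 0=start, 1=digits, 2=after '.', 3=final).
states_to_vocab = {0: {1, 2, 3}, 1: {1, 2, 3}, 2: {3}, 3: {3}}
states_token_states = {
    0: {1: 2, 2: 3, 3: 1},
    1: {1: 2, 2: 3, 3: 1},
    2: {3: 3},
    3: {3: 3},
}

rng = np.random.default_rng(0)
logits = np.ones(len(vocabulary))  # uniform stand-in for model logits
state, completion = 0, ""
for _ in range(7):
    mask = np.full(len(vocabulary), -np.inf)      # step 2: ban everything...
    mask[sorted(states_to_vocab[state])] = 0.0    # ...then re-allow the valid ids
    z = logits + mask
    probs = np.exp(z - z.max())
    probs /= probs.sum()                          # step 3: masked softmax
    token_id = int(rng.choice(len(vocabulary), p=probs))
    state = states_token_states[state][token_id]  # step 4: DFA transition
    completion += vocabulary[token_id]

print(completion, "| final state:", state)
```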


In [None]:
np.random.seed(12349) # you should get the same result as before

logits = np.array([1., 1., 1., 1.])  # same as before

regex = r"([0-9]+)?\.[0-9]+"

completion = ""
state = fsm.initial
for _ in range(7):

    # Build the logit mask using states_to_vocab[state]
    # (no regex needed — just a set lookup!)
    #
    # your code here
    #

print(completion)

### A small benchmark

We now benchmark the two approaches (naive regex partial matching vs DFA-based constrained decoding) across increasing vocabulary sizes. This demonstrates why the naive approach becomes impractical as the vocabulary grows to realistic LLM sizes (~50k tokens), and why precomputing valid transitions via a DFA is essential.

We build a synthetic vocabulary with a small "useful" core of digit/dot tokens and pad the rest with junk tokens that will never match (a realistic scenario where most tokens are irrelevant to the regex).

In [None]:
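import math
import time
from collections import defaultdict

import interegular
import numpy as np
import regex as re  # third-party `regex` package: needed for partial=True
from scipy.special import softmax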

REGEX = r"([0-9]+)?\.[0-9]+"

def build_vocabulary(size: int):
    """Build a synthetic vocabulary: a small useful core + junk padding tokens."""
    core = [
        ".", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
        ".0", ".1", ".2", ".3", ".4", ".5", ".6", ".7", ".8", ".9",
        "10", "12", "42", "100", "256",
    ]
    padding = [f"tok_{i}" for i in range(size - len(core))]
    vocab = core + padding
    return vocab[:size]


def naive_mask(completion, vocabulary, pattern):
    """Approach 1: one regex partial match per vocabulary token, i.e. O(V) work per generation step."""
    mask = []
    for token in vocabulary:
        tentative = completion + token
        if re.fullmatch(pattern, tentative, partial=True) is None:
            mask.append(-math.inf)
        else:
            mask.append(0.0)
    return np.array(mask)


def build_dfa_index(vocabulary, pattern):
    """Approach 2: one-time precomputation — build the DFA and index valid tokens per state."""
    fsm = interegular.parse_pattern(pattern).to_fsm()

    def _partial_match(state, token):
        traversed = (state,)
        for symbol in token:
            alphabet_idx = fsm.alphabet.get(symbol)
            if alphabet_idx is None:
                alphabet_idx = fsm.alphabet.get(interegular.fsm.anything_else)
            if state not in fsm.map or alphabet_idx not in fsm.map[state]:
                return None
            state = fsm.map[state][alphabet_idx]
            traversed += (state,)
        return traversed

    states_to_vocab = defaultdict(set)
    states_token_states = defaultdict(dict)

    for token_id, token in enumerate(vocabulary):
        for state in fsm.map:
            path = _partial_match(state, token)
            if path is not None:
                states_to_vocab[state].add(token_id)
                states_token_states[state][token_id] = path[-1]

    return fsm, states_to_vocab, states_token_states


def dfa_mask(state, states_to_vocab, vocab_size):
    """O(1) lookup of the allowed set in the precomputed index; still allocates and fills an O(V) mask array."""
    mask = np.full(vocab_size, -np.inf)
    valid = list(states_to_vocab[state])
    if valid:
        mask[valid] = 0.0
    return mask

In [None]:
def benchmark(vocab_sizes, n_steps=7, n_repeats=3):
    results = []

    print(f"Regex: {REGEX}")
    print(f"Generating {n_steps} tokens per run, median of {n_repeats} repeats\n")
    print(f"{'Vocab size':>12}  {'Naive (ms/step)':>16}  {'DFA mask (ms/step)':>18}  {'DFA precomp (ms)':>16}  {'Speedup':>8}")
    print("-" * 80)

    for V in vocab_sizes:
        vocabulary = build_vocabulary(V)
        logits = np.ones(V)

        # ---- Naive timing ----
        naive_times = []
        for _ in range(n_repeats):
            np.random.seed(42)
            completion = ""
            t0 = time.perf_counter()
            for _ in range(n_steps):
                mask = naive_mask(completion, vocabulary, REGEX)
                masked_logits = logits + mask
                probs = softmax(masked_logits)
                next_id = np.random.choice(V, p=probs)
                completion += vocabulary[next_id]
            naive_times.append((time.perf_counter() - t0) / n_steps)
        naive_ms = np.median(naive_times) * 1000

        # ---- DFA precomputation ----
        t0 = time.perf_counter()
        fsm_bench, s2v, sts = build_dfa_index(vocabulary, REGEX)
        precomp_ms = (time.perf_counter() - t0) * 1000

        # ---- DFA mask timing ----
        dfa_times = []
        for _ in range(n_repeats):
            np.random.seed(42)
            state = fsm_bench.initial
            completion = ""
            t0 = time.perf_counter()
            for _ in range(n_steps):
                mask = dfa_mask(state, s2v, V)
                masked_logits = logits + mask
                probs = softmax(masked_logits)
                next_id = np.random.choice(V, p=probs)
                state = sts[state][next_id]
                completion += vocabulary[next_id]
            dfa_times.append((time.perf_counter() - t0) / n_steps)
        dfa_ms = np.median(dfa_times) * 1000

        speedup = naive_ms / dfa_ms if dfa_ms > 0 else float("inf")
        results.append((V, naive_ms, dfa_ms, precomp_ms, speedup))
        print(f"{V:>12,}  {naive_ms:>14.2f}ms  {dfa_ms:>16.3f}ms  {precomp_ms:>14.1f}ms  {speedup:>7.0f}x")

    return results

results = benchmark([100, 500, 1_000, 5_000, 10_000])

In [None]:
import matplotlib.pyplot as plt

vocab_sizes = [r[0] for r in results]
naive_times = [r[1] for r in results]
dfa_times = [r[2] for r in results]

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(vocab_sizes, naive_times, "o-", label="Naive (regex partial match)")
ax.plot(vocab_sizes, dfa_times, "s-", label="DFA (precomputed index)")
ax.set_xlabel("Vocabulary size")
ax.set_ylabel("Time per generation step (ms)")
ax.set_title("Constrained Decoding: Naive vs DFA")
ax.legend()
ax.set_yscale("log")
ax.grid(True, which="both", ls="--", alpha=0.5)
plt.tight_layout()
plt.show()

**Key takeaway:**
- **Naive:** $O(V)$ partial matches **per generated token**, linear in vocabulary size; this dominates generation time.
- **DFA:** $O(V \times |\text{states}|)$ **one-time** precomputation, then an $O(1)$ lookup per generation step (plus filling the mask array).
- At $V = 10{,}000$, the naive approach is already orders of magnitude slower. At realistic LLM vocabulary sizes (~50k), it becomes completely impractical.

## Part 3: Coalescence and precomputed masks

The DFA approach already gives a huge speedup: instead of running a regex partial match per token, we do a set lookup. But there is still overhead at each generation step: we allocate a numpy array (`np.full(V, -inf)`) and set the valid indices to 0.

**Coalescence** observes that many DFA states share the *exact same* set of allowed tokens. For our regex, look at `states_to_vocab` above: states 4 and 5 both allow `{3}`, states 0, 1, and 3 all allow `{1, 2, 3}`. Why recompute the same mask for equivalent states?

The idea:
1. Build the token-level FSM (using `token_fsm.py`)
2. Group states by their `frozenset` of allowed token IDs
3. **Precompute one mask array per group** during the one-time preprocessing
4. Build a lookup: `state → precomputed_mask`
5. At generation time: `mask = precomputed_masks[state]` — a single dict lookup, **zero array allocation**
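The grouping step can be sketched with a plain dict keyed by `frozenset`. This sketch assumes the allowed-token sets are already known (here hard-coded for the toy regex, with hypothetical state numbers) rather than read from `tok_fsm.map`:

```python
import numpy as np

V = 4  # toy vocabulary size: ["a", ".", ".2", "1"]
# Allowed token ids per state (hypothetical numbering, same toy regex).
allowed = {0: {1, 2, 3}, 1: {1, 2, 3}, 2: {3}, 3: {3}}

mask_cache = {}         # frozenset of allowed ids -> shared mask array
precomputed_masks = {}  # state -> mask array (possibly shared with other states)
for state, ids in allowed.items():
    key = frozenset(ids)
    if key not in mask_cache:
        mask = np.full(V, -np.inf)
        mask[sorted(key)] = 0.0
        mask_cache[key] = mask
    precomputed_masks[state] = mask_cache[key]

print(f"{len(allowed)} states -> {len(mask_cache)} unique masks")
```

It prints `4 states -> 2 unique masks`. Because states 0 and 1 share the same array object, the masks must be treated as read-only: `logits + mask` allocates a new array and is safe, whereas an in-place update such as `mask += ...` would silently alter every state in the group.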

In [None]:
from token_fsm import make_deterministic_fsm, create_fsm_index_tokenizer


def build_coalesced_index(vocabulary, pattern):
    """Build the token-level FSM and precompute one mask per unique allowed-token set.

    Returns
    -------
    token_fsm : TokenFSM with .map[state][token_id] -> next_state
    precomputed_masks : dict[state] -> np.array (the logit mask, ready to use)
    precomp_ms : total precomputation time in milliseconds
    """
    V = len(vocabulary)
    t0 = time.perf_counter()

    # Step 1: regex -> character-level DFA -> clean up
    raw_fsm = interegular.parse_pattern(pattern).to_fsm()
    clean_fsm, _ = make_deterministic_fsm(raw_fsm)

    # Step 2: build token-level FSM
    tok_fsm, index = create_fsm_index_tokenizer(clean_fsm, vocabulary)

    # Step 3: coalescence — group states by their allowed token set
    # For each state in clean_fsm.states, get the frozenset of allowed token IDs
    # from tok_fsm.map. States with the same allowed set should share the same
    # precomputed mask (a numpy array of shape (V,) with 0 for allowed tokens
    # and -inf for the rest).
    # Build:
    #   mask_cache: frozenset -> np.array (one mask per unique token set)
    #   precomputed_masks: state -> np.array (lookup for generation)
    #
    # your code here
    #

    precomp_ms = (time.perf_counter() - t0) * 1000

    n_states = len(clean_fsm.states)
    n_groups = len(mask_cache)
    print(f"  Coalescence: {n_states} states -> {n_groups} unique masks")

    return tok_fsm, precomputed_masks, precomp_ms

### Exercise 5: Implement coalescence

The helper module `token_fsm.py` provides two functions:
- **`make_deterministic_fsm(fsm)`** cleans up the character-level DFA (remaps states to contiguous integers).
- **`create_fsm_index_tokenizer(fsm, vocabulary)`** builds a **token-level FSM** -- an FSM whose transitions are over token IDs instead of characters. It returns a pair `(token_fsm, index)`, where `token_fsm` is a `TokenFSM` object with a `.map[state][token_id] = next_state` transition table.

Your task is to implement the coalescence step: group states that share the same set of allowed tokens, precompute one mask per group, and build a `precomputed_masks` dict.


In [None]:
# --- Configuration ---
regex_pattern = r"([0-9]+)?\.[0-9]+"
vocabulary = ["a", ".", ".2", "1"]

print(f"Regex:      {regex_pattern}")
print(f"Vocabulary: {vocabulary}")

### Walking through the token-level FSM step by step

Before using `build_coalesced_index`, let's see how the pipeline from `token_fsm.py` works on our toy example. We go through each step: parsing the regex, cleaning the DFA, building the token-level FSM, and running constrained generation with it.


In [None]:
# --- Step 1: Parse regex to character-level DFA ---
raw_fsm = interegular.parse_pattern(regex_pattern).to_fsm()
print("── Raw character-level DFA ──")
print(f"  States:  {raw_fsm.states}")
print(f"  Initial: {raw_fsm.initial}")
print(f"  Finals:  {raw_fsm.finals}")
print(f"  Transitions:")
for state, trans in sorted(raw_fsm.map.items(), key=lambda x: str(x[0])):
    print(f"    State {state}: {dict(trans)}")

In [None]:
# --- Step 2: Clean up the DFA ---
clean_fsm, state_mapping = make_deterministic_fsm(raw_fsm)
print(f"── Cleaned DFA (state mapping: {state_mapping}) ──")
print(f"  States:  {clean_fsm.states}")
print(f"  Initial: {clean_fsm.initial}")
print(f"  Finals:  {clean_fsm.finals}")

In [None]:
# --- Step 3: Build token-level FSM ---
token_fsm, index = create_fsm_index_tokenizer(clean_fsm, vocabulary)

print(f"── Token-level FSM ──")
print(f"  Initial state: {token_fsm.initial}")
print(f"  Accept states: {token_fsm.finals}")
print(f"\n  Transition table (state → token → next_state):")
for state in sorted(token_fsm.map.keys()):
    for tid, next_s in sorted(token_fsm.map[state].items()):
        print(f"    State {state} --[{tid}: '{vocabulary[tid]}']→ State {next_s}")

In [None]:
# --- Step 4: Constrained generation using the token FSM ---
print("── Constrained generation ──")

np.random.seed(12349)
logits = np.ones(len(vocabulary))

completion = ""
state = token_fsm.initial

for step in range(7):
    # Masking is now just a set lookup — no regex needed!
    allowed = token_fsm.allowed_token_ids(state)
    mask = np.full(len(vocabulary), -np.inf)
    mask[list(allowed)] = 0.0

    masked_logits = logits + mask
    probs = softmax(masked_logits)
    next_id = np.random.choice(len(vocabulary), p=probs)

    next_state = token_fsm.next_state(state, next_id)
    print(f"  Step {step}: state={state}, "
          f"allowed={[vocabulary[i] for i in sorted(allowed)]}, "
          f"sampled='{vocabulary[next_id]}' → state={next_state}")

    state = next_state
    completion += vocabulary[next_id]

print(f"\n  Final completion: '{completion}'")
is_full_match = state in token_fsm.finals
print(f"  In accept state?  {is_full_match}")

## Part 4: Applying to a JSON schema with a real tokenizer

So far we used a toy vocabulary of 4 tokens. In practice, LLMs typically use subword tokenizers such as **BPE (Byte-Pair Encoding)** with tens of thousands of tokens (GPT-2's has 50,257), where a single token can be a multi-character subword like `"name"`, `":"`, or `"John"`.

This creates an important subtlety: a single BPE token can span **multiple DFA transitions** at once. For example, the token `"name"` takes 4 character-level DFA transitions in a single generation step. This is exactly why we need a **token-level FSM** rather than a character-level one.
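To see a multi-character token span several transitions, consider a simplified, hypothetical chain DFA for the single branch `{"name":"John"}` (no alternation), where state `i` has one outgoing edge labeled with the `i`-th character:

```python
# Hypothetical chain DFA for the single literal branch {"name":"John"}:
# state i has exactly one outgoing edge, labeled literal[i], to state i + 1.
literal = '{"name":"John"}'
fsm_map = {i: {ch: i + 1} for i, ch in enumerate(literal)}
final = len(literal)

def walk(state, token):
    """Consume every character of `token`; return the landing state or None."""
    for ch in token:
        state = fsm_map.get(state, {}).get(ch)
        if state is None:
            return None
    return state

print(walk(2, "name"))  # 6: one token, four character transitions
print(walk(6, '":"'))   # 9
print(walk(0, '{"x'))   # None: 'x' is not the expected 'n'
```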

Let's see how the token-level FSM handles a JSON-structured regex with GPT-2's real tokenizer. The pattern enforces a specific JSON schema:

```
\{"name":("John"|"Paul"),"age":(20|30)\}
```

This is the bridge between **structured output** (JSON) and **constrained decoding**: we express the schema as a regex, compile it to a DFA, then build the token-level index over the real BPE vocabulary.
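For simple cases, going from schema to regex can be mechanical. The sketch below uses a hypothetical `schema_to_regex` helper for an enum-only schema with a fixed key order and no whitespace. A full JSON Schema converter would also need to handle optional whitespace, free-form strings, numbers and nesting:

```python
import re

# Hypothetical enum-only "schema": each key maps to its allowed JSON literals.
schema = {"name": ['"John"', '"Paul"'], "age": ["20", "30"]}

def schema_to_regex(schema):
    """Compact JSON object with fixed key order and enumerated values (no whitespace)."""
    fields = []
    for key, values in schema.items():
        alternatives = "|".join(re.escape(v) for v in values)
        fields.append(f'"{re.escape(key)}":({alternatives})')
    return r"\{" + ",".join(fields) + r"\}"

pattern = schema_to_regex(schema)
print(pattern)
print(bool(re.fullmatch(pattern, '{"name":"Paul","age":30}')))   # True
print(bool(re.fullmatch(pattern, '{"name":"Ringo","age":30}')))  # False
```

On Python 3.7+, `re.escape` leaves `"` and digits untouched, so `print(pattern)` reproduces the pattern shown above.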

### Exercise 6: Implement `walk_token_fsm` and build the index

This exercise is similar to Exercise 2 (`partial_match`), but now you must also handle the `anything_else` alphabet symbol (for characters not explicitly in the regex), and build the full index over GPT-2's 50k vocabulary.


In [None]:
json_pattern = r'\{"name":("John"|"Paul"),"age":(20|30)\}'

# Build character-level DFA
raw_fsm = interegular.parse_pattern(json_pattern).to_fsm()
json_fsm, _ = make_deterministic_fsm(raw_fsm)

print(f"Pattern: {json_pattern}")
print(f"DFA: {len(json_fsm.states)} states, {len(json_fsm.finals)} accept state(s)")
print(f"\nCharacter-level DFA transitions:")
for state in sorted(json_fsm.map.keys()):
    print(f"  State {state}: {dict(json_fsm.map[state])}")

In [None]:
from transformers import AutoTokenizer
from interegular import fsm as fsm_module

tokenizer = AutoTokenizer.from_pretrained("gpt2")

def walk_token_fsm(fsm, state, token_str):
    """Walk a token string through the character-level DFA.
    
    For each character in token_str, look up its alphabet symbol index,
    then check if there's a valid transition from the current state.
    Return the final state if the full token is consumed, or None if
    any character has no valid transition.

    Hints:
    - Use fsm.alphabet to map characters to symbol indices.
    - Handle characters not in the alphabet using fsm_module.anything_else.
    - Use fsm.map.get(state, {}) for transitions.
    """
    #
    # your code here
    #

print(f"Tokenizer vocab size: {tokenizer.vocab_size:,}")
print("Building token-level FSM index (this may take a minute)...")

t0 = time.perf_counter()
json_index = defaultdict(dict)
# For each DFA state and each token in the vocabulary, check if the token
# can walk through the DFA starting from that state using walk_token_fsm.
# If so, record the landing state in json_index[state][token_id].
# Use tokenizer.decode([token_id]) to get the token string.
#
# your code here
#
elapsed = time.perf_counter() - t0

print(f"Index built in {elapsed:.1f}s")
print(f"States with valid transitions: {len(json_index)}")

In [None]:
# Display: for each state, show the decoded BPE tokens and their target states
for state in sorted(json_index.keys()):
    transitions_decoded = {
        repr(tokenizer.decode([tid])): next_s
        for tid, next_s in json_index[state].items()
    }
    print(f"State {state}: {transitions_decoded}")

Notice how BPE multi-character tokens like `'name'`, `'":"'`, `'John'`, `',"'` each correspond to valid multi-step transitions through the DFA. At each generation step, the LLM can jump several DFA states at once by emitting a single multi-character token. This is exactly why we need a **token-level** FSM rather than a character-level one: real tokenizers don't emit one character at a time.

In [None]:
def plot_char_dfa(fsm):
    """Plot the character-level DFA with graphviz."""
    idx_to_chars = defaultdict(list)
    for char, idx in fsm.alphabet.items():
        if char is fsm_module.anything_else:
            idx_to_chars[idx].append("*")
        elif char == " ":
            idx_to_chars[idx].append("⎵")
        else:
            idx_to_chars[idx].append(char)

    dot = graphviz.Digraph(
        name="Character-level DFA",
        graph_attr={"rankdir": "LR", "dpi": "50", "fontsize": "12"},
        node_attr={"fontsize": "11"},
        edge_attr={"fontsize": "9"},
    )
    dot.node("start", shape="point", width="0")
    dot.edge("start", str(fsm.initial))

    for state in sorted(fsm.states):
        shape = "doublecircle" if state in fsm.finals else "circle"
        dot.node(str(state), str(state), shape=shape)

    edge_labels = defaultdict(list)
    for state, transitions in fsm.map.items():
        for sym_idx, target in transitions.items():
            chars = idx_to_chars.get(sym_idx, [f"[{sym_idx}]"])
            edge_labels[(str(state), str(target))].append(",".join(chars))

    for (src, dst), labels in edge_labels.items():
        dot.edge(src, dst, label=" | ".join(labels))
    return dot


def plot_token_fsm(index, vocabulary, initial, finals):
    """Plot the token-level FSM with BPE tokens as edge labels."""
    dot = graphviz.Digraph(
        name="Token-level FSM",
        graph_attr={"rankdir": "LR", "dpi": "45", "fontsize": "12"},
        node_attr={"fontsize": "11"},
        edge_attr={"fontsize": "9"},
    )
    all_states = {initial} | set(finals)
    for state, transitions in index.items():
        all_states.add(state)
        for tid, target in transitions.items():
            all_states.add(target)

    dot.node("start", shape="point", width="0")
    dot.edge("start", str(initial))

    for state in sorted(all_states):
        shape = "doublecircle" if state in finals else "circle"
        dot.node(str(state), str(state), shape=shape)

    edge_labels = defaultdict(list)
    for state, transitions in index.items():
        for tid, target in transitions.items():
            token_str = vocabulary[tid]
            disp = token_str.replace('"', '\\"').replace("{", "\\{").replace("}", "\\}")
            edge_labels[(str(state), str(target))].append(f'"{disp}"')

    for (src, dst), labels in edge_labels.items():
        dot.edge(src, dst, label=" | ".join(labels))
    return dot

In [None]:
display(plot_char_dfa(json_fsm))

In [None]:
vocab_list = [tokenizer.decode([i]) for i in range(tokenizer.vocab_size)]
display(plot_token_fsm(json_index, vocab_list, json_fsm.initial, json_fsm.finals))