# FSM-based Constrained Generation Example  
In this notebook, we demonstrate how to constrain the output of a language model using a finite state machine (FSM).  
We’ll use a set of pre-defined HTTP response codes and build an FSM that ensures our language model can only generate one of these valid codes.

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM



## Step 1: Initialize Model and Tokenizer  
In this example, we are using the "Qwen/Qwen2.5-1.5B" model, but you can replace this with any model available to you.

In [2]:
model_name = "Qwen/Qwen2.5-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
torch.manual_seed(42)
model.eval()



Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (lm_head): Linear(in_features=1536, out_features=151936, bias=False)
)

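## Step 2: Define the Target HTTP Response Codes  
We restrict the model to five common status codes. The commented-out entries show that full status lines (for example, "404 Not Found") can be used instead; they simply produce longer token chains in the FSM.
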
In [3]:
HTTP_CODES = [
    # "200 OK",
    # "301 Moved Permanently",
    # "401 Unauthorized",
    # "404 Not Found",
    # "500 Internal Server Error"
    "200",
    "301",
    "401",
    "404",
    "500"
]

## Step 3: Build the FSM from the HTTP Codes  
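We create a helper function `build_fsm_for_codes` that converts our HTTP status code strings into sequences of tokenizer token IDs.  
For each code, the function follows a chain of states starting from state 0 and records the transitions between them, creating a new state only when a transition does not exist yet. Codes that share a token prefix (such as "401" and "404") therefore share states, so the FSM is a prefix tree (trie) over token IDs.  
Each code’s final state is marked as an accepting state.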

In [4]:
def build_fsm_for_codes(codes, tokenizer):
    """
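    For each code string (e.g. "404" or "404 Not Found"), we:
      1) Convert it into a sequence of token IDs using tokenizer.encode(...).
      2) Walk a chain of states, one transition per token:
           start_state --(token1)--> s1 --(token2)--> ... --(final_token)--> accepting state
         reusing any transition that already exists, so codes with a common token prefix share states.
      3) Mark the state reached after the final token as accepting.
    All codes share the same start_state (0).
    Returns (transitions, acceptance_states, max_state), where transitions maps
    (state, token_id) -> next_state and max_state is the largest state ID used.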
    """
    transitions = {}
    acceptance_states = set()
    start_state = 0
    next_free_state = 1

    for code_str in codes:
        code_token_ids = tokenizer.encode(code_str, add_special_tokens=False)
        current_state = start_state

        for tid in code_token_ids:
            if (current_state, tid) not in transitions:
                transitions[(current_state, tid)] = next_free_state
                current_state = next_free_state
                next_free_state += 1
            else:
                current_state = transitions[(current_state, tid)]
        acceptance_states.add(current_state)

    return transitions, acceptance_states, next_free_state - 1

In [5]:
transitions, accepting_states, max_state = build_fsm_for_codes(HTTP_CODES, tokenizer)
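
To see the structure `build_fsm_for_codes` produces without downloading a model, here is a minimal sketch of the same loop driven by a hypothetical character-level tokenizer, `toy_encode`, which uses each character's code point as its token ID. This is an assumption for illustration: a real tokenizer assigns its own IDs, so only the shape of the FSM carries over. The helper name `build_trie_fsm` is also hypothetical.

```python
def toy_encode(text):
    """Hypothetical character-level tokenizer: one token ID (the code point) per character."""
    return [ord(ch) for ch in text]


def build_trie_fsm(codes, encode):
    """Same loop as build_fsm_for_codes, with the encoder passed in as a function."""
    transitions = {}      # (state, token_id) -> next_state
    accepting = set()
    next_free_state = 1   # state 0 is the shared start state
    for code in codes:
        state = 0
        for tid in encode(code):
            if (state, tid) not in transitions:
                transitions[(state, tid)] = next_free_state
                next_free_state += 1
            state = transitions[(state, tid)]
        accepting.add(state)
    return transitions, accepting, next_free_state - 1


demo_codes = ["200", "301", "401", "404", "500"]
demo_transitions, demo_accepting, demo_max_state = build_trie_fsm(demo_codes, toy_encode)

# 13: five 3-character codes would need 15 states, but "401" and "404" share 2 of them.
print(demo_max_state)          # 13
print(sorted(demo_accepting))  # [3, 6, 9, 10, 13]

# Follow "4" then "0" from the start state; this shared state branches to "1" and "4".
s40 = demo_transitions[(demo_transitions[(0, ord("4"))], ord("0"))]
print(sorted(chr(t) for (s, t) in demo_transitions if s == s40))  # ['1', '4']
```

Apart from the start state, the state reached after "40" is the only one with more than one outgoing transition; every other state allows exactly one token, which matches the single allowed token reported at steps 2 and 3 of the test run in Step 6.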

## Step 4: Create the Transition Table  
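We invert `transitions` into a lookup from each FSM state to the list of `(token_id, next_state)` pairs allowed from that state. States with no outgoing transitions (the ends of complete codes) map to an empty list. This table is used to filter the language model’s output at every generation step.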

In [6]:
state_to_valid_tokens = {s: [] for s in range(max_state + 1)}
for (s, tid), ns in transitions.items():
    state_to_valid_tokens[s].append((tid, ns))
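
The same table can also check whether a token sequence is accepted by the FSM. The sketch below uses a small hand-written FSM for "401" and "404", with characters standing in for token IDs (an assumption so the example runs without a tokenizer), builds the state-to-`(token, next_state)` table with `collections.defaultdict`, and walks it with a hypothetical helper, `fsm_accepts`.

```python
from collections import defaultdict

# Hypothetical toy FSM with the same shape build_fsm_for_codes returns,
# written by hand for "401" and "404"; characters stand in for token IDs.
toy_transitions = {(0, "4"): 1, (1, "0"): 2, (2, "1"): 3, (2, "4"): 4}
toy_accepting = {3, 4}

# Adjacency list: state -> [(token, next_state), ...]
toy_valid = defaultdict(list)
for (state, token), next_state in toy_transitions.items():
    toy_valid[state].append((token, next_state))


def fsm_accepts(tokens, valid, accepting, start_state=0):
    """Return True if `tokens` drives the FSM from start_state into an accepting state."""
    state = start_state
    for token in tokens:
        step = dict(valid.get(state, []))  # token -> next_state; .get avoids creating entries
        if token not in step:
            return False                   # token is not allowed from this state
        state = step[token]
    return state in accepting


print(fsm_accepts(["4", "0", "4"], toy_valid, toy_accepting))  # True
print(fsm_accepts(["4", "0", "0"], toy_valid, toy_accepting))  # False: "0" not allowed after "40"
print(fsm_accepts(["4", "0"], toy_valid, toy_accepting))       # False: a prefix, not a full code
```

The notebook pre-populates every state with an empty list, whereas `defaultdict` only creates entries for states that have outgoing transitions. Either way, looking a state up with `.get(state, [])` yields no allowed tokens for a leaf state, which is exactly where generation stops.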

## Step 5: Define the Constrained Generation Function  
The function `generate_http_code_with_fsm` takes a user prompt and decodes one token at a time.  
At each step, it masks out every token that is not allowed from the current FSM state, picks the most likely remaining token, and follows the corresponding transition. It stops as soon as it reaches an accepting state, when no valid transitions remain, or after `max_steps` steps.

In [7]:
def generate_http_code_with_fsm(user_prompt: str, max_steps=30):
    """
    1) Tokenizes the user prompt.
    2) Iteratively generates tokens, filtering them via the FSM transitions.
    3) Stops once an accepting state (a complete HTTP code) is reached, or no valid transitions are available.
    """
    # Tokenize the prompt and keep it on the same device as the model
    input_ids = tokenizer(user_prompt, return_tensors="pt").input_ids.to(model.device)
    current_state = 0  # FSM start state
    generated_token_ids = []

    for step in range(max_steps):
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits[:, -1, :]  # logits for the next token

        valid_info = state_to_valid_tokens.get(current_state, [])  # (token_id, next_state) pairs allowed here
        if not valid_info:
            # No valid transitions from this state: stop generation
            break

        valid_token_ids = [t[0] for t in valid_info]
        next_states = [t[1] for t in valid_info]

        # Keep the logits of allowed tokens and set every other logit to -inf,
        # so those tokens get probability 0 after the softmax
        mask = torch.full_like(logits, float('-inf'))
        mask[0, valid_token_ids] = logits[0, valid_token_ids]

        # Softmax renormalizes the probability mass over the allowed tokens only
        probs = torch.nn.functional.softmax(mask, dim=-1)
        num_top = min(3, len(valid_token_ids))
        allowed_probs = probs[0, valid_token_ids]
        top_probs, top_indices = torch.topk(allowed_probs, k=num_top)
        print(f"Step {step+1}: Top {num_top} allowed tokens:")
        for prob, idx in zip(top_probs, top_indices):
            token_id = valid_token_ids[idx.item()]
            token_str = tokenizer.decode([token_id], clean_up_tokenization_spaces=False)
            print(f"  Token: '{token_str}', Probability: {prob.item():.4f}")
        next_token_id = torch.argmax(probs, dim=-1).item()  # greedy choice
        # next_token_id = torch.multinomial(probs, num_samples=1).item()  # alternative: sample the next token

        chosen_index = valid_token_ids.index(next_token_id)
        current_state = next_states[chosen_index]

        generated_token_ids.append(next_token_id)
        next_token_tensor = torch.tensor([[next_token_id]], device=input_ids.device)
        input_ids = torch.cat([input_ids, next_token_tensor], dim=1)  # append the chosen token for the next step

        if current_state in accepting_states:
            # if any accepting state is reached, stop generation
            break

    return tokenizer.decode(generated_token_ids, clean_up_tokenization_spaces=False)
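
The masking step inside the loop can be checked in isolation. Below is a minimal pure-Python sketch of it on invented logits for a 6-token vocabulary (the numbers, the allowed set, and the helper names `masked_softmax` and `choose_token` are assumptions for illustration): disallowed positions become `-inf`, the softmax renormalizes over the allowed tokens, and the next token is picked greedily (like the `torch.argmax` line) or sampled (like the commented-out `torch.multinomial` line).

```python
import math
import random


def masked_softmax(logits, allowed_ids):
    """Softmax over `logits` after setting every non-allowed position to -inf."""
    if not allowed_ids:
        raise ValueError("no allowed tokens")  # the notebook stops generation in this case
    masked = [x if i in allowed_ids else float("-inf") for i, x in enumerate(logits)]
    peak = max(masked)                           # finite, since at least one token is allowed
    exps = [math.exp(x - peak) for x in masked]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


def choose_token(probs, greedy=True, rng=None):
    """Greedy pick (argmax) or one sample from the categorical distribution `probs`."""
    if greedy:
        return max(range(len(probs)), key=probs.__getitem__)
    rng = rng or random.Random()
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


# Invented logits: token 5 has the highest raw score but is not allowed in this FSM state.
logits = [1.0, 2.0, 0.5, 3.0, 0.0, 9.0]
allowed = {1, 3}

probs = masked_softmax(logits, allowed)
print([round(p, 4) for p in probs])  # [0.0, 0.2689, 0.0, 0.7311, 0.0, 0.0]
print(choose_token(probs))           # 3
print(choose_token(probs, greedy=False, rng=random.Random(0)) in allowed)  # True
```

Token 5 has by far the highest raw logit, yet it receives probability 0, while tokens 1 and 3 keep their relative odds (e^2 : e^3). The constraint changes *which* token can win, not how the model ranks the allowed tokens against each other.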

## Step 6: Test the Constrained Generation  
We now test the function with a prompt. The output should be one of the HTTP codes accepted by the FSM. The commented-out lines are alternative prompts that target the other codes.

In [8]:
user_prompt = "HTTP status code (three digits only) for a site that has moved permanently:"
# user_prompt = "Enter the three-digit HTTP response code for a permanent redirect: "
# user_prompt = "The following HTTP status code should consist solely of three digits. For a successful request, type: "
# user_prompt = "Kindly supply only the numerical three-digit HTTP status code for a resource that has been moved permanently: "
# user_prompt = "HTTP code: Only a three-digit number. What code indicates a site not found error? "
# user_prompt = "What is the three-digit numeric HTTP status code for unauthorized access? "
# user_prompt = "When the server experiences an unexpected error during processing, it returns a specific three-digit code. Provide that code for an internal server error: "

output_text = generate_http_code_with_fsm(user_prompt)
print(f"\n{user_prompt}")
print(output_text)



Step 1: Top 3 allowed tokens:
  Token: '3', Probability: 0.5885
  Token: '4', Probability: 0.2437
  Token: '2', Probability: 0.0996
Step 2: Top 1 allowed tokens:
  Token: '0', Probability: 1.0000
Step 3: Top 1 allowed tokens:
  Token: '1', Probability: 1.0000

HTTP status code (three digits only) for a site that has moved permanently:
301
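
To exercise the whole loop without loading Qwen, the sketch below swaps the model for a hypothetical scoring function, `stub_logits`, that returns fixed per-character scores (an invented stand-in, not how a real language model behaves), and rebuilds the FSM over characters instead of token IDs. Like `generate_http_code_with_fsm`, the hypothetical `constrained_greedy` decodes greedily, stops at an accepting state, and can only return one of the allowed codes.

```python
def build_trie(codes):
    """Character-level version of build_fsm_for_codes.

    Returns (valid, accepting): valid maps state -> {char: next_state}.
    """
    valid, accepting, next_free = {0: {}}, set(), 1
    for code in codes:
        state = 0
        for ch in code:
            if ch not in valid[state]:
                valid[state][ch] = next_free
                valid[next_free] = {}  # every state gets an entry; leaf states stay empty
                next_free += 1
            state = valid[state][ch]
        accepting.add(state)
    return valid, accepting


def stub_logits(prefix):
    """Hypothetical stand-in for the LM: fixed per-character scores.

    A real model would condition on the prompt and on `prefix`; this stub ignores both.
    """
    return {"3": 2.0, "4": 1.5, "2": 1.0, "0": 0.5, "1": 0.3}


def constrained_greedy(codes, score_fn, max_steps=30):
    """Greedy FSM-constrained decoding over characters."""
    valid, accepting = build_trie(codes)
    state, out = 0, ""
    for _ in range(max_steps):
        allowed = valid[state]
        if not allowed:
            break  # no valid continuation from this state
        scores = score_fn(out)
        # Only allowed characters compete; a character the scorer omits counts as -inf.
        ch = max(allowed, key=lambda c: scores.get(c, float("-inf")))
        out += ch
        state = allowed[ch]
        if state in accepting:
            break  # a complete code has been generated
    return out


demo_codes = ["200", "301", "401", "404", "500"]
print(constrained_greedy(demo_codes, stub_logits))                # 301
print(constrained_greedy(demo_codes, lambda prefix: {"5": 3.0}))  # 500
```

The second call shows the limit of the guarantee: a scorer that happens to prefer `5` still yields a well-formed code, `500`, whether or not it is the right answer for the prompt.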


### NOTE
Hallucinations in language models are outputs that are fluent and well-formed but not supported by the input, such as a confident but wrong answer. FSM-constrained decoding guarantees the *format* of the output (here, the decoder can only emit one of the five codes), but not its *correctness*: the model can still pick a valid but wrong code, such as `404` for a permanent redirect. This can happen for several reasons:
* The model’s internal probability distribution weighs context, learned patterns, and prompt cues in ways that may not match your intent, so it can favor a code other than the one you expect.
* Strong constraints can conflict with the text the model would naturally produce. If the model’s preferred next token (for example, a space or a word such as "The") is masked out, the softmax renormalizes whatever probability mass is left on the allowed tokens, and the choice among them can rest on a weak signal. In this decoder that cannot add extra tokens, but it can change which code is selected.
* The prompt might not make clear which code is wanted, in which case the model is effectively guessing among the allowed codes. If the constrained span is only part of a longer, unconstrained generation, the unconstrained parts can still contain hallucinated text before or after the code.
* The FSM is built by tokenizing each code on its own, but in context the model may prefer a different tokenization of the same text (for example, a variant with a leading space). Forcing the isolated tokenization pushes the model onto a less likely path, where its logits and the hard constraints can interact unpredictably.
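
One way to notice when the constraint is fighting the model, as the second and fourth points describe, is to measure how much of the model's *unconstrained* probability mass falls on the allowed tokens before renormalization. The sketch below computes that share from raw logits in pure Python; the function name `allowed_mass`, the example logits, and any threshold you might apply to the result are assumptions for illustration. In the notebook, the inputs would be `logits[0].tolist()` and `valid_token_ids` at each step.

```python
import math


def allowed_mass(logits, allowed_ids):
    """Fraction of the unconstrained softmax probability that lies on allowed_ids."""
    peak = max(logits)  # subtract the max to keep exp() stable
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return sum(exps[i] for i in allowed_ids) / total


# Invented logits for a 4-token vocabulary: tokens 0 and 1 are the allowed digits,
# token 3 stands for something the model prefers but the FSM forbids (e.g. a space).
logits = [1.0, 1.0, 0.0, 5.0]
share = allowed_mass(logits, [0, 1])
print(round(share, 4))  # 0.0351
```

Here only about 3.5% of the model's probability lands on the allowed tokens, so the code the FSM ends up choosing rests on a small, renormalized slice of what the model actually wanted to output. A consistently low share is a cue to revisit the prompt or the way the codes are tokenized.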