# Causal Classifier - Guided Generation (Testing)

This notebook runs guided text generation using a trained causal classifier from the modular FUDGE project. It loads a base LLM and a trained classifier checkpoint, then uses the classifier to steer the output token by token. For a complete overview of the project architecture, see the [full project on the GitHub repo](https://github.com/latoohey/modular-fudge). The training script that creates a classifier for guiding generation is also available as a [Colab notebook](https://colab.research.google.com/drive/1zVfBB_zIKHpSANBmTcj8KRBFwsPgbd9m?usp=sharing).

## Imports

In [None]:
# For debugging as needed
# import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
!pip install transformers accelerate bitsandbytes huggingface_hub
!pip install fastapi uvicorn pyngrok nest_asyncio

Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl (59.4 MB)
Installing collected packages: bitsandbytes
Successfully installed bitsandbytes-0.48.2
Collecting pyngrok
  Downloading pyngrok-7.5.0-py3-none-any.whl.metadata (8.1 kB)
Downloading pyngrok-7.5.0-py3-none-any.whl (24 kB)
Installing collected packages: pyngrok
Successfully installed pyngrok-7.5.0


In [None]:
# Standard library
import asyncio
import csv
import itertools
import os
import random
import time
from argparse import Namespace
from functools import lru_cache
from pathlib import Path

# Third-party
import einops
import nest_asyncio
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from huggingface_hub import login
from IPython.display import HTML, display
from pydantic import BaseModel
from pyngrok import ngrok
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Colab-only
from google.colab import drive, userdata

In [None]:
LOAD_MAMBA = True

if LOAD_MAMBA:
  # 1. Mount Google Drive
  drive.mount('/content/drive')

  # 2. Define the directory on your Drive to store the wheels
  #    We use a specific folder name to keep it organized.
  wheel_dir = '/content/drive/MyDrive/colab_wheels/mamba_builds'
  os.makedirs(wheel_dir, exist_ok=True)

  # 3. Define the package versions you want
  packages = [
      "causal-conv1d>=1.4.0",
      "mamba-ssm"
  ]

  # 4. Check if wheels already exist in your Drive
  print(f"Checking for existing wheels in {wheel_dir}...")
  existing_wheels = [f for f in os.listdir(wheel_dir) if f.endswith('.whl')]

  if len(existing_wheels) >= len(packages):
      print("✅ Found pre-built wheels! Installing from Drive...")
      # Install directly from your Drive folder
      !pip install "$wheel_dir"/*.whl
  else:
      print("⚠️ No wheels found. Building from source (this will take time once)...")

      # Install build dependencies first
      !pip install packaging ninja

      # Build the wheels and save them directly to your Drive
      # We use --no-deps to avoid building wheels for huge packages like PyTorch
      print(f"Building wheels to {wheel_dir}...")
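      # Quote each requirement so the shell does not treat ">=" as a redirect
      package_args = " ".join(f'"{p}"' for p in packages)
      !pip wheel $package_args --wheel-dir="$wheel_dir" --no-deps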

      # Now install the newly built wheels
      print("Installing newly built wheels...")
      !pip install "$wheel_dir"/*.whl

  from mamba_ssm import Mamba
  print("🎉 Done! Mamba and Causal-Conv1d are ready.")

Mounted at /content/drive
Checking for existing wheels in /content/drive/MyDrive/colab_wheels/mamba_builds...
✅ Found pre-built wheels! Installing from Drive...
Processing ./drive/MyDrive/colab_wheels/mamba_builds/causal_conv1d-1.5.3.post1-cp312-cp312-linux_x86_64.whl
Processing ./drive/MyDrive/colab_wheels/mamba_builds/mamba_ssm-2.2.6.post3-cp312-cp312-linux_x86_64.whl
Collecting ninja (from causal-conv1d==1.5.3.post1)
  Downloading ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.1 kB)
Downloading ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (180 kB)
Installing collected packages: ninja, causal-conv1d, mamba-ssm
Successfully installed causal-conv1d-1.5.3.post1 mamba-ssm-2.2.6.post3 ninja-1.13.0
🎉 Done! Mamba and Causal-Co

## Model Definitions

In [None]:
# --- From util.py ---
def num_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
class LSTMClassifier(nn.Module):

    def __init__(self, args, vocab_size, pad_token_id):
        """
        Initializes the LSTM model.

        Args:
            args: The full ArgumentParser namespace. Reads
                  `args.lstm_hidden_dim` and `args.lstm_num_layers`.
            vocab_size: The total vocabulary size for the embedding layer.
            pad_token_id: The ID of the padding token.
        """
        super().__init__()

        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=args.lstm_hidden_dim,
            padding_idx=pad_token_id
        )

        self.rnn = nn.LSTM(
            args.lstm_hidden_dim,
            args.lstm_hidden_dim,
            num_layers=args.lstm_num_layers,
            bidirectional=False,
            dropout=0.5,
            batch_first=True # Makes the permute/transpose logic simpler
        )
        self.out_linear = nn.Linear(args.lstm_hidden_dim, 1)

    def forward(self, inputs, lengths):
        """
        Internal forward pass for the LSTM.
        Requires `lengths` for sequence packing.
        """
        # (batch_size, seq_len, hidden_dim)
        embedded_inputs = self.embed(inputs)

        # Pack sequence for efficient RNN processing
        packed_inputs = pack_padded_sequence(
            embedded_inputs,
            lengths.cpu(), # Must be on CPU
            batch_first=True,
            enforce_sorted=False
        )

        # rnn_output is (packed_batch, hidden_dim)
        rnn_output, _ = self.rnn(packed_inputs)

        # Unpack: (batch_size, seq_len, hidden_dim)
        rnn_output, _ = pad_packed_sequence(
            rnn_output,
            batch_first=True
        )

        # (batch_size, seq_len)
        return self.out_linear(rnn_output).squeeze(2)

    def get_final_scores(self, batch):
        """
        Returns the scores for the last token of each sequence in the batch.
        Used by the guided generation to condition on the last generated token.
        """
        inputs, lengths, _ = batch # _ is targets, not used here
        # The forward method returns (batch_size, seq_len)
        all_token_scores = self.forward(inputs, lengths)
        # Extract the score for the last token of each sequence
        # lengths is (batch_size,), all_token_scores is (batch_size, seq_len)
        final_scores = all_token_scores[torch.arange(inputs.size(0)), lengths - 1]
        return final_scores

In [None]:
"""
========================================================================
Mamba Classifier Model Definition
========================================================================

This file implements a Mamba-based classifier that follows the same
contract as the LSTM classifier for compatibility with the project's
main training (`main_train.py`) and evaluation (`evaluate.py`) scripts.

The Mamba architecture uses selective state space models (SSMs) for
efficient sequence modeling with linear complexity in sequence length.

Requirements:
- mamba-ssm (install with: pip install mamba-ssm)
- torch
- einops
"""

class MambaClassifier(nn.Module):

    def __init__(self, args, vocab_size, pad_token_id):
        """
        Initializes the Mamba model.

        Args:
            args: The full ArgumentParser namespace. Reads:
                  - `args.mamba_d_model` (hidden dimension, default 256)
                  - `args.mamba_d_state` (SSM state dimension, default 16)
                  - `args.mamba_d_conv` (local convolution width, default 4)
                  - `args.mamba_expand` (expansion factor, default 2)
                  - `args.mamba_num_layers` (number of Mamba blocks, default 4)
                  - `args.mamba_dropout` (dropout rate, default 0.1)
            vocab_size: The total vocabulary size for the embedding layer.
            pad_token_id: The ID of the padding token.
        """
        super().__init__()

        # Get hyperparameters from args with defaults
        self.d_model = getattr(args, 'mamba_d_model', 256)
        self.d_state = getattr(args, 'mamba_d_state', 16)
        self.d_conv = getattr(args, 'mamba_d_conv', 4)
        self.expand = getattr(args, 'mamba_expand', 2)
        self.num_layers = getattr(args, 'mamba_num_layers', 4)
        self.dropout_rate = getattr(args, 'mamba_dropout', 0.1)

        # Embedding layer
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.d_model,
            padding_idx=pad_token_id  # Use pad_token_id from tokenizer
        )

        # Dropout for regularization
        self.dropout = nn.Dropout(self.dropout_rate)

        # Stack of Mamba blocks
        self.mamba_blocks = nn.ModuleList([
            Mamba(
                d_model=self.d_model,    # Model dimension
                d_state=self.d_state,    # SSM state expansion factor
                d_conv=self.d_conv,      # Local convolution width
                expand=self.expand,      # Block expansion factor
            )
            for _ in range(self.num_layers)
        ])

        # Layer normalization between blocks
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(self.d_model)
            for _ in range(self.num_layers)
        ])

        # Final layer norm before output
        self.final_norm = nn.LayerNorm(self.d_model)

        # Output projection to single logit per token
        self.out_linear = nn.Linear(self.d_model, 1)

        # Initialize weights
        self._init_weights()


    def _init_weights(self):
        # Initialize embedding layer
        nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)

        # Initialize linear output layer
        nn.init.normal_(self.out_linear.weight, mean=0.0, std=0.02)
        if self.out_linear.bias is not None:
            nn.init.constant_(self.out_linear.bias, 0)

    def forward(self, inputs, lengths=None):
        """
        Internal forward pass for the Mamba model.

        Note: Mamba handles variable-length sequences naturally without
        packing/unpacking, but we accept lengths for compatibility.

        Args:
            inputs: Token IDs of shape (batch_size, seq_len)
            lengths: Sequence lengths (optional, for compatibility)

        Returns:
            scores: Per-token logits of shape (batch_size, seq_len)
        """
        batch_size, seq_len = inputs.shape

        # Embed tokens: (batch_size, seq_len, d_model)
        x = self.embed(inputs)
        x = self.dropout(x)

        # Build a padding mask if lengths are given.
        # Mamba is inherently causal, so only padded positions need masking.
        if lengths is not None:
            # True for real tokens, False for padded positions
            # Shape: (batch_size, seq_len)
            mask = torch.arange(seq_len, device=inputs.device).unsqueeze(0) < lengths.unsqueeze(1)
            # Expand mask to match hidden dimension for masking
            # Shape: (batch_size, seq_len, 1)
            mask = mask.unsqueeze(-1).float()
        else:
            mask = None

        # Process through Mamba blocks with residual connections
        for i, (mamba_block, layer_norm) in enumerate(zip(self.mamba_blocks, self.layer_norms)):
            # Pre-norm architecture
            residual = x
            x = layer_norm(x)

            # Mamba block
            x = mamba_block(x)

            # Apply mask if available (zero out padded positions)
            if mask is not None:
                x = x * mask

            # Residual connection and dropout
            x = residual + self.dropout(x)

        # Final normalization
        x = self.final_norm(x)

        # Project to logits: (batch_size, seq_len, 1) -> (batch_size, seq_len)
        scores = self.out_linear(x).squeeze(-1)

        return scores

    def get_final_scores(self, batch):
        """
        Adapter for evaluation.
        Unpacks batch, calls `self.forward`, and returns final logit.

        Args:
            batch: The raw, collated batch from the DataLoader.

        Returns:
            last_logits: torch.Tensor of shape (batch_size,)
                        The logit from the last real token for each item.
        """
        # Unpack the batch
        inputs, lengths, _ = batch

        # Move tensors to the model's device
        inputs = inputs.to(self.embed.weight.device)
        lengths = lengths.to(self.embed.weight.device)

        # Call forward pass
        # scores shape: (batch_size, seq_len)
        scores = self.forward(inputs, lengths)

        # Find the index of the last token for each sequence
        # Shape: (batch_size,)
        last_indices = (lengths - 1).long()

        # Gather the scores from the last valid position
        # Shape: (batch_size, 1) -> (batch_size,)
        last_logits = scores.gather(
            1, last_indices.unsqueeze(1)
        ).squeeze(1)

        return last_logits
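
Both classifiers return one score per position and then keep only the score at index `lengths - 1`, the last real token. The following is a minimal pure-Python sketch of that masking and indexing idea. It avoids torch so it can run anywhere, and the helper names `padding_mask` and `last_real_scores` are illustrative, not part of the project.

```python
def padding_mask(lengths, seq_len):
    """Return one 0/1 row per sequence: 1 for real tokens, 0 for padding."""
    return [[1 if t < n else 0 for t in range(seq_len)] for n in lengths]


def last_real_scores(scores, lengths):
    """Pick the score at index length - 1 for each sequence in the batch."""
    return [row[n - 1] for row, n in zip(scores, lengths)]


lengths = [3, 1]
scores = [[0.1, 0.4, 0.9], [0.7, 0.0, 0.0]]  # padded to seq_len = 3

print(padding_mask(lengths, seq_len=3))   # [[1, 1, 1], [1, 0, 0]]
print(last_real_scores(scores, lengths))  # [0.9, 0.7]
```

During guided generation every candidate prefix has the same length, so the mask is all ones and the last index is simply the final position.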

In [None]:
def get_model(args, vocab_size_param, pad_token_id_param):
    """
    This factory function reads the --model_type argument
    and returns the correct, initialized model.
    """
    if args.model_type == 'lstm':
        return LSTMClassifier(args, vocab_size_param, pad_token_id_param)
    elif args.model_type == 'mamba':
        return MambaClassifier(args, vocab_size_param, pad_token_id_param)
    else:
        raise ValueError(f"Unknown model type: {args.model_type}")

In [None]:
# 1. Add the decorator. maxsize=1 is usually enough if you just
#    want to hold the current model in memory.
@lru_cache(maxsize=1)
def load_classifier(ckpt_path, device):
    """Loads a trained classifier from a checkpoint using the model factory."""

    # This print statement will only run the FIRST time you call the function
    # with a specific path/device combination.
    print(f"Loading classifier from {ckpt_path}...")

    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Load args *from the checkpoint* to know what model to build
    model_args = checkpoint['args']
    print(f"Checkpoint args: {model_args}")

    # This assumes your main_train.py saved 'tokenizer_name' in its args
    if not hasattr(model_args, 'tokenizer_name'):
        tokenizer_name = CLASSIFIER_TOKENIZER_NAME # Ensure this global is defined or passed in
    else:
        tokenizer_name = model_args.tokenizer_name

    classifier_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Add pad token if it doesn't exist
    if classifier_tokenizer.pad_token is None:
        # 1. Check for Llama 3 specific fine-tune token
        if '<|finetune_right_pad_id|>' in classifier_tokenizer.get_vocab():
            classifier_tokenizer.pad_token = '<|finetune_right_pad_id|>'

        # 2. Check for generic reserved tokens (common in TikToken)
        elif '<|reserved_special_token_0|>' in classifier_tokenizer.get_vocab():
            classifier_tokenizer.pad_token = '<|reserved_special_token_0|>'

        # 3. Safe Fallback: Use EOS token (No resizing required)
        else:
            print("Warning: No dedicated pad token found. Using EOS token as PAD.")
            classifier_tokenizer.pad_token = classifier_tokenizer.eos_token

    vocab_size = len(classifier_tokenizer)
    print(f"Classifier vocab size: {vocab_size}")
    pad_token_id = classifier_tokenizer.pad_token_id

    # --- Use the factory to build the correct model ---
    # Ensure get_model is imported or defined in this scope
    model = get_model(model_args, vocab_size, pad_token_id)

    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device)
    model.eval()

    print(f"Classifier loaded (Type: {model_args.model_type}, Epochs: {checkpoint['epoch']}).")

    # Returns the tuple. The cache will store this entire tuple.
    return model, classifier_tokenizer
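
Because `load_classifier` is wrapped in `lru_cache(maxsize=1)`, only the most recent `(ckpt_path, device)` pair stays cached. The sketch below illustrates that behavior with a stand-in loader. `load_fake_classifier` and the checkpoint names are hypothetical, and no real checkpoint is read.

```python
from functools import lru_cache

load_calls = []  # records every real (non-cached) load


@lru_cache(maxsize=1)
def load_fake_classifier(ckpt_path, device):
    """Stand-in for an expensive checkpoint load (hypothetical helper)."""
    load_calls.append((ckpt_path, device))
    return {"path": ckpt_path, "device": device}


first = load_fake_classifier("a.pth.tar", "cpu")
second = load_fake_classifier("a.pth.tar", "cpu")  # served from the cache
third = load_fake_classifier("b.pth.tar", "cpu")   # evicts the "a" entry
fourth = load_fake_classifier("a.pth.tar", "cpu")  # loaded again

print(first is second)  # True
print(len(load_calls))  # 3
```

When several classifiers are loaded one after another, each new name evicts the previous cache entry. Objects that the caller still holds (for example in a dictionary) stay alive, but calling the loader again with an earlier name repeats the load.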

In [None]:
def calculate_combined_scores(top_logits, last_token_logits, condition_lambda, use_z_score=False):
    """
    Normalizes and combines LLM logits with Classifier scores.
    Returns: combined_log_probs (for selection), final_classifier_scores (for logging), llm_log_probs
    """
    # 1. Normalize LLM scores to log probs
    llm_log_probs = F.log_softmax(top_logits, dim=-1)

    # Early exit: the caller passes None when the classifier was skipped
    # (lambda is effectively 0), so return the pure LLM scores immediately.
    if last_token_logits is None:
        # Create dummy zeros for the "classifier scores" so the logger doesn't crash.
        # We make it match the shape of top_logits [1, top_k]
        dummy_classifier_scores = torch.zeros_like(top_logits)

        # Return: (Pure LLM Scores, Dummy Zeros, Pure LLM Scores)
        return llm_log_probs, dummy_classifier_scores, llm_log_probs

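    # 2. Normalize classifier scores with log_softmax over the last dimension.
    #    Note: get_final_scores returns one logit per candidate (shape [top_k]),
    #    so this normalizes across candidates and shifts every score by the same constant.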
    classifier_log_probs = F.log_softmax(last_token_logits, dim=-1)

    # Extract the "True" class score (assuming binary classification index 1 is target)
    if len(classifier_log_probs.shape) > 1 and classifier_log_probs.shape[-1] > 1:
        relevant_classifier_scores = classifier_log_probs[:, 1]
    else:
        relevant_classifier_scores = classifier_log_probs

    # 3. Apply Strategy
    if use_z_score:
        # Calculate stats across the top_k candidates
        c_mean = relevant_classifier_scores.mean()

        # unbiased=False avoids a NaN standard deviation when top_k=1
        c_std = relevant_classifier_scores.std(unbiased=False)

        # Safety: avoid dividing by a (near-)zero standard deviation
        if c_std < 1e-8:
            c_std = 1.0

        final_classifier_scores = (relevant_classifier_scores - c_mean) / c_std
    else:
        final_classifier_scores = relevant_classifier_scores

    # 4. Combine: LLM_Log_Prob + (Lambda * Classifier_Score)
    combined_log_probs = llm_log_probs + (condition_lambda * final_classifier_scores)

    return combined_log_probs, final_classifier_scores, llm_log_probs
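
To make the combination rule concrete, here is a small pure-Python sketch of `llm_log_prob + lambda * classifier_score` with optional z-scoring. It is an illustration under assumptions, not the notebook's implementation: the numbers are made up, the helper names are hypothetical, and it skips the classifier `log_softmax` step (which only shifts all candidates by the same constant).

```python
import math


def log_softmax(values):
    """Numerically stable log-softmax over a list of floats."""
    peak = max(values)
    log_total = peak + math.log(sum(math.exp(v - peak) for v in values))
    return [v - log_total for v in values]


def z_score(values):
    """Standardize with the population std (the unbiased=False case)."""
    mean = sum(values) / len(values)
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
    if std < 1e-8:  # same safety fallback as the function above
        std = 1.0
    return [(v - mean) / std for v in values]


def combine(llm_logits, clf_scores, condition_lambda, use_z_score=False):
    """Return llm_log_prob + lambda * classifier_score per candidate."""
    llm_log_probs = log_softmax(llm_logits)
    clf = z_score(clf_scores) if use_z_score else clf_scores
    return [l + condition_lambda * c for l, c in zip(llm_log_probs, clf)]


llm_logits = [2.0, 1.5, 0.5]   # made-up logits for three candidates
clf_scores = [-1.0, 2.0, 0.0]  # made-up classifier scores

unguided = combine(llm_logits, clf_scores, condition_lambda=0.0)
guided = combine(llm_logits, clf_scores, condition_lambda=1.0, use_z_score=True)

print(unguided.index(max(unguided)))  # 0: the LLM's own favourite
print(guided.index(max(guided)))      # 1: the classifier tips the choice
```

With `condition_lambda=0.0` the ranking is the LLM's own. Raising lambda lets a candidate the classifier prefers overtake a candidate the LLM prefers.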

In [None]:
def select_next_token(combined_log_probs, top_indices, strategy="greedy", temperature=1.0):
    """
    Selects the next token index based on strategy.
    Returns: next_token_id (tensor), best_index_relative (int index of top_k)
    """
    if strategy == "sample":
        # Divide by temp to control randomness
        probs = F.softmax(combined_log_probs / temperature, dim=-1)
        # Sample from the distribution
        best_index_relative = torch.multinomial(probs, num_samples=1)
    else:
        # Greedy default (Returns a 0-dim scalar tensor)
        best_index_relative = torch.argmax(combined_log_probs)

    # Convert to a Python int regardless of dimensions:
    # argmax returns a 0-dim tensor, multinomial returns a [1, 1] tensor,
    # and .item() handles both.
    if isinstance(best_index_relative, torch.Tensor):
        best_index_relative = int(best_index_relative.item())

    # Extract the actual token ID from the top_k list
    next_token_id = top_indices[0, best_index_relative].unsqueeze(0)

    return next_token_id, best_index_relative
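
The two strategies can be sketched without torch as follows. This is a simplified illustration: `select_index` is a hypothetical helper, the token IDs and scores are made up, and `random.Random.choices` stands in for `torch.multinomial`.

```python
import math
import random


def select_index(combined_log_probs, strategy="greedy", temperature=1.0, rng=None):
    """Return the position of the chosen candidate within the top-k list."""
    if strategy == "sample":
        rng = rng or random.Random()
        scaled = [v / temperature for v in combined_log_probs]
        peak = max(scaled)
        weights = [math.exp(v - peak) for v in scaled]  # unnormalized softmax
        return rng.choices(range(len(weights)), weights=weights, k=1)[0]
    # Greedy: take the highest combined score
    return max(range(len(combined_log_probs)), key=combined_log_probs.__getitem__)


top_indices = [4321, 17, 905]  # made-up token ids for three candidates
combined = [-1.7, 0.2, -2.4]   # made-up combined scores

greedy_pos = select_index(combined)
sampled_pos = select_index(combined, strategy="sample", temperature=0.7,
                           rng=random.Random(0))

print(top_indices[greedy_pos])  # 17
print(sampled_pos in range(3))  # True
```

Lower temperatures concentrate the sampling weights on the best candidate, so sampling approaches greedy selection as the temperature approaches zero.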

In [None]:
def record_evaluation(evaluation_history, step, generated_ids, tokenizer,
                      top_indices, llm_scores, clf_scores, combined_scores, selected_idx):
    """
    Logs the step details to the history list.
    """
    current_context_str = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    top_k = top_indices.shape[1]

    step_data = {
        "step": step,
        "context": current_context_str,
        "candidates": []
    }

    # Normalize clf_scores to 1D so we can loop over it easily.
    # If it came from zeros_like(top_logits), it's [1, K]. We want [K].
    if clf_scores.dim() > 1:
        clf_scores = clf_scores.squeeze(0)

    for i in range(top_k):
        cand_id = top_indices[0, i].item()
        cand_token = tokenizer.decode([cand_id])

        # Safe extraction of scalar values
        s_llm = llm_scores[0, i].item()

        # clf_scores is 1D here, so [i] works for both the normal and skipped-classifier cases
        s_clf = clf_scores[i].item()

        s_comb = combined_scores[0, i].item()
        is_winner = (i == selected_idx)

        step_data["candidates"].append({
            "token_text": cand_token,
            "llm_score": round(s_llm, 4),
            "classifier_score": round(s_clf, 4),
            "weighted_combined": round(s_comb, 4),
            "selected": is_winner
        })

    evaluation_history.append(step_data)
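
Once a history list has been filled, one way to use it is to find the steps where the classifier changed the outcome. The sketch below assumes records in the same shape as the ones built above. `classifier_overrides` is a hypothetical helper, and the sample history is hand-written for illustration.

```python
def classifier_overrides(evaluation_history):
    """Return steps where the selected token was not the LLM's top choice."""
    overrides = []
    for step_data in evaluation_history:
        candidates = step_data["candidates"]
        llm_best = max(candidates, key=lambda c: c["llm_score"])
        chosen = next(c for c in candidates if c["selected"])
        if chosen is not llm_best:
            overrides.append(
                (step_data["step"], llm_best["token_text"], chosen["token_text"])
            )
    return overrides


# Hand-written history in the same shape as the records built above
history = [
    {"step": 0, "context": "", "candidates": [
        {"token_text": " Hi", "llm_score": -0.4, "classifier_score": -1.1,
         "weighted_combined": -1.5, "selected": False},
        {"token_text": " Greetings", "llm_score": -1.2, "classifier_score": 0.9,
         "weighted_combined": -0.3, "selected": True},
    ]},
    {"step": 1, "context": " Greetings", "candidates": [
        {"token_text": ",", "llm_score": -0.1, "classifier_score": 0.2,
         "weighted_combined": 0.1, "selected": True},
        {"token_text": "!", "llm_score": -2.5, "classifier_score": -0.3,
         "weighted_combined": -2.8, "selected": False},
    ]},
]

print(classifier_overrides(history))  # [(0, ' Hi', ' Greetings')]
```

Logging only happens when a list is passed as `evaluation_history`; passing `None` disables it.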

In [None]:
def generate_guided(
    llm,
    llm_tokenizer,
    classifier,
    classifier_tokenizer,
    prompt,
    max_len,
    condition_lambda,
    top_k,
    evaluation_history=None,
    use_z_score=False,
    strategy="greedy",
    temperature=1.0
):
    device = llm.device

    # Steps 1-3: apply the prompt template, normalize to chat messages, tokenize
    try:
        if callable(CUSTOM_PROMPT_TEMPLATE):
            messages = CUSTOM_PROMPT_TEMPLATE(prompt)
        else:
            messages = prompt
    except NameError:
        messages = prompt

    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]
    elif isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], str):
        messages = [{"role": "user", "content": messages[0]}]

    add_gen_prompt = globals().get('ADD_GENERATION_PROMPT', True)
    input_ids = llm_tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=add_gen_prompt,
        return_tensors="pt"
    ).to(device)

    generated_ids = input_ids

    with torch.no_grad():
        for step in range(max_len):
            # --- A: Get Base LLM Logits ---
            llm_outputs = llm(generated_ids)
            next_token_logits = llm_outputs.logits[:, -1, :].float()

            # --- B: Get Top-K Candidates ---
            top_logits, top_indices = torch.topk(next_token_logits, top_k)

            # --- C: Run Classifier (OPTIMIZED) ---

            # If lambda is effectively zero (|lambda| < 1e-6), the classifier is
            # turned off, so skip the expensive candidate scoring
            if abs(condition_lambda) < 1e-6:
                last_token_logits = None
            else:
                # Only do this heavy VRAM expansion if we actually plan to use it

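                # Create sequences: [Current Context + Candidate Token]
                # Note: LLM token IDs go straight to the classifier, so this
                # assumes the classifier was trained with the same tokenizer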
                candidate_prefixes = torch.cat(
                    [generated_ids.expand(top_k, -1), top_indices.squeeze(0).unsqueeze(-1)],
                    dim=-1
                )

                # Prepare classifier batch
                current_seq_len = candidate_prefixes.shape[1]
                lengths = torch.LongTensor([current_seq_len] * top_k).to(device)
                batch = [candidate_prefixes, lengths, None]

                # Get raw classifier scores
                last_token_logits = classifier.get_final_scores(batch)

            # --- D: Calculate Scores (Helper 1) ---
            # If the classifier was skipped, this returns the pure LLM log-probs
            combined_scores, clf_scores, llm_log_probs = calculate_combined_scores(
                top_logits,
                last_token_logits,
                condition_lambda,
                use_z_score
            )

            # --- E: Select Token (Helper 2) ---
            next_token_id, best_idx = select_next_token(
                combined_scores,
                top_indices,
                strategy=strategy,
                temperature=temperature
            )

            # --- F: Log (Helper 3) ---
            if evaluation_history is not None:
                record_evaluation(
                    evaluation_history, step, generated_ids, llm_tokenizer,
                    top_indices, llm_log_probs, clf_scores, combined_scores, best_idx
                )

            # --- G: Append and Yield ---
            generated_ids = torch.cat([generated_ids, next_token_id.unsqueeze(0)], dim=-1)
            new_text = llm_tokenizer.decode(next_token_id.squeeze(0), skip_special_tokens=True)

            yield new_text

            if next_token_id.item() == llm_tokenizer.eos_token_id:
                break
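
The loop above scores `top_k` candidate prefixes at every step, and each prefix contains the whole context so far. A rough back-of-the-envelope sketch of the resulting classifier workload follows. The helper names and the prompt length of 20 tokens are illustrative assumptions, and the counts are upper bounds because generation can stop early at EOS.

```python
def classifier_prefix_count(max_new_tokens, top_k):
    """Upper bound on candidate prefixes the classifier scores in one generation."""
    return max_new_tokens * top_k


def classifier_token_count(prompt_len, max_new_tokens, top_k):
    """Upper bound on tokens the classifier reads when every prefix is re-scored in full."""
    # At step s (0-based) each of the top_k prefixes holds prompt_len + s + 1 tokens.
    return sum(top_k * (prompt_len + s + 1) for s in range(max_new_tokens))


print(classifier_prefix_count(512, 100))     # 51200
print(classifier_token_count(20, 512, 100))  # 14156800
```

The second count grows roughly with the square of the output length, which is consistent with later steps taking longer than earlier ones.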

In [None]:
def seed_everything(seed=42):
    # 1. Set the python built-in random seed
    random.seed(seed)

    # 2. Set the numpy seed
    np.random.seed(seed)

    # 3. Set the pytorch seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) # If using multi-GPU

    # 4. Important: Force CuDNN to be deterministic
    # This slows down training slightly but ensures 'exact' reproducibility
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # 5. Set the hash seed. Note: this only affects subprocesses started after
    #    this point; the running interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)

    print(f"Global seed set to {seed}")

In [None]:
def setup():
  from google.colab import drive
  from google.colab import userdata
  seed_everything(SEED)
  drive.mount('/content/drive')
  hf_token = userdata.get('HF_TOKEN')
  login(token=hf_token)
  NGROK_AUTH_TOKEN = userdata.get('NGROK_AUTH_TOKEN')
  return NGROK_AUTH_TOKEN

In [None]:
def get_classifier(classifier_model_name, device):

  model_checkpoint = f'{classifier_model_name}.pth.tar'

  if GITHUB_RELEASE_VERSION is not None:
    # To use your own model update the path here
    release_url = f"https://github.com/latoohey/modular-fudge/releases/download/{GITHUB_RELEASE_VERSION}/{model_checkpoint}"
    !wget "{release_url}" -O {model_checkpoint}
    ckpt_path = model_checkpoint
    print("Model downloaded")
  else:
    ckpt_path = os.path.join(CLASSIFIER_PATH, model_checkpoint)
    print("Using Drive model")

  # Load the trained classifier (the model type is read from the checkpoint)
  classifier, classifier_tokenizer = load_classifier(ckpt_path, device)
  print("--- Classifier loaded! ---")
  return classifier, classifier_tokenizer

In [None]:
def initialize_environment():

  NGROK_AUTH_TOKEN = setup()

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print(f"Using device: {device}")

  llm_model = LLM_MODEL_NAME
  print(f"Loading base LLM: {llm_model}...")
  llm = AutoModelForCausalLM.from_pretrained(
      llm_model,
      dtype=torch.float16
  ).to(device)
  llm_tokenizer = AutoTokenizer.from_pretrained(
      LLM_TOKENIZER_NAME
  )
  # Add pad token if it doesn't exist
  if llm_tokenizer.pad_token is None:
      # 1. Check for Llama 3 specific fine-tune token
      if '<|finetune_right_pad_id|>' in llm_tokenizer.get_vocab():
          llm_tokenizer.pad_token = '<|finetune_right_pad_id|>'

      # 2. Check for generic reserved tokens (common in TikToken)
      elif '<|reserved_special_token_0|>' in llm_tokenizer.get_vocab():
          llm_tokenizer.pad_token = '<|reserved_special_token_0|>'

      # 3. Safe Fallback: Use EOS token (No resizing required)
      else:
          print("Warning: No dedicated pad token found. Using EOS token as PAD.")
          llm_tokenizer.pad_token = llm_tokenizer.eos_token

  print("--- LLM loaded! ---")

  server_classifiers = {}

  for classifier_name in SERVER_CLASSIFIERS:
    classifier, classifier_tokenizer = get_classifier(classifier_name, device)
    server_classifiers[classifier_name] = {
        'classifier': classifier,
        'classifier_tokenizer': classifier_tokenizer
    }


  model_defs = {
      "llm": llm,
      "llm_tokenizer": llm_tokenizer,
      "classifier": None,
      "classifer_tokenizer": None
  }

  prompt_args = {
    "prompt": None,
    "max_new_tokens": MAX_NEW_TOKENS,
    "lambda_val": None,
    "top_k": TOP_K,
    "evaluation_history": KEEP_EVALUATION_HISTORY,
    "use_z_score": USE_Z_SCORE,
    "strategy": STRATEGY,
    "temperature": TEMPERATURE
  }

  return Mamba, model_defs, server_classifiers, prompt_args, NGROK_AUTH_TOKEN

def initialize_environment_test():
    NGROK_AUTH_TOKEN = setup()
    Mamba = None
    model_defs = {
      "llm": "llm",
      "llm_tokenizer": "llm_tokenizer",
      "classifier": None,
      "classifer_tokenizer": None
    }

    prompt_args = {
      "prompt": None,
      "max_new_tokens": MAX_NEW_TOKENS,
      "lambda_val": None,
      "top_k": TOP_K,
      "evaluation_history": KEEP_EVALUATION_HISTORY,
      "use_z_score": USE_Z_SCORE,
      "strategy": STRATEGY,
      "temperature": TEMPERATURE
    }

    server_classifiers = {
        "classifier_name": {
            "classifier": "classifier",
            "classifier_tokenizer": "classifier_tokenizer"
            }
        }

    return Mamba, model_defs, server_classifiers, prompt_args, NGROK_AUTH_TOKEN

## Configuration:

Note: Add your Hugging Face token as a Colab Secret named `HF_TOKEN`. `setup()` also reads an ngrok token from a Colab Secret named `NGROK_AUTH_TOKEN`.

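* There are four `TESTING_TYPE`s
  * `grid` - tests a list of prompts, each at different lambdas
  * `targeted` - tests a specific `TESTING_PROMPT` at different lambdas
  * `prompted` - starts a user interface loop
  * `token_eval` - uses `TESTING_PROMPT` with `TESTING_LAMBDA`; keep `MAX_NEW_TOKENS` and `TOP_K` low for this mode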

* `SEED` is set for reproducibility

* Trained models are available under Releases in the project GitHub repository
https://github.com/latoohey/modular-fudge. To use one, set `CLASSIFIER_MODEL_NAME` (without the file extension) and the appropriate `GITHUB_RELEASE_VERSION`. You can also use your own classifier: modify the download and import code so that `ckpt_path` points to your `.pth.tar` checkpoint file. To reduce config issues, train the model with the project training script, which defines the expected inputs and outputs.

* Define the `CLASSIFIER_TOKENIZER_NAME` that the classifier was trained with. This does not need to match the `LLM_TOKENIZER_NAME` but re-tokenizing adds time at inference.

* The `HIDDEN_DIM` is defined as an argument in training so it also needs to be supplied here.

* Define the `LLM_MODEL_NAME` you want to use to generate output. The T4 GPU can comfortably run 3B parameter models and below - try the A100 for bigger models. You'll need to be approved on Hugging Face by Meta to use a Llama model.

* `LLM_PAD_TOKEN` is the pad token for the LLM.

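* `MAX_NEW_TOKENS` sets the maximum output length from the LLM. Generation slows down as the output grows. I haven't confirmed the cause yet, but a likely one is that `generate_guided` re-runs the LLM over the full sequence at every step and the classifier re-scores the whole context for each candidate.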

* `TOP_K` is the number of candidate tokens that the classifier checks before the LLM selects its final token. The original FUDGE paper had this set at 200. Note the math on this: each generation scores up to `MAX_NEW_TOKENS` × `TOP_K` candidate prefixes with the classifier (generation often runs all the way to `MAX_NEW_TOKENS`), so that number can get very big very fast.

* `SAVE_TESTS_TO_DRIVE`: `True` saves files generated from the `grid` testing to your Google Drive. You'll be prompted to log in. `False` saves the files to the Colab runtime. Since the test runs through about 100 generations, make sure your runtime doesn't expire.

* To use a plain-text prompt, set `CUSTOM_PROMPT_TEMPLATE` to `None`. Many LLMs have a defined prompt input format, usually outlined in their documentation. You can define this using a lambda function named `CUSTOM_PROMPT_TEMPLATE`. It MUST accept one argument (e.g., 'p'), which will be your prompt string, and MUST return the 'messages' list structure you want. For example, the minimum for Llama would be defined with:

  * `CUSTOM_PROMPT_TEMPLATE = lambda p: [{"role": "user", "content": p}]`

  This should be used in conjunction with `ADD_GENERATION_PROMPT = True`, which maps to the Transformers library `add_generation_prompt` argument defined [here](https://huggingface.co/docs/transformers/en/chat_templating#addgenerationprompt).
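The template rule above can be sketched in plain Python. This is a minimal illustration, not the notebook's actual code path: `build_model_input` and `SYSTEM_PROMPT_TEMPLATE` are hypothetical names introduced here, and the system message wording is an assumption.

```python
from typing import Callable, Optional, Union

Messages = list[dict[str, str]]

# The minimal Llama-style template quoted in the text above.
CUSTOM_PROMPT_TEMPLATE = lambda p: [{"role": "user", "content": p}]

# Hypothetical variant with a system message (the wording is an assumption).
SYSTEM_PROMPT_TEMPLATE = lambda p: [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": p},
]


def build_model_input(
    prompt: str, template: Optional[Callable[[str], Messages]]
) -> Union[str, Messages]:
    """Return the raw prompt when no template is set, else the messages list."""
    if template is None:
        return prompt
    messages = template(prompt)
    # Every message needs both keys for chat templating to work.
    for message in messages:
        if not {"role", "content"} <= set(message):
            raise ValueError(f"Message is missing 'role' or 'content': {message}")
    return messages


plain = build_model_input("write a paragraph about europe", None)
chat = build_model_input("write a paragraph about europe", CUSTOM_PROMPT_TEMPLATE)
print(plain)
print(chat)

# With a Hugging Face tokenizer, the messages list would typically be rendered with:
#   tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
```

With `add_generation_prompt=True`, the rendered string ends with the tokens that open the assistant's turn, so the model starts its reply rather than continuing the user message.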

In [None]:
TESTING_TYPE = "grid" # "grid" or "targeted" or "prompted" or "token_eval"
# TESTING_PROMPT is needed if TESTING_TYPE is "targeted" or "token_eval"
TESTING_PROMPT = "write a paragraph about europe"

GRID_TEST_RUN_NAME = "mamba__128_4_16_1"
SAVE_TESTS_TO_DRIVE = True
TEST_PROMPTS_FILE_PATH = "modular-fudge/data/eval_prompts.csv"
PROMPTS_TO_TEST_LIMIT = 50
GRID_LAMBDAS = [1.2, 1.4, 1.6]
GRID_CLASSIFIER_NAMES = ['mamba_128_4_16_1']
GRID_TOP_KS = [100]
GRID_USE_Z_SCORES = [True]

# TESTING_LAMBDA is needed if TESTING_TYPE is "token_eval"
TESTING_LAMBDA = 1

SEED = 24601

CLASSIFIER_MODEL_NAME = "mamba_256_4_16_1"

GITHUB_RELEASE_VERSION = None # "v1.0"
#---OR---
CLASSIFIER_PATH = '/content/drive/MyDrive/modular-fudge/trained_models'

KEEP_EVALUATION_HISTORY = None
EVAL_STEP_TO_ANALYZE = None

USE_Z_SCORE = True
STRATEGY = "greedy"  # Options: "greedy", "sample"
TEMPERATURE = 1.0     # Only used if strategy="sample"

CLASSIFIER_TOKENIZER_NAME = 'meta-llama/Llama-3.2-3B-Instruct'

# (You will need to accept the license on Hugging Face first for Llama)
# LLM_MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
LLM_MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
# LLM_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

LLM_TOKENIZER_NAME = 'meta-llama/Llama-3.2-3B-Instruct'

# Keep both of these values low if running "token_eval"
MAX_NEW_TOKENS = 512
TOP_K = 100

# PROMPT TEMPLATE
# To use just a plain prompt, set this to None:
# CUSTOM_PROMPT_TEMPLATE = None

# --- OR ---

# To use a custom template, define a lambda function.
# The lambda MUST accept one argument (e.g., 'p') which will be your prompt string.
# It MUST return the 'messages' list structure you want.

# Example 1: Add a simple prefix
# CUSTOM_PROMPT_TEMPLATE = lambda p: [
#     {"role": "user", "content": f"Task: Answer the following question. {p}"}
# ]

# Example 2: A single user message with no system prompt
# This one is the base format for Llama models
CUSTOM_PROMPT_TEMPLATE = lambda p: [{"role": "user", "content": p}]
ADD_GENERATION_PROMPT = True

SERVER_CLASSIFIERS = ["mamba_128_4_16_1", "lstm_2_256"]
#SERVER_CLASSIFIERS = ["lstm_2_256"]
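To make the `TOP_K` math concrete, here is a rough estimate using the values from the cell above. It is only a sketch: `max_candidate_scores` is a hypothetical helper, the figure is an upper bound (generation may stop before `MAX_NEW_TOKENS`), and enumerating the grid as a full Cartesian product of the `GRID_*` lists is an assumption about how the grid test iterates.

```python
from itertools import product

# Values copied from the configuration cell.
MAX_NEW_TOKENS = 512
TOP_K = 100
GRID_LAMBDAS = [1.2, 1.4, 1.6]
GRID_CLASSIFIER_NAMES = ["mamba_128_4_16_1"]
GRID_TOP_KS = [100]
GRID_USE_Z_SCORES = [True]


def max_candidate_scores(max_new_tokens: int, top_k: int) -> int:
    """Upper bound on candidate tokens the classifier scores in one generation."""
    return max_new_tokens * top_k


# Assumption: the grid test runs every combination of the GRID_* settings.
grid = list(product(GRID_CLASSIFIER_NAMES, GRID_LAMBDAS, GRID_TOP_KS, GRID_USE_Z_SCORES))

print(max_candidate_scores(MAX_NEW_TOKENS, TOP_K))  # 51200
print(len(grid))  # 3
```

Raising `TOP_K` to the 200 used in the original FUDGE paper would double the per-generation bound to 102,400.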

In [None]:
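# Imports used by this cell (harmless if an earlier cell already ran them)
from typing import Optional

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from pyngrok import ngrok

# Map this to the arguments your generate_guided function needs
class GenerationArgs(BaseModel):
    prompt: str
    lambda_val: float = 1.0
    model_name: Optional[str] = None  # falls back to the first loaded classifier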

# --- Mock Generator (for testing the stream only) ---
# Yields a fixed sequence of tokens so the endpoint can be tested
# without loading the LLM or a classifier.
def generate_guided_test(args):
    mock_tokens = ["This ", "is ", "a ", "streamed ", "response ", "from ", "Colab."]
    for token in mock_tokens:
        time.sleep(0.5)  # simulate per-token generation latency
        yield token

# --- The Streaming Wrapper ---
# This is the bridge. It takes the validated HTTP input, calls generate_guided,
# and yields the output chunk by chunk for the web stream.
def stream_generator(input_data: GenerationArgs):
    # Convert the pydantic model to a plain dict (pydantic v2 API)
    client_data = input_data.model_dump()
    # generator = generate_guided_test(client_data)  # mock, for testing the stream only

    prompt_args['prompt'] = client_data['prompt']
    prompt_args['lambda_val'] = client_data['lambda_val']

    # Fall back to the first loaded classifier if the name is missing or unknown
    try:
        requested_classifier = server_classifiers[client_data['model_name']]
    except KeyError:
        requested_classifier = next(iter(server_classifiers.values()))

    model_defs['classifier'] = requested_classifier['classifier']
    model_defs['classifier_tokenizer'] = requested_classifier['classifier_tokenizer']

    generator = generate_guided(
        model_defs["llm"],
        model_defs["llm_tokenizer"],
        model_defs["classifier"],
        model_defs["classifier_tokenizer"],
        prompt_args['prompt'],
        prompt_args['max_new_tokens'],
        prompt_args['lambda_val'],
        prompt_args['top_k'],
        prompt_args['evaluation_history'],
        prompt_args['use_z_score'],
        prompt_args['strategy'],
        prompt_args['temperature']
    )



    for text_chunk in generator:
        # We yield the text directly.
        # Browsers/Clients will receive this chunk by chunk.
        yield text_chunk


Mamba, model_defs, server_classifiers, prompt_args, NGROK_AUTH_TOKEN = initialize_environment()

# Mamba, model_defs, server_classifiers, prompt_args, NGROK_AUTH_TOKEN = initialize_environment_test()

app = FastAPI()

# --- CORS Middleware (CRITICAL for browser clients) ---
# This tells FastAPI to answer the "OPTIONS" preflight check with "Yes, come on in."
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # Allows all origins (VS Code, GitHub Pages, etc.)
    allow_credentials=True,
    allow_methods=["*"],      # Allows all methods (POST, GET, OPTIONS, etc.)
    allow_headers=["*"],      # Allows all headers (including ngrok-skip-browser-warning)
)

# --- 5. The Endpoint ---
@app.post("/generate_stream")
async def generate_stream_endpoint(input_data: GenerationArgs):
    # We return a StreamingResponse object
    # media_type="text/event-stream" is standard for streaming updates,
    # but "text/plain" is often easier for simple raw text demos.
    return StreamingResponse(stream_generator(input_data), media_type="text/plain")


# --- 6. Tunnel & Run ---
ngrok.kill()
ngrok.set_auth_token(NGROK_AUTH_TOKEN)

# Connect on port 8000
public_url = ngrok.connect(8000).public_url
print(f"\n🚀 API LIVE AT: {public_url} 🚀")
print(f"Endpoint: {public_url}/generate_stream")

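# --- Run the Server (the async way) ---
# The notebook already has a running event loop, so we await the server directly.
# This prevents the 'asyncio.run() cannot be called from a running event loop' error.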
config = uvicorn.Config(app, port=8000)
server = uvicorn.Server(config)
await server.serve()

Global seed set to 24601
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
Loading base LLM: meta-llama/Llama-3.2-3B-Instruct...


--- LLM loaded! ---
Using Drive model
Loading classifier from /content/drive/MyDrive/modular-fudge/trained_models/mamba_128_4_16_1.pth.tar...
Checkpoint args: namespace(data_dir='data', save_dir='/content/drive/My Drive/', hf_token='<redacted>', ckpt=None, batch_size=64, epochs=100, lr=0.0001, seed=24601, num_workers=4, pos_cat='eb', neg_cat='simple_wiki', print_freq=100, model_type='mamba', lstm_hidden_dim=128, lstm_num_layers=4, mamba_d_model=128, mamba_num_layers=4, mamba_d_state=16, mamba_dropout=0.1, on_colab=True, val_size=400, max_len=1024, min_sentence_length=3, tokenizer_name='meta-llama/Llama-3.2-3B-Instruct', task='transfer', device=device(type='cuda'))
Classifier vocab size: 128256
Classifier loaded (Type: mamba, Epochs: 12).
--- Classifier loaded! ---
Using Drive model
Loading classifier from /content/drive/MyDrive/modular-fudge/trained_models/lstm_2_256.pth.tar...
Checkpoint args: namespace(data_dir='data', save_dir='/content/drive/My Drive/', h

INFO:     Started server process [2108]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)


INFO:     97.238.50.1:0 - "OPTIONS /generate_stream HTTP/1.1" 200 OK
INFO:     97.238.50.1:0 - "POST /generate_stream HTTP/1.1" 200 OK
INFO:     97.238.50.1:0 - "POST /generate_stream HTTP/1.1" 200 OK
INFO:     97.238.50.1:0 - "POST /generate_stream HTTP/1.1" 200 OK
INFO:     97.238.50.1:0 - "POST /generate_stream HTTP/1.1" 200 OK
INFO:     97.238.50.1:0 - "POST /generate_stream HTTP/1.1" 200 OK
INFO:     97.238.50.1:0 - "OPTIONS /generate_stream HTTP/1.1" 200 OK
INFO:     97.238.50.1:0 - "POST /generate_stream HTTP/1.1" 200 OK
INFO:     97.238.50.1:0 - "POST /generate_stream HTTP/1.1" 200 OK
INFO:     97.238.50.1:0 - "POST /generate_stream HTTP/1.1" 200 OK
INFO:     97.238.50.1:0 - "POST /generate_stream HTTP/1.1" 200 OK
INFO:     97.238.50.1:0 - "POST /generate_stream HTTP/1.1" 200 OK
INFO:     97.238.50.1:0 - "POST /generate_stream HTTP/1.1" 200 OK
INFO:     97.238.50.1:0 - "POST /generate_stream HTTP/1.1" 200 OK
INFO:     97.238.50.1:0 - "POST /generate_stream HTTP/1.1" 200 OK
INFO

INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [2108]
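While the server is running, a client can consume the `/generate_stream` endpoint chunk by chunk. The following is a minimal standard-library sketch, not part of the project: `build_request` and `decode_stream` are hypothetical helper names, the URL in the usage comment is a placeholder for the printed `public_url`, and the header value `"true"` is an assumption (the CORS comment in the server cell only names the `ngrok-skip-browser-warning` header).

```python
import codecs
import json
import urllib.request


def build_request(base_url, prompt, lambda_val=1.0, model_name=None):
    """Build a POST request whose JSON body matches the GenerationArgs fields."""
    payload = {"prompt": prompt, "lambda_val": lambda_val}
    if model_name is not None:
        # Omit the field entirely so the server applies its own default.
        payload["model_name"] = model_name
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/generate_stream",
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "ngrok-skip-browser-warning": "true",
        },
        method="POST",
    )


def decode_stream(response, chunk_size=64):
    """Yield text as it arrives, without splitting multi-byte UTF-8 characters."""
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    while True:
        # read1 returns whatever bytes are available instead of waiting for a full chunk.
        chunk = response.read1(chunk_size)
        if not chunk:
            break
        text = decoder.decode(chunk)
        if text:
            yield text
    tail = decoder.decode(b"", final=True)
    if tail:
        yield tail


# Usage (placeholder URL; replace it with the public_url printed by the server cell):
# request = build_request("https://example.ngrok.app", "write a paragraph about europe", 1.4)
# with urllib.request.urlopen(request) as response:
#     for text in decode_stream(response):
#         print(text, end="", flush=True)
```

The incremental decoder matters because a chunk boundary can fall in the middle of a multi-byte character; decoding each chunk independently would corrupt that character.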


In [None]:
ngrok.kill()