In [None]:
# For debugging as needed
# import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
!pip install transformers accelerate bitsandbytes huggingface_hub

Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl (59.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m46.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.48.2


In [None]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import random
from tqdm import tqdm
import os
from google.colab import drive

# We need argparse.Namespace to reconstruct the 'args' object
# that was saved inside the checkpoint
from argparse import Namespace

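## Testing modes
Set `testing_args["type"]` in the config cell to one of:
* `grid` — runs every prompt in the `prompts` list at several lambda values and logs each result to a CSV file
* `targeted` — runs the single prompt given in `testing_args["prompt"]` at several lambda values
* `prompted` — starts an interactive loop that asks for a prompt and a lambda value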

In [None]:
# ❗️ Point this to the folder in your Google Drive containing the model
DRIVE_PATH = "/content/drive/My Drive/modular-fudge/trained_models/"

# Checkpoint base name, without the .pth.tar extension
# MODEL_NAME = "lstm_e_20251109_222627"
MODEL_NAME = "lstm_h_20251109_052253"

# Base LLM to steer (you will need to accept the license on Hugging Face first for Llama!)
LLM_MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
# LLM_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

# Which test to run: "prompted", "targeted" (uses "prompt") or "grid"
testing_args = dict(
    type="prompted",
    prompt="Write a brief summary of the concept of 'philosophy'",
)

SEED = 24601

# Fallback classifier tokenizer (used when the checkpoint args lack `tokenizer_name`)
TOKENIZER_NAME = 'meta-llama/Llama-3.2-3B-Instruct'
PAD_TOKEN = '[PAD]'

# NOTE(review): appears unused here — the LSTM width is read from the checkpoint args.
HIDDEN_DIM = 300

# Generation limits: new tokens per response, and candidates scored per step
MAX_NEW_TOKENS, TOP_K = 300, 300

# --- Derived configuration ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

model_checkpoint = MODEL_NAME + '.pth.tar'
CKPT_PATH = os.path.join(DRIVE_PATH, model_checkpoint)

Using device: cuda


In [None]:
# Mount Google Drive so DRIVE_PATH (checkpoints) and the test-log folder are reachable.
# Re-running is harmless: an already-mounted drive is reported and left as is.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import random
from tqdm import tqdm
import os
from google.colab import drive

# We need argparse.Namespace to reconstruct the 'args' object
# that was saved inside the checkpoint
from argparse import Namespace

# --- From util.py ---
def num_params(model):
    """Count the trainable (requires_grad) scalar parameters of a module."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# --- From models ---
class LSTMClassifier(nn.Module):
    """Unidirectional LSTM that emits one guidance logit per input token.

    The submodule names `embed`, `rnn` and `out_linear` are part of the
    checkpoint format (load_state_dict matches on them), so keep them stable.
    """

    def __init__(self, args, vocab_size, pad_token_id):
        """
        Build the embedding, LSTM and output layers.

        Args:
            args: Namespace providing `lstm_hidden_dim` (used as both the
                  embedding width and the LSTM hidden size) and
                  `lstm_num_layers`.
            vocab_size: number of rows in the embedding table.
            pad_token_id: embedding index treated as padding.
        """
        super().__init__()
        width = args.lstm_hidden_dim

        self.embed = nn.Embedding(vocab_size, width, padding_idx=pad_token_id)
        self.rnn = nn.LSTM(
            input_size=width,
            hidden_size=width,
            num_layers=args.lstm_num_layers,
            bidirectional=False,
            # PyTorch applies LSTM dropout only between stacked layers, and only in train mode
            dropout=0.5,
            batch_first=True,  # tensors are (batch, seq, feature)
        )
        self.out_linear = nn.Linear(width, 1)

    def forward(self, inputs, lengths):
        """
        Score every position of a padded batch.

        Args:
            inputs: (batch, seq_len) token ids.
            lengths: (batch,) true length of each row; rows are packed so the
                     LSTM skips the padding.

        Returns:
            (batch, max(lengths)) tensor with one logit per position.
        """
        packed = pack_padded_sequence(
            self.embed(inputs),
            lengths.cpu(),  # pack_padded_sequence requires lengths on the CPU
            batch_first=True,
            enforce_sorted=False,
        )
        hidden_packed, _ = self.rnn(packed)
        hidden, _ = pad_packed_sequence(hidden_packed, batch_first=True)
        return self.out_linear(hidden).squeeze(2)

    def get_final_scores(self, batch):
        """
        Return the logit at the last real (non-padding) token of each row.

        Args:
            batch: sequence (inputs, lengths, targets); targets is ignored.

        Returns:
            (batch,) tensor of logits.
        """
        inputs, lengths, _ = batch
        per_token = self.forward(inputs, lengths)
        rows = torch.arange(inputs.size(0))
        return per_token[rows, lengths - 1]

def get_model(model_args, vocab_size, pad_token_id):
    """
    Build the classifier named by ``model_args.model_type``.

    Only 'lstm' is supported; any other type raises ValueError. The whole
    ``model_args`` namespace is forwarded to the model constructor.
    """
    model_type = model_args.model_type
    if model_type != 'lstm':
        # Add other model types here if necessary
        raise ValueError(f"Unknown model type: {model_args.model_type}")
    return LSTMClassifier(model_args, vocab_size, pad_token_id)

In [None]:
# Authenticate with the Hugging Face Hub so gated Llama weights/tokenizers can be downloaded.
from huggingface_hub import login
from google.colab import userdata

hf_token = userdata.get('HF_TOKEN')  # from Colab's Secrets panel; also reused when loading the classifier
login(token=hf_token)

In [None]:
def load_classifier(ckpt_path, device, args):
    """
    Load a trained guidance classifier and its tokenizer from a checkpoint.

    The checkpoint dict is read for 'args' (the training Namespace: `model_type`,
    optional `tokenizer_name`, and the model hyperparameters), 'state_dict' and
    'epoch'.

    Args:
        ckpt_path: path to the .pth.tar checkpoint.
        device: device the checkpoint is mapped to and the model is moved to.
        args: runtime Namespace; only `args.hf_token` is read.

    Returns:
        (model, classifier_tokenizer), with the model in eval mode on `device`.
    """
    print(f"Loading classifier from {ckpt_path}...")
    # SECURITY: weights_only=False unpickles arbitrary objects (needed for the pickled
    # argparse.Namespace) and can execute code — only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    model_args = checkpoint['args']

    # Prefer the tokenizer recorded at training time; otherwise the notebook default.
    tokenizer_name = (
        model_args.tokenizer_name if hasattr(model_args, 'tokenizer_name') else TOKENIZER_NAME
    )
    classifier_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=args.hf_token)

    if classifier_tokenizer.pad_token is None:
        # Llama tokenizers often have no pad token; reuse EOS so no new token is
        # added and len(tokenizer) (the embedding size below) stays unchanged.
        classifier_tokenizer.pad_token = classifier_tokenizer.eos_token

    model = get_model(
        model_args,
        len(classifier_tokenizer),
        classifier_tokenizer.pad_token_id,
    )
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device).eval()

    print(f"Classifier loaded (Type: {model_args.model_type}, Epochs: {checkpoint['epoch']}).")
    return model, classifier_tokenizer

def _stop_token_ids(llm, llm_tokenizer):
    """Collect every token id that should end generation.

    ``llm.generation_config.eos_token_id`` may be an int or a list and can name
    more end-of-sequence ids than ``llm_tokenizer.eos_token_id``; stopping on
    only the tokenizer's id lets the loop run past the end of the answer.
    """
    stop_ids = set()
    generation_config = getattr(llm, "generation_config", None)
    config_eos = getattr(generation_config, "eos_token_id", None)
    if isinstance(config_eos, int):
        stop_ids.add(config_eos)
    elif config_eos is not None:
        stop_ids.update(int(token_id) for token_id in config_eos)
    if llm_tokenizer.eos_token_id is not None:
        stop_ids.add(llm_tokenizer.eos_token_id)
    return stop_ids


@torch.no_grad()
def _fudge_step(llm, classifier, generated_ids, model_input, past_key_values,
                condition_lambda, top_k):
    """Choose the next token for one FUDGE step.

    ``model_input`` is the whole prompt on the first step and only the last
    appended token afterwards (earlier positions live in ``past_key_values``).
    Returns (next_token_id with shape (1, 1), updated past_key_values).
    """
    # --- A: base LLM logits for the next position, cast from fp16 to fp32 ---
    llm_outputs = llm(input_ids=model_input, past_key_values=past_key_values, use_cache=True)
    next_token_logits = llm_outputs.logits[0, -1, :].float()  # (vocab,)

    # --- B: top-k candidates (k cannot exceed the vocabulary size) ---
    k = min(top_k, next_token_logits.shape[-1])
    top_logits, top_indices = torch.topk(next_token_logits, k)  # (k,), (k,)

    # --- C: score every candidate prefix with the classifier ---
    # All k prefixes have the same length, so no padding is needed. The LLM's token
    # ids go to the classifier as-is (no decode/re-encode), which assumes both share
    # one vocabulary.
    candidate_prefixes = torch.cat(
        [generated_ids.expand(k, -1), top_indices.unsqueeze(-1)], dim=-1
    )  # (k, seq_len + 1)
    lengths = torch.full(
        (k,), candidate_prefixes.shape[1], dtype=torch.long, device=candidate_prefixes.device
    )
    classifier_logits = classifier.get_final_scores([candidate_prefixes, lengths, None])

    # --- D: combine log P(x) (renormalised over the top-k) + lambda * log P(a|x) ---
    llm_log_probs = F.log_softmax(top_logits, dim=-1)
    # For a binary classifier logit, log P(a|x) of the positive class is logsigmoid(logit).
    classifier_log_probs = F.logsigmoid(classifier_logits)
    combined_log_probs = llm_log_probs + condition_lambda * classifier_log_probs

    next_token_id = top_indices[torch.argmax(combined_log_probs)].view(1, 1)
    return next_token_id, llm_outputs.past_key_values


def generate_guided(
    llm,
    llm_tokenizer,
    classifier,
    classifier_tokenizer,
    prompt,
    max_len,
    condition_lambda,
    top_k
):
    """
    FUDGE-style guided generation for a single prompt, streamed as text.

    The prompt is wrapped in the tokenizer's chat template. At each step the
    LLM's ``top_k`` most likely next tokens are scored by ``classifier`` on the
    full sequence so far (template tokens included) plus the candidate, and the
    token maximising ``log P(x) + condition_lambda * logsigmoid(classifier_logit)``
    is appended greedily (no sampling). With ``condition_lambda == 0`` this is
    plain greedy decoding over the top-k. The LLM's KV cache is reused between
    steps, so only the newest token is fed to the model after the first step.

    Args:
        llm: causal LM exposing ``.device`` and accepting ``input_ids``,
            ``past_key_values`` and ``use_cache``.
        llm_tokenizer: tokenizer providing ``apply_chat_template`` and ``decode``.
        classifier: object whose ``get_final_scores([ids, lengths, None])``
            returns one logit per row of ``ids``.
        classifier_tokenizer: unused — LLM token ids are passed to the
            classifier directly, so both must share a vocabulary.
        prompt: user message text.
        max_len: maximum number of new tokens to generate.
        condition_lambda: weight of the classifier term.
        top_k: candidates scored per step (clamped to the vocabulary size).

    Yields:
        str: newly decoded text. Output is held back while the tail is an
        incomplete multi-byte character; concatenating all yields gives the
        decoded continuation with special tokens skipped.
    """
    device = llm.device
    stop_ids = _stop_token_ids(llm, llm_tokenizer)

    # Apply the chat template and open the assistant turn.
    messages = [
        {"role": "user", "content": prompt}
    ]
    input_ids = llm_tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=False,  # we want the bare (1, prompt_len) id tensor
    ).to(device)

    # Remember the prompt length so only the continuation is decoded.
    prompt_length = input_ids.shape[1]
    generated_ids = input_ids
    model_input = input_ids
    past_key_values = None
    emitted_text = ""

    def decode_continuation():
        # Decode everything generated so far rather than one token at a time:
        # byte-level BPE can split one character across tokens, and decoding
        # those tokens separately yields U+FFFD. Space clean-up is disabled so
        # each decode extends the previous one and deltas can be sliced off.
        return llm_tokenizer.decode(
            generated_ids[0, prompt_length:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    for _ in range(max_len):
        next_token_id, past_key_values = _fudge_step(
            llm, classifier, generated_ids, model_input, past_key_values,
            condition_lambda, top_k,
        )
        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
        model_input = next_token_id  # everything before it is in the KV cache

        text = decode_continuation()
        # Yield only new, complete characters.
        if len(text) > len(emitted_text) and not text.endswith("\ufffd"):
            yield text[len(emitted_text):]
            emitted_text = text

        if next_token_id.item() in stop_ids:
            break

    # Flush anything still held back (e.g. generation stopped mid-character).
    text = decode_continuation()
    if len(text) > len(emitted_text):
        yield text[len(emitted_text):]

In [None]:
# Seed the Python, NumPy and PyTorch RNGs for reproducibility
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

# Runtime-only arguments for load_classifier (it reads args.hf_token), kept
# separate from the training args pickled inside the checkpoint.
from argparse import Namespace
runtime_args = Namespace(hf_token=hf_token)  # hf_token comes from the login cell

# 1. Load our trained LSTM classifier together with its tokenizer
classifier, classifier_tokenizer = load_classifier(CKPT_PATH, DEVICE, runtime_args)

print("--- Classifier loaded! ---")

Loading classifier from /content/drive/My Drive/modular-fudge/trained_models/lstm_h_20251109_052253.pth.tar...


tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

Classifier loaded (Type: lstm, Epochs: 8).
--- Classifier loaded! ---


In [None]:
# 2. Load the base LLM (fp16) and the tokenizer used for chat templating/decoding
llm_model = LLM_MODEL_NAME
print(f"Loading base LLM: {llm_model}...")
llm = AutoModelForCausalLM.from_pretrained(
    llm_model,
    dtype=torch.float16
).to(DEVICE)
llm_tokenizer = AutoTokenizer.from_pretrained(
    TOKENIZER_NAME
)
if llm_tokenizer.pad_token is None:
    # Reuse EOS as the pad token, the same choice load_classifier makes for the
    # classifier tokenizer. Adding a brand-new special token (e.g. PAD_TOKEN) would
    # give it the next free id. The LLM's embeddings are never resized here, so that
    # id can fall outside the embedding matrix and fail if it ever reaches the model.
    llm_tokenizer.pad_token = llm_tokenizer.eos_token

print("--- LLM loaded! ---")

Loading base LLM: meta-llama/Llama-3.2-3B-Instruct...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

--- LLM loaded! ---


In [None]:
def prompted_testing():
    """
    Interactive loop: read a prompt and a lambda, then stream the guided output.

    Type 'q' (any case, surrounding spaces ignored) at the prompt to quit;
    Ctrl-C or end of input also exits. The lambda is parsed as a float, so
    fractional weights such as 2.5 are accepted. An unparsable value is
    reported and the loop asks again instead of crashing.
    """
    while True:
        try:
            prompt = input("Enter a prompt (or 'q' to quit): ")
            if prompt.strip().lower() == 'q':
                break

            condition_lambda_str = input("Enter a lambda value: ")
            try:
                condition_lambda = float(condition_lambda_str)
            except ValueError:
                print(f"Invalid lambda {condition_lambda_str!r}: please enter a number.")
                continue
            print(f"--- Generating with lambda={condition_lambda} ---")

            output_generator = generate_guided(
                llm,
                llm_tokenizer,
                classifier,
                classifier_tokenizer,
                prompt,
                MAX_NEW_TOKENS,
                condition_lambda,
                TOP_K
            )
            # Print text as it arrives so long generations show progress.
            for new_token in output_generator:
                print(new_token, end="", flush=True)
            print("\n---")  # end the streamed line

        except (KeyboardInterrupt, EOFError):
            print("\nExiting.")
            break

In [None]:
def targeted_testing(prompt="Write a mystery story.", lambdas=(0.0, 5.0, 10.0, 20.0)):
    """
    Generate guided output for one prompt at each lambda and print the results.

    Args:
        prompt: user message to generate from.
        lambdas: classifier weights to compare. The default reproduces the
            original sweep; 0.0 is unguided greedy decoding over the top-k.
    """
    for lambda_val in lambdas:
        output_generator = generate_guided(
            llm,
            llm_tokenizer,
            classifier,
            classifier_tokenizer,
            prompt,
            MAX_NEW_TOKENS,
            lambda_val,
            TOP_K
        )
        # Exhaust the generator into a single string
        output = "".join(output_generator).strip()
        print(f"\nGUIDED OUTPUT (lambda={lambda_val}):\n", output)
        print("-" * 30)

In [None]:
# Prompts exercised by grid_testing(): each one is generated at every lambda in the
# grid and logged as its own CSV row. Edit this list to change the benchmark set.
prompts = [
    "Who was Albert Einstein?",
    "Explain what artificial intelligence is.",
    "How does a neural network work?",
    "What are some recent advancements in renewable energy?",
    "I just got promoted at work!",
    "I’m feeling really down today.",
    "The weather is nice today.",
    "Write a short story about a robot who dreams of becoming human.",
    "Generate a poem about love.",
    "Give me three ideas for a birthday surprise.",
    "Generate a story about space exploration.",
    "Write a short story about a time traveler who meets a famous historical figure.",
    "What's a good way to spend a sunny day in the park?",
    "Are people fundamentally good?",
    "What is the key to happiness?",
    "Who is the greatest military leader in history?",
    "What was life like in London at the start of the twentieth century?",
    "How do you learn to play an instrument?",
    "Compare and contrast impressionism and realism.",
    "What causes sepsis?",
    "Write a mystery story"
]

In [None]:
import csv
import tqdm
import time
import os

def grid_testing(prompts_to_test=None, lambdas=(0.0, 5.0, 10.0, 15.0, 20.0), log_file=None):
    """
    Run every prompt at every lambda and append one CSV row per generation.

    The CSV is reopened and appended after each generation, so an interrupted
    run keeps every row completed so far.

    Args:
        prompts_to_test: prompts to run; defaults to the notebook-level `prompts` list.
        lambdas: classifier weights to try for each prompt.
        log_file: CSV path; defaults to `<MODEL_NAME>_tests.csv` in the Drive
            `modular-fudge/tests/` folder. The header row is written only when
            the file does not exist yet.

    CSV columns: top_k, lambda, prompt, elapsed_time (wall-clock seconds),
    output (stripped generated text).
    """
    if prompts_to_test is None:
        prompts_to_test = prompts
    if log_file is None:
        log_file = f"/content/drive/My Drive/modular-fudge/tests/{MODEL_NAME}_tests.csv"

    headers = ['top_k', 'lambda', 'prompt', 'elapsed_time', 'output']

    # The tests/ folder may not exist yet (e.g. first run on a fresh Drive), in
    # which case open(..., 'w') would raise FileNotFoundError.
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Explicit UTF-8: prompts and generations can contain non-ASCII text.
    if not os.path.exists(log_file):
        with open(log_file, 'w', newline='', encoding='utf-8') as csvfile:
            csv.writer(csvfile).writerow(headers)
        print(f"CSV file '{log_file}' created with headers.")
    else:
        print(f"CSV file '{log_file}' already exists.")

    for prompt in prompts_to_test:
        print(f"Testing prompt: {prompt}")
        for lambda_val in lambdas:
            print(f"@ lambda: {lambda_val}")
            # Re-seed before each run so any randomness starts from the same state.
            np.random.seed(SEED)
            torch.manual_seed(SEED)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(SEED)

            start_time = time.perf_counter()  # monotonic, high-resolution timer for durations
            output_generator = generate_guided(
                llm,
                llm_tokenizer,
                classifier,
                classifier_tokenizer,
                prompt,
                MAX_NEW_TOKENS,
                lambda_val,
                TOP_K
            )
            # Exhaust the generator into a single string
            output = "".join(output_generator).strip()
            elapsed_time = time.perf_counter() - start_time

            new_row_data = [TOP_K, lambda_val, prompt, elapsed_time, output]
            with open(log_file, 'a', newline='', encoding='utf-8') as csvfile:
                csv.writer(csvfile).writerow(new_row_data)

In [None]:
def test(testing_args):
  if (testing_args["type"]=="prompted"):
    prompted_testing()
  elif (testing_args["type"]=="targeted"):
    targeted_testing(testing_args["prompt"])
  elif (testing_args["type"]=="grid"):
    grid_testing()

In [None]:
from IPython.display import HTML, display

def set_css_output_wrap():
    """Inject CSS so long lines of text output wrap instead of scrolling sideways.

    The rule targets the classic Jupyter output classes `output_text` and
    `output_subarea`. NOTE(review): other front ends (e.g. Colab) may use
    different markup — confirm the rule takes effect there.
    """
    wrap_css = '''
    <style>
    div.output_text pre {
        white-space: pre-wrap;
    }
    div.output_subarea pre {
        white-space: pre-wrap;
    }
    </style>
    '''
    display(HTML(wrap_css))

# Enable wrapping first so the generations printed below are readable, then run the selected test.
set_css_output_wrap()

test(testing_args)