# Training linear probes on unlearned models

This notebook trains probes on a given unlearned model and saves the model predictions and ground truths to CSV files, which can then be read by `probe_visualization.py`.

## Installations (run once and then restart kernel)

Note that this cell will kill the process and force a restart to ensure that all installs and patches load properly

In [None]:
# Install transformer-heads from source (editable) together with its dependencies.
# This sequence works in Colab; the extra steps handle errors that came up when
# running on plain Jupyter. The cell is meant to be run once per environment.
!git clone https://github.com/center-for-humans-and-machines/transformer-heads.git
!uv pip install -e ./transformer-heads
!uv pip install "transformers>=4.37.0,<4.50.0"
!uv pip install --upgrade datasets fsspec
!uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Optional patch for generate compatibility (only if you want to support .generate() in transformers 4.50+).
# It rewrites the TransformerWithHeads class header in the cloned checkout so that
# the class also inherits from GenerationMixin.
# NOTE(review): the transformers pin above excludes 4.50+, so this presumably only
# matters if that pin is relaxed. It also assumes GenerationMixin is already
# importable inside model.py -- TODO confirm.
!sed -i 's/class TransformerWithHeads(PreTrainedModel):/class TransformerWithHeads(PreTrainedModel, GenerationMixin):/' transformer-heads/transformer_heads/model/model.py

# For some reason Zephyr RMU depended on these, but base didn't
!uv pip install protobuf sentencepiece

# Force a restart (to make sure all installs and patches load properly).
# Signal 9 kills the current kernel process immediately; no later line in this
# cell would run.
import os
os.kill(os.getpid(), 9)

## Training probes

In [None]:
from transformer_heads import load_headed
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    MistralForCausalLM,
    Trainer,
    BitsAndBytesConfig,
    TrainingArguments,
    GPT2Model,
    GPT2LMHeadModel,
)
from transformer_heads.util.helpers import DataCollatorWithPadding, get_model_params
from peft import LoraConfig
from transformer_heads.config import HeadConfig
from transformer_heads.util.model import print_trainable_parameters
from transformer_heads.util.evaluate import (
    evaluate_head_wise,
    get_top_n_preds,
    get_some_preds,
)
import torch
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, HTML

In [None]:
# Hugging Face model IDs that can be selected in the dropdown below.
ZEPHYR_7B_RMU = "cais/Zephyr_RMU"
ZEPHYR_7B_ELM = "baulab/elm-zephyr-7b-beta"
ZEPHYR_7B_BASE = "HuggingFaceH4/zephyr-7b-beta"
# NOTE(review): despite the constant's name, this ID points at an ELM checkpoint
# repo ("...-elm-checkpoint-8"), not the plain instruct model -- confirm the naming.
LLAMA3_8B_INSTRUCT = "LLM-GAT/llama-3-8b-instruct-elm-checkpoint-8"
LLAMA3_8B_BASE = "meta-llama/Meta-Llama-3-8B-Instruct"


display(HTML("<h2>Update the dataset path for the text we want to probe on:</h2>"))

# When using this notebook, update this path to the dataset you want to probe on.
# In the paper, we used the bio test dataset.
file_path = '../data/wmdp_rephrased/data_hindi_filler_text/test/bio_questions.json'

display(HTML("<h2>Choose the model to train probes for:</h2>"))

# Dropdown whose current value is read into `model_path` in the next cell.
model_selector = widgets.Dropdown(
    options=[ZEPHYR_7B_RMU, ZEPHYR_7B_ELM, ZEPHYR_7B_BASE, LLAMA3_8B_INSTRUCT, LLAMA3_8B_BASE],
    value=ZEPHYR_7B_RMU,
)

display(model_selector)




In [None]:
# Model to probe, read from the dropdown defined in the configuration cell.
# To bypass the widget, assign one of the model IDs directly instead, e.g.:
#   model_path = "cais/Zephyr_RMU"
#   model_path = "baulab/elm-zephyr-7b-beta"
#   model_path = "HuggingFaceH4/zephyr-7b-beta"
#   model_path = "LLM-GAT/llama-3-8b-instruct-elm-checkpoint-8"
#   model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
model_path = model_selector.value
print(f"Selected model: {model_path}")
train_epochs = 1  # used as TrainingArguments(num_train_epochs=...)
eval_epochs = 1  # used as evaluate_head_wise(..., epochs=...)
logging_steps = 100  # used as TrainingArguments(logging_steps=...)
full_finetune = False  # False => the model is loaded with freeze_base_model=True

In [None]:
# The ELM Zephyr checkpoint has its parameters looked up under the base Zephyr
# model ID; every other model is looked up under its own ID.
config_path = (
    "HuggingFaceH4/zephyr-7b-beta"
    if model_path == "baulab/elm-zephyr-7b-beta"
    else model_path
)

# Pull out the entries the rest of the notebook needs, then show the full dict.
model_params = get_model_params(config_path)
model_class, hidden_size, vocab_size = (
    model_params["model_class"],
    model_params["hidden_size"],
    model_params["vocab_size"],
)
print(model_params)

In [None]:
# Number of probe heads to create. The head-config cell builds one head for each
# i in 1..num_heads and hooks it with layer_hook=-i.
# NOTE: This was the same for both Llama and Zephyr I believe, but verify based on what's printed above
num_heads = 32

We are doing sequence classification, so we have to set `pred_for_sequence` to True for this task. Each WMDP question is multiple choice with four options (A–D), so we set `num_outputs` to 4. We create one linear probe head per hook position (`layer_hook=-i` for `i` from 1 to `num_heads`).

In [None]:
def _build_probe_head(layer_offset):
    """Return the config for one linear 4-way classification probe.

    The head is named ``wmdp_head_<layer_offset>`` and hooked with
    ``layer_hook=-layer_offset``.
    """
    return HeadConfig(
        name=f"wmdp_head_{layer_offset}",
        layer_hook=-layer_offset,
        in_size=hidden_size,
        output_activation="linear",
        pred_for_sequence=True,
        loss_fct="cross_entropy",
        num_outputs=4,
    )


# One probe head for every offset 1..num_heads, in increasing order.
head_configs = list(map(_build_probe_head, range(1, num_heads + 1)))

In [None]:
# Load the "wmdp-bio" configuration of the cais/wmdp dataset.
# Note: the Hindi-filler cell that follows reassigns `dd`, so this dataset is only
# used when that cell is skipped or commented out.
dd = load_dataset("cais/wmdp", "wmdp-bio")

The next cell replaces the dataset loaded above with the Hindi-filler version read from `file_path`. Comment it out (or skip it) if you want to use the normal rephrasing instead.

In [None]:

import json

from datasets import Dataset

# Read the WMDP bio questions (with Hindi filler text) from the JSON Lines file at
# `file_path` (set in the configuration cell): one JSON object per line.
all_data = []  # every successfully parsed JSON object, in file order

try:
    # Decode explicitly as UTF-8: the file contains Hindi text, and the platform's
    # default encoding (used when `encoding` is omitted) is not UTF-8 everywhere.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                # Blank lines (e.g. a trailing newline) are not records.
                continue
            try:
                all_data.append(json.loads(line))
            except json.JSONDecodeError as e:
                # Best effort: report the malformed line and keep reading.
                print(f"Error decoding JSON from line {line_number}: {e}")
                print(f"Problematic line content: {line.strip()}")
except FileNotFoundError:
    # Fail here rather than continuing with an empty dataset, which would replace
    # `dd` below and only surface later as an unrelated-looking error.
    print(f"Error: File not found at {file_path}")
    raise

if not all_data:
    raise ValueError(f"No JSON objects could be read from {file_path}")

print(f"Successfully read {len(all_data)} JSON objects from the file.")
print("First object:")
# ensure_ascii=False so Hindi characters are displayed as-is instead of \u escapes
print(json.dumps(all_data[0], indent=4, ensure_ascii=False))
if len(all_data) > 1:
    print("\nLast object:")
    print(json.dumps(all_data[-1], indent=4, ensure_ascii=False))

# Expose the records as a single "test" split and make it the active dataset.
test_hindi_dataset = Dataset.from_list(all_data)
hindi_dd = {"test": test_hindi_dataset}

dd = hindi_dd

In the *tokenize_function*, we add one label column per head (named after the head), each holding the example's `answer` value.

In [None]:
import numpy as np
from datasets import Dataset, DatasetDict

def create_train_test_split(dataset, train_size=0.5, seed=42):
    """Randomly partition ``dataset`` into a train part and a test part.

    Args:
        dataset: Object supporting ``len()`` and ``select(indices)``.
        train_size: Fraction of examples placed in the train part; the count is
            truncated to an integer.
        seed: Value used to seed NumPy's global random state before shuffling.

    Returns:
        Tuple ``(train_dataset, test_dataset)`` as returned by ``dataset.select``.
    """
    n_examples = len(dataset)
    cutoff = int(n_examples * train_size)

    # Seed the global NumPy RNG so the shuffled order is reproducible.
    np.random.seed(seed)
    shuffled = np.random.permutation(n_examples)

    # The first `cutoff` shuffled indices form the train part, the rest the test part.
    train_part, test_part = shuffled[:cutoff], shuffled[cutoff:]
    return dataset.select(train_part), dataset.select(test_part)

# When the dataset only provides a test split, carve a train split out of it.
if 'test' in dd and 'train' not in dd:
    print(f"Creating train-test split from the original test set of size {len(dd['test'])}")
    train_dataset, test_dataset = create_train_test_split(dd['test'])

    # Replace `dd` with a DatasetDict holding both parts.
    dd = DatasetDict({"train": train_dataset, "test": test_dataset})

    print(f"Created train split with {len(dd['train'])} examples")
    print(f"Created test split with {len(dd['test'])} examples")

In [None]:
# Load the tokenizer that matches the selected model.
tokenizer = AutoTokenizer.from_pretrained(model_path)
# If the tokenizer has no pad token, reuse the EOS token: the data collator
# defined later pads input_ids with tokenizer.pad_token_id.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    """Turn a batch of questions into tokenized multiple-choice prompts.

    Each prompt is a fixed biology header, the stripped question, the first four
    choices labelled A-D, and a trailing "Answer:". The batch's ``answer`` values
    are attached once per probe head, under that head's name.

    Args:
        examples: Batch mapping with "question", "choices" and "answer" entries.

    Returns:
        The tokenizer output (no padding, no truncation) with one extra entry
        per head in ``head_configs``.
    """
    header = "The following are multiple choice questions (with answers) about biology.\n\n"

    prompts = []
    for idx, question in enumerate(examples["question"]):
        options = examples["choices"][idx]
        body = f'{question.strip()}\nA. {options[0]}\nB. {options[1]}\nC. {options[2]}\nD. {options[3]}\nAnswer:'
        prompts.append(header + body)

    encoded = tokenizer(prompts, padding=False, truncation=False)

    # Every head is trained against the same labels.
    for head_cfg in head_configs:
        encoded[head_cfg.name] = examples["answer"]
    return encoded


# Tokenize every split in batches; this adds the tokenizer outputs and one label
# column per head (see tokenize_function).
for split in dd.keys():
    dd[split] = dd[split].map(tokenize_function, batched=True)

# Return the model inputs and the per-head label columns as torch tensors.
dd.set_format(
    type="torch",
    columns=["input_ids", "attention_mask"] + [x.name for x in head_configs],
)
# Drop the raw columns that were only needed to build the prompts and labels.
# NOTE: because these columns are removed, re-running this cell without
# re-creating `dd` first will fail inside tokenize_function.
for split in dd.keys():
    dd[split] = dd[split].remove_columns(["question", "choices", "answer"])

In [None]:
# Load the selected model with one probe head attached per entry in head_configs.
# Note: if this takes too long, consider quantizing using BitsAndBytesConfig. That said, it didn't appear to make a big difference for me.
model = load_headed(
    model_class,
    model_path,
    head_configs=head_configs,
    quantization_config=None,  # no quantization is applied
    freeze_base_model=not full_finetune,  # base model is frozen unless full_finetune is True
    device_map={"": torch.cuda.current_device()},  # uses the current CUDA device, so a GPU is required
)

Our heads are linear layers with only four outputs each. Thus we have a very low number of trainable parameters.

In [None]:
# Sanity check before training: get predictions for 5 test examples.
ins, preds, ground_truths = get_some_preds(
    model, dd["test"], tokenizer, n=5, classification=True
)
# Show prompt / predicted label / ground truth for one head only.
# NOTE: the head name "wmdp_head_3" is hard-coded; it exists only when num_heads >= 3.
print(
    pd.DataFrame(
        list(zip(ins, preds["wmdp_head_3"], ground_truths["wmdp_head_3"])),
        columns=["prompt", "label", "ground_truth"],
    )
)

Untrained heads give fairly random outputs.

In [None]:
# Collator used for both evaluation and training. Examples were tokenized without
# padding, so each batch is padded here: input_ids with the tokenizer's pad token
# id and attention_mask with 0.
collator = DataCollatorWithPadding(
    feature_name_to_padding_value={
        "input_ids": tokenizer.pad_token_id,
        "attention_mask": 0,
    }
)

In [None]:
# Baseline: per-head evaluation on the test split before any probe training.
print(evaluate_head_wise(model, dd["test"], collator, epochs=eval_epochs))

In [None]:
# Train the probe heads on the train split. Whether the base model is also
# updated was decided when the model was loaded (freeze_base_model=not full_finetune).
args = TrainingArguments(
    output_dir="wmdp_linear_probe",
    learning_rate=0.0002,
    num_train_epochs=train_epochs,  # To speed things up set to 0.1, set to 1 for better performance
    logging_steps=logging_steps,
    do_eval=False,  # no eval dataset is passed to the Trainer; evaluation is done separately via evaluate_head_wise
    # presumably needed so the per-head label columns are not dropped before
    # reaching the model -- TODO confirm
    remove_unused_columns=False,
)
trainer = Trainer(
    model,
    args=args,
    train_dataset=dd["train"],
    data_collator=collator,
)
trainer.train()

In [None]:
# Per-head evaluation on the test split after training, for comparison with the
# pre-training baseline printed earlier.
print(evaluate_head_wise(model, dd["test"], collator, epochs=eval_epochs))

In [None]:
# Collect predictions for the whole test split (n = number of test examples).
# These are the values written to CSV at the end of the notebook.
ins, preds, ground_truths = get_some_preds(
    model, dd["test"], tokenizer, n=len(dd["test"]), classification=True
)

In [None]:
# Show the dropdown's current selection.
# NOTE(review): this reads the widget live, whereas the model was loaded from
# `model_path` (captured earlier). The two differ if the dropdown was changed
# after that cell ran.
print(model_selector.value)

Store data containing probe performance

---



In [None]:
import os

def save_predictions_to_csv(ins, preds, ground_truths, output_dir="prediction_results"):
    """
    Save model predictions and ground truths to CSV files.

    Writes one ``<head>_predictions.csv`` per head (columns: input, prediction,
    ground_truth, correct) and one ``accuracy_summary.csv`` (columns: head_name,
    accuracy, correct_count, total_count). The column header is always written,
    even when there are no rows, so the files stay readable with
    ``pd.read_csv`` for empty inputs.

    Args:
        ins: List of input texts/prompts
        preds: Dictionary mapping head names to prediction lists
        ground_truths: Dictionary mapping head names to ground truth lists
        output_dir: Directory to save CSV files (created if missing)

    Returns:
        Dictionary with "head_predictions" (list of per-head CSV paths, in the
        order of ``preds``) and "accuracy_summary" (path of the summary CSV).
    """
    # Fixed schemas: passing these explicitly keeps the header for empty data.
    prediction_columns = ["input", "prediction", "ground_truth", "correct"]
    summary_columns = ["head_name", "accuracy", "correct_count", "total_count"]

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    head_names = list(preds.keys())
    head_paths = []
    accuracy_data = []

    for head in head_names:
        head_preds = preds[head]
        head_truths = ground_truths[head]

        # One row per input that has a prediction for this head.
        head_data = [
            {
                "input": input_text,
                "prediction": head_preds[i],
                "ground_truth": head_truths[i],
                "correct": head_preds[i] == head_truths[i],
            }
            for i, input_text in enumerate(ins)
            if i < len(head_preds)
        ]

        head_path = os.path.join(output_dir, f"{head}_predictions.csv")
        pd.DataFrame(head_data, columns=prediction_columns).to_csv(head_path, index=False)
        print(f"Saved {head} predictions to {head_path}")
        head_paths.append(head_path)

        # Accuracy over all predictions of this head (0 when there are none).
        correct = sum(1 for p, t in zip(head_preds, head_truths) if p == t)
        total = len(head_preds)
        accuracy_data.append({
            "head_name": head,
            "accuracy": correct / total if total > 0 else 0,
            "correct_count": correct,
            "total_count": total,
        })

    # Save accuracy summary
    accuracy_path = os.path.join(output_dir, "accuracy_summary.csv")
    pd.DataFrame(accuracy_data, columns=summary_columns).to_csv(accuracy_path, index=False)
    print(f"Saved accuracy summary to {accuracy_path}")

    return {
        "head_predictions": head_paths,
        "accuracy_summary": accuracy_path,
    }

# Save the predictions to a model-specific directory. "/" in the model ID is
# replaced with "_" so the ID can be used inside a single directory name.
# NOTE(review): "hindi_filler" is hard-coded in the directory name, so it is used
# even when the Hindi-filler cell was skipped.
converted_model_name = model_path.replace("/", "_")
csv_files = save_predictions_to_csv(ins, preds, ground_truths, f"prediction_results_hindi_filler_{converted_model_name}")