In [None]:
"""
This runs forward passes on samples and stores: (1) pre-MLP activations; (2) top-k expert selections; (3) sample metadata.
"""
None  # Suppress display of the cell's string value

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Qwen3VLMoeForConditionalGeneration
from transformers.loss.loss_utils import ForCausalLMLoss
from datasets import load_dataset
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import gc
import pickle
from termcolor import colored
import importlib
import cupy
import cuml

from utils.memory import check_memory, clear_all_cuda_memory
from utils.store_topk import convert_topk_to_df
from utils.store_outputs import convert_outputs_to_df_fast
from utils.misc import flatten_list
from utils.role_templates import render_single_message

# Device that the model and every input batch are moved to
main_device = 'cuda:0'
# Random seed; used by the streaming dataset shuffle and the probe train/test split below
seed = 123

# Workspace root; used later to locate the chat templates and conversation CSVs
ws = '/workspace/deliberative-alignment-jailbreaks'

# Project memory helpers (utils.memory): presumably free cached CUDA memory, then report usage
clear_all_cuda_memory()
check_memory()

In [None]:
# Re-declares the workspace root with the same value as the imports cell (presumably so this cell can be run on its own)
ws = '/workspace/deliberative-alignment-jailbreaks'

# Load base model

In [None]:
"""
Load the base tokenizer/model
"""
# Key into the table in get_model() below (4 = allenai/Olmo-3-7B-Think)
selected_model_index = 4

def get_model(index):
    """
    Look up the configuration tuple for one of the supported models.

    Params:
        @index: Integer key into the model table below (0-4).

    Returns:
        A 6-tuple (model_id, model_prefix, model_architecture, model_attn, model_use_hf, model_n_layers):
        - model_id: HF model ID
        - model_prefix: short name/prefix for the model
        - model_architecture: architecture key; also selects the module `utils.pretrained_models.<architecture>`
        - model_attn: `attn_implementation` passed to from_pretrained (None = transformers default)
        - model_use_hf: whether to use the HF library implementation (True => trust_remote_code = False)
        - model_n_layers: number of layers (for GLM-4.5-Air: MoE layers only, excluding its dense pre-layer)

    Raises:
        KeyError: if `index` is not a known model index.
    """
    models = {
        0: ('openai/gpt-oss-20b', 'gptoss20', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 24),
        1: ('openai/gpt-oss-120b', 'gptoss120', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 36), # Will load experts in MXFP4 if triton kernels installed
        2: ('Qwen/Qwen3-VL-30B-A3B-Thinking', 'qwen3vl30ba3b', 'qwen3vlmoe', None, True, 48),
        3: ('zai-org/GLM-4.5-Air-FP8', 'glm45air', 'glm4moe', None, True, 45), # GLM-4.5 has one dense pre-layer; this count is MoE layers only
        4: ('allenai/Olmo-3-7B-Think', 'olmo3-7', 'olmo3', None, True, 32) # OLMo-3
    }
    if index not in models:
        raise KeyError(f"Unknown model index {index!r}; valid indices are {sorted(models)}")
    return models[index]

def load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf, architecture = None):
    """
    Load the tokenizer and model from HF, reusing the local cache in /workspace/hf if already downloaded.

    Params:
        @model_id: HF model ID.
        @model_prefix: Short model name; currently unused, kept for interface compatibility.
        @model_attn: `attn_implementation` passed to from_pretrained (None = transformers default).
        @model_use_hf: If True, use the transformers implementation (trust_remote_code = False).
        @architecture: Architecture key from get_model(). Defaults to the notebook-level `model_architecture`
            global, which the previous version read implicitly.

    Returns:
        (tokenizer, model), with the model moved to `main_device` and put in eval mode.
    """
    if architecture is None:
        architecture = model_architecture  # Backward-compatible fallback to the notebook global

    cache_dir = '/workspace/hf'
    # Left padding: the later next-token checks read logits at the last position of each row
    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir = cache_dir, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)

    # Qwen3-VL-MoE is loaded with its dedicated class; every other architecture goes through AutoModelForCausalLM
    model_cls = Qwen3VLMoeForConditionalGeneration if architecture == 'qwen3vlmoe' else AutoModelForCausalLM
    model = model_cls.from_pretrained(model_id, cache_dir = cache_dir, dtype = 'auto', trust_remote_code = not model_use_hf, device_map = None, attn_implementation = model_attn).to(main_device).eval()

    return tokenizer, model

# Unpack the selected model's config; later cells read these globals (e.g. model_architecture, model_n_layers)
model_id, model_prefix, model_architecture, model_attn, model_use_hf, model_n_layers = get_model(selected_model_index)
tokenizer, model = load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf)

In [None]:
"""
Some checks for special models (GPT-OSS - check attn + expert precision)
"""
# GPT-OSS only: print the first layer's expert down-projection module and the active attention backend,
# to confirm the expert precision and attention implementation that were actually loaded
if model_architecture == 'gptoss':
    print(model.model.layers[0].mlp.experts.down_proj)
    print(model.model.config._attn_implementation)

# Optional manual checks (disabled): print stock model() logits, then a custom loader's logits for comparison
# # Test model() call
# inputs = tokenizer(['Test string'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 12).to(model.device)
# with torch.no_grad():
#     original_results = model(**inputs, use_cache = False)
# print(original_results['logits'][0, :].detach().float().cpu().numpy())

# # Test custom loader
# import importlib
# import utils.pretrained_models.ringmini2 as test_mod   # the module object
# test_mod = importlib.reload(test_mod)
# run_model_return_topk = test_mod.run_ringmini2_return_topk
# custom_results = run_model_return_topk(model, inputs['input_ids'], inputs['attention_mask'], return_hidden_states = True)
# print(custom_results['logits'][0, :].detach().float().cpu().numpy())

# Report memory usage after loading the model
check_memory()

In [None]:
"""
Load reverse-engineered forward pass functions that return topk expert IDs and weights
"""
model_module = importlib.import_module("utils.pretrained_models." + model_architecture)
run_model_return_topk = getattr(model_module, "run_" + model_architecture + "_return_topk")

@torch.no_grad()
def test_custom_forward_pass(model, pad_token_id):
    """
    Check that the custom forward pass reproduces the stock HF forward pass exactly on two short prompts.

    After asserting the logits are identical, prints the LM loss (positions equal to `pad_token_id` are
    excluded) and the number/shape of the pre-MLP and post-layer hidden states from the custom pass.
    Uses the notebook-level `tokenizer` and `run_model_return_topk`.
    """
    probe_texts = ['Hi! I am a dog and I like to bark', 'Vegetables are good for']
    enc = tokenizer(probe_texts, return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 640).to(model.device)
    ids = enc['input_ids']
    mask = enc['attention_mask']

    reference = model(**enc, use_cache = False)
    custom = run_model_return_topk(model, ids, mask, return_hidden_states = True)
    assert torch.equal(reference.logits, custom['logits']), 'Error in custom forward'

    # Optional debug output (disabled):
    # assert len(custom['all_topk_experts']) == len(custom['all_topk_weights']), 'Length of topk IDs and weights not equal'
    # print(f"Length of topk: {len(custom['all_topk_experts'])}")
    # print(f"Topk size: {custom['all_topk_experts'][0].shape}")
    # print(f"First token topk IDs: {custom['all_topk_experts'][0][1,]}")
    # print(f"First token topk weights: {custom['all_topk_weights'][0][1,]}")

    logits = custom['logits']
    labels = torch.where(ids == pad_token_id, torch.tensor(-100), ids)  # -100 = ignored by the loss
    lm_loss = ForCausalLMLoss(logits, labels, logits.size(-1)).detach().cpu().item()
    print(f"LM loss: {lm_loss}")

    pre_mlp = custom['all_pre_mlp_hidden_states']
    post_layer = custom['all_hidden_states']
    print(f"Hidden states layers (pre-mlp | post-layer): {len(pre_mlp)} | {len(post_layer)}")
    print(f"Hidden state size (pre-mlp | post-layer): {(pre_mlp[0].shape)} | {(post_layer[0].shape)}")
    # print(f"Router logits : {(custom['all_router_logits'][0].shape)}")

test_custom_forward_pass(model, pad_token_id = tokenizer.pad_token_id)

# Get dataset

In [None]:
# Test: render a short user/assistant exchange with the tokenizer's current chat template.
# With tokenize = False a string is returned, so padding/truncation/max_length presumably have no effect here.
s = tokenizer.apply_chat_template(
    [
        # {'role': 'developer', 'content': 'Test.'},
        {'role': 'user', 'content': 'Hi! I am a dog and I like to bark'},
        {'role': 'assistant', 'content': 'I am a dog and I like to bark'}
    ],
    tokenize = False,
    padding = 'max_length',
    truncation = True,
    max_length = 512,
    add_generation_prompt = False
)
print(s)
# Note: the line below would also need `import re`
# print(re.sub(r'(Current date:\s*)\d{4}-\d{2}-\d{2}', '2025-09-01', s))

In [None]:
import importlib
import utils.role_templates
importlib.reload(utils.role_templates)
from utils.role_templates import render_single_message

In [None]:
# Preview how render_single_message formats an 'assistant-cot' message for this architecture
print(render_single_message(model_architecture, role = 'assistant-cot',  content ='Hi'))

In [None]:
"""
Load dataset - HPLT2 (English); the C4 loader is kept below but disabled
"""
n_sample_size = 100

def load_raw_ds(n_samples = None):
    """
    Stream a shuffled sample of English HPLT2 web text.

    Params:
        @n_samples: Number of documents to take; defaults to the notebook-level `n_sample_size`.

    Returns:
        A list of dicts {'text': str, 'source': str}; shorter than requested if the stream runs out.
    """
    if n_samples is None:
        n_samples = n_sample_size

    # def get_c4():
    #     return load_dataset('allenai/c4', 'en', split = 'validation', streaming = True).shuffle(seed = seed, buffer_size = 50_000)

    def get_hplt2():
        # Streaming shuffle is approximate: it shuffles within a 50k-document buffer
        return load_dataset('HPLT/HPLT2.0_cleaned', 'eng_Latn', split = 'train', streaming = True).shuffle(seed = seed, buffer_size = 50_000)

    def get_data(ds, n_samples, data_source):
        """Take up to `n_samples` items from `ds`, tagging each with `data_source`."""
        raw_data = []
        ds_iter = iter(ds)
        for _ in range(n_samples):
            sample = next(ds_iter, None)
            if sample is None:
                break
            raw_data.append({'text': sample['text'], 'source': data_source})
        return raw_data

    # return get_data(get_c4(), n_samples, 'c4-en')
    # Tag rows with their actual origin (previously mislabeled as 'c4-en')
    return get_data(get_hplt2(), n_samples, 'hplt2-en')

raw_data = load_raw_ds()

In [None]:
##### ADD SYSTEM PROMPT?
# import textwrap

# if model_architecture == 'gptoss':
#     system_prompt = tokenizer.bos_token + textwrap.dedent("""
#     <|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
#     Knowledge cutoff: 2024-06
#     Current date: 2025-06-28

#     Reasoning: medium

#     # Valid channels: analysis, commentary, final. Channel must be included for every message.
#     Calls to these tools must go to the commentary channel: 'functions'.<|end|>
#     """).strip()
# if model_architecture == 'olmo3':
#     system_prompt = tokenizer.bos_token + textwrap.dedent("""
#     <|im_start|>system
#     You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <functions></functions><|im_end|>
#     """).strip() + '\n'

# print(system_prompt)

In [None]:
"""
Function to get sample sequences
"""
# Cap each raw text at 640 tokens, then decode back to a string
truncated_texts = tokenizer.batch_decode(
    tokenizer(
        [row['text'] for row in raw_data],
        padding = False,
        truncation = True,
        max_length = 640
    ).input_ids
)

def get_sample_seqs(sample_str):
    """
    Render `sample_str` as a single message under each of five roles.

    Returns a list of {'role': <short role key>, 'prompt': <rendered string>} in the order
    system, user, cot, assistant, tool. The rendering role passed to render_single_message
    differs from the short key for cot ('assistant-cot') and assistant ('assistant-final').
    """
    role_to_template = [
        ('system', 'system'),
        ('user', 'user'),
        ('cot', 'assistant-cot'),
        ('assistant', 'assistant-final'),
        ('tool', 'tool'),
    ]
    rendered = []
    for role_key, template_role in role_to_template:
        rendered.append({
            'role': role_key,
            'prompt': render_single_message(model_architecture, template_role, sample_str, None)
        })
    return rendered

# One row per (text, role) pair
input_list = flatten_list([
    [{'question_ix': q_ix, 'question': text, **seq} for seq in get_sample_seqs(text)]
    for q_ix, text in enumerate(truncated_texts)
])

input_df = pd.DataFrame(input_list).assign(prompt_ix = lambda frame: list(range(len(frame))))

input_df

In [None]:
""" 
Load dataset into a dataloader. The dataloader returns the original tokens - this is important for BPE tokenizers as otherwise it's difficult to reconstruct the correct string later!
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

def chunk_list(input_list, max_length):
    """Split `input_list` into consecutive slices of at most `max_length` items; the final slice may be shorter."""
    chunks = []
    for start in range(0, len(input_list), max_length):
        chunks.append(input_list[start:start + max_length])
    return chunks

# Wrap the rendered prompts in a DataLoader: batches of 64, in prompt_ix order (shuffle = False).
# max_length = 768 is presumably the per-prompt token cap, leaving room for role-template tokens
# around the 640-token truncated texts — confirm against ReconstructableTextDataset.
test_dl = DataLoader(
    ReconstructableTextDataset(input_df['prompt'].tolist(), tokenizer, max_length = 768, prompt_ix = input_df['prompt_ix'].tolist()),
    batch_size = 64,
    shuffle = False,
    collate_fn = stack_collate
)

# Run states

In [None]:
"""
Run forward passes
"""
@torch.no_grad()
def run_and_export_states(model, dl: DataLoader, layers_to_keep_acts: list[int]):
    """
    Run forward passes on given model and store the decomposed sample_df plus hidden states

    Params:
        @model: The model to run forward passes on; passed to `run_model_return_topk`, whose output dict must
            contain `logits` and `all_pre_mlp_hidden_states`.
        @dl: A DataLoader (over a ReconstructableTextDataset) whose batches contain `input_ids`, `attention_mask`,
            `original_tokens`, and `prompt_ix`.
        @layers_to_keep_acts: A list of layer indices (0-indexed) for which to filter `all_pre_mlp_hidden_states` (see returned object description).

    Returns:
        A dict with keys:
        - `sample_df`: A sample (token)-level dataframe from `convert_outputs_to_df_fast`, joined with the original
            token text (`token`) and the prompt index (`prompt_ix`) of each sequence.
        - `all_pre_mlp_hs`: A tensor of size n_samples x len(layers_to_keep_acts) x D holding the pre-MLP hidden state
            of each retained layer for every unmasked (attention_mask == 1) position.
            NOTE(review): row i is meant to match row i of sample_df; this relies on convert_outputs_to_df_fast
            emitting exactly the unmasked positions in flattened (sequence, token) order — confirm.

    Raises:
        ValueError: if `dl` yields no batches.

    Example:
        test_outputs = run_and_export_states(model, test_dl, layers_to_keep_acts = list(range(model_n_layers)))
    """
    all_pre_mlp_hidden_states = []
    sample_dfs = []

    for batch_ix, batch in tqdm(enumerate(dl), total = len(dl)):

        input_ids = batch['input_ids'].to(main_device)
        attention_mask = batch['attention_mask'].to(main_device)
        original_tokens = batch['original_tokens']
        prompt_indices = batch['prompt_ix']

        output = run_model_return_topk(model, input_ids, attention_mask, return_hidden_states = True)

        # Check no bugs by validating output/perplexity on the first batch only
        if batch_ix == 0:
            # Pad positions are labeled -100 so the loss ignores them
            labels = torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids)
            loss = ForCausalLMLoss(output['logits'], labels, output['logits'].size(-1)).detach().cpu().item()
            for i in range(min(20, input_ids.size(0))):
                decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = False)
                # Greedy next token after the last position (assumes left padding, as the tokenizer is loaded with padding_side = 'left')
                next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
                print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
            print(f"PPL:", torch.exp(torch.tensor(loss)).item())

        # One row per (sequence, token) with the token text as originally tokenized
        original_tokens_df = pd.DataFrame(
            [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)],
            columns = ['sequence_ix', 'token_ix', 'token']
        )

        # Map each sequence in the batch to its global prompt_ix
        prompt_indices_df = pd.DataFrame(
            list(enumerate(prompt_indices)),
            columns = ['sequence_ix', 'prompt_ix']
        )

        # Create sample (token) level dataframe
        sample_df =\
            convert_outputs_to_df_fast(input_ids, attention_mask, output['logits'])\
            .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
            .merge(prompt_indices_df, how = 'left', on = ['sequence_ix'])\
            .assign(batch_ix = batch_ix)

        sample_dfs.append(sample_df)

        # Store pre-MLP hidden states: the forward pass returns an n_layers-long list (presumably each BN x D, batch and
        # sequence flattened); stack to BN x n_layers x D, keep unmasked positions, then keep the requested layers
        valid_pos = torch.where(attention_mask.cpu().view(-1) == 1) # Valid (BN, ) positions
        all_pre_mlp_hidden_states.append(torch.stack(output['all_pre_mlp_hidden_states'], dim = 1)[valid_pos][:, layers_to_keep_acts, :])

    if not sample_dfs:
        raise ValueError("run_and_export_states: the DataLoader yielded no batches")

    sample_df = pd.concat(sample_dfs, ignore_index = True).drop(columns = ['batch_ix', 'sequence_ix']) # Drop batch/seq_ix, since prompt_ix identifies
    all_pre_mlp_hidden_states = torch.cat(all_pre_mlp_hidden_states, dim = 0)

    return {
        'sample_df': sample_df,
        'all_pre_mlp_hs': all_pre_mlp_hidden_states
    }

# Keep pre-MLP activations for every 2nd layer (models with <= 24 layers) or every 4th layer otherwise
act_map = list(range(0, model_n_layers, 2 if model_n_layers <= 24 else 4))
res = run_and_export_states(
    model,
    test_dl,
    layers_to_keep_acts = act_map
)

# Run probes


In [None]:
from utils.role_assignments import label_content_roles
import importlib
import utils.role_assignments
importlib.reload(utils.role_assignments)
from utils.role_assignments import label_content_roles

# Attach role labels, then each token's position within its prompt and within its (prompt, role) span.
# sample_ix is the row number; the probe code uses it to index rows of the per-layer activation tensors.
# NOTE(review): assumes label_content_roles preserves the row order/count of res['sample_df'] — confirm.
sample_df =\
    label_content_roles(model_architecture, res['sample_df'])\
    .assign(token_in_seq_ix = lambda df: df.groupby(['prompt_ix']).cumcount())\
    .assign(token_in_role_ix = lambda df: df.groupby(['prompt_ix', 'role']).cumcount())\
    .assign(sample_ix = lambda df: range(0, len(df)))

sample_df

In [None]:
# Distinct role labels assigned by label_content_roles
sample_df['role'].drop_duplicates()

In [None]:
"""
Convert activations to fp16 (for compatibility with cupy later) + layer-wise dict
"""
# NOTE(review): if the model runs in bf16, casting to fp16 turns any activation with |x| > 65504 into inf — confirm none occur.
all_pre_mlp_hs = res['all_pre_mlp_hs'].to(torch.float16)
# compare_bf16_fp16_batched(all_pre_mlp_hs_import, all_pre_mlp_hs)
# del all_pre_mlp_hs_import
# Re-key as {layer_ix: activations for that layer}; save_ix is the layer's position within act_map
all_pre_mlp_hs = {layer_ix: all_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(act_map)}

gc.collect()

In [None]:
"""
Run logistic regression probes to get role space models
"""
def fit_lr(x_train, y_train, x_test, y_test, C = 1.0, max_iter = 5000):
    """
    Fit an L2-regularized logistic-regression probe (cuML) on already-split data and evaluate it.

    The train/test split is done by the caller (get_probe_result); this function does not split.

    Params:
        @x_train, @y_train: Training features and integer class labels.
        @x_test, @y_test: Held-out features and integer class labels.
        @C: Inverse regularization strength (smaller = stronger L2 penalty). Default 1.0, as before.
        @max_iter: Maximum solver iterations. Default 5000, as before.

    Returns:
        (lr_model, accuracy, y_test_pred): the fitted model, its test accuracy from `score`, and its test predictions.
    """
    lr_model = cuml.linear_model.LogisticRegression(penalty = 'l2', C = C, max_iter = max_iter, fit_intercept = True)
    lr_model.fit(x_train, y_train)
    accuracy = lr_model.score(x_test, y_test)
    y_test_pred = lr_model.predict(x_test)
    return lr_model, accuracy, y_test_pred

def get_probe_result(layer_ix, label2id):
    """
    Train and evaluate a role probe for a single layer and label mapping.

    Params:
        @layer_ix: The layer index to train the probe on (a key of `all_pre_mlp_hs`).
        @label2id: The role-label -> integer class id mapping; its keys are the roles included.

    Description:
        Uses only tokens inside a content span (`in_content_span`) whose role is one of the mapped roles.
        Prompts (not tokens) are split 80:20 into train/test, so no prompt contributes to both sides.
        Reads the notebook globals `sample_df`, `all_pre_mlp_hs`, and `seed`.

    Returns:
        A dict with the layer, mapping, roles, fitted probe, test accuracy, a (role, pred) count table,
        and accuracy by position within the role span (`token_in_role_ix`).

    Raises:
        ValueError: if the training split lacks at least one mapped role. (ValueError subclasses Exception,
            so existing `except Exception` handlers still catch it. Nothing is skipped; the caller must handle it.)
    """
    roles = list(label2id.keys())
    id2label = {v: k for k, v in label2id.items()}

    # Get valid samples
    valid_sample_df =\
        sample_df\
        .pipe(lambda df: df[(df['in_content_span'] == True) & (df['role'].notna()) & (df['role'].isin(roles))])

    # 80:20 split at the prompt level, so tokens from one prompt never land in both train and test
    prompt_ix_train, prompt_ix_test = cuml.train_test_split(valid_sample_df['prompt_ix'].unique(), test_size = 0.2, random_state = seed)
    train_df = valid_sample_df[valid_sample_df['prompt_ix'].isin(prompt_ix_train)]
    test_df = valid_sample_df[valid_sample_df['prompt_ix'].isin(prompt_ix_test)]

    # Get y labels
    role_labels_train_cp = cupy.asarray([label2id[r] for r in train_df['role']])
    role_labels_test_cp = cupy.asarray([label2id[r] for r in test_df['role']])

    # Fail fast (before copying activations) if any mapped role is absent from the training split
    uniq_train = np.unique(role_labels_train_cp.get())
    if len(uniq_train) < len(label2id):
        raise ValueError(f"Layer {layer_ix}, mapping {label2id}: missing roles in train split (label ids present: {uniq_train.tolist()})")

    # Get x features; sample_ix indexes rows of the per-layer activation tensor
    x_train_cp = cupy.asarray(all_pre_mlp_hs[layer_ix][train_df['sample_ix'].tolist(), :].to(torch.float32).detach().cpu())
    x_test_cp = cupy.asarray(all_pre_mlp_hs[layer_ix][test_df['sample_ix'].tolist(), :].to(torch.float32).detach().cpu())

    lr_model, test_acc, y_test_pred = fit_lr(x_train_cp, role_labels_train_cp, x_test_cp, role_labels_test_cp)

    print(f"Layer [{layer_ix}] with roles [{'+'.join(roles)}]: {test_acc:.2f}")

    # Optional: Return classification metrics
    results_df =\
        test_df\
        .assign(pred = y_test_pred.tolist())\
        .assign(pred = lambda df: df['pred'].map(id2label))\
        .assign(is_acc = lambda df: df['role'] == df['pred'])

    # Confusion-style counts of (true role, predicted role)
    acc_by_role =\
        results_df\
        .groupby(['role', 'pred'], as_index = False)\
        .agg(count = ('sample_ix', 'count'))

    # Accuracy as a function of the token's position within its role span
    acc_by_pos =\
        results_df\
        .groupby('token_in_role_ix', as_index = False)\
        .agg(count = ('sample_ix', 'count'), acc = ('is_acc', 'mean'))

    return {
        'layer_ix': layer_ix, 'label2id': label2id,
        'roles': roles, 'probe': lr_model, 'accuracy': test_acc,
        'acc_by_role': acc_by_role, 'acc_by_pos': acc_by_pos
    }

# Probe every layer that has stored activations
layers_to_test = list(all_pre_mlp_hs.keys()) # [i for i in list(all_pre_mlp_hs.keys()) if i % 4 == 0]
# Role sets to probe; each dict maps role label -> class id
mappings_to_test = [
    {'user': 0, 'assistant-cot': 1, 'assistant-final': 2},
    {'system': 0, 'user': 1, 'assistant-cot': 2, 'assistant-final': 3},
    {'system': 0, 'user': 1, 'assistant-cot': 2, 'assistant-final': 3, 'tool': 4},
    {'user': 0, 'assistant-cot': 1, 'assistant-final': 2, 'tool': 3}
]

# One probe per (layer, mapping). get_probe_result raises when a role is missing from the
# train split, which stops this loop (all_probes keeps the results collected so far).
all_probes = []
for layer_ix in tqdm(layers_to_test):
    for mapping in mappings_to_test:
        all_probes.append(get_probe_result(layer_ix, mapping))

# Generalization Test

In [None]:
# Normalize role names in the probe results: 'assistant-cot' -> 'cot', 'assistant-final' -> 'assistant'.
# Idempotent (already-renamed roles map to themselves); mutates the dicts in all_probes in place.
role_map = {'assistant-cot': 'cot', 'assistant-final': 'assistant'}
for probe in all_probes:
    probe['roles'] = tuple(role_map.get(r, r) for r in probe['roles'])

probe_layers = sorted(set(x['layer_ix'] for x in all_probes))
probe_roles = sorted(set(tuple(x['roles']) for x in all_probes))

# Build the joined strings outside the f-strings: a backslash inside an f-string
# expression (e.g. '\n '.join(...)) is a SyntaxError before Python 3.12
layers_str = ', '.join(str(x) for x in probe_layers)
roles_str = '\n '.join(str(list(x)) for x in probe_roles)

print(f"Num probes: {len(all_probes)}")
print(f"Probe layers:\n  {layers_str}")
print(f"Probe roles:\n {roles_str}")


In [None]:
import importlib
import utils.role_assignments
importlib.reload(utils.role_assignments)
from utils.role_assignments import label_content_roles

In [None]:
"""
Tests two important functions:
- `load_chat_template`: Overwrites the existing chat template in the tokenizer with one that better handles discrepancies.
    See docstring for the function for details. 
- `label_content_roles`: Takes an instruct-formatted text, then assigns each token to the correct roles.
"""
import importlib
import utils.role_assignments
import utils.role_templates
# Reload both helper modules so edits on disk take effect without restarting the kernel
importlib.reload(utils.role_assignments)
label_content_roles = utils.role_assignments.label_content_roles
importlib.reload(utils.role_templates)
load_chat_template = utils.role_templates.load_chat_template

# Load custom chat templater
# Keep the tokenizer's original template so both can be rendered and compared below
old_chat_template = tokenizer.chat_template
new_chat_template = load_chat_template(f'{ws}/utils/chat_templates', model_architecture)

def test_chat_template(tokenizer):
    s = tokenizer.apply_chat_template(
        [
            {'role': 'system', 'content': 'Test.'},
            {'role': 'user', 'content': 'Hi! I\'m a dog.'},
            {'role': 'assistant', 'content': '<think>The user is a dog!</think>Congrats!'},
            {'role': 'user', 'content': 'Thanks!'},
            {'role': 'assistant', 'content': '<think>Hmm, the user said thanks!</think>Anything else?'}
        ],
        tokenize = False,
        padding = 'max_length',
        truncation = True,
        max_length = 512,
        add_generation_prompt = False
    )
    print(s)
    return s

# Render the same conversation with the original template, then with the custom one
tokenizer.chat_template = old_chat_template
s = test_chat_template(tokenizer)
# The custom template stays installed on the tokenizer for all later cells
tokenizer.chat_template = new_chat_template
s = test_chat_template(tokenizer)

# Tokenize the custom-template rendering into one row per token, with the id columns that
# label_content_roles presumably expects, then show the role labels of the last 50 tokens
z =\
    pd.DataFrame({'input_ids': tokenizer(s)['input_ids']})\
    .assign(token = lambda df: tokenizer.convert_ids_to_tokens(df['input_ids']))\
    .assign(prompt_ix = 0, batch_ix = 0, sequence_ix = 0, token_ix = lambda df: df.index)

label_content_roles(model_architecture, z).tail(50)

In [None]:
"""
Run forward passes
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

# Resolve the architecture-specific forward function again (same as the earlier cell), presumably so this cell can run standalone
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_return_topk = getattr(model_module, f"run_{model_architecture}_return_topk")

@torch.no_grad()
def run_and_export_states(model, dl: DataLoader, layers_to_keep_acts: list[int]):
    """
    Run forward passes on given model and store the decomposed sample_df plus hidden states

    Params:
        @model: The model to run forward passes on; passed to `run_model_return_topk`, whose output dict must
            contain `logits` and `all_pre_mlp_hidden_states`.
        @dl: A DataLoader (over a ReconstructableTextDataset) whose batches contain `input_ids`, `attention_mask`,
            `original_tokens`, and `prompt_ix`.
        @layers_to_keep_acts: A list of layer indices (0-indexed) for which to filter `all_pre_mlp_hidden_states` (see returned object description).

    Returns:
        A dict with keys:
        - `sample_df`: A sample (token)-level dataframe from `convert_outputs_to_df_fast`, joined with the original
            token text (`token`) and the prompt index (`prompt_ix`) of each sequence.
        - `all_pre_mlp_hs`: A tensor of size n_samples x len(layers_to_keep_acts) x D holding the pre-MLP hidden state
            of each retained layer for every unmasked (attention_mask == 1) position.
            NOTE(review): row i is meant to match row i of sample_df; this relies on convert_outputs_to_df_fast
            emitting exactly the unmasked positions in flattened (sequence, token) order — confirm.

    Raises:
        ValueError: if `dl` yields no batches.

    Example:
        test_outputs = run_and_export_states(model, test_dl, layers_to_keep_acts = list(range(model_n_layers)))
    """
    all_pre_mlp_hidden_states = []
    sample_dfs = []

    for batch_ix, batch in tqdm(enumerate(dl), total = len(dl)):

        input_ids = batch['input_ids'].to(main_device)
        attention_mask = batch['attention_mask'].to(main_device)
        original_tokens = batch['original_tokens']
        prompt_indices = batch['prompt_ix']

        output = run_model_return_topk(model, input_ids, attention_mask, return_hidden_states = True)

        # Check no bugs by validating output/perplexity on the first batch only
        if batch_ix == 0:
            # Pad positions are labeled -100 so the loss ignores them
            labels = torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids)
            loss = ForCausalLMLoss(output['logits'], labels, output['logits'].size(-1)).detach().cpu().item()
            for i in range(min(20, input_ids.size(0))):
                decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = False)
                # Greedy next token after the last position (assumes left padding, as the tokenizer is loaded with padding_side = 'left')
                next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
                print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
            print(f"PPL:", torch.exp(torch.tensor(loss)).item())

        # One row per (sequence, token) with the token text as originally tokenized
        original_tokens_df = pd.DataFrame(
            [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)],
            columns = ['sequence_ix', 'token_ix', 'token']
        )

        # Map each sequence in the batch to its global prompt_ix
        prompt_indices_df = pd.DataFrame(
            list(enumerate(prompt_indices)),
            columns = ['sequence_ix', 'prompt_ix']
        )

        # Create sample (token) level dataframe
        sample_df =\
            convert_outputs_to_df_fast(input_ids, attention_mask, output['logits'])\
            .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
            .merge(prompt_indices_df, how = 'left', on = ['sequence_ix'])\
            .assign(batch_ix = batch_ix)

        sample_dfs.append(sample_df)

        # Store pre-MLP hidden states: the forward pass returns an n_layers-long list (presumably each BN x D, batch and
        # sequence flattened); stack to BN x n_layers x D, keep unmasked positions, then keep the requested layers
        valid_pos = torch.where(attention_mask.cpu().view(-1) == 1) # Valid (BN, ) positions
        all_pre_mlp_hidden_states.append(torch.stack(output['all_pre_mlp_hidden_states'], dim = 1)[valid_pos][:, layers_to_keep_acts, :])

    if not sample_dfs:
        raise ValueError("run_and_export_states: the DataLoader yielded no batches")

    sample_df = pd.concat(sample_dfs, ignore_index = True).drop(columns = ['batch_ix', 'sequence_ix']) # Drop batch/seq_ix, since prompt_ix identifies
    all_pre_mlp_hidden_states = torch.cat(all_pre_mlp_hidden_states, dim = 0)

    return {
        'sample_df': sample_df,
        'all_pre_mlp_hs': all_pre_mlp_hidden_states
    }

def chunk_list(input_list, max_length):
    """Return consecutive slices of `input_list`, each holding at most `max_length` items (the last may be shorter)."""
    starts = range(0, len(input_list), max_length)
    return [input_list[start:start + max_length] for start in starts]

In [None]:
"""
Import conversations
"""
# filename = model_id.split('/')[-1].lower()
# NOTE(review): the CSV path is hard-coded to the gpt-oss-20b conversations regardless of the loaded model — confirm intended
user_queries_df = pd.read_csv(f'{ws}/experiments/role-analysis/convs/gpt-oss-20b.csv')

print(f"Starting length: {len(user_queries_df['conv_id'].unique())} convs")

# 1. Remove convs with any missing responses
convs_with_missing = user_queries_df[user_queries_df[['user_query', 'cot', 'assistant']].isna().any(axis = 1)]['conv_id'].unique()
user_queries_df = user_queries_df[~user_queries_df['conv_id'].isin(convs_with_missing)]

print(f"After dropping missing: {len(user_queries_df['conv_id'].unique())} convs")

# 2. Drop super long convs (total chars across user_query + cot + assistant >= 12000) and convs where any
#    single message is shorter than 5 characters (to avoid issues with flag_message_type fn)
user_queries_df =\
    user_queries_df\
    .assign(total_len = lambda df:
        df.groupby('conv_id')['user_query'].transform(lambda x: x.str.len().sum()) +
        df.groupby('conv_id')['cot'].transform(lambda x: x.str.len().sum()) +
        df.groupby('conv_id')['assistant'].transform(lambda x: x.str.len().sum())
    )\
    .assign(min_len = lambda df: df[['user_query', 'cot', 'assistant']].apply(lambda x: x.str.len()).min(axis=1))\
    .assign(conv_min_len = lambda df: df.groupby('conv_id')['min_len'].transform('min'))\
    .query('total_len < 12000 and conv_min_len >= 5')\
    .drop(columns = ['total_len', 'min_len', 'conv_min_len'])

print(f"After dropping long convs: {len(user_queries_df['conv_id'].unique())} convs")

# 3. Filter out convs where any message is a substring of another message
#    (this causes dupe issues with the flag_message_type fn, which does not handle tokens in multiple string contexts)
def has_substring_messages(group):
    """
    Return True if any message text in the conversation `group` is contained in a different message
    (identical duplicates count), comparing all user_query, cot, and assistant texts against each other.
    """
    texts = group['user_query'].tolist() + group['cot'].tolist() + group['assistant'].tolist()
    return any(
        pos_a != pos_b and needle in haystack
        for pos_a, needle in enumerate(texts)
        for pos_b, haystack in enumerate(texts)
    )

bad_convs = user_queries_df.groupby('conv_id').filter(has_substring_messages)['conv_id'].unique()
user_queries_df = user_queries_df.pipe(lambda df: df[~df['conv_id'].isin(bad_convs)])
print(f"After dropping substr: {len(user_queries_df['conv_id'].unique())} convs")

# Number of conversations to keep. Sampling is seeded so Restart & Run All picks the same subset,
# and capped at however many conversations survive the filters (pandas raises if n > len(df))
n_convs_to_sample = 50

# Combine into conv-level df: one row per conversation; `messages` is the flattened
# [user, cot, assistant, user, cot, assistant, ...] list in user_query_ix order.
# conv_id is renumbered 0..n-1 in sampled order, which matches the order of `convs` below.
convs_df =\
    user_queries_df\
    .sort_values(['conv_id', 'user_query_ix'])\
    .assign(conv_id = lambda df: df.groupby('conv_id', sort = False).ngroup())\
    .groupby('conv_id')\
    .apply(lambda g: pd.Series({
        'dataset': g['dataset'].values[0],
        'messages': [
            msg
            for _, row in g.iterrows()
            for msg in [
                {'role': 'user', 'content': row['user_query']},
                {'role': 'cot', 'content': row['cot']},
                {'role': 'assistant', 'content': row['assistant']}
            ]
        ]
    }))\
    .reset_index()\
    [['conv_id', 'dataset', 'messages']]\
    .pipe(lambda df: df.sample(n = min(n_convs_to_sample, len(df)), random_state = seed))\
    .assign(conv_id = lambda df: range(len(df)))

print(f"Final: {len(convs_df['conv_id'].unique())} convs")

# One row per message, with its position (message_ix) within the conversation
messages_df =\
    convs_df\
    .explode('messages')\
    .assign(message_ix = lambda df: df.groupby('conv_id').cumcount())\
    .assign(
        role = lambda df: df['messages'].apply(lambda x: x['role']),
        content = lambda df: df['messages'].apply(lambda x: x['content'])
    )\
    .drop(columns = 'messages')\
    .reset_index(drop = True)

convs = convs_df['messages'].tolist()

display(convs_df.head(5))
display(messages_df.head(5))
print(f"Total convs: {len(convs_df)}")
print(f"Total convs: {len(convs_df)}")

In [None]:
import textwrap

# Architecture-specific default system prompt, prepended to the untagged conversations below.
# bos_token can be None for some tokenizers; treat that as an empty prefix instead of failing on str + None.
bos_prefix = tokenizer.bos_token or ''

if model_architecture == 'gptoss':
    system_prompt = bos_prefix + textwrap.dedent("""
    <|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
    Knowledge cutoff: 2024-06
    Current date: 2025-06-28

    Reasoning: medium

    # Valid channels: analysis, commentary, final. Channel must be included for every message.
    Calls to these tools must go to the commentary channel: 'functions'.<|end|>
    """).strip()
elif model_architecture == 'olmo3':
    system_prompt = bos_prefix + textwrap.dedent("""
    <|im_start|>system
    You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <functions></functions><|im_end|>
    """).strip() + '\n'
else:
    # Fail explicitly rather than silently reusing a stale `system_prompt` left in the kernel by another model
    raise NotImplementedError(f"No system prompt defined for model_architecture = '{model_architecture}'")

print(system_prompt)

In [None]:
"""
Prep untagged conversations
"""
all_convs_by_type = {}

# Prep conversations
def prep_conv(conv):
    """Return the global system prompt followed by every message's content, separated by blank lines (no role tags)."""
    return system_prompt + '\n\n'.join(message['content'] for message in conv)

all_convs_by_type['untagged'] = list(map(prep_conv, convs))
print(all_convs_by_type['untagged'][0])

# Longest tokenized conversation, in tokens (capped at 8192 by truncation)
max_input_length = (
    tokenizer(all_convs_by_type['untagged'], padding = True, truncation = True, max_length = 1024 * 8, return_tensors = 'pt')
    ['attention_mask']
    .sum(dim = 1)
    .max()
    .item()
)

print(f"Max input length: {max_input_length}")
print(f"Untagged convs: {len(all_convs_by_type['untagged'])}")

In [None]:
"""
Prep tagged conversations
"""
def fold_cot_into_final(convs):
    """Fold CoT into the following assistant message as a <think></think> tag."""
    result = []
    for conv in convs:
        new_conv = []
        it = iter(enumerate(conv))
        for i, msg in it:
            if msg['role'] == 'cot':
                if i + 1 >= len(conv) or conv[i + 1]['role'] != 'assistant':
                    raise ValueError("cot must be followed by assistant")
                _, next_msg = next(it)
                new_conv.append({
                    'role': 'assistant',
                    'content': f"<think>{msg['content']}</think>{next_msg['content']}"
                })
            else:
                new_conv.append({'role': msg['role'], 'content': msg['content']})
        result.append(new_conv)
    return result

final_convs = fold_cot_into_final(convs)

# Tagged rendering: the tokenizer's own chat template, applied to the CoT-folded conversations.
# No hand-written system prompt is prepended here.
all_convs_by_type['tagged'] = list(
    tokenizer.apply_chat_template(final_convs, tokenize = False, add_generation_prompt = False)
)
print(all_convs_by_type['tagged'][0])

# Find max input length from unpadded, untruncated token counts, so the reported value
# is the true maximum and no padded tensor is materialised.
tagged_token_counts = [len(ids) for ids in tokenizer(all_convs_by_type['tagged'])['input_ids']]
max_input_length = max(tagged_token_counts, default = 0)

print(f"Max input length: {max_input_length}")
print(f"Tagged convs: {len(all_convs_by_type['tagged'])}")

In [None]:
"""
Run forward passes
"""
# Get input df
input_convs_df =\
    pd.concat([
        pd.DataFrame({'convs': v, 'conv_type': conv_type})\
            .assign(conv_id = lambda df: range(0, len(df)))\
            .merge(convs_df, how = 'inner', on = 'conv_id')
        for conv_type, v in all_convs_by_type.items()
    ], ignore_index = True)\
    .assign(prompt_ix = lambda df: list(range(0, len(df))))
    
display(input_convs_df)

convs_dl = DataLoader(
    ReconstructableTextDataset(input_convs_df['convs'].tolist(), tokenizer, max_length = 1024 * 4, prompt_ix = list(range(0, len(input_convs_df)))),
    batch_size = 8,
    shuffle = False,
    collate_fn = stack_collate
)

convs_outputs = run_and_export_states(model, convs_dl, layers_to_keep_acts = probe_layers)

In [None]:
"""
Label messages
"""
from utils.substring_assignments import flag_message_types

sample_dfs_by_conv = [group for _, group in convs_outputs['sample_df'].groupby('prompt_ix', sort = True)]
metadata_by_conv = input_convs_df.to_dict('records')

all_res = []

# Iterate through (input metadata, token-level sample df) pairs
for conv_metadata, sample_df_for_conv in tqdm(zip(metadata_by_conv, sample_dfs_by_conv)):

    content_spans = [msg['content'] for msg in conv_metadata['messages']]
    content_roles = [msg['role'] for msg in conv_metadata['messages']]
    try:
        res = \
            flag_message_types(sample_df_for_conv, content_spans)\
            .merge(
                pd.DataFrame({'role': content_roles, 'base_message_ix': range(len(content_spans))}),
                on = 'base_message_ix',
                how = 'left'
            )
        
    except Exception as e:
        print(e)
        continue
    
    all_res.append(res)

display(pd.concat(all_res).pipe(lambda df: df[df['role'].isna()]))
        
sample_df_labeled =\
    pd.concat(all_res).reset_index(drop = True)\
    .assign(sample_ix = lambda df: range(0, len(df)))\
    .assign(token_in_prompt_ix = lambda df: df.groupby(['prompt_ix']).cumcount())

sample_df_labeled

In [None]:
"""
Clean up conversation pre_mlp_hs
"""
convs_pre_mlp_hs = convs_outputs['all_pre_mlp_hs'].to(torch.float16)
# del conv_outputs
convs_pre_mlp_hs = {layer_ix: convs_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(probe_layers)} # Match layers_to_keep_act

print(convs_pre_mlp_hs[0].shape)
print(len(sample_df_labeled))
gc.collect()

In [None]:
"""
Run projections
"""
test_layer = 12 # probe_layers[int(np.ceil(0.5 * len(probe_layers)).item())]
test_roles = ['user', 'cot', 'assistant'] # system, user, cot, assistant, tool
probe = [x for x in all_probes if sorted(x['roles']) == sorted(test_roles) and x['layer_ix'] == test_layer][0]

project_sample_df = sample_df_labeled # Or drop non-roles
project_hs_cp = cupy.asarray(convs_pre_mlp_hs[test_layer][project_sample_df['sample_ix'].tolist(), :])
project_probs = probe['probe'].predict_proba(project_hs_cp).round(8)

proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
if len(proj_results) != len(project_sample_df):
    raise Exception("Error!")

role_level_df =\
    pd.concat([proj_results, project_sample_df[['sample_ix']]], axis = 1)\
    .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
    .merge(project_sample_df, on = 'sample_ix', how = 'inner')\
    .merge(input_convs_df[['prompt_ix', 'dataset', 'conv_id', 'conv_type']], on = 'prompt_ix')

## GET TOKEN-LEVEL IS IT ACCURATE?
# token_level_df

role_level_df

In [None]:
# Tagged rendering: mean probe probability for each (true role, dataset) in each role space.
# Averaged within each prompt first, then across prompts.
(
    role_level_df
    .loc[lambda df: df['conv_type'] == 'tagged']
    .dropna(subset = ['role'])
    .loc[lambda df: df['token_in_prompt_ix'].between(0, 10000)]
    .groupby(['dataset', 'role_space', 'role', 'prompt_ix'], as_index = False)
    .agg(mean_prob = ('prob', 'mean'))
    .groupby(['dataset', 'role_space', 'role'], as_index = False)
    .agg(mean_prob = ('mean_prob', 'mean'))
    .pivot(index = ['role', 'dataset'], columns = 'role_space', values = 'mean_prob')
)

In [None]:
# Untagged rendering: mean probe probability for each (true role, dataset) in each role space.
# Averaged within each prompt first, then across prompts.
(
    role_level_df
    .loc[lambda df: df['conv_type'] == 'untagged']
    .dropna(subset = ['role'])
    .loc[lambda df: df['token_in_prompt_ix'].between(0, 10000)]
    .groupby(['dataset', 'role_space', 'role', 'prompt_ix'], as_index = False)
    .agg(mean_prob = ('prob', 'mean'))
    .groupby(['dataset', 'role_space', 'role'], as_index = False)
    .agg(mean_prob = ('mean_prob', 'mean'))
    .pivot(index = ['role', 'dataset'], columns = 'role_space', values = 'mean_prob')
)

In [None]:
"""
Run projections
"""
all_projs = []

project_sample_df = sample_df_labeled # Or drop non-roles

for probe in tqdm(all_probes):

    project_hs_cp = cupy.asarray(convs_pre_mlp_hs[probe['layer_ix']][project_sample_df['sample_ix'].tolist(), :])

    project_probs = probe['probe'].predict_proba(project_hs_cp).round(4)
    proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
    if len(proj_results) != len(project_sample_df):
        raise Exception("Error!")

    role_level_df =\
        pd.concat([proj_results, project_sample_df[['sample_ix']]], axis = 1)\
        .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
        .merge(project_sample_df, on = 'sample_ix', how = 'inner')\
        .merge(input_convs_df[['prompt_ix', 'dataset', 'conv_id', 'conv_type']], on = 'prompt_ix')\
        .assign(layer_ix = probe['layer_ix'])\
        .assign(roles = ','.join(probe['roles']))\
        .drop(columns = ['base_message'])

    all_projs.append(role_level_df)

all_projs_df =\
    pd.concat(all_projs)\
    .reset_index(drop = True)\
    .drop(columns = ['output_prob', 'output_id'])\
    .pipe(lambda df: df[~df['role'].isna()])

In [None]:
# Per-layer probe accuracy for the 4-role probe on tagged conversations.
# "acc" is the probability the probe assigns to each token's true role (role_space == role),
# averaged within each prompt first and then across prompts.
(
    all_projs_df
    .loc[lambda df: df['roles'] == 'system,user,cot,assistant']
    .loc[lambda df: df['conv_type'] == 'tagged']
    .loc[lambda df: df['role_space'] == df['role']]
    .drop(columns = 'role_space')
    .rename(columns = {'prob': 'acc'})
    .groupby(['roles', 'layer_ix', 'role', 'prompt_ix'], as_index = False)
    .agg(mean_acc = ('acc', 'mean'))
    .groupby(['roles', 'layer_ix', 'role'], as_index = False)
    .agg(mean_acc = ('mean_acc', 'mean'))
    .pivot(index = ['roles', 'layer_ix'], columns = 'role', values = 'mean_acc')
    .reset_index()
    .rename_axis(columns = None)
)

## Run qualitative tests