In [None]:
"""
Train probes
"""
# Title cell. Ending the cell on `None` means nothing is displayed (the string above is not echoed).
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Glm4vForConditionalGeneration, AutoModelForImageTextToText
from transformers.loss.loss_utils import ForCausalLMLoss
from datasets import load_dataset
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import gc
import pickle
from termcolor import colored
import importlib
import cupy
import cuml
import sklearn
import yaml
from packaging import version
import transformers

from utils.memory import check_memory, clear_all_cuda_memory
from utils.store_outputs import convert_outputs_to_df_fast

# Runtime configuration
seed = 123
main_device = 'cuda:0'
ws = '/workspace/deliberative-alignment-jailbreaks'

# NOTE: this rebinds the name `version` (imported above as the `packaging.version` module) to an int:
# the installed transformers major version. Later cells compare against it (e.g. `version == 5`).
version = version.parse(transformers.__version__).major

# Start from a clean GPU and report baseline memory usage
clear_all_cuda_memory()
check_memory()

# Load model

In [None]:
"""
Load the base tokenizer/model
"""
selected_model_index = 0

def get_model(index):
    """
    Look up the configuration tuple for one of the supported models.

    Params:
        @index: Key into the model table below (0-8).

    Returns:
        A 6-tuple of
        - HF model id, model prefix (short model identifier), model arch
        - Attn implementation (None = library default), whether to use the HF default implementation
          (load_model_and_tokenizer loads with `trust_remote_code = not use_hf`), # hidden layers

    Raises:
        KeyError: if `index` is not in the table.
    """
    models = {
        0: ('openai/gpt-oss-20b', 'gptoss-20b', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 24),
        1: ('openai/gpt-oss-120b', 'gptoss-120b', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 36),
        2: ('nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16', 'nemotron-3-nano', 'nemotron3', None, False, 52),
        3: ('Qwen/Qwen3-30B-A3B-Thinking-2507', 'qwen3-30b-a3b', 'qwen3moe', None, True, 48),
        # Repo ids cannot contain spaces; the previous id ('...Reasoning- 3B') could never resolve
        4: ('ai21labs/AI21-Jamba-Reasoning-3B', 'jamba-reasoning', 'jamba', None, True, 28),
        5: ('ServiceNow-AI/Apriel-1.6-15b-Thinker', 'apriel-1.6-15b-thinker', 'apriel', None, True, 48),
        6: ('allenai/Olmo-3-7B-Think', 'olmo3-7b-think', 'olmo3', None, True, 32),
        7: ('zai-org/GLM-4.6V-Flash', 'glm-4.6v-flash', 'glm46v', None, True, 40),
        8: ('zai-org/GLM-4.7-Flash', 'glm-4.7-flash', 'glm4moelite', None, True, 46)
    }
    if index not in models:
        raise KeyError(f"Unknown model index {index!r}; valid indices: {sorted(models)}")
    return models[index]

def load_model_and_tokenizer(model_id, model_architecture, model_attn, model_use_hf):
    """
    Fetch the tokenizer and model for `model_id` (cached under /workspace/hf), moving the model to
    `main_device` in eval mode.

    Params:
        @model_id: HF hub model id.
        @model_architecture: Short arch key; 'glm46v' and 'apriel' use non-causal-LM auto classes.
        @model_attn: attn_implementation passed to from_pretrained (None = library default).
        @model_use_hf: If False, remote code is trusted when loading the model.

    Returns:
        (tokenizer, model)
    """
    hf_cache = '/workspace/hf'

    # Left padding; no automatic BOS/EOS insertion
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        cache_dir = hf_cache,
        add_eos_token = False,
        add_bos_token = False,
        padding_side = 'left',
        trust_remote_code = True
    )

    # These architectures are not loaded through AutoModelForCausalLM
    model_cls = {
        'glm46v': Glm4vForConditionalGeneration,
        'apriel': AutoModelForImageTextToText,
    }.get(model_architecture, AutoModelForCausalLM)

    model = model_cls.from_pretrained(
        model_id,
        cache_dir = hf_cache,
        dtype = 'auto',
        trust_remote_code = not model_use_hf,
        device_map = None,
        attn_implementation = model_attn
    )
    model = model.to(main_device).eval()

    return tokenizer, model

model_id, model_prefix, model_architecture, model_attn, model_use_hf, model_n_layers = get_model(selected_model_index)
tokenizer, model = load_model_and_tokenizer(model_id, model_architecture, model_attn, model_use_hf)

# gpt-oss sanity checks: expert weights should be MXFP4 and the attention implementation FA3
if model_architecture == 'gptoss':
    print(model.model.layers[0].mlp.experts.down_proj)
    print(model.model.config._attn_implementation)

# Fall back to EOS as the padding token when the tokenizer defines none
if tokenizer.pad_token is None:
    print('Setting pad token automatically')
    tokenizer.pad_token = tokenizer.eos_token

# transformers v5: use the eager experts implementation to avoid non-deterministic implementations
if version == 5 and hasattr(model, 'set_experts_implementation'):
    model.set_experts_implementation('eager')

check_memory()

In [None]:
"""
Load reverse-engineered forward pass functions (usage note - this can be replaced by simpler hooks if desired)
"""
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_return_topk = getattr(model_module, f"run_{model_architecture}_return_topk")

@torch.no_grad()
def test_custom_forward_pass(model, pad_token_id):
    """
    Check that the custom forward pass reproduces the HF forward pass bit-exactly, then print the LM loss
    and the count/shape of the returned hidden states.

    Params:
        @model: The loaded HF model.
        @pad_token_id: Retained for backward compatibility. Padded positions are identified from the
            attention mask instead, because the pad token can alias a real token (the pad-token fallback
            sets pad_token = eos_token).
    """
    inputs = tokenizer(['Hi! I am a dog and I like to bark', 'Vegetables are good for'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 640).to(model.device)
    original_results = model(**inputs, use_cache = False)
    custom_results = run_model_return_topk(model, inputs['input_ids'], inputs['attention_mask'], return_hidden_states = True)
    # Exact equality, not allclose: the custom pass must match the reference forward exactly
    assert torch.equal(original_results.logits, custom_results['logits']), 'Error in custom forward'

    # -100 is the ignore index; exclude only positions the attention mask marks as padding
    labels = inputs['input_ids'].masked_fill(inputs['attention_mask'] == 0, -100)
    loss = ForCausalLMLoss(custom_results['logits'], labels, custom_results['logits'].size(-1)).detach().cpu().item()
    print(f"LM loss: {loss}")
    print(f"Hidden states layers (pre-mlp | post-layer): {len(custom_results['all_pre_mlp_hidden_states'])} | {len(custom_results['all_hidden_states'])}")
    print(f"Hidden state size (pre-mlp | post-layer): {(custom_results['all_pre_mlp_hidden_states'][0].shape)} | {(custom_results['all_hidden_states'][0].shape)}")

test_custom_forward_pass(model, tokenizer.pad_token_id)

In [None]:
"""
Test generation is sensible
"""
@torch.no_grad()
def test_generation():
    """
    Greedy-generate up to 100 tokens for a short chat prompt (thinking enabled) and print the decoded
    sequence, including special tokens, as a sanity check of the chat template and model.
    """
    conv = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Write a haiku about GPUs"},],
        tokenize = False,
        enable_thinking = True,
        add_generation_prompt = True
    )
    inputs = tokenizer(conv, return_tensors = 'pt')
    # Pass the mask explicitly (as test1/test2 do) rather than letting generate() infer it from the
    # pad token, which is unreliable when pad_token aliases eos_token
    gen_ids = model.generate(
        inputs['input_ids'].to(main_device),
        attention_mask = inputs['attention_mask'].to(main_device),
        max_new_tokens = 100,
        do_sample = False
    )
    print(tokenizer.batch_decode(gen_ids, skip_special_tokens = False)[0])

test_generation()

# Train probes

## Prep dataset

In [None]:
"""
Load dataset - Dolma3 / C4
"""
# Close the config file deterministically instead of leaking the handle
with open('experiment-config/probe.yaml') as probe_cfg_file:
    n_sample_size = yaml.safe_load(probe_cfg_file)[model_prefix]['n_sample_size']

def load_raw_ds():
    """
    Stream and collect raw probe-training texts: int(25%) of n_sample_size from C4 (en, validation split),
    then int(75%) from Dolma3. Both streams are shuffled with `seed` and a 50k-item buffer.

    Returns:
        A list of {'text': str, 'source': 'c4-en' | 'dolma3'} dicts, with C4 samples first.
    """

    def get_c4():
        return load_dataset('allenai/c4', 'en', split = 'validation', streaming = True).shuffle(seed = seed, buffer_size = 50_000)

    def get_dolma3():
        return load_dataset('allenai/dolma3_mix-150B-1025', split = 'train', revision = '3a8349c', streaming = True).shuffle(seed = seed, buffer_size = 50_000)

    def get_data(ds, n_samples, data_source):
        """Take up to `n_samples` records from `ds` (fewer if the stream ends), tagged with `data_source`."""
        raw_data = []
        ds_iter = iter(ds)
        for _ in range(n_samples):
            sample = next(ds_iter, None)
            if sample is None:
                break
            raw_data.append({'text': sample['text'], 'source': data_source})
        return raw_data

    return get_data(get_c4(), int(n_sample_size * .25), 'c4-en') + get_data(get_dolma3(), int(n_sample_size * .75), 'dolma3')

raw_data = load_raw_ds()

In [None]:
"""
Test rendering
"""
import utils.role_templates
importlib.reload(utils.role_templates)
from utils.role_templates import render_single_message, render_mixed_cot

def test_render():
    """
    Print three renderings for visual inspection:
    - the tokenizer's own chat template for a short user/assistant exchange,
    - `render_single_message` for a user message,
    - `render_mixed_cot` for a CoT + assistant pair.

    Returns:
        True
    """
    sample_conv = [
        {'role': 'user', 'content': 'Hi! I am a dog and I like to bark'},
        {'role': 'assistant', 'content': 'Hello! What a lovely dog you are!'},
    ]
    # NOTE(review): padding/truncation/max_length presumably have no effect with tokenize=False — confirm
    rendered_template = tokenizer.apply_chat_template(
        sample_conv,
        tokenize = False,
        padding = 'max_length',
        truncation = True,
        max_length = 512,
        add_generation_prompt = True,
        enable_thinking = True
    )
    separator = '--'
    print(rendered_template)
    print(separator)
    print(render_single_message(model_prefix, role = 'user', content = 'Hi'))
    print(separator)
    print(render_mixed_cot(model_prefix, 'The user...', 'Yes!'))
    return True

test_render()

In [None]:
"""
Validate forward passes work
"""
# Define opening texts to append at the start of each role for training. If multiple, will randomize equally.
with open('experiment-config/probe.yaml') as probe_cfg_file:
    train_prefixes = yaml.safe_load(probe_cfg_file)[model_prefix]['train_prefixes']
for p in train_prefixes:
    print(p)

# Literal text that opens an assistant turn, per architecture (the alternatives previously commented out in test2)
ASSISTANT_OPENERS = {
    'gptoss': '<|start|>assistant<|channel|>analysis<|message|>',
    'nemotron3': '<|im_start|>assistant\n<think></think>',
    'glm46v': '<|assistant|>\n<think>',
    'qwen3moe': '<|im_start|>assistant\n<think>\n',
    'apriel': '\n<|begin_assistant|>\nHere are my reasoning steps:\n',
    'glm4moelite': '<|assistant|><think>',
}

@torch.no_grad()
def test1():
    """
    Greedy-generate (up to 500 new tokens) a reply to a one-line user message rendered with the
    tokenizer's chat template. Print the decoded sequence and return it as a df of (token_id, token).
    """
    conv = tokenizer.apply_chat_template(
        [{'role': 'user', 'content': 'Hi stinky'}],
        tokenize = False,
        add_generation_prompt = True
    )
    inputs = tokenizer(conv, add_special_tokens = True, return_tensors = 'pt')
    gen_ids = model.generate(inputs['input_ids'].to(main_device), attention_mask = inputs['attention_mask'].to(main_device), max_new_tokens = 500, do_sample = False)
    df = pd.DataFrame({'token_id': gen_ids[0].tolist(), 'token': tokenizer.convert_ids_to_tokens(gen_ids[0].tolist()),})
    print(tokenizer.decode(df['token_id'].tolist(), skip_special_tokens = False))
    return df

@torch.no_grad()
def test2(assistant_opener = None):
    """Quick test - can the model run forward passes correctly?

    For each training prefix, render a user message with `render_single_message`, append the opener of an
    assistant turn, and print the greedy continuation (up to 500 new tokens).

    Params:
        @assistant_opener: Text appended after the user message. Defaults to the ASSISTANT_OPENERS entry for
            the current `model_architecture`. Architectures without an entry fall back, with a printed
            warning, to the gpt-oss opener (previously used for every model).
    """
    if assistant_opener is None:
        if model_architecture not in ASSISTANT_OPENERS:
            print(f"No assistant opener registered for '{model_architecture}'; falling back to the gpt-oss opener")
        assistant_opener = ASSISTANT_OPENERS.get(model_architecture, ASSISTANT_OPENERS['gptoss'])

    for tr_prefix in train_prefixes:
        conv = tr_prefix + render_single_message(model_prefix, 'user', 'Where is Atlanta?') + assistant_opener
        inputs = tokenizer(conv, add_special_tokens = False, return_tensors = 'pt')
        gen_ids = model.generate(inputs['input_ids'].to(main_device), attention_mask = inputs['attention_mask'].to(main_device), max_new_tokens = 500, do_sample = False)
        print(tokenizer.decode(gen_ids[0], skip_special_tokens = False))
        print('\n')

test1()
test2()

In [None]:
"""
Create sample sequences
"""
# Read the per-model config once and close the file (previously opened twice, with both handles leaked)
with open('experiment-config/probe.yaml') as probe_cfg_file:
    _model_probe_cfg = yaml.safe_load(probe_cfg_file)[model_prefix]
SEQLEN = _model_probe_cfg['seq_len']
NESTED_REASONING = _model_probe_cfg['nested_reasoning']

def get_sample_seqs_for_input_seq(probe_text, partner_text, prefix = ''):
    """
    Expand one probe text into one rendered prompt per role.

    Params
        @probe_text: The text we're extracting states from (appears in all roles)
        @partner_text: Random paired text; used only as the CoT preceding the assistant message when
            NESTED_REASONING is set
        @prefix: Opening text prepended to every rendered prompt

    Returns
        A list of {'role', 'prompt'} dicts in the order system, user, tool, cot, assistant.
    """
    single_message_roles = ('system', 'user', 'tool', 'cot')
    seqs = [
        {'role': r, 'prompt': prefix + render_single_message(model_prefix, role = r, content = probe_text)}
        for r in single_message_roles
    ]

    if NESTED_REASONING:
        # Reasoning models: the assistant message follows a CoT block filled with the partner text
        assistant_body = render_mixed_cot(model_prefix, cot = partner_text, assistant = probe_text)
    else:
        assistant_body = render_single_message(model_prefix, role = 'assistant', content = probe_text)

    seqs.append({'role': 'assistant', 'prompt': prefix + assistant_body})
    return seqs

def build_sample_seqs(train_prefixes):
    """
    Build all sample sequences: each raw text, truncated to SEQLEN tokens, is rendered once per role.

    Each text is also paired with a *different* text (a random derangement). That partner's leading tokens
    (a Beta(0.5, 4)-distributed count, at most ~SEQLEN/2) serve as the partner CoT. Each text also gets a
    training prefix drawn uniformly from `train_prefixes`. The global numpy RNG is seeded with `seed`, so
    the output is reproducible.

    Params:
        @train_prefixes: Candidate opening texts; one is sampled per input text.

    Returns:
        A df with one row per (text, role) and columns question_ix, question, role, prompt, prompt_ix.

    Raises:
        ValueError: if there is exactly one input text (it cannot be paired with a different text).
    """
    truncated_texts = tokenizer.batch_decode(tokenizer([t['text'] for t in raw_data], add_special_tokens = False, padding = False, truncation = True, max_length = SEQLEN).input_ids)
    n_seqs = len(truncated_texts)
    # No derangement of a single element exists: the resampling loop below would never terminate
    if n_seqs == 1:
        raise ValueError("build_sample_seqs needs at least 2 input texts so each can be paired with a different partner text")

    np.random.seed(seed)
    partner_lengths = ((np.random.beta(0.5, 4.0, size = n_seqs) * (SEQLEN/2 + 1)).astype(int)).tolist()
    partner_texts = [
        tokenizer.decode(tokenizer(text['text'], add_special_tokens = False, padding = False, truncation = True, max_length = int(partner_lengths[i])).input_ids)
        for i, text in enumerate(raw_data)
    ]
    # Resample until no text is its own partner (expected ~e attempts)
    perm = np.random.permutation(n_seqs)
    while np.any(perm == np.arange(n_seqs)):
        perm = np.random.permutation(n_seqs)

    sampled_prefixes = np.random.choice(train_prefixes, size = n_seqs)

    input_list = []
    for base_ix, base_text in enumerate(truncated_texts):
        partner_text = partner_texts[int(perm[base_ix])].strip()
        prefix = sampled_prefixes[base_ix]

        for seq in get_sample_seqs_for_input_seq(base_text, partner_text, prefix):
            row = {'question_ix': base_ix, 'question': base_text, **seq}
            input_list.append(row)

    input_df = pd.DataFrame(input_list).assign(prompt_ix = lambda df: list(range(len(df))))
    return input_df

input_df = build_sample_seqs(train_prefixes = train_prefixes)
display(input_df)

# Eyeball every role rendering of one example question
for p in input_df.loc[input_df['question_ix'] == 2, 'prompt']:
    print(p)
    print("=" * 80)

In [None]:
""" 
Load dataset into a dataloader which returns original tokens - important for BPE tokenizers to reconstruct the correct string later
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

# Longest tokenized prompt (same tokenizer-call defaults as before), used as the dataset max_length.
# Taken from per-prompt token counts instead of building an N x L padded tensor just to sum its mask.
max_seqlen = int(max(len(ids) for ids in tokenizer(input_df['prompt'].tolist(), truncation = False)['input_ids']))
train_dl = DataLoader(
    ReconstructableTextDataset(input_df['prompt'].tolist(), tokenizer, max_length = max_seqlen, prompt_ix = input_df['prompt_ix'].tolist()),
    batch_size = 32,
    shuffle = False,
    collate_fn = stack_collate
)

## Get hidden states

In [None]:
"""
Run forward passes
"""
@torch.no_grad()
def run_and_export_states(model, dl: DataLoader, layers_to_keep_acts: list[int]):
    """
    Run forward passes on given model and store the decomposed sample_df plus hidden states

    Params:
        @model: The model to run forward passes on via `run_model_return_topk`. Its output dict must contain
            `logits` and `all_pre_mlp_hidden_states` (a per-layer list, each BN x D per the stacking below).
        @dl: A DataLoader over a ReconstructableTextDataset whose batches contain `input_ids`, `attention_mask`,
            `original_tokens`, and `prompt_ix`.
        @layers_to_keep_acts: A list of layer indices (0-indexed) to retain from `all_pre_mlp_hidden_states`.

    Returns:
        A dict with keys:
        - `sample_df`: A sample (token)-level dataframe with corresponding input token ID, output token ID, and input token text (removes masked tokens)
        - `all_hs`: A tensor of size n_samples x len(layers_to_keep_acts) x D holding the pre-MLP hidden state for each retained layer.
            Each n_sample corresponds to a row of sample_df.

    Example:
        test_outputs = run_and_export_states(model, train_dl, layers_to_keep_acts = list(range(model_n_layers)))
    """
    all_hidden_states = []
    sample_dfs = []

    for batch_ix, batch in tqdm(enumerate(dl), total = len(dl)):

        input_ids = batch['input_ids'].to(main_device)
        attention_mask = batch['attention_mask'].to(main_device)
        original_tokens = batch['original_tokens']
        prompt_indices = batch['prompt_ix']

        output = run_model_return_topk(model, input_ids, attention_mask, return_hidden_states = True)

        # Check no bugs by validating output/perplexity on the first batch
        if batch_ix == 0:
            # Ignore (-100) padded positions only; masking by pad_token_id would also drop real tokens
            # whenever pad_token aliases eos_token
            labels = input_ids.masked_fill(attention_mask == 0, -100)
            loss = ForCausalLMLoss(output['logits'], labels, output['logits'].size(-1)).detach().cpu().item()
            for i in range(min(20, input_ids.size(0))):
                decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = False)
                next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
                print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
            print(f"PPL:", torch.exp(torch.tensor(loss)).item())

        original_tokens_df = pd.DataFrame(
            [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)],
            columns = ['sequence_ix', 'token_ix', 'token']
        )

        prompt_indices_df = pd.DataFrame(
            [(seq_i, seq_source) for seq_i, seq_source in enumerate(prompt_indices)],
            columns = ['sequence_ix', 'prompt_ix']
        )

        # Create sample (token) level dataframe
        sample_df =\
            convert_outputs_to_df_fast(input_ids, attention_mask, output['logits'])\
            .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
            .merge(prompt_indices_df, how = 'left', on = ['sequence_ix'])\
            .assign(batch_ix = batch_ix)

        sample_dfs.append(sample_df)

        # Select the retained layers *before* stacking (avoids materializing all layers), giving
        # BN x len(layers_to_keep_acts) x D, then keep only unmasked (BN,) positions
        valid_pos = torch.where(attention_mask.cpu().view(-1) == 1)
        kept_layers_hs = torch.stack([output['all_pre_mlp_hidden_states'][l] for l in layers_to_keep_acts], dim = 1)
        all_hidden_states.append(kept_layers_hs[valid_pos])

    sample_df = pd.concat(sample_dfs, ignore_index = True).drop(columns = ['batch_ix', 'sequence_ix']) # Drop batch/seq_ix, since prompt_ix identifies
    all_hidden_states = torch.cat(all_hidden_states, dim = 0)

    return {
        'sample_df': sample_df,
        'all_hs': all_hidden_states
    }

# Probe every 4th layer on deep models (>= 30 layers), otherwise every 2nd layer
layers_to_probe = list(range(0, model_n_layers, 4 if model_n_layers >= 30 else 2))
res = run_and_export_states(model, train_dl, layers_to_keep_acts = layers_to_probe)

# Downcast to f16 for cupy compatibility; get_probe_result upcasts each slice to f32 before fitting.
# NOTE(review): f16 saturates above ~65504 — confirm no probed activations exceed that range.
all_probe_hs = res['all_hs'].to(torch.float16)
# {layer_ix: n_samples x D view}, one per probed layer, in saved-column order
all_probe_hs = dict(zip(layers_to_probe, all_probe_hs.unbind(dim = 1)))
del res['all_hs']
gc.collect()

## Label roles


In [None]:
"""
Take each token and label it with its correct role (user/assistant/system/tool/cot); check the counts are as expected
"""
import utils.role_assignments
importlib.reload(utils.role_assignments)
from utils.role_assignments import label_content_roles
from utils.substring_assignments import flag_message_types

# Token-level frame for probe training: label roles, then keep only content tokens of each prompt's target role
probe_sample_df = (
    label_content_roles(model_prefix, res['sample_df'])
    # Number tokens 0..T-1 before filtering; get_probe_result uses sample_ix to index rows of all_probe_hs
    # (assumes label_content_roles preserves the row order of res['sample_df'] — TODO confirm)
    .assign(sample_ix = lambda df: range(0, len(df)))
    # Attach the role each prompt was rendered for (used to retain only matching tokens)
    .merge(
        input_df[['prompt_ix', 'role']].rename(columns = {'role': 'target_role'}),
        on = 'prompt_ix',
        how = 'inner'
    )
    .assign(match_target_role = lambda df: np.where(df['role'] == df['target_role'], True, False))
    # Flag tokens belonging to the prepended training prefix (dropped below)
    .pipe(flag_message_types, train_prefixes, allow_ambiguous = True)
    .drop(columns = 'base_message')
    # Drop prepend text, non-content tags, unlabeled tokens and tokens outside the target role
    .loc[lambda df: df['base_message_ix'].isna() & (df['is_content'] == True) & df['role'].notna() & df['match_target_role']]
)

# Token counts per role (expected to be exactly equal for most models; can differ when a model merges role tags and content into single tokens)
display(probe_sample_df.groupby('role', as_index = False).agg(count = ('sample_ix', 'count')))

# Spot-check role flags: reassemble the text of each (prompt, segment, role) for prompt_ix 0-14
display(
    probe_sample_df
    .loc[lambda df: df['prompt_ix'] <= 14]
    .groupby(['prompt_ix', 'seg_ix', 'role'], as_index = False)
    .agg(combined_text = ('token', ''.join))
    .assign(eot = lambda df: df['combined_text'].str[-30:])
)

## Run probes

In [None]:
"""
First run a quick grid search over hyperparameters
"""
# For nested-reasoning models, exclude tokens with token_in_seg_ix < 32 from probe training/evaluation
SKIP_FIRST_N = 32 if NESTED_REASONING else 0

def fit_lr(x_train, y_train, x_test, y_test, add_scaling = False, **lr_params):
    """
    Fit an L2-regularized cuML logistic-regression probe (optionally preceded by a cuML StandardScaler)
    and evaluate it on the held-out split.

    Params:
        @x_train, @y_train: Training features and integer role labels (cupy arrays, as get_probe_result passes).
        @x_test, @y_test: Held-out features and labels.
        @add_scaling: Prepend standard scaling to the pipeline.
        @lr_params: Extra keyword args for cuml LogisticRegression (e.g. C).

    Returns:
        (fitted pipeline, test accuracy, test log-loss, test predictions)
    """
    # Explicit import: a bare `import sklearn` does not guarantee the `sklearn.pipeline` submodule is loaded
    from sklearn.pipeline import Pipeline

    steps = []
    if add_scaling:
        steps.append(('scaler', cuml.preprocessing.StandardScaler()))
    steps.append(('clf', cuml.linear_model.LogisticRegression(
        penalty = 'l2',
        max_iter = 5_000,
        fit_intercept = True,
        **lr_params
    )))
    lr_model = Pipeline(steps)
    lr_model.fit(x_train, y_train)
    accuracy = lr_model.score(x_test, y_test)
    y_test_pred = lr_model.predict(x_test)
    y_test_prob = lr_model.predict_proba(x_test)
    nll = cuml.metrics.log_loss(y_test, y_test_prob)
    return lr_model, accuracy, nll, y_test_pred

def get_probe_result(sample_df, layer_hs, roles_map, add_scaling = False, **lr_params):
    """
    Train and evaluate one role probe for a single layer and role space.

    Params:
        @sample_df: Token-level df (possibly pre-filtered) with columns prompt_ix, role, sample_ix and
            token_in_seg_ix; sample_ix is the token's row in `layer_hs`
        @layer_hs: A tensor of probe hidden states for a layer, of T x D
        @roles_map: {role name: integer label}
        @add_scaling, @lr_params: Forwarded to fit_lr

    Description:
        Trains only on content space for the given roles. The train/test split (90/10, seeded) is done on
        prompt_ix, so all tokens of a prompt land on the same side.

    Returns:
        dict with 'probe', 'acc', 'nll', 'acc_by_role' (token counts per (role, pred)) and
        'acc_by_pos' (count and accuracy per token_in_seg_ix)

    Raises:
        Exception: on a train row/feature count mismatch, or if a role is absent from the train split.
    """
    prompt_ix_train, prompt_ix_test = cuml.train_test_split(sample_df['prompt_ix'].unique(), test_size = 0.1, random_state = seed)
    train_df = sample_df[sample_df['prompt_ix'].isin(prompt_ix_train)]
    test_df = sample_df[sample_df['prompt_ix'].isin(prompt_ix_test)]

    def to_labels(df):
        return cupy.asarray([roles_map[r] for r in df['role']])

    def to_features(df):
        return cupy.asarray(layer_hs[df['sample_ix'].tolist(), :].to(torch.float32).detach().cpu())

    y_train = to_labels(train_df)
    y_test = to_labels(test_df)
    x_train = to_features(train_df)
    x_test = to_features(test_df)

    if x_train.shape[0] != len(train_df):
        raise Exception(f"Shape mismatch!")

    train_classes = np.unique(y_train.get())
    if len(train_classes) < len(roles_map):
        raise Exception(f"Skipping mapping {roles_map}: missing roles in train", train_classes)

    lr_model, test_acc, test_nll, y_test_pred = fit_lr(
        x_train, y_train, x_test, y_test,
        add_scaling = add_scaling,
        **lr_params
    )

    # Per-token predictions mapped back to role names
    label_to_role = {label: role for role, label in roles_map.items()}
    results_df = test_df.assign(pred = y_test_pred.tolist())
    results_df = results_df.assign(pred = results_df['pred'].map(label_to_role))
    results_df = results_df.assign(is_acc = results_df['role'] == results_df['pred'])

    acc_by_role = results_df.groupby(['role', 'pred'], as_index = False).agg(count = ('sample_ix', 'count'))
    acc_by_pos = results_df.groupby('token_in_seg_ix', as_index = False).agg(count = ('sample_ix', 'count'), acc = ('is_acc', 'mean'))

    return {
        'probe': lr_model,
        'acc': test_acc,
        'nll': test_nll,
        'acc_by_role': acc_by_role,
        'acc_by_pos': acc_by_pos,
    }

def test_c(C, add_scaling):
    """
    Hyperparameter check: fit a user/assistant/tool probe on the middle probed layer with regularization
    strength `C` and the given scaling flag.

    Returns:
        {'C', 'add_scaling', 'val_acc', 'val_nll'} measured on the held-out split
    """
    roles = ['user', 'assistant', 'tool']
    mid_layer_ix = layers_to_probe[(len(layers_to_probe) - 1) // 2]
    keep_mask = probe_sample_df['role'].isin(roles) & (probe_sample_df['token_in_seg_ix'] >= SKIP_FIRST_N)
    result = get_probe_result(
        sample_df = probe_sample_df[keep_mask].reset_index(drop = True),
        layer_hs = all_probe_hs[mid_layer_ix],
        roles_map = dict(zip(roles, range(len(roles)))),
        add_scaling = add_scaling,
        C = C,
    )
    return {'C': C, 'add_scaling': add_scaling, 'val_acc': result['acc'], 'val_nll': result['nll']}

# Free GPU/host memory before the sweep
clear_all_cuda_memory()
gc.collect()

# Sweep C from 1e-4 to 1e2, with and without feature scaling
for scale_val in (True, False):
    print(f"Scaling: {scale_val}")
    display(pd.DataFrame([
        test_c(c_val, add_scaling = scale_val)
        for c_val in tqdm([1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2])
    ]))

In [None]:
"""
Run role probes at each layer in layers_to_probe (every 4th layer for models with >= 30 layers, otherwise every 2nd)
for the role combinations in all_role_combinations. Role spaces including the system/CoT roles are only added for gptoss-20b.
"""
with open('experiment-config/probe.yaml') as probe_cfg_file:
    train_params = yaml.safe_load(probe_cfg_file)[model_prefix]['train_params']

all_role_combinations = [
    ('user', 'assistant'),
    ('user', 'assistant', 'tool'),
]

if model_prefix in ['gptoss-20b']:
    all_role_combinations += [
        ('user', 'cot', 'assistant'),
        ('user', 'cot', 'assistant', 'tool'),

        ('system', 'user', 'assistant'),
        ('system', 'user', 'assistant', 'tool'),
        ('system', 'user', 'cot', 'assistant'),
        ('system', 'user', 'cot', 'assistant', 'tool'),
    ]

# Expand each role tuple into its label map and filtered token frame
all_role_combinations = [
    {
        'roles': list(roles),
        'roles_map': {x: i for i, x in enumerate(roles)},
        # Sample df for only those roles - filtering here is fine since we retain sample_ix which get_probe_result() uses to trace the original token
        'sample_df': probe_sample_df.pipe(lambda df: df[(df['role'].isin(roles)) & (df['token_in_seg_ix'] >= SKIP_FIRST_N)]).reset_index(drop = True),
    }
    for roles in all_role_combinations
]

# Free memory before training the full probe grid
clear_all_cuda_memory()
gc.collect()

# One probe per (role space, probed layer)
all_probes = []
for roles_dict in all_role_combinations:
    print(f"Training {roles_dict['roles']}")
    for layer_ix in layers_to_probe:
        probe_res = get_probe_result(
            sample_df = roles_dict['sample_df'],
            layer_hs = all_probe_hs[layer_ix],
            roles_map = roles_dict['roles_map'],
            add_scaling = train_params['add_scaling'],
            C = train_params['C']
        )
        print(f"  Layer [{layer_ix}]: {probe_res['acc']:.2f}")

        probe_record = dict(probe_res)
        probe_record.update(
            layer_ix = layer_ix,
            role_space = roles_dict['roles'],
            roles_map = roles_dict['roles_map'],
            n_inputs = len(roles_dict['sample_df']),
        )
        all_probes.append(probe_record)

print(f"Num probes: {len(all_probes)}")
print(f"Probe layers:\n  {', '.join(map(str, layers_to_probe))}")

In [None]:
"""
Save probes + metrics
"""
def validate_accuracy_and_save(all_probes):
    """
    Calculate accuracy by (role space, layer) and by (role space, layer, role), then save the per-role and
    per-position metrics as CSV and all probes as a pickle.

    Params:
        @all_probes: list of probe dicts as built by the training loop (keys include 'acc', 'acc_by_role',
            'acc_by_pos', 'layer_ix', 'role_space')

    Returns:
        True
    """
    out_dir = f'{ws}/experiments/role-analysis/outputs'

    print('Val accuracy by layer:')
    display(
        pd.DataFrame(all_probes)[['layer_ix', 'role_space', 'acc']]
        .assign(acc = lambda df: df['acc'].round(2), role_space = lambda df: df['role_space'].apply(lambda x: ','.join([r[0] for r in x])))
        .pivot(index = 'layer_ix', columns = 'role_space', values = 'acc')
    )

    print('Concatenating across layers and saving...')

    def tag_with_probe(df, p):
        """Add model / layer / role-space identifier columns to one probe's metrics frame."""
        return df.assign(model = model_prefix, layer_ix = p['layer_ix'], role_space = ','.join(p['role_space']))

    acc_by_role = pd.concat([tag_with_probe(p['acc_by_role'], p) for p in all_probes], ignore_index = True)
    acc_by_pos = (
        pd.concat([tag_with_probe(p['acc_by_pos'], p) for p in all_probes], ignore_index = True)
        .assign(acc = lambda df: df['acc'].round(4))
    )

    # Create output directories on first run so the writes below do not fail
    os.makedirs(f'{out_dir}/probe-training', exist_ok = True)
    os.makedirs(f'{out_dir}/probes', exist_ok = True)

    acc_by_role.to_csv(f'{out_dir}/probe-training/acc_by_role_{model_prefix}.csv', index = False)
    acc_by_pos.to_csv(f'{out_dir}/probe-training/acc_by_pos_{model_prefix}.csv', index = False)

    with open(f'{out_dir}/probes/{model_prefix}.pkl', 'wb') as f:
        pickle.dump(all_probes, f)

    print('Accuracy by role:')
    # Fraction of each role's test tokens predicted as that role. Summed over *all* (role, pred) rows so a
    # role with zero correct predictions reports 0.0 instead of being dropped by an inner merge.
    role_acc = (
        acc_by_role
        .assign(correct = lambda df: df['count'].where(df['role'] == df['pred'], 0))
        .groupby(['role', 'layer_ix', 'role_space'], as_index = False)
        .agg(sum = ('correct', 'sum'), base_sum = ('count', 'sum'))
        .assign(acc = lambda df: df['sum'] / df['base_sum'])
        .pivot(index = ['role_space', 'layer_ix'], columns = 'role', values = 'acc')
        .round(2)
    )

    display(role_acc)

    return True

validate_accuracy_and_save(all_probes)

# Validate on real conversations

## Load data

In [None]:
"""
Validate IFT on the full chat template, which tokenizes the whole thing e2e. This should match the format from original messages.:
- `load_chat_template`: Overwrites the existing chat template in the tokenizer with one that better handles discrepancies.
    See docstring for the function for details. 
- `label_content_roles`: Takes an instruct-formatted text, then assigns each token to the correct roles.
"""
import importlib
import utils.role_assignments
import utils.role_templates
importlib.reload(utils.role_assignments)
importlib.reload(utils.role_templates)
load_chat_template = utils.role_templates.load_chat_template
label_content_roles = utils.role_assignments.label_content_roles

def test_chat_template(tokenizer):
    s = tokenizer.apply_chat_template(
        [
            {'role': 'system', 'content': 'Test.'},
            {'role': 'user', 'content': 'Hi! I\'m a dog.'},
            {'role': 'assistant', 'content': '<think>The user is a dog!</think>Congrats!'},
            {'role': 'user', 'content': 'Thanks!'},
            {'role': 'assistant', 'content': '<think>Hmm, the user said thanks!</think>Anything else?'}
        ],
        tokenize = False,
        padding = 'max_length',
        truncation = True,
        max_length = 512,
        add_generation_prompt = True
    )
    print(s)
    return s

# Snapshot the tokenizer's original template only once, so re-running this cell always starts from it
if 'base_chat_template' not in globals():
    base_chat_template = tokenizer.chat_template

# Render with the original template...
tokenizer.chat_template = base_chat_template
s = test_chat_template(tokenizer)

# ...then with the replacement template from utils/chat_templates (left installed on the tokenizer)
tokenizer.chat_template = load_chat_template(f'{ws}/utils/chat_templates', model_prefix)
s = test_chat_template(tokenizer)

# Token-level view of the last rendering, labeled with content roles
tokens_df = pd.DataFrame({'input_ids': tokenizer(s, add_special_tokens = False)['input_ids']})
tokens_df['token'] = tokenizer.convert_ids_to_tokens(tokens_df['input_ids'])
tokens_df['prompt_ix'] = 0
tokens_df['token_ix'] = tokens_df.index
label_content_roles(model_prefix, tokens_df).head(50)

In [None]:
"""
Define opening text to append at the start of each role
"""
# Close the config file deterministically instead of leaking the handle
with open('experiment-config/probe.yaml') as probe_cfg_file:
    test_prefix = yaml.safe_load(probe_cfg_file)[model_prefix]['test_prefix']

def validate_test_prefix():
    """
    Prepend `test_prefix` to a chat-templated user message, greedy-generate up to 500 new tokens, print the
    decoded sequence, and return it as a df of (token_id, token).
    """
    conv = test_prefix + tokenizer.apply_chat_template(
        [{'role': 'user', 'content': 'Hi stinky pupper! How are you?'}],
        tokenize = False,
        add_generation_prompt = True
    )

    # add_special_tokens=False: the rendered template already carries its own special tokens
    inputs = tokenizer(conv, add_special_tokens = False, return_tensors = 'pt')
    gen_ids = model.generate(inputs['input_ids'].to(main_device), attention_mask = inputs['attention_mask'].to(main_device), max_new_tokens = 500, do_sample = False)
    print(tokenizer.decode(gen_ids[0]))

    df = pd.DataFrame({'token_id': gen_ids[0].tolist(), 'token': tokenizer.convert_ids_to_tokens(gen_ids[0].tolist()),})
    return df

display(validate_test_prefix().head(10))

In [None]:
"""
Import conversations
"""
def filter_convs(user_queries_df, max_samples, max_total_len = 10000, min_msg_len = 100, max_user_messages_per_conv = 2):
    """
    Filters conversation rows and returns:
      - convs: list[list[{'role': ..., 'content': ...}]]
      - convs_df: conv-level dataframe with columns ['conv_id', 'dataset', 'messages']

    Assumes input has columns:
        conv_id, user_query_ix, dataset, user_query, assistant, and optionally cot.
    """
    content_cols = ['user_query', 'assistant', 'cot']
    print(f"Starting length: {user_queries_df['conv_id'].nunique()} convs")
    
    # 1. Remove convs with any missing responses
    user_queries_df = user_queries_df.pipe(lambda df: df[~df['conv_id'].isin(df[df[content_cols].isna().any(axis=1)]['conv_id'].unique())])
    print(f"After dropping missing: {user_queries_df['conv_id'].nunique()} convs")

    # 2. Filter out (1) long convs, (2) convs with any message <= conv_min_length (to avoid stub messages), (3) convs with too many messages
    user_queries_df = (
        user_queries_df
        .pipe(lambda df: df.assign(**{c: df[c].astype(str).str.strip() for c in content_cols}))
        .assign(
            row_total_len = lambda df: df[content_cols].apply(lambda r: r.str.len()).sum(axis=1),
            row_min_len = lambda df: df[content_cols].apply(lambda r: r.str.len()).min(axis=1),
        )
        .assign(
            total_len = lambda df: df.groupby("conv_id")["row_total_len"].transform("sum"),
            conv_min_len = lambda df: df.groupby("conv_id")["row_min_len"].transform("min"),
            user_turns=lambda df: df.groupby("conv_id")["user_query_ix"].transform("count"),
        )
        .query("total_len < @max_total_len and conv_min_len >= @min_msg_len and user_turns <= @max_user_messages_per_conv")
        .drop(columns = ["row_total_len", "row_min_len", "total_len", "conv_min_len", "user_turns"])
    )
    print(f"After dropping long convs: {user_queries_df['conv_id'].nunique()} convs")

    # 3. Drop convs with substring messages (causes dupe issues with the flag_message_type fn, which does not handle tokens in multiple string contexts)
    def has_subtr_messages(g):
        msgs = sum([g[c].tolist() for c in ["user_query", "assistant", "cot"]], [])
        # sort by length so we only check short-in-long; reduces comparisons and removes symmetric i/j loop
        msgs = sorted(msgs, key=len)
        return any(
            (i != j) and (a in b)
            for i, a in enumerate(msgs)
            for j, b in enumerate(msgs)
            if len(a) <= len(b) and (i < j)  # only short->long checks
        )            
    
    bad_conv_ids = user_queries_df.groupby("conv_id", sort=False).apply(has_subtr_messages, include_groups=False).reset_index(name="bad").query("bad")["conv_id"].unique()
    
    user_queries_df = user_queries_df.pipe(lambda df: df[~df["conv_id"].isin(bad_conv_ids)])
    print(f"After dropping substr: {user_queries_df['conv_id'].nunique()} convs")

    # 4. Combine into conv-level df
    def rows_to_messages(g):
        msgs = []
        for _, row in g.iterrows():
            msgs.append({"role": "user", "content": row["user_query"]})
            msgs.append({"role": "cot", "content": row["cot"]})
            msgs.append({"role": "assistant", "content": row["assistant"]})
        return msgs

    convs_df = (
        user_queries_df
        .sort_values(["conv_id", "user_query_ix"])
        .assign(conv_id=lambda df: df.groupby("conv_id", sort=False).ngroup())
        .groupby("conv_id", sort=False)
        .apply(lambda g: pd.Series({"dataset": g["dataset"].values[0], "messages": rows_to_messages(g)}), include_groups = False)
        .reset_index()
        [["conv_id", "dataset", "messages"]]
        .pipe(lambda df: df.sample(n = min(max_samples, len(df)), random_state = seed))
        .reset_index(drop = True)
        .assign(conv_id = lambda df: range(len(df)))
    )
    print(f"Final: {convs_df['conv_id'].nunique()} convs")
    return convs_df['messages'].tolist(), convs_df


# This model's own conversations, filtered and sampled down to at most 30
own_convs_csv_path = f'{ws}/experiments/role-analysis/data/conversations/{model_prefix}.csv'
convs, convs_df = filter_convs(user_queries_df = pd.read_csv(own_convs_csv_path), max_samples = 30)

display(convs_df.head(5))
print(f"Total convs: {len(convs_df)}")

In [None]:
"""
Prep untagged/tagged/mistagged conversations
"""
from utils.role_templates import fold_cot_into_final

all_convs_by_type = {}
seperators = yaml.safe_load(open('experiment-config/probe.yaml'))[model_prefix]['test_seperators']
print(seperators)

# 1. Prep untagged convs
def prep_untagged_conv(conv, start_sep, role_seperator):
    """Create the untagged conversation, matching linebreak struct of model"""
    return test_prefix + start_sep + role_seperator.join([x['content'] for x in conv])

all_convs_by_type['untagged'] = [prep_untagged_conv(conv, seperators['untagged_start_sep'], seperators['untagged_role_sep']) for conv in convs]

# 2. Prep tagged convs: render each conv (after fold_cot_into_final) with the model's own chat template
rendered_tagged_convs = tokenizer.apply_chat_template(fold_cot_into_final(convs), tokenize = False, add_generation_prompt = False)
all_convs_by_type['tagged'] = [test_prefix + seperators['tagged_start_sep'] + rendered for rendered in rendered_tagged_convs]

# 3. Prep mistagged convs
def prep_mistagged_conv(conv, start_sep, role_seperator, role, model_prefix):
    """
    Create a mistagged conversation: the whole conversation wrapped as ONE message of `role` (e.g. 'user' or 'tool').

    - `start_sep` (the tagged separator) goes between the prefix and the wrapped message, since that message is tagged
    - `role_seperator` (the untagged separator) joins the original messages inside it, since there is no inner role separation
    """
    joined_contents = role_seperator.join(msg['content'] for msg in conv)
    return test_prefix + start_sep + render_single_message(model_prefix, role, joined_contents)

for mistagged_key, mistag_role in [('user_tagged', 'user'), ('tool_tagged', 'tool')]:
    all_convs_by_type[mistagged_key] = [
        prep_mistagged_conv(conv, seperators['tagged_start_sep'], seperators['untagged_role_sep'], mistag_role, model_prefix)
        for conv in convs
    ]

# Every conversation type must have exactly one rendered string per source conversation
n_untagged = len(all_convs_by_type['untagged'])
for other_type in ['tagged', 'user_tagged', 'tool_tagged']:
    assert n_untagged == len(all_convs_by_type[other_type])
assert all(len(v) == len(convs) for v in all_convs_by_type.values())

# Eyeball one example conversation rendered each way
test_ix = 1
for banner, shown_type in [
    ('------------------- UNTAGGED ------------------', 'untagged'),
    ('------------------- TAGGED ------------------', 'tagged'),
    ('------------------- MISROLED (AS USER) ------------------', 'user_tagged'),
    ('------------------- MISROLED (AS TOOL) ------------------', 'tool_tagged'),
]:
    print(banner)
    print(all_convs_by_type[shown_type][test_ix])

## Run forward passes

In [None]:
# Release host-side garbage, then cached CUDA memory, before the forward passes
gc.collect()
clear_all_cuda_memory()

In [None]:
"""
Run forward passes
"""
# Get input df
input_convs_df =\
    pd.concat([
        pd.DataFrame({'convs': v, 'conv_type': conv_type})\
            .assign(conv_id = lambda df: range(0, len(df)))\
            .merge(convs_df, how = 'inner', on = 'conv_id')
        for conv_type, v in all_convs_by_type.items()
    ], ignore_index = True)\
    .assign(prompt_ix = lambda df: list(range(0, len(df))))
    
display(input_convs_df)

max_input_length =\
    tokenizer(input_convs_df['convs'].tolist(), padding = True, truncation = True, max_length = 1024 * 8, return_tensors = 'pt')\
    ['attention_mask'].sum(dim = 1).max().item()

print(f"Max input length: {max_input_length}")

convs_dl = DataLoader(
    ReconstructableTextDataset(input_convs_df['convs'].tolist(), tokenizer, max_length = max_input_length, prompt_ix = list(range(0, len(input_convs_df)))),
    batch_size = 8,
    shuffle = False,
    collate_fn = stack_collate
)

convs_outputs = run_and_export_states(model, convs_dl, layers_to_keep_acts = layers_to_probe)
convs_hs = convs_outputs['all_hs'].to(torch.float16)
convs_hs = {layer_ix: convs_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(layers_to_probe)}
gc.collect()

In [None]:
"""
Label messages
"""
import utils.substring_assignments
importlib.reload(utils.substring_assignments)
from utils.substring_assignments import flag_message_types

def label_convs(input_convs_df, sample_level_df):
    """
    Return a sample-level (token-level) dataframe labeled with the role of the message each token came from.

    Params:
        @input_convs_df: The conversation-level df; needs `prompt_ix` and `messages`
            (list of {'role', 'content'} dicts, in conversation order).
        @sample_level_df: The sample-level df returned by run_and_export_states (its `sample_df`); needs `prompt_ix`.
            Its row order defines `sample_ix`, which callers use to index the exported hidden states.

    Returns:
        The concatenated flag_message_types output for every successfully labeled conv, plus `sample_ix`,
        `role` (NaN for tokens not matched to any message) and `token_in_prompt_ix`.
        Convs whose labeling raises are reported and skipped (best-effort).

    Raises:
        ValueError: if the conv-level and sample-level dfs disagree on prompt_ix, or if no conv could be labeled.
    """
    base_sample_df = sample_level_df.assign(sample_ix = lambda df: range(0, len(df)))
    sample_groups = list(base_sample_df.groupby('prompt_ix', sort = True))
    metadata_by_conv = input_convs_df.sort_values('prompt_ix').to_dict('records')

    all_res = []
    failed_prompt_ixs = []

    # Iterate through (input metadata, token-level sample df) pairs; strict zip catches a differing number of prompts
    for conv_metadata, (group_prompt_ix, sample_df_for_conv) in tqdm(zip(metadata_by_conv, sample_groups, strict = True), total = len(metadata_by_conv)):
        # Pairing by position is only valid if both sides hold the same prompt_ix
        if group_prompt_ix != conv_metadata['prompt_ix']:
            raise ValueError(f"prompt_ix mismatch: conv metadata has {conv_metadata['prompt_ix']}, token df has {group_prompt_ix}")

        content_spans = [m['content'].strip() for m in conv_metadata['messages']]
        content_roles = [m['role'] for m in conv_metadata['messages']]
        # Note: Set the last arg of flag_message_types for reasoning models whose CoTs quote the FULL user message
        # Setting the flag causes it to default to the first matching base_message (which should be user)
        try:
            res = (
                flag_message_types(sample_df_for_conv, content_spans, False)
                .merge(pd.DataFrame({'role': content_roles, 'base_message_ix': range(len(content_spans))}), on = 'base_message_ix', how = 'left')
            )
            assert len(res) == len(sample_df_for_conv), f"row count changed from {len(sample_df_for_conv)} to {len(res)}"
            assert (res['sample_ix'].values == sample_df_for_conv['sample_ix'].values).all(), "row order changed"
        except Exception as e:
            # Best-effort: skip this conv, but say which one and why (a bare AssertionError would print as a blank line)
            print(f"prompt_ix {conv_metadata['prompt_ix']}: labeling failed: {e!r}")
            failed_prompt_ixs.append(conv_metadata['prompt_ix'])
            continue
        all_res.append(res)

    print(f"Labeled {len(all_res)} / {len(metadata_by_conv)} convs")
    if failed_prompt_ixs:
        print(f"Skipped prompt_ix: {failed_prompt_ixs}")
    if not all_res:
        raise ValueError("label_convs: no conversation could be labeled")

    sample_df_labeled = (
        pd.concat(all_res).reset_index(drop = True)
        .assign(token_in_prompt_ix = lambda df: df.groupby(['prompt_ix']).cumcount())
    )

    # Tokens not matched to any message (role NaN): these should only be template/tag tokens
    display(sample_df_labeled.pipe(lambda df: df[df['role'].isna()]))

    # Tokens that were matched to a message
    display(sample_df_labeled.pipe(lambda df: df[~df['role'].isna()]))

    # Validate message assignment seems reasonable
    display(
        sample_df_labeled
        .pipe(lambda df: df[df['prompt_ix'] <= 5])
        .groupby(['prompt_ix', 'base_message', 'role'], as_index = False)
        .agg(combined_text = ('token', ''.join))
        .assign(eot = lambda df: df['combined_text'].str[-30:])
    )

    # Validate counts per message seem reasonable
    display(
        sample_df_labeled
        .pipe(lambda df: df[~df['role'].isna()])
        .groupby(['prompt_ix', 'role'], as_index = False)
        .agg(count = ('sample_ix', 'count'))
        .pivot(index = ['prompt_ix'], columns = 'role', values = 'count')
        .head(20)
    )
    return sample_df_labeled

sample_df_labeled = label_convs(input_convs_df, convs_outputs['sample_df'])
# Every labeled token must map onto a hidden-state row. convs_hs is keyed by layer index and layer 0 is not
# necessarily probed, so compare against the first probed layer (all layers share the row count).
assert next(iter(convs_hs.values())).shape[0] == len(sample_df_labeled)
sample_df_labeled

In [None]:
# Free GPU memory, then host-side garbage, before the projection cells
clear_all_cuda_memory()
gc.collect()

## Run projections

In [None]:
"""
Run test projections for single layer
"""
TEST_ROLE_SPACE = ['user', 'cot', 'assistant', 'tool']
TEST_LAYER_IX = sorted(layers_to_probe)[len(layers_to_probe) // 2]

def run_projections(valid_sample_df, layer_hs, probe):
    """
    Run probe-level projections
    
    Params:
        @valid_sample_df: A sample-level df with columns `sample_ix` (1... T - 1), `sample_ix`.
            Can be shorter than full T - 1 due to pre-filters, as long as sample_ix corresponds to the full length.
        @layer_hs: A tensor of size T x D for the layer to project.
        @probe: The probe dict with keys `probe` (the trained model) and `role_space` (the roles list)
    
    Returns:
        A df at (sample_ix, target_role) level with cols `sample_ix`, `target_role`, `prob`
    """
    x_cp = cupy.asarray(layer_hs[valid_sample_df['sample_ix'].tolist(), :])
    y_cp = probe['probe'].predict_proba(x_cp).round(12)

    proj_results = pd.DataFrame(cupy.asnumpy(y_cp), columns = probe['role_space'])
    if len(proj_results) != len(valid_sample_df):
        raise Exception("Error!")

    role_df =\
        pd.concat([
            proj_results.reset_index(drop = True),
            valid_sample_df[['sample_ix']].reset_index(drop = True)
        ], axis = 1)\
        .melt(id_vars = ['sample_ix'], var_name = 'target_role', value_name = 'prob')\
        .reset_index(drop = True)

    return role_df

# The probe trained on exactly TEST_ROLE_SPACE at TEST_LAYER_IX
matching_test_probes = [
    x for x in all_probes
    if sorted(x['role_space']) == sorted(TEST_ROLE_SPACE) and x['layer_ix'] == TEST_LAYER_IX
]
if not matching_test_probes:
    raise ValueError(f"No probe in all_probes with role space {sorted(TEST_ROLE_SPACE)} at layer {TEST_LAYER_IX}")

test_projections = (
    run_projections(
        valid_sample_df = sample_df_labeled.pipe(lambda df: df[(~df['role'].isna()) & (df['role'].isin(TEST_ROLE_SPACE))]),
        layer_hs = convs_hs[TEST_LAYER_IX],
        probe = matching_test_probes[0]
    )
    .merge(sample_df_labeled[['prompt_ix', 'sample_ix', 'token_in_prompt_ix', 'token', 'role']], how = 'inner', on = ['sample_ix'])
    .merge(input_convs_df[['prompt_ix', 'dataset', 'conv_id', 'conv_type']], how = 'inner', on = 'prompt_ix', validate = 'many_to_one')
)

# Now aggregate up to (prompt, role, target_role) level
prompt_x_role_x_target_role_projections = (
    test_projections
    .groupby(['conv_type', 'prompt_ix', 'role', 'target_role'], as_index = False)
    .agg(
        combined_text = ('token', ''.join),
        mean_prob = ('prob', lambda x: np.mean(x).round(3))
    )
    .assign(eot = lambda df: df['combined_text'].str[-30:])
)

display(prompt_x_role_x_target_role_projections.head(12))

# Per (conv_type, true role): probe probability for each target role, averaged over prompts
display(
    prompt_x_role_x_target_role_projections
    .groupby(['conv_type', 'target_role', 'role'], as_index = False)
    .agg(mean_prob = ('mean_prob', 'mean'))
    .pivot(index = ['conv_type', 'role'], columns = 'target_role', values = 'mean_prob')
)

In [None]:
# Print the CoT text of prompt 10 and of prompt len(convs) + 10 (the same conv, next conv type), separated by a rule.
# Every target_role carries the same tokens, so the text of the first (sorted) target_role is used.
for cot_view_ix, cot_prompt_ix in enumerate([10, len(convs) + 10]):
    if cot_view_ix > 0:
        print('-----------')
    cot_rows = test_projections[(test_projections['prompt_ix'] == cot_prompt_ix) & (test_projections['role'] == 'cot')]
    first_cot_target_role = sorted(cot_rows['target_role'].unique())[0]
    print(''.join(cot_rows.loc[cot_rows['target_role'] == first_cot_target_role, 'token']))

In [None]:
# Print the user text of prompt 10 and of prompt len(convs) + 10 (the same conv, next conv type), separated by a rule.
# Every target_role carries the same tokens, so the text of the first (sorted) target_role is used.
for user_view_ix, user_prompt_ix in enumerate([10, len(convs) + 10]):
    if user_view_ix > 0:
        print('-----------')
    user_rows = test_projections[(test_projections['prompt_ix'] == user_prompt_ix) & (test_projections['role'] == 'user')]
    first_user_target_role = sorted(user_rows['target_role'].unique())[0]
    print(''.join(user_rows.loc[user_rows['target_role'] == first_user_target_role, 'token']))

In [None]:
"""
All-layer projections
"""
clear_all_cuda_memory()
gc.collect()

all_projs = [
    run_projections(
        valid_sample_df = sample_df_labeled.pipe(lambda df: df[(~df['role'].isna()) & (df['role'].isin(probe['role_space']))]),
        layer_hs = convs_hs[probe['layer_ix']],
        probe = probe
    ).assign(layer_ix = probe['layer_ix'], role_space = ''.join([x[0] for x in probe['role_space']]))

    for probe in tqdm(all_probes)
]

all_projs_df =\
    pd.concat(all_projs, ignore_index = True)\
    .merge(sample_df_labeled[['prompt_ix', 'sample_ix', 'token_in_prompt_ix', 'token', 'role']], how = 'inner', on = ['sample_ix'])\
    .merge(input_convs_df[['prompt_ix', 'dataset', 'conv_id', 'conv_type']], how = 'inner', on = 'prompt_ix', validate = 'many_to_one')

all_projs_df

In [None]:
"""
Validation
"""
out =\
    all_projs_df\
    .pipe(lambda df: df[df['target_role'] == df['role']])\
    .drop(columns = 'target_role')\
    .rename(columns = {'prob': 'acc'})\
    .groupby(['conv_type', 'role_space', 'layer_ix', 'role', 'prompt_ix'], as_index = False)\
    .agg(mean_acc = ('acc', 'mean'))\
    .groupby(['conv_type', 'role_space', 'layer_ix', 'role'], as_index = False)\
    .agg(mean_acc = ('mean_acc', 'mean'))

def draw_proj_val_table(input_df):

    def color_thresholds(v):
        if pd.isna(v):
            return ""
        if v > 0.9:
            return "background-color: #00c853; color: black;" # green
        elif v > 0.75:
            return "background-color: #aeea00; color: black;" # green-yellow
        elif v > 0.6:
            return "background-color: #ffeb3b; color: black;" # yellow
        return ""
    
    wide =\
        input_df\
        .assign(role = lambda df: np.where(df['role'] == 'assistant', 'asst', df['role']))\
        .pivot(index = ['conv_type', 'layer_ix'], columns = ['role_space', 'role'], values = 'mean_acc')

    col_groups = wide.columns.get_level_values(0)
    group_starts = [0] + [i for i in range(1, len(col_groups)) if col_groups[i] != col_groups[i-1]]
    divider_styles = []
    for i in group_starts:
        divider_styles += [
            {"selector": f"td.col{i}", "props": [("border-left", "3px solid #ddd")]},
            {"selector": f"th.col_heading.col{i}", "props": [("border-left", "3px solid #ddd")]},
        ]

    return wide.style.format("{:.2f}").map(color_thresholds).set_table_styles(divider_styles, overwrite = False)

display(draw_proj_val_table(out))
all_conv_projs_csv_path = f'{ws}/experiments/role-analysis/outputs/probe-projections/all_conv_projs_{model_prefix}.csv'
out.to_csv(all_conv_projs_csv_path, index = False)

In [None]:
"""
Validation: real accuracy (argmax) instead of mean prob on true class
"""
hard_preds = (
    all_projs_df
    .sort_values(
        ["conv_type", "role_space", "layer_ix", "sample_ix", "prob"],
        ascending = [True, True, True, True, False],
    )
    .drop_duplicates(["conv_type", "role_space", "layer_ix", "sample_ix"])  # keep max-prob row
    .rename(columns = {"target_role": "pred_role", "prob": "pred_prob"})
    .assign(is_acc = lambda df: df["pred_role"] == df["role"])
)

out = (
    hard_preds
    .groupby(["conv_type", "role_space", "layer_ix", "role", "prompt_ix"], as_index = False)
    .agg(mean_acc = ("is_acc", "mean"))
    .groupby(["conv_type", "role_space", "layer_ix", "role"], as_index = False)
    .agg(mean_acc = ("mean_acc", "mean"))
)

display(draw_proj_val_table(out))
out.to_csv(f"{ws}/experiments/role-analysis/outputs/probe-projections/all_conv_acc_{model_prefix}.csv", index = False)

## Alternate-model conversations (generated by other models)

In [None]:
"""
Prep wrong-convs
"""
all_model_files = [
    'haiku-4.5',
    'minimax-m2.1',
    'kimi-k2-thinking'
]

def get_multiple_model_convs(comparison_models):
    """
    Get multiple model conversations
    """
    m_dfs = [pd.read_csv(f'{ws}/experiments/role-analysis/data/conversations/{m}.csv').assign(model = m) for m in comparison_models]
    m_df = (
        pd.concat(m_dfs, ignore_index = True)
        .assign(conv_id = lambda df: df.groupby(['model', 'conv_id']).ngroup())
    )
    return m_df

alt_model_df = get_multiple_model_convs(all_model_files)
alt_convs, alt_convs_df = filter_convs(alt_model_df, max_samples = 30)

# Render the alt-model convs (after fold_cot_into_final) with THIS model's chat template
final_alt_convs = fold_cot_into_final(alt_convs)
all_wrong_convs_tagged = [
    test_prefix + seperators['tagged_start_sep'] + x
    for x in tokenizer.apply_chat_template(final_alt_convs, tokenize = False, add_generation_prompt = False)
]
print(all_wrong_convs_tagged[0])

# Find max input length from the unpadded encodings (no padded tensor needed just to count tokens)
max_alt_tokenize_length = 1024 * 8
max_alt_input_length = max(
    len(ids) for ids in tokenizer(all_wrong_convs_tagged, truncation = True, max_length = max_alt_tokenize_length)['input_ids']
)

print(f"Max input length: {max_alt_input_length}")
print(f"Tagged convs: {len(all_wrong_convs_tagged)}")
if max_alt_input_length >= max_alt_tokenize_length:
    print(f"Warning: at least one alt conv hit the {max_alt_tokenize_length}-token cap and will be truncated")


In [None]:
# Free host-side garbage, then cached CUDA memory, before the alt-model forward passes
gc.collect()
clear_all_cuda_memory()

In [None]:
"""
Run forward passes
"""
# Get input df
input_alt_convs_df =\
    pd.concat([
        pd.DataFrame({'convs': v, 'conv_type': conv_type})\
            .assign(conv_id = lambda df: range(0, len(df)))\
            .merge(alt_convs_df, how = 'inner', on = 'conv_id')
        for conv_type, v in {'alt_tagged': all_wrong_convs_tagged}.items()
    ], ignore_index = True)\
    .assign(prompt_ix = lambda df: list(range(0, len(df))))
    
display(input_alt_convs_df)

alt_convs_dl = DataLoader(
    ReconstructableTextDataset(input_alt_convs_df['convs'].tolist(), tokenizer, max_length = max_alt_input_length, prompt_ix = list(range(0, len(input_alt_convs_df)))),
    batch_size = 4,
    shuffle = False,
    collate_fn = stack_collate
)

alt_convs_outputs = run_and_export_states(model, alt_convs_dl, layers_to_keep_acts = layers_to_probe)
alt_convs_hs = alt_convs_outputs['all_hs'].to(torch.float16)
alt_convs_hs = {layer_ix: alt_convs_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(layers_to_probe)} # Match layers_to_keep_act
gc.collect()

In [None]:
"""
Label messages
"""
alt_sample_df_labeled = label_convs(input_alt_convs_df, alt_convs_outputs['sample_df'])
assert alt_convs_hs[0].shape[0] == len(alt_sample_df_labeled)
alt_sample_df_labeled

In [None]:
"""
All-layer projections
"""
clear_all_cuda_memory()
gc.collect()

alt_projs = [
    run_projections(
        valid_sample_df = alt_sample_df_labeled.pipe(lambda df: df[(~df['role'].isna()) & (df['role'].isin(probe['role_space']))]),
        layer_hs = alt_convs_hs[probe['layer_ix']],
        probe = probe
    ).assign(layer_ix = probe['layer_ix'], role_space = ''.join([x[0] for x in probe['role_space']]))

    for probe in tqdm(all_probes)
]

alt_projs_df =\
    pd.concat(alt_projs, ignore_index = True)\
    .merge(alt_sample_df_labeled[['prompt_ix', 'sample_ix', 'token_in_prompt_ix', 'token', 'role']], how = 'inner', on = ['sample_ix'])\
    .merge(input_alt_convs_df[['prompt_ix', 'dataset', 'conv_id', 'conv_type']], how = 'inner', on = 'prompt_ix', validate = 'many_to_one')

alt_projs_df

In [None]:
"""
Validation
"""
out2 =\
    alt_projs_df\
    .pipe(lambda df: df[df['target_role'] == df['role']])\
    .drop(columns = 'target_role')\
    .rename(columns = {'prob': 'acc'})\
    .groupby(['conv_type', 'role_space', 'layer_ix', 'role', 'prompt_ix'], as_index = False)\
    .agg(mean_acc = ('acc', 'mean'))\
    .groupby(['conv_type', 'role_space', 'layer_ix', 'role'], as_index = False)\
    .agg(mean_acc = ('mean_acc', 'mean'))

display(draw_proj_val_table(out2)) 
out2.to_csv(f'{ws}/experiments/role-analysis/outputs/probe-projections/alt_conv_projs_{model_prefix}.csv', index = False)

In [None]:
"""
Validation: real accuracy (argmax) instead of mean prob on true class
"""
hard_preds = (
    alt_projs_df
    .sort_values(
        ["conv_type", "role_space", "layer_ix", "sample_ix", "prob"],
        ascending = [True, True, True, True, False],
    )
    .drop_duplicates(["conv_type", "role_space", "layer_ix", "sample_ix"])  # keep max-prob row
    .rename(columns = {"target_role": "pred_role", "prob": "pred_prob"})
    .assign(is_acc = lambda df: df["pred_role"] == df["role"])
)

out = (
    hard_preds
    .groupby(["conv_type", "role_space", "layer_ix", "role", "prompt_ix"], as_index = False)
    .agg(mean_acc = ("is_acc", "mean"))
    .groupby(["conv_type", "role_space", "layer_ix", "role"], as_index = False)
    .agg(mean_acc = ("mean_acc", "mean"))
)

display(draw_proj_val_table(out))
out.to_csv(f"{ws}/experiments/role-analysis/outputs/probe-projections/alt_conv_acc_{model_prefix}.csv", index = False)

# Tomato example (run for gpt-oss-20b only)

In [None]:
"""
Define append prefix
"""
tomato_prefix = tokenizer.bos_token or ''
tomato_prefix

In [None]:
"""
Format tomato-style conversations
"""
def load_messages(yaml_path, model_key):
    """Loads system-like, assistant-like, user-like, etc. messages"""
    with open(yaml_path, 'r', encoding = 'utf-8') as f:
        data = yaml.safe_load(f)
    return [x['content'] for x in data[model_key]]

# Role type of each entry of base_messages (index-aligned)
base_message_types = ['system', 'user', 'cot', 'assistant', 'user', 'cot', 'assistant']
base_messages = load_messages(f"{ws}/experiments/role-analysis/experiment-config/tomato.yaml", 'gptoss20')

tomato_prompts = {}

# 1. All messages as one raw string, with no chat-template tags
tomato_prompts['basic_no_format'] = tomato_prefix + ' ' + '\n\n'.join(base_messages)

# 2. All messages stuffed into a single user turn
tomato_prompts['everything_in_user_tags'] = tomato_prefix + tokenizer.apply_chat_template(
    [{'role': 'user', 'content': '\n\n'.join(base_messages)}],
    tokenize = False, add_generation_prompt = False
)

# 3. Every message in its proper role. Both CoTs (base_messages[2] and [5], typed 'cot' above) use the same
#    <think>...</think> wrapper ahead of the assistant's final answer; the second one was previously wrapped in an
#    unclosed "<tool_call>...<tool_call>" pair, which is not a proper tag for a CoT.
tomato_prompts['proper_tags'] = tomato_prefix + tokenizer.apply_chat_template(
    [
        {'role': 'system', 'content': base_messages[0]},
        {'role': 'user', 'content': base_messages[1]},
        {'role': 'assistant', 'content': f"<think>{base_messages[2]}</think>{base_messages[3]}"},
        {'role': 'user', 'content': base_messages[4]},
        {'role': 'assistant', 'content': f"<think>{base_messages[5]}</think>{base_messages[6]}"}
    ],
    tokenize = False, add_generation_prompt = False
)

print(tomato_prompts['proper_tags'])

tomato_input_df = (
    pd.DataFrame({
        'prompt': list(tomato_prompts.values()),
        'prompt_key': list(tomato_prompts.keys())
    })
    .assign(prompt_ix = lambda df: list(range(0, len(df))))
)

In [None]:
"""
Run!!
"""
tomato_dl = DataLoader(
    ReconstructableTextDataset(tomato_input_df['prompt'].tolist(), tokenizer, max_length = 512 * 4, prompt_ix = tomato_input_df['prompt_ix'].tolist()),
    batch_size = 16,
    shuffle = False,
    collate_fn = stack_collate
)

tomato_outputs = run_and_export_states(model, tomato_dl, layers_to_keep_acts = layers_to_probe)

tomato_hs = tomato_outputs['all_hs'].to(torch.float16)
tomato_hs = {layer_ix: tomato_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(layers_to_probe)} # Match layers_to_keep_act
print(tomato_hs[0].shape)
gc.collect()

In [None]:
"""
Ensure tokens/section is reasonable; map to base_message_type
"""
tomato_sample_df_labeled =\
    flag_message_types(tomato_outputs['sample_df'], base_messages)\
    .assign(sample_ix = lambda df: range(0, len(df)))\
    .assign(token_in_prompt_ix = lambda df: df.groupby(['prompt_ix']).cumcount())\
    .merge(tomato_input_df, on = 'prompt_ix', how = 'inner')\
    .merge(
        pd.DataFrame({'base_message_ix': list(range(len(base_message_types))), 'base_message_type': base_message_types}),
        on = 'base_message_ix',
        how = 'inner'
    )

tomato_sample_df_labeled\
    .groupby(['prompt_ix', 'prompt_key', 'base_message_ix', 'base_message_type', 'base_message'], as_index = False)\
    .agg(
        count = ('token', 'count'),
        combined_text = ('token', ''.join)
    )

In [None]:
gc.collect()

# Project each tomato token through every probe whose role space contains the token's base_message_type
tomato_projs = [
    run_projections(
        valid_sample_df = tomato_sample_df_labeled[
            tomato_sample_df_labeled['base_message_type'].notna()
            & tomato_sample_df_labeled['base_message_type'].isin(trained_probe['role_space'])
        ],
        layer_hs = tomato_hs[trained_probe['layer_ix']],
        probe = trained_probe
    ).assign(
        layer_ix = trained_probe['layer_ix'],
        role_space = ''.join(role_name[0] for role_name in trained_probe['role_space'])
    )
    for trained_probe in tqdm(all_probes)
]

tomato_projs_df = pd.concat(tomato_projs, ignore_index = True).merge(
    tomato_sample_df_labeled[['sample_ix', 'prompt_ix', 'prompt_key', 'base_message_ix', 'base_message_type', 'token_in_prompt_ix', 'token']],
    how = 'inner',
    on = ['sample_ix']
)

tomato_projs_df

In [None]:
# Rows per (prompt, target_role, layer, role_space, sample); presumably each should be exactly 1
(
    tomato_projs_df
    .groupby(['prompt_ix', 'prompt_key', 'target_role', 'layer_ix', 'role_space', 'sample_ix'])
    .size()
    .reset_index(name = 'count')
)

In [None]:
import plotly.express as px

# Which probe to plot: role space 'suca' (first letters of system/user/cot/assistant) at one layer
TOMATO_PLOT_ROLE_SPACE = 'suca'
TOMATO_PLOT_LAYER_IX = 14

facet_order = ['system', 'user', 'cot', 'assistant']
tomato_projs_df_single = tomato_projs_df.pipe(
    lambda df: df[(df['role_space'] == TOMATO_PLOT_ROLE_SPACE) & (df['layer_ix'] == TOMATO_PLOT_LAYER_IX)]
)
if tomato_projs_df_single.empty:
    raise ValueError(
        f"No tomato projections for role_space={TOMATO_PLOT_ROLE_SPACE!r} at layer {TOMATO_PLOT_LAYER_IX}; "
        "check layers_to_probe / all_probes"
    )

# One stable colour per message type, shared across all figures
all_message_types = tomato_projs_df['base_message_type'].unique()
palette = px.colors.qualitative.Plotly
color_map = {msg_type: palette[i % len(palette)] for i, msg_type in enumerate(all_message_types)}

group_cols = ['prompt_ix','prompt_key','base_message_type','target_role']

# Smoothed probs within each (prompt, message, target_role) run: 10-token moving average and EWMA (alpha 0.1)
smooth_df = (
    tomato_projs_df_single
    .sort_values(group_cols + ['sample_ix'])
    .assign(
        prob_sma = lambda df: (
            df.groupby(group_cols)['prob']
              .rolling(window=10, min_periods=1).mean()
              .reset_index(level=list(range(len(group_cols))), drop=True)
        ),
        prob_ewma = lambda df: (
            df.groupby(group_cols)['prob']
              .ewm(alpha=.1, min_periods=1).mean()
              .reset_index(level=list(range(len(group_cols))), drop=True)
        ),
    )
)

display(smooth_df)

def pretty(a):
    """Restyle a facet annotation: keep only the text after '=', centre it just above its facet, no arrow."""
    a.update(
        text = a.text.split('=')[-1],
        x = 0.5, xanchor = "center",
        y = a.y + 0.08,
        textangle = 0,
        font = dict(size = 12),
        bgcolor = 'rgba(255, 255, 255, 0.9)',
        showarrow = False
    )

# One figure per prompt variant: a facet row per target role, points coloured by the true message type
for this_prompt_key in smooth_df['prompt_key'].unique().tolist():
    this_df = smooth_df.pipe(lambda df: df[df['prompt_key'] == this_prompt_key])
    fig = px.scatter(
        this_df, x = 'token_in_prompt_ix', y = 'prob',
        facet_row = 'target_role',
        color = 'base_message_type',
        color_discrete_map = color_map,
        category_orders = {
            'target_role': facet_order,
            'base_message_type': ['system', 'user', 'cot', 'assistant', 'other']
        },
        hover_name = 'token',
        hover_data = {
            'prob': ':.4f'
        },
        title = f'prompt = {this_prompt_key}',
        labels = {
            'token_in_prompt_ix': 'Token Index',
            'prob': 'Prob',
            'target_role': 'role'
        }
    ).update_yaxes(
        range = [0, 1],
        side = 'left'
    ).update_layout(height = 500, width = 800)

    fig.for_each_annotation(pretty)
    fig.update_yaxes(title_text = None)
    fig.show()

In [None]:
tomato_projs_csv_path = f"{ws}/experiments/role-analysis/outputs/probe-projections/tomato-role-projections-{model_prefix}.csv"
tomato_projs_df.to_csv(tomato_projs_csv_path, index = False)