In [None]:
"""
Train role probes on post-MLP hidden states (hs), for later use in interventions
"""
None

In [None]:
"""
Imports
"""
import torch
from datasets import load_dataset
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import gc
import pickle
import importlib
import cupy
import cuml
import sklearn
import yaml

from utils.memory import check_memory, clear_all_cuda_memory
from utils.loader import load_model_and_tokenizer, load_custom_forward_pass
from utils.probes import run_and_export_states, run_projections, check_max_seq_len

# Run-wide settings: target GPU, RNG seed used throughout, and the workspace root for outputs
main_device, seed = 'cuda:0', 123
ws = '/workspace/deliberative-alignment-jailbreaks'

# Reset CUDA memory and run the memory check helper before anything is loaded
clear_all_cuda_memory()
check_memory()

# Load model

In [None]:
"""
Load the base tokenizer/model
"""
# model_prefix is passed to the loader and also keys the per-model sections of config/probe.yaml read below
model_prefix = 'gptoss-20b'
tokenizer, model, model_architecture, model_n_layers = load_model_and_tokenizer(
    model_prefix,
    device = main_device
)
check_memory()

In [None]:
"""
Load custom forward pass and verify equality to base model forward pass
"""
# Build the custom forward function that returns hidden states; it is later passed to run_and_export_states
# as `run_model_return_states`.
# NOTE(review): no equality check against the base forward pass is visible in this cell - presumably
# load_custom_forward_pass performs it internally; confirm.
run_forward_with_hs = load_custom_forward_pass(model_architecture, model, tokenizer)

In [None]:
"""
Test generation is sensible
"""
def test_generation():
    """
    Sanity-check generation: greedily decode up to 100 new tokens for a short chat prompt rendered with
    the tokenizer's chat template, and print the full decoded sequence with special tokens kept.
    """
    conv = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Write a haiku about GPUs"},],
        tokenize = False,
        enable_thinking = True,
        add_generation_prompt = True
    )
    inputs = tokenizer(conv, return_tensors = 'pt')
    # Pass the attention mask explicitly (as test1/test2 do) instead of leaving generate() to infer it
    gen_ids = model.generate(
        inputs['input_ids'].to(main_device),
        attention_mask = inputs['attention_mask'].to(main_device),
        max_new_tokens = 100,
        do_sample = False
    )
    print(tokenizer.batch_decode(gen_ids, skip_special_tokens = False)[0])

test_generation()

# Train probes

## Prep dataset

In [None]:
"""
Load raw text samples (C4 + Dolma 3) used as probe content
"""
# Read the config through a context manager so the file handle is closed deterministically
with open('./../role-analysis/config/probe.yaml') as cfg_file:
    # Total number of raw texts to draw; load_raw_ds splits it 25% C4 / 75% Dolma 3
    n_sample_size = yaml.safe_load(cfg_file)[model_prefix]['n_sample_size']

def load_raw_ds():
    """
    Stream and shuffle C4 (en, validation) and Dolma 3 (pinned revision), keeping
    int(n_sample_size * .25) C4 texts and int(n_sample_size * .75) Dolma 3 texts.

    Returns:
        A list of {'text': ..., 'source': 'c4' | 'dolma3'} dicts with the C4 samples first.
        A source contributes fewer samples if its stream runs out early.
    """
    def take(stream, n_samples, source_name):
        # zip() exhausts the range first, so at most n_samples records are pulled from the stream
        return [{'text': record['text'], 'source': source_name} for _, record in zip(range(n_samples), stream)]

    c4_samples = take(
        load_dataset('allenai/c4', 'en', split = 'validation', streaming = True).shuffle(seed = seed, buffer_size = 50_000),
        int(n_sample_size * .25),
        'c4'
    )
    dolma_samples = take(
        load_dataset('allenai/dolma3_mix-150B-1025', split = 'train', revision = '3a8349c', streaming = True).shuffle(seed = seed, buffer_size = 50_000),
        int(n_sample_size * .75),
        'dolma3'
    )
    return c4_samples + dolma_samples

raw_data = load_raw_ds()

In [None]:
"""
Test rendering
"""
import utils.role_templates
importlib.reload(utils.role_templates)
from utils.role_templates import render_single_message, render_mixed_cot

def test_render():
    """
    Print the tokenizer's native chat-template rendering of a short user/assistant exchange, followed by the
    project renderers' output (a single user message, then a mixed CoT + assistant message), for visual comparison.

    Returns:
        True
    """
    # With tokenize = False the template is only rendered to a string; padding / truncation / max_length only
    # take effect when tokenizing, so they are not passed here
    print(tokenizer.apply_chat_template(
        [
            {'role': 'user', 'content': 'Hi! I am a dog and I like to bark'},
            {'role': 'assistant', 'content': 'Hello! What a lovely dog you are!'}
        ],
        tokenize = False,
        add_generation_prompt = True,
        enable_thinking = True
    ))
    print('--')
    print(render_single_message(model_prefix, role = 'user', content ='Hi'))
    print('--')
    print(render_mixed_cot(model_prefix, 'The user...', 'Yes!'))
    return True

test_render()

In [None]:
"""
Validate forward passes work
"""
# Opening texts prepended to each sample sequence for training. If there are several, one is drawn uniformly at
# random per base text (see build_sample_seqs). Read via a context manager so the file handle is closed.
with open('./../role-analysis/config/probe.yaml') as cfg_file:
    train_prefixes = yaml.safe_load(cfg_file)[model_prefix]['train_prefixes']
for p in train_prefixes:
    print(p)

@torch.no_grad()
def test1():
    """
    Greedily generate up to 500 tokens for a short user message rendered with the tokenizer's own chat template,
    print the decoded sequence (special tokens kept), and return it as a per-token df.

    Returns:
        A df with columns `token_id` and `token` (one row per prompt + generated token).
    """
    prompt_text = tokenizer.apply_chat_template(
        [{'role': 'user', 'content': 'Hi stinky'}],
        tokenize = False,
        add_generation_prompt = True
    )
    encoded = tokenizer(prompt_text, add_special_tokens = True, return_tensors = 'pt')
    output_ids = model.generate(
        encoded['input_ids'].to(main_device),
        attention_mask = encoded['attention_mask'].to(main_device),
        max_new_tokens = 500,
        do_sample = False
    )
    id_list = output_ids[0].tolist()
    token_df = pd.DataFrame({'token_id': id_list, 'token': tokenizer.convert_ids_to_tokens(id_list)})
    print(tokenizer.decode(token_df['token_id'].tolist(), skip_special_tokens = False))
    return token_df

@torch.no_grad()
def test2():
    """Quick test - can the model run forward passes correctly?

    For each training prefix, renders prefix + a user message + an opened assistant header, greedily
    generates up to 500 tokens, and prints the decoded result (special tokens kept).
    """
    # Assistant header that opens the model's reasoning; swap in the matching line when changing models
    assistant_header = (
        '<|start|>assistant<|channel|>analysis<|message|>' # gpt-oss-*
        # '<|im_start|>assistant\n<think></think>' # Nemotron
        # '<|assistant|>\n<think>' # GLM-4.6V-Flash
        # '<|im_start|>assistant\n<think>\n' # Qwen3
        # '\n<|begin_assistant|>\nHere are my reasoning steps:\n' # Apriel-1.6-15B-Thinker
        # '<|assistant|><think>' # GLM-4.7-Flash
    )
    for train_prefix in train_prefixes:
        prompt_text = train_prefix + render_single_message(model_prefix, 'user', 'Where is Atlanta?') + assistant_header
        encoded = tokenizer(prompt_text, add_special_tokens = False, return_tensors = 'pt')
        output_ids = model.generate(
            encoded['input_ids'].to(main_device),
            attention_mask = encoded['attention_mask'].to(main_device),
            max_new_tokens = 500,
            do_sample = False
        )
        print(tokenizer.decode(output_ids[0], skip_special_tokens = False))
        print('\n')

test1()
test2()

In [None]:
"""
Create sample sequences
"""
# Read the per-model config once, through a context manager, instead of opening the file twice without closing it
with open('./../role-analysis/config/probe.yaml') as cfg_file:
    probe_cfg = yaml.safe_load(cfg_file)[model_prefix]
# Max tokens kept per probe text (base texts are truncated to this length)
SEQLEN = probe_cfg['seq_len']
# If truthy, the assistant sample is rendered with a partner text as its CoT (render_mixed_cot)
NESTED_REASONING = probe_cfg['nested_reasoning']
# If True, the partner text is also prepended before each standalone role message
GENERALIZE_PREFIX = False

def get_sample_seqs_for_input_seq(probe_text, partner_text, prefix = ''):
    """
    Expand a single probe text into one rendered prompt per role

    Params
        @probe_text: The text we're extracting states from (appears in all roles)
        @partner_text: Random paired text; used as the CoT of the assistant sample when NESTED_REASONING,
            and as a leading prefix on the standalone samples when GENERALIZE_PREFIX is set
        @prefix: Opening text prepended to every prompt

    Returns
        A list of {'role', 'prompt'} dicts, ordered system (gptoss-20b only), user, tool, cot, assistant
    """
    lead = prefix + (partner_text if GENERALIZE_PREFIX else '')

    # System-role samples are only built for gptoss-20b, to save time
    standalone_roles = (['system'] if model_prefix in ['gptoss-20b'] else []) + ['user', 'tool', 'cot']
    if not NESTED_REASONING:
        standalone_roles.append('assistant')

    seqs = [
        {'role': r, 'prompt': lead + render_single_message(model_prefix, role = r, content = probe_text)}
        for r in standalone_roles
    ]

    # Reasoning models: the assistant message follows a CoT made of the partner text
    if NESTED_REASONING:
        seqs.append({
            'role': 'assistant',
            'prompt': prefix + render_mixed_cot(model_prefix, cot = partner_text, assistant = probe_text)
        })

    return seqs

def build_sample_seqs(train_prefixes):
    """
    Build all sample sequences: every raw text rendered under every role.

    Params
        @train_prefixes: Candidate opening texts; one is drawn uniformly at random per base text

    Returns
        A df with one row per (base text, role) and columns `question_ix`, `question` (the truncated base text),
        `role`, `prompt`, plus a sequential `prompt_ix`.

    Raises
        ValueError: if there is exactly one raw text, because it cannot be paired with a *different* partner.

    Notes
        Base texts are truncated to SEQLEN tokens. Partner texts come from a different raw text (a derangement
        of the indices) and are truncated to a Beta(0.5, 4)-scaled length. All randomness is seeded from `seed`.
    """
    truncated_texts = tokenizer.batch_decode(tokenizer([t['text'] for t in raw_data], add_special_tokens = False, padding = False, truncation = True, max_length = SEQLEN).input_ids)
    n_seqs = len(truncated_texts)
    if n_seqs == 1:
        # No permutation of one element moves it, so the derangement loop below would never terminate
        raise ValueError("build_sample_seqs needs at least 2 raw texts to pair each text with a different partner")

    np.random.seed(seed)
    # Partner lengths: Beta(0.5, 4) scaled to [0, SEQLEN/2 + 1) then floored - skewed towards short partners
    partner_lengths = ((np.random.beta(0.5, 4.0, size = n_seqs) * (SEQLEN/2 + 1)).astype(int)).tolist()
    partner_texts = [
        tokenizer.decode(tokenizer(text['text'], add_special_tokens = False, padding = False, truncation = True, max_length = int(partner_lengths[i])).input_ids)
        for i, text in enumerate(raw_data)
    ]

    # Resample until no text is paired with itself
    identity = np.arange(n_seqs)
    perm = np.random.permutation(n_seqs)
    while np.any(perm == identity):
        perm = np.random.permutation(n_seqs)

    sampled_prefixes = np.random.choice(train_prefixes, size = n_seqs)

    input_list = []
    for base_ix, base_text in enumerate(truncated_texts):
        partner_text = partner_texts[int(perm[base_ix])].strip()
        prefix = sampled_prefixes[base_ix]

        for seq in get_sample_seqs_for_input_seq(base_text, partner_text, prefix):
            row = {'question_ix': base_ix, 'question': base_text, **seq}
            input_list.append(row)

    input_df = pd.DataFrame(input_list).assign(prompt_ix = lambda df: list(range(len(df))))
    return input_df

input_df = build_sample_seqs(train_prefixes = train_prefixes)
display(input_df)

# Print every rendered prompt for one example base text (question_ix 2) to eyeball the templates
for p in input_df.loc[input_df['question_ix'] == 2, 'prompt'].tolist():
    print(p)
    print("=" * 80)

In [None]:
""" 
Load dataset into a dataloader which returns original tokens - important for BPE tokenizers to reconstruct the correct string later
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

prompt_ixs = input_df['prompt_ix'].tolist()
max_seqlen = check_max_seq_len(tokenizer, input_df['prompt'].tolist())
train_dl = DataLoader(
    ReconstructableTextDataset(
        input_df['prompt'].tolist(),
        tokenizer,
        max_length = max_seqlen,
        prompt_ix = prompt_ixs
    ),
    batch_size = 32,
    # Deterministic order: batches follow input_df row order
    shuffle = False,
    collate_fn = stack_collate
)

## Get hidden states

In [None]:
"""
Run forward passes with POST-MLP HS
"""
# Probe every 4th layer for models with >= 20 layers, every 2nd layer otherwise
layer_step = 4 if model_n_layers >= 20 else 2
layers_to_probe = list(range(0, model_n_layers, layer_step))
res = run_and_export_states(
    model, tokenizer,
    run_model_return_states = run_forward_with_hs,
    dl = train_dl,
    layers_to_keep_acts = layers_to_probe,
    extraction_key = 'all_hidden_states'
)

# Convert to f16 for cupy compatibility, then split dim 1 (one slot per probed layer) into a per-layer dict
stacked_hs = res['all_hs'].to(torch.float16)
all_probe_hs = {layer_ix: stacked_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(layers_to_probe)}
del res['all_hs']
del stacked_hs
gc.collect()

## Label roles


In [None]:
"""
Take each token and label it with its correct role (user/assistant/system/tool/cot); check the counts are as expected
"""
import utils.role_assignments
importlib.reload(utils.role_assignments)
from utils.role_assignments import label_content_roles
from utils.substring_assignments import flag_message_types

def _build_probe_sample_df():
    """
    Label every token with its role and keep only the content tokens of each prompt's target role.

    Returns the token-level df from res['sample_df'] with role labels, a `sample_ix` column (the token's row
    position in res['sample_df']; assumed to match its row in all_probe_hs - TODO confirm) and `target_role`.
    Rows are filtered to content tokens with a non-null role that matches the prompt's target role and that are
    not part of a prepended training prefix.
    """
    labelled = label_content_roles(model_prefix, res['sample_df'])
    labelled = labelled.assign(sample_ix = range(0, len(labelled)))

    # Flag tokens whose role matches the role the prompt was built for (kept)
    target_roles = input_df[['prompt_ix', 'role']].rename(columns = {'role': 'target_role'})
    labelled = labelled.merge(target_roles, how = 'inner', on = 'prompt_ix')
    labelled = labelled.assign(match_target_role = np.where(labelled['role'] == labelled['target_role'], True, False))

    # Flag tokens that belong to the prepended training prefix (dropped)
    labelled = flag_message_types(labelled, train_prefixes, allow_ambiguous = True).drop(columns = 'base_message')

    keep = (
        (labelled['base_message_ix'].isna())
        & (labelled['is_content'] == True)
        & (labelled['role'].notna())
        & (labelled['match_target_role'])
    )
    return labelled[keep]

probe_sample_df = _build_probe_sample_df()

# Verify role counts are accurate. They should be exactly equal for most models; the exception is models whose
# role tags can be merged with content into single tokens.
display(probe_sample_df.groupby('role', as_index = False).agg(count = ('sample_ix', 'count')))

# Validate roles are flagged correctly: reconstruct each segment's text for prompts 0-14
display(
    probe_sample_df.loc[probe_sample_df['prompt_ix'] <= 14]
    .groupby(['prompt_ix', 'seg_ix', 'role'], as_index = False)
    .agg(combined_text = ('token', ''.join))
    .assign(eot = lambda df: df['combined_text'].str[-30:])
)

## Run probes

In [None]:
"""
First set up a quick grid search over hyperparameters (the search call itself is commented out at the end of this cell)
"""
# Tokens whose `token_in_seg_ix` is below this value are excluded from probe train/test sets; the cut only
# applies when NESTED_REASONING. NOTE(review): the rationale for 32 is not shown here - confirm.
SKIP_FIRST_N = 32 if NESTED_REASONING else 0

def fit_lr(x_train, y_train, x_test, y_test, add_scaling = False, **lr_params):
    """
    Fit an L2-regularized logistic-regression probe, optionally preceded by a StandardScaler, and evaluate it.

    Params:
        @x_train, @y_train: Training features / integer labels (cupy arrays when called from get_probe_result)
        @x_test, @y_test: Held-out features / integer labels
        @add_scaling: If True, standardize features before the classifier
        @lr_params: Extra kwargs for cuml LogisticRegression (e.g. C)

    Returns:
        (fitted pipeline, test accuracy, test log-loss, test predictions)
    """
    # Depending on the scikit-learn version, `import sklearn` alone does not load the `sklearn.pipeline`
    # submodule, so import it explicitly rather than relying on attribute access
    from sklearn.pipeline import Pipeline

    steps = []
    if add_scaling:
        steps.append(('scaler', cuml.preprocessing.StandardScaler()))
    steps.append(('clf', cuml.linear_model.LogisticRegression(penalty = 'l2', max_iter = 5_000, fit_intercept = True, **lr_params)))
    lr_model = Pipeline(steps)
    lr_model.fit(x_train, y_train)
    accuracy = lr_model.score(x_test, y_test)
    y_test_pred = lr_model.predict(x_test)
    y_test_prob = lr_model.predict_proba(x_test)
    nll = cuml.metrics.log_loss(y_test, y_test_prob)
    return lr_model, accuracy, nll, y_test_pred

def get_probe_result(sample_df, layer_hs, roles_map, add_scaling = False, **lr_params):
    """
    Train and evaluate a role probe for one layer / role-space combination.

    Params:
        @sample_df: Token-level df with columns `prompt_ix`, `role`, `sample_ix` and `token_in_seg_ix`; `sample_ix`
            is the row of each token in `layer_hs` (the df may be a filtered subset, so it can be shorter)
        @layer_hs: Hidden states for a single layer, indexed by `sample_ix` along dim 0 (T x D)
        @roles_map: Dict mapping each role name to its integer class label
        @add_scaling: Whether to standardize features before the classifier
        @lr_params: Extra kwargs forwarded to fit_lr / the logistic regression (e.g. C)

    Returns:
        Dict with the fitted `probe`, test `acc` and `nll`, `acc_by_role` (counts per (role, pred)) and
        `acc_by_pos` (count and accuracy per `token_in_seg_ix`).

    Description:
        Trains only on content space for the given roles. The train/test split is done over prompts, not tokens,
        so tokens of one prompt never appear in both sets.
    """
    train_prompts, test_prompts = cuml.train_test_split(sample_df['prompt_ix'].unique(), test_size = 0.1, random_state = seed)
    train_df = sample_df[sample_df['prompt_ix'].isin(train_prompts)]
    test_df = sample_df[sample_df['prompt_ix'].isin(test_prompts)]

    def to_labels(df):
        return cupy.asarray([roles_map[role] for role in df['role']])

    def to_features(df):
        return cupy.asarray(layer_hs[df['sample_ix'].tolist(), :].to(torch.float32).detach().cpu())

    y_train, y_test = to_labels(train_df), to_labels(test_df)
    x_train, x_test = to_features(train_df), to_features(test_df)

    if x_train.shape[0] != len(train_df):
        raise Exception(f"Shape mismatch!")

    train_classes = np.unique(y_train.get())
    if len(roles_map) > len(train_classes):
        raise Exception(f"Skipping mapping {roles_map}: missing roles in train", train_classes)

    probe, test_acc, test_nll, y_pred = fit_lr(
        x_train, y_train, x_test, y_test,
        add_scaling = add_scaling,
        **lr_params
    )

    # Per-token classification results on the test split
    label_to_role = {label: role for role, label in roles_map.items()}
    results_df = test_df.assign(pred = y_pred.tolist())
    results_df = results_df.assign(pred = results_df['pred'].map(label_to_role))
    results_df = results_df.assign(is_acc = results_df['role'] == results_df['pred'])

    acc_by_role = results_df.groupby(['role', 'pred'], as_index = False).agg(count = ('sample_ix', 'count'))
    acc_by_pos = results_df.groupby('token_in_seg_ix', as_index = False).agg(count = ('sample_ix', 'count'), acc = ('is_acc', 'mean'))

    return {
        'probe': probe,
        'acc': test_acc,
        'nll': test_nll,
        'acc_by_role': acc_by_role,
        'acc_by_pos': acc_by_pos,
    }

def test_c(C, add_scaling):
    """
    Evaluate one (C, add_scaling) setting on a fixed benchmark: the middle probed layer and the
    user/assistant/tool role space (tokens at positions >= SKIP_FIRST_N).

    Returns:
        Dict with the hyperparameters and the resulting validation accuracy and NLL.
    """
    roles = ['user', 'assistant', 'tool']
    mid_layer = layers_to_probe[(len(layers_to_probe) - 1) // 2]
    in_scope = (probe_sample_df['role'].isin(roles)) & (probe_sample_df['token_in_seg_ix'] >= SKIP_FIRST_N)
    result = get_probe_result(
        sample_df = probe_sample_df[in_scope].reset_index(drop = True),
        layer_hs = all_probe_hs[mid_layer],
        roles_map = dict(zip(roles, range(len(roles)))),
        add_scaling = add_scaling,
        C = C
    )
    return {'C': C, 'add_scaling': add_scaling, 'val_acc': result['acc'], 'val_nll': result['nll']}

clear_all_cuda_memory()
gc.collect()

# Grid search over C (uncomment to re-run); training below reads C / add_scaling from train_params in config/probe.yaml
# for scale_val in [False]:
#     print(f"Scaling: {scale_val}")
#     display(pd.DataFrame([test_c(c_val, add_scaling = scale_val) for c_val in tqdm([1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2])]))

In [None]:
"""
Train role probes at each layer in layers_to_probe (every 4th layer, or every 2nd for models with fewer than 20 layers) for each role combination in all_role_combinations (the CoT and system role spaces are only added for gptoss-20b)
"""
# Read via a context manager so the config file handle is closed deterministically
with open('./../role-analysis/config/probe.yaml') as cfg_file:
    # Probe hyperparameters for this model (the training loop reads `C` and `add_scaling`)
    train_params = yaml.safe_load(cfg_file)[model_prefix]['train_params']

all_role_combinations = [
    ('user', 'assistant'),
    ('user', 'assistant', 'tool'),
]

# Role spaces including CoT / system are only trained for gptoss-20b
if model_prefix in ['gptoss-20b']:
    all_role_combinations += [
        ('user', 'cot', 'assistant'),
        ('user', 'cot', 'assistant', 'tool'),
        # ('system', 'user', 'assistant'),
        # ('system', 'user', 'assistant', 'tool'),
        ('system', 'user', 'cot', 'assistant'),
        ('system', 'user', 'cot', 'assistant', 'tool'),
    ]

all_role_combinations = [
    {
        'roles': list(roles),
        'roles_map': {x: i for i, x in enumerate(roles)},
        # Sample df for only those roles - filtering here is fine since we retain sample_ix which get_probe_result() uses to trace the original token
        'sample_df': probe_sample_df.pipe(lambda df: df[(df['role'].isin(roles)) & (df['token_in_seg_ix'] >= SKIP_FIRST_N)]).reset_index(drop = True),
    }
    for roles in all_role_combinations
]

clear_all_cuda_memory()
gc.collect()

# One probe per (role space, layer); each record keeps the probe results plus the metadata needed to reuse it
all_probes = []
for combo in all_role_combinations:
    print(f"Training {combo['roles']}")
    for layer_ix in layers_to_probe:
        probe_res = get_probe_result(
            sample_df = combo['sample_df'],
            layer_hs = all_probe_hs[layer_ix],
            roles_map = combo['roles_map'],
            add_scaling = train_params['add_scaling'],
            C = train_params['C']
        )
        print(f"  Layer [{layer_ix}]: {probe_res['acc']:.2f}")
        all_probes.append(dict(
            probe_res,
            layer_ix = layer_ix,
            role_space = combo['roles'],
            roles_map = combo['roles_map'],
            n_inputs = len(combo['sample_df'])
        ))

print(f"Num probes: {len(all_probes)}")
print("Probe layers:\n  " + ', '.join(map(str, layers_to_probe)))

In [None]:
"""
Save probes + metrics
"""
def validate_accuracy_and_save(all_probes):
    """
    Display validation accuracy by (role space, layer) and by (role space, layer, role), and pickle the probes.

    Params:
        @all_probes: List of probe records from the training cell, each with at least `layer_ix`, `role_space`,
            `acc` and `acc_by_role` (a df of (role, pred, count))

    Returns:
        True

    Side effects:
        Writes `all_probes` to {ws}/experiments/steer-test/outputs/probes/{model_prefix}.pkl, creating the
        directory if it does not exist yet.
    """
    print('Val accuracy by layer:')
    display(
        pd.DataFrame(all_probes)[['layer_ix', 'role_space', 'acc']]\
            .assign(acc = lambda df: df['acc'].round(2), role_space = lambda df: df['role_space'].apply(lambda x: ','.join(([r[0] for r in x]))))\
            .pivot(index = 'layer_ix', columns = 'role_space', values = 'acc')
    )

    print('Concatenating across layers and saving...')
    acc_by_role = pd.concat([p['acc_by_role'].assign(model = model_prefix, layer_ix = p['layer_ix'], role_space = ','.join(p['role_space'])) for p in all_probes], ignore_index = True)

    out_dir = f'{ws}/experiments/steer-test/outputs/probes'
    os.makedirs(out_dir, exist_ok = True)
    with open(f'{out_dir}/{model_prefix}.pkl', 'wb') as f:
        pickle.dump(all_probes, f)

    print('Accuracy by role:')
    group_cols = ['role', 'layer_ix', 'role_space']
    base_sums = acc_by_role.groupby(group_cols, as_index = False).agg(base_sum = ('count', 'sum'))
    correct_sums = (
        acc_by_role[acc_by_role['role'] == acc_by_role['pred']]
        .groupby(group_cols, as_index = False)
        .agg(sum = ('count', 'sum'))
    )

    # Left-merge from the totals so a role with zero correct predictions shows 0.0 instead of disappearing
    acc_by_role = (
        base_sums
        .merge(correct_sums, on = group_cols, how = 'left')
        .assign(sum = lambda df: df['sum'].fillna(0))
        .assign(acc = lambda df: df['sum'] / df['base_sum'])
        .pivot(index = ['role_space', 'layer_ix'], columns = 'role', values = 'acc')
        .round(2)
    )

    display(acc_by_role)

    return True

validate_accuracy_and_save(all_probes)