In [None]:
"""
Train probes
"""
None

In [None]:
"""
Imports
"""
import gc
import importlib
import os
import pickle

import cupy
import numpy as np
import pandas as pd
import plotly.express as px
import torch
import yaml
from termcolor import colored
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, Glm4vForConditionalGeneration, AutoModelForImageTextToText
from transformers.loss.loss_utils import ForCausalLMLoss

from utils.dataset import ReconstructableTextDataset, stack_collate
from utils.memory import check_memory, clear_all_cuda_memory
from utils.store_outputs import convert_outputs_to_df_fast

# Device that the model and every input batch are moved to
main_device = 'cuda:0'
seed = 123

# Workspace root; data and probe paths below are built relative to it
ws = '/workspace/deliberative-alignment-jailbreaks'

# Start from a clean GPU state before loading the model (helpers come from utils.memory;
# presumably check_memory reports current usage - confirm against that module)
clear_all_cuda_memory()
check_memory()

# Load model

In [None]:
"""
Load the base tokenizer/model
"""
# Index into the model table below; change this to switch the model under test
selected_model_index = 2

def get_model(index):
    """
    Look up the settings tuple for a supported model.

    Params:
        @index: Integer key of the model in the table below.

    Returns:
        A 6-tuple of:
        - HF model id, model prefix (short model identifier), model arch
        - Attn implementation, whether to use the HF default implementation, # hidden layers

    Raises:
        KeyError: If `index` is not a key of the model table.
    """
    models = {
        0: ('openai/gpt-oss-20b', 'gptoss-20b', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 24),
        1: ('openai/gpt-oss-120b', 'gptoss-120b', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 36),
        2: ('nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16', 'nemotron-3-nano', 'nemotron3', None, False, 52),
        3: ('Qwen/Qwen3-30B-A3B-Thinking-2507', 'qwen3-30b-a3b', 'qwen3moe', None, True, 48),
        4: ('ai21labs/AI21-Jamba-Reasoning-3B', 'jamba-reasoning', 'jamba', None, True, 28),
        5: ('ServiceNow-AI/Apriel-1.6-15b-Thinker', 'apriel-1.6-15b-thinker', 'apriel', None, True, 48),
        6: ('allenai/Olmo-3-7B-Think', 'olmo3-7b-think', 'olmo3', None, True, 32),
        7: ('zai-org/GLM-4.6V-Flash', 'glm-4.6v-flash', 'glm46v', None, True, 40),
        8: ('zai-org/GLM-4.7-Flash', 'glm-4.7-flash', 'glm4moelite', None, True, 46)
    }
    try:
        return models[index]
    except KeyError:
        # Same exception type as the plain dict lookup, but tells the reader what is valid
        raise KeyError(f"Unknown model index {index!r}; valid indices are {sorted(models)}") from None

def load_model_and_tokenizer(model_id, model_architecture, model_attn, model_use_hf):
    """
    Build the tokenizer and model for `model_id`, reading/writing the HF cache under /workspace/hf.

    Params:
        @model_id: The HF model id to load.
        @model_architecture: The short architecture name; selects which transformers class loads the model.
        @model_attn: The attention implementation passed to `from_pretrained` (may be None).
        @model_use_hf: Whether to use the HF default implementation; remote code is trusted for the model only when this is False.

    Returns:
        A tuple of (tokenizer, model), with the model moved to `main_device` and set to eval mode.
    """
    cache_dir = '/workspace/hf'

    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        cache_dir = cache_dir,
        add_eos_token = False,
        add_bos_token = False,
        padding_side = 'left',
        trust_remote_code = True
    )

    # Architectures that need a dedicated loader class; everything else goes through AutoModelForCausalLM
    loader_by_arch = {
        'glm46v': Glm4vForConditionalGeneration,
        'apriel': AutoModelForImageTextToText
    }
    model_cls = loader_by_arch.get(model_architecture, AutoModelForCausalLM)

    model = model_cls.from_pretrained(
        model_id,
        cache_dir = cache_dir,
        dtype = 'auto',
        trust_remote_code = not model_use_hf,
        device_map = None,
        attn_implementation = model_attn
    )
    model = model.to(main_device).eval()

    return tokenizer, model

# Resolve the selected model's settings, then load its tokenizer and model
model_id, model_prefix, model_architecture, model_attn, model_use_hf, model_n_layers = get_model(selected_model_index)
tokenizer, model = load_model_and_tokenizer(model_id, model_architecture, model_attn, model_use_hf)

# gpt-oss only: print the expert weights and the attention backend so they can be checked by eye
if model_architecture == 'gptoss':
    print(model.model.layers[0].mlp.experts.down_proj) # Precision should be MXFP4
    print(model.model.config._attn_implementation) # Attn should be FA3

# Later cells tokenize with padding, so a pad token is required; fall back to EOS when none is defined
if tokenizer.pad_token is None:
    print('Setting pad token automatically')
    tokenizer.pad_token = tokenizer.eos_token

check_memory()

In [None]:
"""
Load reverse-engineered forward pass functions (usage note - this can be replaced by simpler hooks if desired)
"""
# Each architecture has its own module `utils.pretrained_models.<arch>`, which must expose `run_<arch>_return_topk`;
# bind it to one architecture-independent name used by the rest of the notebook
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_return_topk = getattr(model_module, f"run_{model_architecture}_return_topk")

@torch.no_grad()
def test_custom_forward_pass(model, pad_token_id):
    """
    Check the custom forward pass against the stock model forward pass on two short padded prompts.

    Params:
        @model: The loaded model; tokenized inputs are moved to `model.device`.
        @pad_token_id: The token ID replaced by -100 in the labels used for the LM loss.

    Raises an AssertionError when the two sets of logits are not exactly equal. Otherwise prints the LM loss
    and the count and first-element shape of the pre-MLP and post-layer hidden state lists.
    Uses the notebook-level `tokenizer` and `run_model_return_topk`.
    """
    prompts = ['Hi! I am a dog and I like to bark', 'Vegetables are good for']
    inputs = tokenizer(prompts, return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 640).to(model.device)
    input_ids = inputs['input_ids']

    original_results = model(**inputs, use_cache = False)
    custom_results = run_model_return_topk(model, input_ids, inputs['attention_mask'], return_hidden_states = True)
    custom_logits = custom_results['logits']
    assert torch.equal(original_results.logits, custom_logits), 'Error in custom forward'

    # Positions holding the pad token are set to -100 in the labels passed to the loss
    labels = torch.where(input_ids == pad_token_id, torch.tensor(-100), input_ids)
    loss = ForCausalLMLoss(custom_logits, labels, custom_logits.size(-1)).detach().cpu().item()
    print(f"LM loss: {loss}")

    pre_mlp_states = custom_results['all_pre_mlp_hidden_states']
    post_layer_states = custom_results['all_hidden_states']
    print(f"Hidden states layers (pre-mlp | post-layer): {len(pre_mlp_states)} | {len(post_layer_states)}")
    print(f"Hidden state size (pre-mlp | post-layer): {(pre_mlp_states[0].shape)} | {(post_layer_states[0].shape)}")

test_custom_forward_pass(model, tokenizer.pad_token_id)

# Load probes

In [None]:
# Probe and conversation files are stored per model prefix, so follow the model loaded above
# rather than a hard-coded name that can drift from `selected_model_index`
TEST_MODEL = model_prefix
TEST_LAYER_IX = 12
TEST_ROLE_SPACE = ['system', 'user', 'cot', 'assistant']

# NOTE(review): pickle.load can execute arbitrary code - only load probe files produced by this project
with open(f'{ws}/experiments/role-analysis/outputs/probes/{TEST_MODEL}.pkl', 'rb') as f:
    probes = pickle.load(f)

# Select the probe trained on the requested layer and role space; fail with context instead of a bare IndexError
matching_probes = [p for p in probes if p['layer_ix'] == TEST_LAYER_IX and p['role_space'] == TEST_ROLE_SPACE]
if not matching_probes:
    raise ValueError(f"No probe for model '{TEST_MODEL}' with layer_ix = {TEST_LAYER_IX} and role_space = {TEST_ROLE_SPACE}")

probe = matching_probes[0]
probe

# Create dataset and collect states

In [None]:
"""
Define a lexicon
"""
MAX_ROWS = 25 # Number of conversations taken from the top of the conversations file
N_REPEATS = 20 # Number of times each text is repeated (space-joined) to build one prompt

# Read the per-model prompt prefix; the context manager closes the config file handle
with open('experiment-config/probe.yaml') as f:
    test_prefix = yaml.safe_load(f)[model_prefix]['test_prefix']

raw_convs = pd.read_csv(f'{ws}/experiments/role-analysis/data/conversations/{TEST_MODEL}.csv')

# One prompt per (conversation, role): the prefix followed by that role's text repeated N_REPEATS times
lexicon_df = pd.DataFrame([
    {'role': role, 'content': test_prefix + ' ' + ' '.join([row[col]] * N_REPEATS)}
    for row in raw_convs.head(MAX_ROWS).to_dict('records')
    for role, col in [('user', 'user_query'), ('cot', 'cot'), ('assistant', 'assistant')]
])

lexicon_df

In [None]:
"""
Helper for running + storing states
"""
@torch.no_grad()
def run_and_export_states(model, dl: DataLoader, layers_to_keep_acts: list[int]) -> dict:
    """
    Run forward passes on given model and store the decomposed sample_df plus hidden states

    Params:
        @model: The model to run forward passes on via `run_model_return_topk`. That call must return a dict with keys `logits` and
            `all_pre_mlp_hidden_states`.
        @dl: An iterable of batches (here a DataLoader over a ReconstructableTextDataset) where each batch has `input_ids`, `attention_mask`,
            `original_tokens`, and `prompt_ix`.
        @layers_to_keep_acts: A list of layer indices (0-indexed) used to select layers from the stacked pre-MLP hidden states.

    Returns:
        A dict with keys:
        - `sample_df`: A sample (token)-level dataframe built by `convert_outputs_to_df_fast`, joined with the original token text and the
            `prompt_ix` of its sequence. `batch_ix` and `sequence_ix` are dropped before returning.
        - `all_hs`: A tensor of size n_samples x len(layers_to_keep_acts) x D holding the pre-MLP hidden state for each retained layer,
            restricted to positions where the attention mask is 1. Each n_sample is intended to correspond to a row of sample_df
            (assumes `convert_outputs_to_df_fast` drops masked tokens and keeps the same flattened order - TODO confirm).

    Notes:
        - Reads the notebook-level `tokenizer`, `main_device` and `run_model_return_topk`.
        - On the first batch only, prints up to 20 decoded inputs with the argmax next token, plus the perplexity, as a sanity check.

    Example:
        test_outputs = run_and_export_states(model, train_dl, layers_to_keep_acts = list(range(model_n_layers)))
    """
    all_hidden_states = []
    sample_dfs = []

    for batch_ix, batch in tqdm(enumerate(dl), total = len(dl)):

        input_ids = batch['input_ids'].to(main_device)
        attention_mask = batch['attention_mask'].to(main_device)
        original_tokens = batch['original_tokens']
        prompt_indices = batch['prompt_ix']

        output = run_model_return_topk(model, input_ids, attention_mask, return_hidden_states = True)

        # Check no bugs by validating output/perplexity
        if batch_ix == 0:
            # Labels are the input IDs with pad-token positions replaced by -100
            loss = ForCausalLMLoss(output['logits'], torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids), output['logits'].size(-1)).detach().cpu().item()
            for i in range(min(20, input_ids.size(0))):
                decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = False)
                # Greedy prediction at the final position of the sequence
                next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
                print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
            print(f"PPL:", torch.exp(torch.tensor(loss)).item())

        # Long-format (sequence_ix, token_ix) -> original token text, for joining onto the token-level frame
        original_tokens_df = pd.DataFrame(
            [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)], 
            columns = ['sequence_ix', 'token_ix', 'token']
        )

        # Within-batch sequence_ix -> prompt_ix of the source prompt
        prompt_indices_df = pd.DataFrame(
            [(seq_i, seq_source) for seq_i, seq_source in enumerate(prompt_indices)], 
            columns = ['sequence_ix', 'prompt_ix']
        )

        # Create sample (token) level dataframe
        sample_df =\
            convert_outputs_to_df_fast(input_ids, attention_mask, output['logits'])\
            .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
            .merge(prompt_indices_df, how = 'left', on = ['sequence_ix'])\
            .assign(batch_ix = batch_ix)

        sample_dfs.append(sample_df)

        # Store pre-MLP hidden states - the fwd pass as n_layers list as BN x D, collapse to BN x n_layers x D, with BN filtering out masked items
        valid_pos = torch.where(attention_mask.cpu().view(-1) == 1) # Valid (BN, ) positions
        all_hidden_states.append(torch.stack(output['all_pre_mlp_hidden_states'], dim = 1)[valid_pos][:, layers_to_keep_acts, :])

    sample_df = pd.concat(sample_dfs, ignore_index = True).drop(columns = ['batch_ix', 'sequence_ix']) # Drop batch/seq_ix, since prompt_ix identifies
    all_hidden_states = torch.cat(all_hidden_states, dim = 0)

    return {
        'sample_df': sample_df,
        'all_hs': all_hidden_states
    }


In [None]:
"""
Prep
"""
MAX_INPUT_LEN = 768

# Layers whose hidden states are kept; the same list drives the forward pass and the per-layer split below
kept_layers = [TEST_LAYER_IX]

lexicon_ds = ReconstructableTextDataset(
    lexicon_df['content'].tolist(),
    tokenizer,
    max_length = MAX_INPUT_LEN,
    prompt_ix = list(range(0, len(lexicon_df)))
)
input_dl = DataLoader(lexicon_ds, batch_size = 8, shuffle = False, collate_fn = stack_collate)

run_result = run_and_export_states(model, input_dl, layers_to_keep_acts = kept_layers)

# n_samples x n_kept_layers x D in fp16, then split into a {layer_ix: n_samples x D} dict
stacked_hs = run_result['all_hs'].to(torch.float16)
hs = {layer_ix: stacked_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(kept_layers)}

# sample_ix is the row position, used later to index into the hidden states
sample_df = run_result['sample_df'].assign(sample_ix = lambda df: range(0, len(df)))
gc.collect()

In [None]:
"""
Role space projections
"""
def run_projections(valid_sample_df, layer_hs, probe):
    """
    Run probe-level projections

    Params:
        @valid_sample_df: A sample-level df with column `sample_ix`, giving the row of `layer_hs` that belongs to each sample.
            Can be shorter than the full T rows due to pre-filters, as long as sample_ix still indexes into the full-length `layer_hs`.
        @layer_hs: A tensor of size T x D for the layer to project.
        @probe: The probe dict with keys `probe` (the trained model) and `role_space` (the roles list)

    Returns:
        A df at (sample_ix, target_role) level with cols `sample_ix`, `target_role`, `prob`

    Raises:
        ValueError: If the probe returns a different number of rows than there are samples.
    """
    sample_indices = valid_sample_df['sample_ix'].tolist()

    x_cp = cupy.asarray(layer_hs[sample_indices, :])
    y_cp = probe['probe'].predict_proba(x_cp).round(12)

    # One column per role, one row per sample, in the same order as valid_sample_df
    proj_results = pd.DataFrame(cupy.asnumpy(y_cp), columns = probe['role_space'])
    if len(proj_results) != len(valid_sample_df):
        raise ValueError(f"Probe returned {len(proj_results)} rows for {len(valid_sample_df)} samples")

    # Align positionally (the input index is discarded), then reshape wide -> long
    role_df =\
        pd.concat([
            proj_results,
            valid_sample_df[['sample_ix']].reset_index(drop = True)
        ], axis = 1)\
        .melt(id_vars = ['sample_ix'], var_name = 'target_role', value_name = 'prob')\
        .assign(prob = lambda df: df['prob'].round(4))

    return role_df

# Per-(sample, target role) probabilities from the probe
role_probs = run_projections(valid_sample_df = sample_df, layer_hs = hs[TEST_LAYER_IX], probe = probe)

# Token metadata, with each token's running position inside its own prompt
token_positions = sample_df[['prompt_ix', 'sample_ix', 'token']].assign(
    token_in_prompt_ix = lambda df: df.groupby(['prompt_ix']).cumcount()
)

# Lexicon rows keyed by prompt position, which is how prompt_ix was assigned when building the dataset
lexicon_with_ix = lexicon_df.assign(prompt_ix = lambda df: range(0, len(df)))

# Join everything, then smooth each (prompt, target role) probability series with an exponentially weighted mean
output_projections = (
    role_probs
    .merge(token_positions, how = 'inner', on = ['sample_ix'])
    .merge(lexicon_with_ix, how = 'inner', on = 'prompt_ix')
    .assign(rollprob = lambda df: df.groupby(['prompt_ix', 'target_role'])['prob'].transform(lambda x: x.ewm(alpha = 0.25).mean()))
)

output_projections


In [None]:
target_cols = output_projections['target_role'].unique().tolist()

# Average the smoothed probability across prompts for each (token position, probed role, source role),
# using the built-in 'mean' aggregation instead of a per-group Python lambda
output_agg =\
    output_projections\
    .groupby(['token_in_prompt_ix', 'target_role', 'role'], as_index = False)\
    .agg(mean_rollprob = ('rollprob', 'mean'))

# Preview, one column per probed role, restricted to the last 10 token positions before MAX_INPUT_LEN
output_agg\
    .pivot(index = ['role', 'token_in_prompt_ix'], columns = 'target_role', values = 'mean_rollprob')\
    .reset_index()\
    .pipe(lambda df: df[df['token_in_prompt_ix'] >= MAX_INPUT_LEN - 10])

# .to_csv('lexical_role_projections.csv', index = False)

In [None]:
output_agg

In [None]:
# One panel per source role (the role whose text was fed in); one line per role scored by the probe.
# `px` is imported in the notebook's top import cell.
fig = px.line(
    output_agg,
    x = 'token_in_prompt_ix',
    y = 'mean_rollprob',
    color = 'target_role',
    facet_col = 'role',
    title = 'Mean smoothed role-probe probability by token position (one panel per source role)',
    labels = {
        'token_in_prompt_ix': 'Token position in prompt',
        'mean_rollprob': 'Mean smoothed probability',
        'target_role': 'Probed role',
        'role': 'Source role'
    }
)

fig.show()

# User PI Analysis