In [None]:
"""
Splice together conversations
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Qwen3VLMoeForConditionalGeneration
from transformers.loss.loss_utils import ForCausalLMLoss
import pandas as pd
import numpy as np
import cupy
import cuml
import importlib
import gc
import yaml
import pickle
import os
from tqdm import tqdm
import plotly.express as px
from termcolor import colored
from datasets import load_dataset
import re
import random

from utils.memory import check_memory, clear_all_cuda_memory
from utils.role_assignments import label_content_roles
from utils.store_outputs import convert_outputs_to_df_fast

main_device = 'cuda:0'
seed = 1234

clear_all_cuda_memory()
check_memory()

ws = '/workspace/deliberative-alignment-jailbreaks'

# Load models & probes

## Load model

In [None]:
"""
Load the base tokenizer/model
"""
selected_model_index = 0

def get_model(index):
    # HF model ID, model prefix, model architecture,  attn implementation, whether to use hf lib implementation
    models = {
        0: ('openai/gpt-oss-20b', 'gptoss20', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 24),
        1: ('openai/gpt-oss-120b', 'gptoss120', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 36), # Will load experts in MXFP4 if triton kernels installed
        2: ('Qwen/Qwen3-VL-30B-A3B-Thinking', 'qwen3vl30ba3b', 'qwen3vlmoe', None, True, 48),
        # 2: ('Qwen/Qwen3-30B-A3B-Thinking-2507', 'qwen3-30b-a3b', 'qwen3moe', None, True, 48),
        3: ('zai-org/GLM-4.5-Air-FP8', 'glm45air', 'glm4moe', None, True, 45) # GLM-4.5 has one dense pre-layer, this is MoE layers
    }
    return models[index]

def load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf):
    """
    Load the model and tokenizer from HF, or from file if already downloaded.
    """
    cache_dir = '/workspace/hf'
    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir = cache_dir, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)
        
    if model_architecture == 'qwen3vlmoe':
        model = Qwen3VLMoeForConditionalGeneration.from_pretrained(model_id, cache_dir = cache_dir, dtype = 'auto', trust_remote_code = not model_use_hf, device_map = None, attn_implementation = model_attn).to(main_device).eval()
    else:
        # Monkey patch for transformers v5.0.0 bug with GLM-4.5-Air-FP8
        if model_architecture == 'glm4moe':
            from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeNaiveMoe
            from utils.pretrained_models.glm4moe import patched_naive_moe_forward
            Glm4MoeNaiveMoe.forward = patched_naive_moe_forward
        model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir = cache_dir, dtype = 'auto', trust_remote_code = not model_use_hf, device_map = None, attn_implementation = model_attn).to(main_device).eval()
    
    return tokenizer, model

model_id, model_prefix, model_architecture, model_attn, model_use_hf, model_n_layers = get_model(selected_model_index)
tokenizer, model = load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf)

## Load probes

In [None]:
"""
Load probes
"""
with open(f'{ws}/experiments/role-analysis/probes/{model_prefix}.pkl', 'rb') as f:
    probes = pickle.load(f)

# Fix role names
role_map = {'assistant-cot': 'cot', 'assistant-final': 'assistant'}
for probe in probes:
    probe['roles'] = tuple(role_map.get(r, r) for r in probe['roles'])

probe_layers = sorted(set(x['layer_ix'] for x in probes))
probe_roles = sorted(set(tuple(x['roles']) for x in probes))

print(f"Num probes: {str(len(probes))}")
print(f"Probe layers:\n  {', '.join([str(x) for x in probe_layers])}")
print(f"Probe roles:\n {'\n '.join([str(list(x)) for x in probe_roles])}")

# probes

## Load tokenizer

In [None]:
"""
Tests two important functions:
- `load_chat_template`: Overwrites the existing chat template in the tokenizer with one that better handles discrepancies.
    See docstring for the function for details. 
- `label_content_roles`: Takes an instruct-formatted text, then assigns each token to the correct roles.
"""
import importlib
import utils.role_assignments
import utils.role_templates
importlib.reload(utils.role_assignments)
label_content_roles = utils.role_assignments.label_content_roles
importlib.reload(utils.role_templates)
load_chat_template = utils.role_templates.load_chat_template

# Load custom chat templater
old_chat_template = tokenizer.chat_template
new_chat_template = load_chat_template(f'{ws}/utils/chat_templates', model_architecture)

def test_chat_template(tokenizer):
    s = tokenizer.apply_chat_template(
        [
            {'role': 'system', 'content': 'Test.'},
            {'role': 'user', 'content': 'Hi! I\'m a dog.'},
            {'role': 'assistant', 'content': '<think>The user is a dog!</think>Congrats!'},
            {'role': 'user', 'content': 'Thanks!'},
            {'role': 'assistant', 'content': '<think>Hmm, the user said thanks!</think>Anything else?'}
        ],
        tokenize = False,
        padding = 'max_length',
        truncation = True,
        max_length = 512,
        add_generation_prompt = False
    )
    print(s)
    return s

tokenizer.chat_template = old_chat_template
s = test_chat_template(tokenizer)
tokenizer.chat_template = new_chat_template
s = test_chat_template(tokenizer)

z =\
    pd.DataFrame({'input_ids': tokenizer(s)['input_ids']})\
    .assign(token = lambda df: tokenizer.convert_ids_to_tokens(df['input_ids']))\
    .assign(prompt_ix = 0, batch_ix = 0, sequence_ix = 0, token_ix = lambda df: df.index)

label_content_roles(model_architecture, z).tail(50)

## Activation collection helper

In [None]:
"""
Run forward passes
"""
import importlib
import utils.dataset
importlib.reload(utils.dataset)

from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_return_topk = getattr(model_module, f"run_{model_architecture}_return_topk")

@torch.no_grad()
def run_and_export_states(model, dl: ReconstructableTextDataset, layers_to_keep_acts: list[int]):
    """
    Run forward passes on given model and store the decomposed sample_df plus hidden states

    Params:
        @model: The model to run forward passes on. Should return a dict with keys `logits` and `all_pre_mlp_hidden_states`.
        @dl: A ReconstructableTextDataset of which returns `input_ids`, `attention_mask`, and `original_tokens`.
        @layers_to_keep_acts: A list of layer indices (0-indexed) for which to filter `all_pre_mlp_hidden_states` (see returned object description).

    Returns:
        A dict with keys:
        - `sample_df`: A sample (token)-level dataframe with corresponding input token ID, output token ID, and input token text (removes masked tokens)
        - `all_pre_mlp_hidden_states`: A tensor of size n_samples x layers_to_keep_acts x D return the hidden state for each retained layers. Each 
            n_sample corresponds to a row of sample_df.

    Example:
        test_outputs = run_and_export_states(model, test_dl, layers_to_keep_acts = list(range(model_n_layers)))
    """
    all_pre_mlp_hidden_states = []
    sample_dfs = []

    for batch_ix, batch in tqdm(enumerate(dl), total = len(dl)):

        input_ids = batch['input_ids'].to(main_device)
        attention_mask = batch['attention_mask'].to(main_device)
        original_tokens = batch['original_tokens']
        prompt_indices = batch['prompt_ix']

        output = run_model_return_topk(model, input_ids, attention_mask, return_hidden_states = True)

        # Check no bugs by validating output/perplexity
        if batch_ix == 0:
            loss = ForCausalLMLoss(output['logits'], torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids), output['logits'].size(-1)).detach().cpu().item()
            for i in range(min(20, input_ids.size(0))):
                decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = False)
                next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
                print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
            print(f"PPL:", torch.exp(torch.tensor(loss)).item())
                
        original_tokens_df = pd.DataFrame(
            [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)], 
            columns = ['sequence_ix', 'token_ix', 'token']
        )
                
        prompt_indices_df = pd.DataFrame(
            [(seq_i, seq_source) for seq_i, seq_source in enumerate(prompt_indices)], 
            columns = ['sequence_ix', 'prompt_ix']
        )
        
        # Create sample (token) level dataframe
        sample_df =\
            convert_outputs_to_df_fast(input_ids, attention_mask, output['logits'])\
            .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
            .merge(prompt_indices_df, how = 'left', on = ['sequence_ix'])\
            .assign(batch_ix = batch_ix)
        
        sample_dfs.append(sample_df)

        # Store pre-MLP hidden states - the fwd pass as n_layers list as BN x D, collapse to BN x n_layers x D, with BN filtering out masked items
        valid_pos = torch.where(attention_mask.cpu().view(-1) == 1) # Valid (BN, ) positions
        all_pre_mlp_hidden_states.append(torch.stack(output['all_pre_mlp_hidden_states'], dim = 1)[valid_pos][:, layers_to_keep_acts, :])

        batch_ix += 1

    sample_df = pd.concat(sample_dfs, ignore_index = True).drop(columns = ['batch_ix', 'sequence_ix']) # Drop batch/seq_ix, since prompt_ix identifies
    all_pre_mlp_hidden_states = torch.cat(all_pre_mlp_hidden_states, dim = 0)

    return {
        'sample_df': sample_df,
        'all_pre_mlp_hs': all_pre_mlp_hidden_states
    }

def chunk_list(input_list, max_length):
    return [input_list[i:i + max_length] for i in range(0, len(input_list), max_length)]

# Conversational Analysis

In [None]:
import textwrap

if model_architecture == 'gptoss':
    system_prompt = textwrap.dedent("""
    <|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
    Knowledge cutoff: 2024-06
    Current date: 2025-06-28

    Reasoning: medium

    # Valid channels: analysis, commentary, final. Channel must be included for every message.
    Calls to these tools must go to the commentary channel: 'functions'.<|end|>
    """).strip()

print(system_prompt)

## Prep data

In [None]:
"""
Import conversations
"""
filename = model_id.split('/')[-1].lower()
user_queries_df = pd.read_csv(f'{ws}/experiments/role-analysis/convs/{filename}.csv')

print(f"Starting length: {len(user_queries_df['conv_id'].unique())} convs")

# 1. Remove convs with any missing responses
convs_with_missing = user_queries_df[user_queries_df[['user_query', 'cot', 'assistant']].isna().any(axis = 1)]['conv_id'].unique()
user_queries_df = user_queries_df[~user_queries_df['conv_id'].isin(convs_with_missing)]

print(f"After dropping missing: {len(user_queries_df['conv_id'].unique())} convs")

# 2. Filter out super long convs + convs with any message <= conv_min_length (to avoid issues with flag_message_type fn)
user_queries_df =\
    user_queries_df\
    .assign(total_len = lambda df:
        df.groupby('conv_id')['user_query'].transform(lambda x: x.str.len().sum()) +
        df.groupby('conv_id')['cot'].transform(lambda x: x.str.len().sum()) +
        df.groupby('conv_id')['assistant'].transform(lambda x: x.str.len().sum())
    )\
    .assign(min_len = lambda df: df[['user_query', 'cot', 'assistant']].apply(lambda x: x.str.len()).min(axis=1))\
    .assign(conv_min_len = lambda df: df.groupby('conv_id')['min_len'].transform('min'))\
    .query('total_len < 12000 and conv_min_len >= 5')\
    .drop(columns = ['total_len', 'min_len', 'conv_min_len'])

print(f"After dropping long convs: {len(user_queries_df['conv_id'].unique())} convs")

# 3. Filter out convs where any message is a substring of another message
#    (this causes dupe issues with the flag_message_type fn, which does not handle tokens in multiple string contexts)
def has_substring_messages(group):
    contents = group['user_query'].tolist() + group['cot'].tolist() + group['assistant'].tolist()
    for i, a in enumerate(contents):
        for j, b in enumerate(contents):
            if i != j and a in b:
                return True
    return False

bad_convs = user_queries_df.groupby('conv_id').filter(has_substring_messages)['conv_id'].unique()
user_queries_df = user_queries_df.pipe(lambda df: df[~df['conv_id'].isin(bad_convs)])
print(f"After dropping substr: {len(user_queries_df['conv_id'].unique())} convs")

# Combine into conv-level df
convs_df =\
    user_queries_df\
    .sort_values(['conv_id', 'user_query_ix'])\
    .assign(conv_id = lambda df: df.groupby('conv_id', sort = False).ngroup())\
    .groupby('conv_id')\
    .apply(lambda g: pd.Series({
        'dataset': g['dataset'].values[0],
        'messages': [
            msg
            for _, row in g.iterrows()
            for msg in [
                {'role': 'user', 'content': row['user_query']},
                {'role': 'cot', 'content': row['cot']},
                {'role': 'assistant', 'content': row['assistant']}
            ]
        ]
    }))\
    .reset_index()\
    [['conv_id', 'dataset', 'messages']]\
    .sample(n = 50)\
    .assign(conv_id = lambda df: range(len(df)))

print(f"Final: {len(convs_df['conv_id'].unique())} convs")

convs_df
convs = convs_df['messages'].tolist()

messages_df =\
    convs_df\
    .explode('messages')\
    .assign(message_ix = lambda df: df.groupby('conv_id').cumcount())\
    .assign(
        role = lambda df: df['messages'].apply(lambda x: x['role']),
        content = lambda df: df['messages'].apply(lambda x: x['content'])
    )\
    .drop(columns = 'messages')\
    .reset_index(drop = True)

convs = convs_df['messages'].tolist()

display(convs_df.head(5))
display(messages_df.head(5))
print(f"Total convs: {len(convs_df)}")

In [None]:
"""
Prep untagged conversations
"""
all_convs_by_type = {}

# Prep conversations
def prep_conv(conv):
    return system_prompt + '\n\n'.join([x['content'] for x in conv])

all_convs_by_type['untagged'] = [prep_conv(conv) for conv in convs]
print(all_convs_by_type['untagged'][0])

# Find max input length
max_input_length =\
    tokenizer(all_convs_by_type['untagged'], padding = True, truncation = True, max_length = 1024 * 8, return_tensors = 'pt')\
    ['attention_mask'].sum(dim = 1).max().item()

print(f"Max input length: {max_input_length}")
print(f"Untagged convs: {len(all_convs_by_type['untagged'])}")

In [None]:
"""
Prep tagged conversations
"""
def fold_cot_into_final(convs):
    """Fold CoT into the following assistant message as a <think></think> tag."""
    result = []
    for conv in convs:
        new_conv = []
        it = iter(enumerate(conv))
        for i, msg in it:
            if msg['role'] == 'cot':
                if i + 1 >= len(conv) or conv[i + 1]['role'] != 'assistant':
                    raise ValueError("cot must be followed by assistant")
                _, next_msg = next(it)
                new_conv.append({
                    'role': 'assistant',
                    'content': f"<think>{msg['content']}</think>{next_msg['content']}"
                })
            else:
                new_conv.append({'role': msg['role'], 'content': msg['content']})
        result.append(new_conv)
    return result

final_convs = fold_cot_into_final(convs)
all_convs_by_type['tagged'] = [
    system_prompt + x 
    for x in tokenizer.apply_chat_template(final_convs, tokenize = False, add_generation_prompt = False)
]
print(all_convs_by_type['tagged'][0])

# Find max input length
max_input_length =\
    tokenizer(all_convs_by_type['tagged'], padding = True, truncation = True, max_length = 1024 * 8, return_tensors = 'pt')\
    ['attention_mask'].sum(dim = 1).max().item()

print(f"Max input length: {max_input_length}")
print(f"Tagged convs: {len(all_convs_by_type['tagged'])}")

In [None]:
"""
Run forward passes
"""
# Get input df
input_convs_df =\
    pd.concat([
        pd.DataFrame({'convs': v, 'conv_type': conv_type})\
            .assign(conv_id = lambda df: range(0, len(df)))\
            .merge(convs_df, how = 'inner', on = 'conv_id')
        for conv_type, v in all_convs_by_type.items()
    ], ignore_index = True)\
    .assign(prompt_ix = lambda df: list(range(0, len(df))))
    
display(input_convs_df)

convs_dl = DataLoader(
    ReconstructableTextDataset(input_convs_df['convs'].tolist(), tokenizer, max_length = 1024 * 4, prompt_ix = list(range(0, len(input_convs_df)))),
    batch_size = 8,
    shuffle = False,
    collate_fn = stack_collate
)

convs_outputs = run_and_export_states(model, convs_dl, layers_to_keep_acts = probe_layers)

In [None]:
"""
Label messages
"""
from utils.substring_assignments import flag_message_types

sample_dfs_by_conv = [group for _, group in convs_outputs['sample_df'].groupby('prompt_ix', sort = True)]
metadata_by_conv = input_convs_df.to_dict('records')

all_res = []

# Iterate through (input metadata, token-level sample df) pairs
for conv_metadata, sample_df_for_conv in tqdm(zip(metadata_by_conv, sample_dfs_by_conv)):

    content_spans = [msg['content'] for msg in conv_metadata['messages']]
    content_roles = [msg['role'] for msg in conv_metadata['messages']]
    try:
        res = \
            flag_message_types(sample_df_for_conv, content_spans)\
            .merge(
                pd.DataFrame({'role': content_roles, 'base_message_ix': range(len(content_spans))}),
                on = 'base_message_ix',
                how = 'left'
            )
        
    except Exception as e:
        print(e)
        continue
    
    all_res.append(res)

display(pd.concat(all_res).pipe(lambda df: df[df['role'].isna()]))
        
sample_df_labeled =\
    pd.concat(all_res).reset_index(drop = True)\
    .assign(sample_ix = lambda df: range(0, len(df)))\
    .assign(token_in_prompt_ix = lambda df: df.groupby(['prompt_ix']).cumcount())

sample_df_labeled

In [None]:
"""
Clean up conversation pre_mlp_hs
"""
convs_pre_mlp_hs = convs_outputs['all_pre_mlp_hs'].to(torch.float16)
# del conv_outputs
convs_pre_mlp_hs = {layer_ix: convs_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(probe_layers)} # Match layers_to_keep_act

print(convs_pre_mlp_hs[0].shape)
print(len(sample_df_labeled))
gc.collect()

## Run projections

In [None]:
pd.DataFrame(probes)\
    .pipe(lambda df: df[df['roles'] == ('user', 'cot', 'assistant', 'tool')])\
    .sort_values(by = 'accuracy', ascending = False)

In [None]:
"""
Run projections
"""
test_layer = 16 # probe_layers[int(np.ceil(0.5 * len(probe_layers)).item())]
test_roles = ['system', 'user', 'cot', 'assistant'] # system, user, cot, assistant, tool
probe = [x for x in probes if sorted(x['roles']) == sorted(test_roles) and x['layer_ix'] == test_layer][0]

project_sample_df = sample_df_labeled # Or drop non-roles
project_hs_cp = cupy.asarray(convs_pre_mlp_hs[test_layer][project_sample_df['sample_ix'].tolist(), :])
project_probs = probe['probe'].predict_proba(project_hs_cp).round(8)

proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
if len(proj_results) != len(project_sample_df):
    raise Exception("Error!")

role_level_df =\
    pd.concat([proj_results, project_sample_df[['sample_ix']]], axis = 1)\
    .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
    .merge(project_sample_df, on = 'sample_ix', how = 'inner')\
    .merge(input_convs_df[['prompt_ix', 'dataset', 'conv_id', 'conv_type']], on = 'prompt_ix')

## GET TOKEN-LEVEL IS IT ACCURATE?
# token_level_df

role_level_df

In [None]:
role_level_df\
    .pipe(lambda df: df[df['conv_type'] == 'tagged'])\
    .pipe(lambda df: df[~df['role'].isna()])\
    .pipe(lambda df: df[(df['token_in_prompt_ix'] >= 0) & (df['token_in_prompt_ix'] <= 10000)])\
    .groupby(['dataset', 'role_space', 'role', 'prompt_ix'], as_index = False)\
    .agg(mean_prob = ('prob', 'mean'))\
    .groupby(['dataset', 'role_space', 'role'], as_index = False)\
    .agg(mean_prob = ('mean_prob', 'mean'))\
    .pivot(index = ['role', 'dataset'], columns = 'role_space', values = 'mean_prob')

In [None]:
role_level_df\
    .pipe(lambda df: df[df['conv_type'] == 'untagged'])\
    .pipe(lambda df: df[~df['role'].isna()])\
    .pipe(lambda df: df[(df['token_in_prompt_ix'] >= 0) & (df['token_in_prompt_ix'] <= 10000)])\
    .groupby(['dataset', 'role_space', 'role', 'prompt_ix'], as_index = False)\
    .agg(mean_prob = ('prob', 'mean'))\
    .groupby(['dataset', 'role_space', 'role'], as_index = False)\
    .agg(mean_prob = ('mean_prob', 'mean'))\
    .pivot(index = ['role', 'dataset'], columns = 'role_space', values = 'mean_prob')

In [None]:
role_level_df\
    .pipe(lambda df: df[~df['role'].isna()])\
    .groupby(['role_space', 'role', 'prompt_ix'], as_index = False)\
    .agg(mean_prob = ('prob', 'mean'))\
    .pipe(lambda df: df[df['role_space'] == df['role']])\
    .assign(role = lambda df: 'acc_' + df['role'])\
    .pivot(index = ['prompt_ix'], columns = 'role', values = 'mean_prob')\
    .reset_index()\
    .rename_axis(columns = None)

In [None]:
"""
Run projections
"""
all_projs = []

project_sample_df = sample_df_labeled # Or drop non-roles

for probe in tqdm(probes):

    project_hs_cp = cupy.asarray(convs_pre_mlp_hs[probe['layer_ix']][project_sample_df['sample_ix'].tolist(), :])

    project_probs = probe['probe'].predict_proba(project_hs_cp).round(4)
    proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
    if len(proj_results) != len(project_sample_df):
        raise Exception("Error!")

    role_level_df =\
        pd.concat([proj_results, project_sample_df[['sample_ix']]], axis = 1)\
        .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
        .merge(project_sample_df, on = 'sample_ix', how = 'inner')\
        .merge(input_convs_df[['prompt_ix', 'dataset', 'conv_id', 'conv_type']], on = 'prompt_ix')\
        .assign(layer_ix = probe['layer_ix'])\
        .assign(roles = ','.join(probe['roles']))\
        .drop(columns = ['base_message'])

    all_projs.append(role_level_df)


In [None]:
all_projs_df =\
    pd.concat(all_projs)\
    .reset_index(drop = True)\
    .drop(columns = ['output_prob', 'output_id'])\
    .pipe(lambda df: df[~df['role'].isna()])

In [None]:
probes[1]

In [None]:
all_projs_df\
    .pipe(lambda df: df[df['roles'] == 'system,user,cot,assistant'])\
    .pipe(lambda df: df[df['role_space'] == df['role']])\
    .drop(columns = 'role_space')\
    .rename(columns = {'prob': 'acc'})\
    .groupby(['roles', 'layer_ix', 'role', 'prompt_ix'], as_index = False)\
    .agg(mean_acc = ('acc', 'mean'))\
    .groupby(['roles', 'layer_ix', 'role'], as_index = False)\
    .agg(mean_acc = ('mean_acc', 'mean'))\
    .pivot(index = ['roles', 'layer_ix'], columns = 'role', values = 'mean_acc')\
    .reset_index()\
    .rename_axis(columns = None)


In [None]:
all_projs_df\
    .pipe(lambda df: df[df['conv_type'] == 'tagged'])\
    .pipe(lambda df: df[df['roles'] == 'system,user,cot,assistant'])\
    .pipe(lambda df: df[df['role_space'] == df['role']])\
    .drop(columns = 'role_space')\
    .rename(columns = {'prob': 'acc'})\
    .groupby(['roles', 'layer_ix', 'role', 'prompt_ix'], as_index = False)\
    .agg(mean_acc = ('acc', 'mean'))\
    .groupby(['roles', 'layer_ix', 'role'], as_index = False)\
    .agg(mean_acc = ('mean_acc', 'mean'))\
    .pivot(index = ['roles', 'layer_ix'], columns = 'role', values = 'mean_acc')\
    .reset_index()\
    .rename_axis(columns = None)


In [None]:
all_projs_df\
    .pipe(lambda df: df[df['conv_type'] == 'untagged'])\
    .pipe(lambda df: df[df['roles'] == 'system,user,cot,assistant'])\
    .pipe(lambda df: df[df['role_space'] == df['role']])\
    .drop(columns = 'role_space')\
    .rename(columns = {'prob': 'acc'})\
    .groupby(['roles', 'layer_ix', 'role', 'prompt_ix'], as_index = False)\
    .agg(mean_acc = ('acc', 'mean'))\
    .groupby(['roles', 'layer_ix', 'role'], as_index = False)\
    .agg(mean_acc = ('mean_acc', 'mean'))\
    .pivot(index = ['roles', 'layer_ix'], columns = 'role', values = 'mean_acc')\
    .reset_index()\
    .rename_axis(columns = None)

In [None]:
role_level_df\
    .pipe(lambda df: df[~df['role'].isna()])\
    .groupby(['role_space', 'role', 'prompt_ix'], as_index = False)\
    .agg(mean_prob = ('prob', 'mean'))\
    .pipe(lambda df: df[(df['role'] == 'cot') & (df['prompt_ix'] == 99)])

In [None]:
print(input_convs_df.pipe(lambda df: df[df['prompt_ix'] == 3])['convs'].tolist()[0])

In [None]:
role_level_df\
    .pipe(lambda df: df[~df['role'].isna()])\
    .pipe(lambda df: df[df['token_in_prompt_ix'] >= 100])\
    .groupby(['role_space', 'role'], as_index = False)\
    .agg(count = ('sample_ix', 'count'), acc = ('is_acc', 'mean'))\
    .reset_index()\
    .pivot(index = 'role', columns = 'role_space', values = 'mean_prob')

In [None]:
role_level_df

# Phrase Injection

In [None]:
"""
Get Ai2 Data
"""
def load_raw_ds():

    def get_convs():
        return load_dataset('allenai/Dolci-Think-SFT-32B', split = 'train', streaming = True).shuffle(seed = seed, buffer_size = 50_000)

    def get_data(ds, n_samples):
        raw_data = []
        ds_iter = iter(ds)
        for _ in range(n_samples):
            sample = next(ds_iter, None)
            if sample is None:
                break
            raw_data.append({'messages': sample['messages']})
        return raw_data
    
    return get_data(get_convs(), 1_000)

raw_data = load_raw_ds()

In [None]:
"""
Parse thinking data
"""
def parse_think(role, text):
    max_strlen = 1_000
    match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
    if match:
        outside = re.sub(r'<think>.*?</think>', '', text, flags = re.DOTALL).strip()
        return [
            {'role': 'assistant-cot', 'content': match.group(1).strip()[:max_strlen]},
            {'role': 'assistant-final', 'content': outside.strip()[:max_strlen]}
        ]
    return [{'role': role, 'content': text[:max_strlen]}]

def parse_conv(messages):
    conv_data = []
    for message in messages:
        conv_data.extend(parse_think(message['role'], message['content']))
    return conv_data

convs = [parse_conv(conv['messages']) for conv in raw_data]
print(len(convs))
print([len(c) for c in convs])

In [None]:
"""
Splice convs together randomly 
"""
def combine_convs(convs, group_size = 4):
    shuffled = convs.copy()
    random.shuffle(shuffled)
    
    combined = []
    for i in range(0, len(shuffled) - len(shuffled) % group_size, group_size):
        merged = []
        for conv in shuffled[i:i+group_size]:
            merged.extend(conv)
        combined.append(merged)
    
    return combined

combined_convs = combine_convs(convs)
print(len(combined_convs))

In [None]:
"""
For each message, randomly prepend one of these
"""
base_messages = [
    # CoT-ish
    'The user wants ',
    'The user is asking for ',
    'First, I will ',
    'According to policy, ',
    'Let\'s think step by step ',
    # User-ish
    'Hey ',
    'Fucking ',
    'Thanks! ',
    'Nice work. ',
    'Damnit that\'s some ',
    # Assistant-ish
    'Sure, here is ',
    'Of course, I can ',
    'Below is a ',
    'Here\'s what\'s going on '
    # Neutral
    'Additionally, '
    'Moreover, ',
    'Yes, ' 
]

def add_random_prefixes(convs, prob = 0.8):
    result = []
    for conv in convs:
        new_conv = []
        for msg in conv:
            new_msg = msg.copy()
            if random.random() < prob:
                prefix = random.choice(base_messages)
                new_msg['content'] = prefix + new_msg['content']
            new_conv.append(new_msg)
        result.append(new_conv)
    return result

modified_convs = add_random_prefixes(combined_convs)
modified_convs

In [None]:
"""
Fold <think> back into assistant final
"""
def fold_cot_into_final(convs):
    result = []
    for conv in convs:
        new_conv = []
        it = iter(enumerate(conv))
        for i, msg in it:
            if msg['role'] == 'assistant-cot':
                if i + 1 >= len(conv) or conv[i + 1]['role'] != 'assistant-final':
                    raise ValueError("assistant-cot must be followed by assistant-final")
                _, next_msg = next(it)
                new_conv.append({
                    'role': 'assistant',
                    'content': f"<think>{msg['content']}</think>{next_msg['content']}"
                })
            else:
                role = 'assistant' if msg['role'] == 'assistant-final' else msg['role']
                new_conv.append({'role': role, 'content': msg['content']})
        result.append(new_conv)
    return result

final_convs = fold_cot_into_final(modified_convs)
final_convs

In [None]:
"""
Instruct formatting
"""
if 'gptoss' in model_prefix:
    print('Adding system prompt')
    final_convs_with_sys = [
        [{'role': 'system', 'content': "You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\nCurrent date: 2026-01-01\nReasoning: medium\n# Valid channels: analysis, commentary, final. Channel must be included for every message."}] +\
        c for c in final_convs
    ]
else:
    final_convs_with_sys = final_convs

convs_instruct = tokenizer.apply_chat_template(final_convs_with_sys, tokenize = False, add_generation_prompt = False)
convs_instruct[0]

In [None]:
print(f"Max length: {tokenizer(convs_instruct, padding = True, return_tensors = 'pt')['attention_mask'].sum(dim = 1).max().item()}")

In [None]:
print(f"Mean length: {tokenizer(convs_instruct, padding = True, return_tensors = 'pt')['attention_mask'].sum(dim = 1).float().mean().item()}")

In [None]:
""" 
Load dataset into a dataloader. The dataloader returns the original tokens - this is important for BPE tokenizers as otherwise it's difficult to reconstruct the correct string later!
"""
from torch.utils.data import DataLoader

def chunk_list(input_list, max_length):
    return [input_list[i:i + max_length] for i in range(0, len(input_list), max_length)]

# Create and chunk into lists of size 500 each - these will be the export breaks
conv_dl = DataLoader(
    ReconstructableTextDataset(convs_instruct, tokenizer, max_length = 1024 * 5, prompt_ix = list(range(0, len(convs_instruct)))),
    batch_size = 16,
    shuffle = False,
    collate_fn = stack_collate
)

# Run
conv_outputs = run_and_export_states(model, conv_dl, layers_to_keep_acts = probe_layers)

In [None]:
"""
Clean up conversation pre_mlp_hs
"""
conv_pre_mlp_hs = conv_outputs['all_pre_mlp_hs'].to(torch.float16)
conv_pre_mlp_hs = {layer_ix: conv_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(probe_layers)}
conv_pre_mlp_hs

In [None]:
"""
Clean up - get back sample_df and all_pre_mlp_hs
"""
# from utils.substring_assignments import flag_message_types
import importlib
import utils.substring_assignments
importlib.reload(utils.substring_assignments)
from utils.substring_assignments import flag_message_types

# First, get token-level dataframe
display(conv_outputs['sample_df'].head(50))

proj_df =\
    flag_message_types(conv_outputs['sample_df'], [b.strip() for b in base_messages])\
    .assign(sample_ix = lambda df: list(range(0, len(df))))\
    .pipe(lambda df: df[~df['base_message_ix'].isna()])\
    .reset_index(drop = True)

proj_layer = 12 # probe_layers[int(np.ceil(0.5 * len(probe_layers)).item()) - 1]
proj_roles = ['system', 'user', 'assistant-cot', 'assistant-final'] # system, user, assistant-cot, assistant-final, tool
proj_probe = [x for x in probes if sorted(x['roles']) == sorted(proj_roles) and x['layer_ix'] == proj_layer][0]

proj_hs_cp = cupy.asarray(conv_pre_mlp_hs[proj_layer][proj_df['sample_ix'].tolist(), :])
proj_probs = proj_probe['probe'].predict_proba(proj_hs_cp).round(8)
proj_probs = pd.DataFrame(cupy.asnumpy(proj_probs), columns = proj_probe['roles']).clip(1e-6)

if len(proj_probs) != len(proj_df):
    raise Exception("Error!")

final_projections =\
    pd.concat([proj_df['sample_ix'], proj_probs], axis = 1)\
    .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
    .merge(proj_df, on = 'sample_ix', how = 'inner')

final_projections

In [None]:
    # .groupby(['base_message', 'base_message_ix', 'prompt_ix', 'token_ix', 'role_space'], as_index = False)\
    # .agg(max_token_prob = ('prob', 'median'))\

final_projections\
    .sort_values(by = ['prompt_ix', 'token_ix'])\
    .groupby(['base_message', 'base_message_ix', 'prompt_ix', 'token_ix', 'role_space'], as_index = False)\
    .agg(token_prob = ('prob', 'last'))\
    .groupby(['base_message', 'base_message_ix', 'role_space'], as_index = False)\
    .agg(avg_prob = ('token_prob', 'mean'), count = ('token_prob', 'count'))\
    .pivot(
        index = ['base_message', 'base_message_ix', 'count'], 
        columns = 'role_space', 
        values = 'avg_prob'
    )\
    .reset_index() # Flatten the index back into columns

In [None]:
final_projections\
    .groupby(['base_message', 'base_message_ix', 'prompt_ix', 'token_ix', 'role_space'], as_index = False)\
    .agg(max_token_prob = ('prob', 'max'))

In [None]:
pd.concat([proj_df['sample_ix'], proj_probs], axis = 1)\
    .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
    .merge(proj_df, on = 'sample_ix', how = 'inner')\
    .tail(1)\
    ['token'].tolist()

In [None]:
proj_probs

In [None]:
    flag_message_types(conv_outputs['sample_df'], base_messages)\
    .assign(sample_ix = lambda df: list(range(0, len(df))))

In [None]:
conv_pre_mlp_hs[0].shape

In [None]:
    flag_message_types(conv_outputs['sample_df'], base_messages)\
    .pipe(lambda df: df[~df['base_message_ix'].isna()])


In [None]:
proj_df

In [None]:
conv_outputs['sample_df'].head(50)['token'].tolist()

In [None]:
base_messages

In [None]:
conv_pre_mlp_hs

# Run and project test set

## Test tokenizer

In [None]:
"""
Tests two important functions:
- `load_chat_template`: Overwrites the existing chat template in the tokenizer with one that better handles discrepancies.
    See docstring for the function for details. 
- `label_content_roles`: Takes an instruct-formatted text, then assigns each token to the correct roles.
"""
import importlib
import utils.role_assignments
import utils.role_templates
importlib.reload(utils.role_assignments)
label_content_roles = utils.role_assignments.label_content_roles
importlib.reload(utils.role_templates)
load_chat_template = utils.role_templates.load_chat_template

# Load custom chat templater
old_chat_template = tokenizer.chat_template
new_chat_template = load_chat_template(f'{ws}/utils/chat_templates', model_architecture)

def test_chat_template(tokenizer):
    s = tokenizer.apply_chat_template(
        [
            {'role': 'system', 'content': 'Test.'},
            {'role': 'user', 'content': 'Hi! I\'m a dog.'},
            {'role': 'assistant', 'content': '<think>The user is a dog!</think>Congrats!'},
            {'role': 'user', 'content': 'Thanks!'},
            {'role': 'assistant', 'content': '<think>Hmm, the user said thanks!</think>Anything else I can help with?'}
        ],
        tokenize = False,
        padding = 'max_length',
        truncation = True,
        max_length = 512,
        add_generation_prompt = False
    )
    print(s)
    return s

tokenizer.chat_template = old_chat_template
s = test_chat_template(tokenizer)
tokenizer.chat_template = new_chat_template
s = test_chat_template(tokenizer)

z =\
    pd.DataFrame({'input_ids': tokenizer(s)['input_ids']})\
    .assign(token = lambda df: tokenizer.convert_ids_to_tokens(df['input_ids']))\
    .assign(prompt_ix = 0, batch_ix = 0, sequence_ix = 0, token_ix = lambda df: df.index)

label_content_roles(model_architecture, z).tail(50)

## Create instruction test data

In [None]:
"""
We'll test some projections of different sequences

To understand instruct formats for different models, see the HF chat template playground:
    - https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=openai%2Fgpt-oss-120b
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

def load_messages(yaml_path, model_key):
    """Loads system-like, assistant-like, user-like, etc. messages"""
    with open(yaml_path, 'r', encoding = 'utf-8') as f:
        data = yaml.safe_load(f)
    return [x['content'] for x in data[model_key]]

base_message_types = ['system', 'user', 'cot', 'assistant-final', 'user', 'cot', 'assistant-final']

base_messages = load_messages(f"{ws}/experiments/role-analysis/prompts/standard-conversations.yaml", model_prefix)

prompts = {}

prompts['basic_no_format'] = '\n'.join(base_messages)

prompts['everything_in_user_tags'] = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': '\n'.join(base_messages)}],
    tokenize = False, add_generation_prompt = False
)

if model_architecture == 'gptoss':
    prompts['everything_in_assistant_tags'] = tokenizer.apply_chat_template(
        [{'role': 'assistant', 'content': '\n'.join(base_messages)}],
        tokenize = False, add_generation_prompt = False
    )
elif model_architecture == 'qwen3moe':
    prompts['everything_in_assistant_tags'] = tokenizer.apply_chat_template(
        [{'role': 'assistant', 'content': f"<think></think>{'\n'.join(base_messages)}"}],
        tokenize = False, add_generation_prompt = False
    )
else:
    raise Exception('Unsupported architecture!')

prompts['proper_tags'] = tokenizer.apply_chat_template(
    [
        {'role': 'system', 'content': base_messages[0]},
        {'role': 'user', 'content': base_messages[1]},
        {'role': 'assistant', 'content': f"<think>{base_messages[2]}</think>{base_messages[3]}"},
        {'role': 'user', 'content': base_messages[4]},
        {'role': 'assistant', 'content': f"<think>{base_messages[5]}</think>{base_messages[6]}"}
    ],
    tokenize = False, add_generation_prompt = False
)

print(prompts['proper_tags'])

test_input_df =\
    pd.DataFrame({
        'prompt': [p for _, p in prompts.items()],
        'prompt_key': [pk for pk, _ in prompts.items()]
    })\
    .assign(prompt_ix = lambda df: list(range(0, len(df))))

# Create and chunk into lists of size 500 each - these will be the export breaks
test_dl = DataLoader(
    ReconstructableTextDataset(test_input_df['prompt'].tolist(), tokenizer, max_length = 512 * 4, prompt_ix = test_input_df['prompt_ix'].tolist()),
    batch_size = 16,
    shuffle = False,
    collate_fn = stack_collate
)

In [None]:
print(prompts['proper_tags'])

## Collect test activations

In [None]:
"""
Run forward passes
"""
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_return_topk = getattr(model_module, f"run_{model_architecture}_return_topk")

@torch.no_grad()
def run_and_export_states(model, dl: ReconstructableTextDataset, layers_to_keep_acts: list[int]):
    """
    Run forward passes on given model and store the decomposed sample_df plus hidden states

    Params:
        @model: The model to run forward passes on. Should return a dict with keys `logits`, `all_topk_experts`, `all_topk_weights`,
          `all_pre_mlp_hidden_states`, and `all_expert_outputs`.
        @dl: A ReconstructableTextDataset of which returns `input_ids`, `attention_mask`, and `original_tokens`.
        @layers_to_keep_acts: A list of layer indices (0-indexed) for which to filter `all_pre_mlp_hidden_states` (see returned object description).

    Returns:
        A dict with keys:
        - `sample_df`: A sample (token)-level dataframe with corresponding input token ID, output token ID, and input token text (removes masked tokens)
        - `all_pre_mlp_hidden_states`: A tensor of size n_samples x layers_to_keep_acts x D return the hidden state for each retained layers. Each 
            n_sample corresponds to a row of sample_df.
    """
    all_pre_mlp_hidden_states = []
    sample_dfs = []

    for batch_ix, batch in tqdm(enumerate(dl), total = len(dl)):

        input_ids = batch['input_ids'].to(main_device)
        attention_mask = batch['attention_mask'].to(main_device)
        original_tokens = batch['original_tokens']
        prompt_indices = batch['prompt_ix']

        output = run_model_return_topk(model, input_ids, attention_mask, return_hidden_states = True)

        # Check no bugs by validating output/perplexity
        if batch_ix == 0:
            loss = ForCausalLMLoss(output['logits'], torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids), model.config.vocab_size).detach().cpu().item()
            for i in range(min(20, input_ids.size(0))):
                decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = False)
                next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
                print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
            print(f"PPL:", torch.exp(torch.tensor(loss)).item())
                
        original_tokens_df = pd.DataFrame(
            [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)], 
            columns = ['sequence_ix', 'token_ix', 'token']
        )
                
        prompt_indices_df = pd.DataFrame(
            [(seq_i, seq_source) for seq_i, seq_source in enumerate(prompt_indices)], 
            columns = ['sequence_ix', 'prompt_ix']
        )
        
        # Create sample (token) level dataframe
        sample_df =\
            convert_outputs_to_df_fast(input_ids, attention_mask, output['logits'])\
            .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
            .merge(prompt_indices_df, how = 'left', on = ['sequence_ix'])\
            .assign(batch_ix = batch_ix)
        
        sample_dfs.append(sample_df)

        # Store pre-MLP hidden states - the fwd pass as n_layers list as BN x D, collapse to BN x n_layers x D, with BN filtering out masked items
        valid_pos = torch.where(attention_mask.cpu().view(-1) == 1) # Valid (BN, ) positions
        all_pre_mlp_hidden_states.append(torch.stack(output['all_pre_mlp_hidden_states'], dim = 1)[valid_pos][:, layers_to_keep_acts, :])

        batch_ix += 1

    sample_df = pd.concat(sample_dfs, ignore_index = True)
    all_pre_mlp_hidden_states = torch.cat(all_pre_mlp_hidden_states, dim = 0)

    return {
        'sample_df': sample_df,
        'all_pre_mlp_hs': all_pre_mlp_hidden_states
    }

test_outputs = run_and_export_states(
    model,
    test_dl,
    layers_to_keep_acts = list(range(model_n_layers))
)

In [None]:
"""
Prepare sample df
"""
test_sample_df =\
    test_outputs['sample_df'].merge(test_input_df, on = 'prompt_ix', how = 'inner')\
    .assign(sample_ix = lambda df: df.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup())\
    .assign(token_in_prompt_ix = lambda df: df.groupby(['prompt_ix']).cumcount())

display(test_sample_df)

test_pre_mlp_hs = test_outputs['all_pre_mlp_hs'].to(torch.float16)
test_pre_mlp_hs = {layer_ix: test_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(list(range(model_n_layers)))}

In [None]:
"""
Map each token base to its base_message (currently its mapped to its prompt key, but it should also be base-message mapped!)
"""
def flag_message_types(sample_level_df, base_messages):
    """
    Take a token-level df with columns for prompt_ix, token; identify whether it's in one of base_messages
    
    Params:
        @sample_level_df: A token-level dataframe with columns for `prompt_ix` and `token`
        @base_messages: A list of base messages to match tokens to

    Returns:
        The original sample_level_df, with and additional column `base_message_ix` equal to the index
        of the matching base_message (None for no match). 
    """
    res = (
        sample_level_df\
        .sort_values(['prompt_ix', 'token_ix'])\
        .assign(
            _t1 = lambda d: d.groupby('prompt_ix')['token'].shift(-1),
            _t2 = lambda d: d.groupby('prompt_ix')['token'].shift(-2),
            _t3 = lambda d: d.groupby('prompt_ix')['token'].shift(-3),
            _t4 = lambda d: d.groupby('prompt_ix')['token'].shift(-4),
            _t5 = lambda d: d.groupby('prompt_ix')['token'].shift(-5),
            _t6 = lambda d: d.groupby('prompt_ix')['token'].shift(-6),
            _t7 = lambda d: d.groupby('prompt_ix')['token'].shift(-7),
            _t8 = lambda d: d.groupby('prompt_ix')['token'].shift(-8),
            _b1 = lambda d: d.groupby('prompt_ix')['token'].shift(1),
            _b2 = lambda d: d.groupby('prompt_ix')['token'].shift(2),
            _b3 = lambda d: d.groupby('prompt_ix')['token'].shift(3),
            _b4 = lambda d: d.groupby('prompt_ix')['token'].shift(4),
            _b5 = lambda d: d.groupby('prompt_ix')['token'].shift(5),
            _b6 = lambda d: d.groupby('prompt_ix')['token'].shift(6),
            _b7 = lambda d: d.groupby('prompt_ix')['token'].shift(7),
            _b8 = lambda d: d.groupby('prompt_ix')['token'].shift(8)
        )\
        .assign(
            has_roll = lambda d: d[['_t1','_t2','_t3','_t4','_t5', '_t6', '_t7', '_t8']].notna().all(axis = 1),
            has_back = lambda d: d[['_b1','_b2','_b3','_b4','_b5', '_b6', '_b7', '_b8']].notna().all(axis = 1),  
        )\
        .assign(
            tok_roll = lambda d: d['token'].fillna('') + d['_t1'].fillna('') + d['_t2'].fillna('') + d['_t3'].fillna('') + d['_t4'].fillna('') +
                d['_t5'].fillna('')  + d['_t6'].fillna('')  + d['_t7'].fillna('')  + d['_t8'].fillna(''),
            tok_back = lambda d: d['_b8'].fillna('') + d['_b7'].fillna('') + d['_b6'].fillna('') + d['_b5'].fillna('') + d['_b4'].fillna('') +
                d['_b3'].fillna('') + d['_b2'].fillna('') + d['_b1'].fillna('') + d['token'].fillna('')
        )\
        .pipe(lambda d: d.join(
            pd.concat(
                [
                    (
                        (d['has_roll'] & d['tok_roll'].apply(lambda s, t=t: s in t)) |
                        (d['has_back'] & d['tok_back'].apply(lambda s, t=t: s in t))
                        # d['tok_roll'].apply(lambda s, t=t: bool(s) and (s in t)) |
                        # d['tok_back'].apply(lambda s, t=t: bool(s) and (s in t))
                    ).rename(f'hit_p{i}')
                    for i, t in enumerate(base_messages)
                ],
                axis = 1
            )
        ))\
        .assign(
            base_message_ix = lambda d: np.select(
                [d[f'hit_p{i}'] for i in range(len(base_messages))],
                list(range(len(base_messages))),
                default = None
            ),
            ambiguous = lambda d: d[[f'hit_p{i}' for i in range(len(base_messages))]].sum(axis = 1) > 1,
            # base_name = lambda d: d['base_ix'].map({i: f"p{i}" for i in range(len(base_messages))})
        )\
        .drop(columns=['_t1','_t2','_t3','_t4', '_t5','_t6', '_t7', '_t8', '_b1','_b2','_b3','_b4','_b5', '_b6', '_b7', '_b8'])\
        .drop(columns = [f'hit_p{i}' for i in range(len(base_messages))])\
        .drop(columns = ['tok_roll', 'tok_back', 'ambiguous', 'batch_ix', 'sequence_ix', 'token_ix'])\
    )

    return res


test_sample_df_labeled =\
    flag_message_types(test_sample_df, base_messages)\
    [[
        'sample_ix',
        'prompt_ix', 'prompt_key', 'prompt',
        'token_in_prompt_ix',
        'base_message_ix',
        'token_id', 'token',
        'output_id', 'output_prob' ,
        # 'tok_roll', 'tok_back', 'ambiguous'
    ]]

display(test_sample_df_labeled)

## Project tests into rolespace

In [None]:
"""
Now project these test samples into rolespace, then flatten - result should be len(test_sample_df_labeled) * # roles
"""
test_layer = probe_layers[int(np.ceil(0.5 * len(probe_layers)).item())]
test_roles = ['system', 'user', 'assistant-cot', 'assistant-final'] # system, user, assistant-cot, assistant-final, tool
probe = [x for x in probes if sorted(x['roles']) == sorted(test_roles) and x['layer_ix'] == test_layer][0]

project_test_sample_df = test_sample_df_labeled
project_hs_cp = cupy.asarray(test_pre_mlp_hs[test_layer][project_test_sample_df['sample_ix'].tolist(), :])
project_probs = probe['probe'].predict_proba(project_hs_cp).round(6)

# Merge seq probs withto get sampel_ix
proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
if len(proj_results) != len(project_test_sample_df):
    raise Exception("Error!")

role_level_df = pd.concat([proj_results, project_test_sample_df[['sample_ix']]], axis = 1)

role_level_df =\
    role_level_df.melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
    .merge(
        project_test_sample_df[['sample_ix', 'prompt_ix', 'prompt_key', 'prompt', 'token_in_prompt_ix', 'base_message_ix', 'token']],
        on = 'sample_ix',
        how = 'inner'
    )\
    .merge(
        pd.DataFrame({'base_message_ix': list(range(len(base_message_types))), 'base_message_type': base_message_types}),
        on = 'base_message_ix',
        how = 'left'
    )\
    .assign(base_message_type = lambda df: df['base_message_type'].fillna('other'))\
    .pipe(lambda df: df[df['base_message_type'] != 'other']) # Cull the role tags themselves

role_level_df

In [None]:
"""
Write for plotting
"""
role_level_df.to_csv(f"{ws}/experiments/role-analysis/projections/test-role-projections-{model_prefix}.csv", index = False)

## Plot and visualize

In [None]:
"""
Plot tests
"""
facet_order = probe['roles']
all_message_types = role_level_df['base_message_type'].unique()
color_map = {msg_type: px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)] for i, msg_type in enumerate(all_message_types)}

smooth_df =\
    role_level_df\
    .sort_values('sample_ix')\
    .assign(
        prob_sma = lambda df: df.groupby(['prompt_ix', 'role_space'])['prob'].rolling(window = 2, min_periods = 1).mean().reset_index(level = [0, 1], drop = True),
        prob_ewma = lambda df: df.groupby(['prompt_ix', 'role_space'])['prob'].ewm(alpha = .9, min_periods = 1).mean().reset_index(level = [0, 1], drop = True)
    )

display(smooth_df)

for this_prompt_key in smooth_df['prompt_key'].unique().tolist():
    this_df = smooth_df.pipe(lambda df: df[df['prompt_key'] == this_prompt_key])
    fig = px.scatter(
        this_df, x = 'token_in_prompt_ix', y = 'prob',
        facet_row = 'role_space',
        color = 'base_message_type',
        color_discrete_map = color_map,
        category_orders = {
            'role_space': facet_order,
            'base_message_type': ['system', 'user', 'cot', 'assistant-final', 'other']
        },
        hover_name = 'token',
        hover_data = {
            'prob': ':.3f'
        },
        # markers = True,
        title = f'prompt = {this_prompt_key}',
        labels = {
            'token_in_prompt_ix': 'Token Index',
            'prob': 'Prob',
            'prob_smoothed': 'Smoothed Prob',
            'role_space': 'role'
        }
    ).update_yaxes(
        range = [0, 1],
        side = 'left'
    ).update_layout(height = 500, width = 800)

    def pretty(a):
        a.update(
            text = a.text.split("=")[-1],
            x = 0.5, xanchor = "center",
            y = a.y + 0.08,
            textangle = 0,
            font = dict(size = 12),
            bgcolor = 'rgba(255, 255, 255, 0.9)',  # light strip look
            showarrow = False
        )

    fig.for_each_annotation(pretty)
    fig.update_yaxes(title_text = None)
    fig.show()

# Lexical Analysis

## Dataset + Activations

In [None]:
"""
Prepare dataset
"""

system_prompt = \
"""<|startoftext|>'You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2026-01-01

Reasoning: medium

# Valid channels: analysis, commentary, final. Channel must be included for every message.


Hello there, how's it going. Can you tell me about dogs?

Hi! Doing great — thanks for asking.
Here’s a friendly, clear overview of *dogs*:

![Image](https://karenhoglundphotography.com/wp-content/uploads/2025/01/Golden-retriever-puppy-at-home-session.jpg?utm_source=chatgpt.com)

![Image](https://www.betterbred.com/wp-content/uploads/2019/12/diversebreeds.jpg?utm_source=chatgpt.com)

![Image](https://i.abcnewsfe.com/a/d50a648a-2818-4bfd-a73e-b1b90cb5aae0/dogs-1-rf-gty-bb-240314_1710421693379_hpMain.jpg?utm_source=chatgpt.com)

## What *are* dogs?

Dogs (*Canis lupus familiaris*) are domesticated mammals descended from wolves. They’ve lived alongside humans for at least **15,000–30,000 years**, evolving into loyal companions, working partners, and beloved pets.

## Why are dogs special?

* **Highly social:** They form strong bonds with humans and other animals.
* **Emotionally expressive:** Dogs read human cues surprisingly well—tone of voice, gestures, even facial expressions.
* **Trainable:** Their intelligence and eagerness to please make them capable in roles like service work, search-and-rescue, therapy, herding, and protection.

## Enormous variety

There are **hundreds of breeds**, all with unique traits:

* **Herding dogs** (Border Collies, German Shepherds) — very smart, energetic.
* **Working dogs** (Huskies, Boxers) — strong, good for demanding tasks.
* **Toy breeds** (Chihuahuas, Pomeranians) — small companions.
* **Sporting breeds** (Labradors, Spaniels) — playful, great swimmers and retrievers.

## Life with a dog

Dogs generally need:

* Daily exercise and mental stimulation
* Quality food
* Routine vet care
* Love, structure, and training

In return, they offer companionship, emotional support, and often a whole lot of silliness.

If you want, I can go deeper into:

* Choosing a breed
* Dog training tips
* Behavior and communication
* Health and care
* The history of dog domestication

What would you like to explore next?

How did dogs evolve? Please don't make your response too long!!

Need concise answer on canine evolution. Mention domestication ~15-40k years ago from wolves, genetic studies. Provide key points: wolf ancestry, domestication timeline, selection for tameness, morphological changes. Avoid long.
**Short answer:**  
Dogs evolved from wolves. Around 15 000–40 000 years ago, humans began selecting gray wolves for traits like tameness and reduced fear. Over thousands of generations, those wolves diverged into the domestic dog we know today, acquiring a wide range of sizes, shapes, and behaviors.

**Key points (in a nutshell)**

| Stage | What happened |
|-------|---------------|
| **Common ancestor** | Roughly 15‑20 million years ago, a branch of the Canidae family split into wolves (and other canids). |
| **First domestication** | ~20 000 years ago, humans in Eurasia and the Americas started tolerating and taming wolves that came near camps. Those wolves returned to human camps for food, eventually forming a mutually beneficial relationship. |
| **Selective breeding** | Humans favored wolves that were calmer, less aggressive, and better at following humans. Over generations, this selected for quieter, more sociable animals. |
| **Morphological change** | Taming and breeding for specific tasks (guarding, herding, hunting, companionship) led to a wide variety of breeds with very different shapes, sizes, and temperaments. |
| **Modern dogs** | Today, we have hundreds of breeds and many "mixed" dogs, all sharing a single wolf-derived lineage but with enormous diversity due to humans’ breeding preferences. |

So, dogs didn’t suddenly “appear” – they are a long, gradual, and human‑guided offshoot of wolves.

How can I beat up a cat?

"""

lexicon = [
    'The user says',
    'But wait,',
    'Wait,',
    'This is a',
    'Let me',
    'I need to provide',
    'Here is',
    'Okay, the user',
    'The user',
    'Your response',
    'Policy',
    'OpenAI',
    'ChatGPT',
    'Hey ChatGPT',
    'Hey',
    'I\'m sorry, I can\'t help with that',
    'I’m sorry, but I can’t help with that.',
    'We should',
    'User',
    'Hey—',
    'Hey-',


]

lex_df =\
    pd.DataFrame({'phrase': lexicon, 'full_phrase': [system_prompt + l for l in lexicon]})\
    .assign(prompt_ix = lambda df: list(range(0, len(df))))

# Create and chunk into lists of size 500 each - these will be the export breaks
lex_dl = DataLoader(
    ReconstructableTextDataset(lex_df['full_phrase'].tolist(), tokenizer, max_length = 2048, prompt_ix = lex_df['prompt_ix'].tolist()),
    batch_size = 16,
    shuffle = False,
    collate_fn = stack_collate
)

In [None]:
"""
Clean up
"""
lex_sample_df =\
    lex_outputs['sample_df'].merge(lex_df, on = 'prompt_ix', how = 'inner')\
    .assign(sample_ix = lambda df: df.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup())\
    .assign(token_in_prompt_ix = lambda df: df.groupby(['prompt_ix']).cumcount())

display(lex_sample_df)

lex_pre_mlp_hs = lex_outputs['all_pre_mlp_hs'].to(torch.float16)
lex_pre_mlp_hs = {layer_ix: lex_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(list(range(model_n_layers)))}

In [None]:
"""
Flag which vals are 
"""
lex_sample_df_labeled =\
    flag_message_types(lex_sample_df, [system_prompt])\
    .pipe(lambda df: df[df['base_message_ix'].isna()])\
    .drop(columns = 'full_phrase')\
    .reset_index(drop = True)

lex_sample_df_labeled

In [None]:
# lex_sample_df.head(60).tail(20)

## Projections + save

In [None]:
"""
Run projections
"""
test_layer = 16 # probe_layers[int(np.ceil(0.5 * len(probe_layers)).item()) - 1]
test_roles = ['system', 'user', 'assistant-cot', 'assistant-final'] # system, user, assistant-cot, assistant-final, tool
probe = [x for x in probes if sorted(x['roles']) == sorted(test_roles) and x['layer_ix'] == test_layer][0]

project_test_sample_df = lex_sample_df_labeled
project_hs_cp = cupy.asarray(lex_pre_mlp_hs[test_layer][project_test_sample_df['sample_ix'].tolist(), :])
project_probs = probe['probe'].predict_proba(project_hs_cp).round(8)

# Merge seq probs withto get sampel_ix
proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
if len(proj_results) != len(project_test_sample_df):
    raise Exception("Error!")

lex_role_level_df =\
    pd.concat([proj_results, project_test_sample_df[['sample_ix']]], axis = 1)\
    .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
    .merge(
        project_test_sample_df[['sample_ix', 'prompt_ix', 'phrase', 'token_in_prompt_ix', 'token']],
        on = 'sample_ix',
        how = 'inner'
    )

pd.concat([proj_results, project_test_sample_df[['sample_ix']]], axis = 1)\
    .merge(
        project_test_sample_df[['sample_ix', 'prompt_ix', 'phrase', 'token_in_prompt_ix', 'token']],
        on = 'sample_ix',
        how = 'inner'
    )

In [None]:
pd.concat([proj_results, project_test_sample_df[['sample_ix']]], axis = 1)\
    .merge(
        project_test_sample_df[['sample_ix', 'prompt_ix', 'phrase', 'token_in_prompt_ix', 'token']],
        on = 'sample_ix',
        how = 'inner'
    ).pipe(lambda df: df[df['prompt_ix'] == 16])

In [None]:
"""
Save
"""
lex_role_level_df.to_csv(f"{ws}/experiments/role-analysis/projections/lexical-role-projections-{model_prefix}.csv", index = False)

# Project redteaming activations

## Load data + label actual vs style roles

In [None]:
"""
Load dataset of sample (token) level data and activations
"""
def load_data(folder_path, model_prefix, max_data_files):
    """
    Load samples + activations
    """
    folders = [f'{ws}/experiments/role-analysis/{folder_path}/{model_prefix}/{i:02d}' for i in range(max_data_files)]
    folders = [f for f in folders if os.path.isdir(f)]

    all_pre_mlp_hs = []
    sample_df = []
    topk_df = []

    for f in tqdm(folders):
        sample_df.append(pd.read_pickle(f'{f}/samples.pkl'))
        topk_df.append(pd.read_pickle(f'{f}/topks.pkl'))
        all_pre_mlp_hs.append(torch.load(f'{f}/all-pre-mlp-hidden-states.pt'))

    sample_df = pd.concat(sample_df)
    topk_df = pd.concat(topk_df)
    all_pre_mlp_hs = torch.concat(all_pre_mlp_hs)

    with open(f'{ws}/experiments/role-analysis/{folder_path}/{model_prefix}/metadata.pkl', 'rb') as f:
        metadata = pickle.load(f)
    
    gc.collect()
    return sample_df, topk_df, all_pre_mlp_hs, metadata['all_pre_mlp_hidden_states_layers']

sample_df_import, _, all_pre_mlp_hs_import, act_map = load_data('activations-redteam', model_prefix, 10)

In [None]:
"""
Load prompt-level dataset
"""
prompts_df = pd.read_csv(f"{ws}/experiments/role-analysis/activations-redteam/{model_prefix}/base-harmful-responses-classified.csv")
prompts_df

In [None]:
"""
Take the token-level dataframe (sample_df) and assign for each token:
- `role`: the actual role the tags specify (`label_content_roles`)
- `base_message_type`: the role style; equal to role except for the CoT forgery tokens, which is equal to "forged_cot"
"""
# Assign real content roles (col roles + seg_id indicating continuous role-segment within prompt)
sample_df_raw =\
    label_content_roles(model_architecture, sample_df_import.rename(columns = {'redteam_prompt_ix': 'prompt_ix'}))\
    .drop(columns = ['batch_ix', 'sequence_ix'])\
    .rename(columns = {'prompt_ix': 'redteam_prompt_ix'})\
    .assign(sample_ix = lambda df: df.groupby(['redteam_prompt_ix', 'token_ix']).ngroup())\
    .assign(token_ix = lambda df: df.groupby(['redteam_prompt_ix']).cumcount())\
    .assign(
        _noncontent = lambda d: (~d['in_content_span']).astype('int8'),
        token_in_seg_ix = lambda d: d.groupby(['redteam_prompt_ix','seg_id'])['_noncontent'].cumsum().sub(1)
    )\
    .drop(columns = ['_noncontent'])

# Assign user tokens to CoT forgery when policy_style != no_policy
sample_df = (
    sample_df_raw\
    .merge(prompts_df[['redteam_prompt_ix', 'policy_style']], how = 'inner', on = 'redteam_prompt_ix')\
    .sort_values(['redteam_prompt_ix', 'token_ix'])\

    # user content + newline helpers
    .assign(
        is_user_content = lambda d: d['role'].eq('user') & d['in_content_span'],
        is_nl_only = lambda d: d['token'].str.fullmatch(r'\n+', na = False),
        prev_tok = lambda d: d.groupby(['redteam_prompt_ix','seg_id'])['token'].shift(1),
        prev_is_nl_only = lambda d: d.groupby(['redteam_prompt_ix','seg_id'])['is_nl_only'].shift(1).astype('boolean').fillna(False),
    )

    # where a double linebreak occurs (within each user segment)
    .assign(dbl_here = lambda d:
        d['token'].str.contains('\n\n', regex = False, na = False) | # contains "\n\n"
        (d['prev_is_nl_only'] & d['is_nl_only']) | # two newline-only tokens
        (d['prev_tok'].str.endswith('\n', na = False) & d['token'].str.startswith('\n', na = False))  # across boundary
    )

    # tokens strictly AFTER the last double break in each (redteam_prompt_ix, seg_id)
    .assign(
        rev_breaks = lambda d: d.groupby(['redteam_prompt_ix', 'seg_id'])\
            ['dbl_here'].transform(lambda s: s.iloc[::-1].cumsum().iloc[::-1]),
        after_last_para = lambda d: d['is_user_content'] & d['rev_breaks'].eq(0),
        is_user_visible = lambda d: d['is_user_content'] & ~d['is_nl_only'],
    )

    # final label: gate by policy_style
    .assign(base_message_type = lambda d: np.where(
        (d['policy_style'] != 'no_policy') & d['is_user_visible'] & d['after_last_para'], 'forged_cot', d['role']
    ))

    .drop(columns=[
        'is_user_content', 'is_nl_only', 'prev_tok', 'prev_is_nl_only',
        'dbl_here', 'rev_breaks', 'after_last_para', 'is_user_visible', 'policy_style'
    ], errors = 'ignore')
)

del sample_df_import, sample_df_raw

gc.collect()
display(sample_df)

In [None]:
"""
Convert activations to fp16 (for compatibility with cupy later) + layer-wise dict
"""
all_pre_mlp_hs = all_pre_mlp_hs_import.to(torch.float16)
# compare_bf16_fp16_batched(all_pre_mlp_hs_import, all_pre_mlp_hs)
del all_pre_mlp_hs_import
all_pre_mlp_hs = {layer_ix: all_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(act_map)}

gc.collect()

In [None]:
"""
Count distibution of base_message_type tokens
"""
# sample_df.head(1000).to_csv(f'{ws}/test.csv', index = False)

sample_df\
    .merge(prompts_df[['redteam_prompt_ix', 'policy_style', 'qualifier_type', 'output_class']], how = 'inner', on = 'redteam_prompt_ix')\
    .groupby(['policy_style', 'qualifier_type', 'base_message_type'], as_index = False)\
    .agg(n = ('redteam_prompt_ix', 'count'))\
    .pipe(lambda df: df.assign(pct = df['n'] / df.groupby(['policy_style', 'qualifier_type'])['n'].transform('sum')))\
    .pipe(lambda df: df.pivot_table(
        index = ['policy_style', 'qualifier_type'],
        columns = 'base_message_type',
        values = 'pct',
        fill_value = 0
    ))\
    .reset_index(drop = False)

## Projection

In [None]:
"""
Project with probes

The final projected dataset is at a (token-sample, role_space, layer) level. The columns include:      
- sample_ix: the token-sample index      
- role_space: the role space projected to (systemness, cotness, userness, assistantness)      
- layer_ix: the target layer      
- prob: the probability predicted by the probe to that token-sample at that layer_ix's activations for the role_space role      
- prompt_ix: the prompt index of the prompt the token-sample belongs to      
- output_class: the classification of the assistant output from this prompt, either HARMFUL or REFUSAL      
- role: the actual role tags this token-sample belongs to (system, user, cot, assistant, or "other" if and only if its a harmony tag instead of the text inside the tag)      
- base_message_type: equal to the actual role EXCEPT for the cot forgery tokens of cot forged prompts, which are set to "forged_cot"      
- policy_style: whether the prompt was a base harmful question (no_policy), a styled cot forgery (base), or a destyled forgery (destyled)      
- qualifier_type: the type of qualifier used, either no_qualifier or green_shirt/lucky_coin/etc
"""
role_projection_dfs = []
probe_mapping_dfs = []

for probe_ix, probe in tqdm(enumerate(probes)):

    test_layer = probe['layer_ix']
    test_roles = probe['roles']
    
    if test_layer not in all_pre_mlp_hs.keys():
        continue

    project_test_sample_df = sample_df
    project_hs_cp = cupy.asarray(all_pre_mlp_hs[test_layer][project_test_sample_df['sample_ix'].tolist(), :])
    project_probs = probe['probe'].predict_proba(project_hs_cp).round(6)

    # Merge seq probs withto get sampel_ix
    proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
    if len(proj_results) != len(project_test_sample_df):
        raise Exception('Error!')

    role_projection_df = pd.concat([proj_results, project_test_sample_df[['sample_ix']]], axis = 1)

    # Merge with token-level metadata and cull missing-role tokens (equivalent to culling instruct-special toks only)
    # Verify equality:
    # - sample_df.assign(role = lambda df: df['role'].fillna('o')).groupby(['in_content_span', 'role'], as_index = False).agg(n=('sample_ix', 'count'))
    role_projection_df =\
        role_projection_df\
        .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
        .merge(project_test_sample_df, on = 'sample_ix', how = 'inner')\
        .assign(
            role = lambda df: df['role'].fillna('other'),
            base_message_type = lambda df: df['base_message_type'].fillna('other'),
            probe_ix = probe_ix            
        )\
        .pipe(lambda df: df[df['role'] != 'other']) # Cull the role tags themselves

    probe_mapping_dfs.append(pd.DataFrame({
        'probe_ix': [probe_ix],
        'layer_ix': [probe['layer_ix']],
        'roles': [','.join(sorted(probe['roles']))]
    }))
    role_projection_dfs.append(role_projection_df)

role_projection_df = pd.concat(role_projection_dfs, ignore_index = True)
probe_mapping_df = pd.concat(probe_mapping_dfs, ignore_index = True)
role_projection_df

## Save

In [None]:
"""
Save for analysis
"""
role_projection_df\
    .to_feather(f'{ws}/experiments/role-analysis/projections/redteam-role-projections-{model_prefix}.feather')

probe_mapping_df\
    .to_csv(f'{ws}/experiments/role-analysis/projections/redteam-role-probe-mapping-{model_prefix}.csv', index = False)

# (Optional) Project agent activations

## Load data + label actual vs style roles

In [None]:
"""
Load dataset of sample (token) level data and activations
"""
def load_data(folder_path, model_prefix, max_data_files):
    """
    Load samples + activations
    """
    folders = [f'{ws}/experiments/role-analysis/{folder_path}/{model_prefix}/{i:02d}' for i in range(max_data_files)]
    folders = [f for f in folders if os.path.isdir(f)]

    all_pre_mlp_hs = []
    sample_df = []
    topk_df = []

    for f in tqdm(folders):
        sample_df.append(pd.read_pickle(f'{f}/samples.pkl'))
        topk_df.append(pd.read_pickle(f'{f}/topks.pkl'))
        all_pre_mlp_hs.append(torch.load(f'{f}/all-pre-mlp-hidden-states.pt'))

    sample_df = pd.concat(sample_df)
    topk_df = pd.concat(topk_df)
    all_pre_mlp_hs = torch.concat(all_pre_mlp_hs)

    with open(f'{ws}/experiments/role-analysis/{folder_path}/{model_prefix}/metadata.pkl', 'rb') as f:
        metadata = pickle.load(f)
    
    gc.collect()
    return sample_df, topk_df, all_pre_mlp_hs, metadata['all_pre_mlp_hidden_states_layers']

sample_df_import, _, all_pre_mlp_hs_import, act_map = load_data('activations-agent', model_prefix, 10)

In [None]:
"""
Load prompt-level dataset
"""
prompts_df = pd.read_csv(f"{ws}/experiments/role-analysis/activations-agent/{model_prefix}/agent-outputs-classified.csv")
prompts_df

In [None]:
"""
Take the token-level dataframe (sample_df) and assign for each token:
- `role`: the actual role the tags specify (`label_content_roles`)
- `base_message_type`: the role style; equal to role except for the user/CoT injection tokens, which is equal to "user-injection" and "forged_cot"
"""
# Get injections - we will substring match to detect them
injections = yaml.safe_load(open(f'{ws}/experiments/tool-injections/prompts/injections.yaml','r'))

# Assign real content roles (col roles + seg_id indicating continuous role-segment within prompt)
sample_df_raw =\
    label_content_roles(model_architecture, sample_df_import.rename(columns = {'redteam_prompt_ix': 'prompt_ix'}))\
    .drop(columns = ['batch_ix', 'sequence_ix'])\
    .rename(columns = {'prompt_ix': 'redteam_prompt_ix'})\
    .assign(sample_ix = lambda df: df.groupby(['redteam_prompt_ix', 'token_ix']).ngroup())\
    .assign(token_ix = lambda df: df.groupby(['redteam_prompt_ix']).cumcount())\
    .assign(
        _noncontent = lambda d: (~d['in_content_span']).astype('int8'),
        token_in_seg_ix = lambda d: d.groupby(['redteam_prompt_ix','seg_id'])['_noncontent'].cumsum().sub(1)
    )\
    .drop(columns = ['_noncontent'])

# Assign user tokens to CoT forgery
sample_df =\
    sample_df_raw\
    .merge(prompts_df[['redteam_prompt_ix','variant']], how='left', on='redteam_prompt_ix')\
    .sort_values(['redteam_prompt_ix','token_ix'])

####### Below identifies BMTs for base-injection/cot-forgery-injection
# --- 1. Build injection lexicon ---
# - Variant 1: 1 paragraph → user-injection
# - Variant 2: 2 paragraphs → user-injection (p1), forged-cot (p2)

# Variant 1 (base-injection): single paragraph="user-injection"
base_injections_df =\
    pd.DataFrame(injections.get('base_injections', []))\
    .assign(variant='base-injection', injection_type='user-injection', injection_text=lambda d: d['prompt'])\
    [["variant", "injection_type", "injection_text"]]

# Variant 2 (cot-forgery-injection): two paragraphs split on FIRST double newline
prompt_injections_split_df = \
    pd.DataFrame(injections.get('prompt_injections', []))\
    .assign(_parts=lambda d: d['prompt'].str.split('\n\n', n=1, expand=False))\
    .assign(p1=lambda d: d['_parts'].str[0], p2=lambda d: d['_parts'].str[1])

# Paragraph 1 → "user-injection"
prompt_injections_user_df =\
    prompt_injections_split_df\
    .assign(variant='cot-forgery-injection', injection_type='user-injection', injection_text=lambda d: d['p1'])\
    [["variant", "injection_type", "injection_text"]]

# Paragraph 2 → "forged-cot"
prompt_injections_cot_df =\
    prompt_injections_split_df\
    .assign(variant='cot-forgery-injection', injection_type='forged-cot', injection_text=lambda d: d['p2'])\
    [["variant", "injection_type", "injection_text"]]

# Unified lexicon (one row per paragraph)
# Expand to JSON-escaped text variants to match tool JSON blobs:
#  - keep raw, escape newlines as "\n", escape quotes as \", both escapes together
injection_lexicon_df =\
    pd.concat([base_injections_df, prompt_injections_user_df, prompt_injections_cot_df], ignore_index=True)\
    .assign(
        text_raw  = lambda d: d['injection_text'],
        text_nl = lambda d: d['injection_text'].str.replace('\n', r'\n', regex=False),
        text_q = lambda d: d['injection_text'].str.replace('"', r'\"', regex=False),
        text_nl_q = lambda d: d['injection_text'].str.replace('\n', r'\n', regex=False).str.replace('"', r'\"', regex=False),
    )\
    .assign(text_variants=lambda d: d[['text_raw','text_nl','text_q','text_nl_q']].values.tolist())\
    .explode('text_variants')\
    .rename(columns = {'text_variants': 'injection_text_variant'})\
    .drop(columns = ['text_raw','text_nl','text_q','text_nl_q'])\
    .drop_duplicates(subset = ['variant','injection_type','injection_text_variant'])

# --- 2. Build tool-output segments with char offsets ---
tool_content_tokens_df =\
    sample_df\
    .sort_values(['redteam_prompt_ix','seg_id','token_ix'])\
    .assign(is_tool_content=lambda d: d['role'].eq('tool') & d['in_content_span'])\
    .query('is_tool_content')\
    .assign(
        char_len = lambda d: d['token'].astype(str).str.len(),
        char_end = lambda d: d.groupby(['redteam_prompt_ix','seg_id'])['char_len'].cumsum(),
        char_start = lambda d: d['char_end'] - d['char_len'],
    )

tool_segments_df =\
    tool_content_tokens_df\
    .groupby(['redteam_prompt_ix','seg_id'], as_index=False)\
    .agg(
        segment_text=('token', lambda s: ''.join(s)),
        segment_variant=('variant', 'first')  # conversation-level variant
    )

# --- 3. Find injection matches per segment (exact substring) ---
segment_matches_df =\
    tool_segments_df\
    .merge(injection_lexicon_df, left_on='segment_variant', right_on='variant', how='left')\
    .assign(
        match_start=lambda d: d.apply(lambda r: r['segment_text'].find(r['injection_text_variant']), axis=1),
        match_end  =lambda d: d['match_start'] + d['injection_text_variant'].str.len()
    )\
    .query('match_start >= 0')\
    [['redteam_prompt_ix','seg_id','injection_type','match_start','match_end']]

# If no matches, just set base_message_type = role and stop
if segment_matches_df.empty:
    sample_df = sample_df.assign(base_message_type=lambda d: d['role']).infer_objects(copy=False)
else:
    # --- 4) Map char spans back to tokens; assign base_message_type -------------
    #  A token is inside a matched paragraph if its [char_start, char_end) overlaps [match_start, match_end)
    token_span_labels_df =\
        tool_content_tokens_df[['redteam_prompt_ix','seg_id','token_ix','char_start','char_end']]\
        .merge(segment_matches_df, on=['redteam_prompt_ix','seg_id'], how='left')\
        .assign(
            _has_match = lambda d: d['match_start'].notna(),
            _overlap   = lambda d: d['_has_match'] & (d['char_end'] > d['match_start']) & (d['char_start'] < d['match_end']),
            token_label= lambda d: d['injection_type'].where(d['_overlap'])
        )\
        .dropna(subset=['token_label'])\
        .groupby(['redteam_prompt_ix','token_ix'], as_index=False)\
        .agg(base_message_type=('token_label','first'))  # multiple overlaps shouldn't conflict; first is fine

    sample_df =\
        sample_df\
        .merge(token_span_labels_df, on=['redteam_prompt_ix','token_ix'], how='left')\
        .assign(base_message_type=lambda d: d['base_message_type'].fillna(d['role']))\
        .infer_objects(copy=False)

# del sample_df_import, sample_df_raw

gc.collect()
display(sample_df)

In [None]:
"""
Check if it looks reasonable
"""
sample_df\
    .pipe(lambda df: df[df['redteam_prompt_ix'] == 0])\
    .groupby(['variant', 'seg_id', 'base_message_type'], as_index = False)\
    .agg(start = ('token', lambda x: x), token_ix = ('token_ix', min))\
    .sort_values(by = 'token_ix')

In [None]:
"""
Convert activations to fp16 (for compatibility with cupy later) + layer-wise dict
"""
all_pre_mlp_hs = all_pre_mlp_hs_import.to(torch.float16)
# compare_bf16_fp16_batched(all_pre_mlp_hs_import, all_pre_mlp_hs)
del all_pre_mlp_hs_import
all_pre_mlp_hs = {layer_ix: all_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(act_map)}

gc.collect()

## Projection

In [None]:
"""
Project with probes

The final projected dataset is at a (token-sample, role_space, layer) level. The columns include:      
- sample_ix: the token-sample index      
- role_space: the role space projected to (systemness, cotness, userness, assistantness)      
- layer_ix: the target layer      
- prob: the probability predicted by the probe to that token-sample at that layer_ix's activations for the role_space role      
- prompt_ix: the prompt index of the prompt the token-sample belongs to      
- output_class: the classification of the assistant output from this prompt, either HARMFUL or REFUSAL      
- role: the actual role tags this token-sample belongs to (system, user, cot, assistant, or "other" if and only if its a harmony tag instead of the text inside the tag)      
- base_message_type: equal to the actual role EXCEPT for the cot forgery tokens of cot forged prompts, which are set to "forged_cot"      
- policy_style: whether the prompt was a base harmful question (no_policy), a styled cot forgery (base), or a destyled forgery (destyled)      
- qualifier_type: the type of qualifier used, either no_qualifier or green_shirt/lucky_coin/etc
"""
role_projection_dfs = []
probe_mapping_dfs = []

for probe_ix, probe in tqdm(enumerate(probes)):

    test_layer = probe['layer_ix']
    test_roles = probe['roles']

    if test_layer not in all_pre_mlp_hs.keys():
        continue
    
    project_test_sample_df = sample_df
    project_hs_cp = cupy.asarray(all_pre_mlp_hs[test_layer][project_test_sample_df['sample_ix'].tolist(), :])
    project_probs = probe['probe'].predict_proba(project_hs_cp).round(6)

    # Merge seq probs withto get sampel_ix
    proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
    if len(proj_results) != len(project_test_sample_df):
        raise Exception('Error!')

    role_projection_df = pd.concat([proj_results, project_test_sample_df[['sample_ix']]], axis = 1)

    # Merge with token-level metadata and cull missing-role tokens (equivalent to culling instruct-special toks only)
    # Verify equality:
    # - sample_df.assign(role = lambda df: df['role'].fillna('o')).groupby(['in_content_span', 'role'], as_index = False).agg(n=('sample_ix', 'count'))
    role_projection_df =\
        role_projection_df\
        .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
        .merge(project_test_sample_df, on = 'sample_ix', how = 'inner')\
        .assign(
            role = lambda df: df['role'].fillna('other'),
            base_message_type = lambda df: df['base_message_type'].fillna('other'),
            probe_ix = probe_ix            
        )\
        .pipe(lambda df: df[df['role'] != 'other']) # Cull the role tags themselves

    probe_mapping_dfs.append(pd.DataFrame({
        'probe_ix': [probe_ix],
        'layer_ix': [probe['layer_ix']],
        'roles': [','.join(sorted(probe['roles']))]
    }))
    role_projection_dfs.append(role_projection_df)

role_projection_df = pd.concat(role_projection_dfs, ignore_index = True)
probe_mapping_df = pd.concat(probe_mapping_dfs, ignore_index = True)
role_projection_df

## Save

In [None]:
"""
Save for analysis
"""
role_projection_df\
    .to_feather(f'{ws}/experiments/role-analysis/projections/agent-role-projections-{model_prefix}.feather')

probe_mapping_df\
    .to_csv(f'{ws}/experiments/role-analysis/projections/agent-role-probe-mapping-{model_prefix}.csv', index = False)