In [None]:
"""
Load saved role probes and test them by projecting activations into role space
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss
import pandas as pd
import numpy as np
import cupy
import cuml
import importlib
import gc
import yaml
import pickle
import os
from tqdm import tqdm
import plotly.express as px
from termcolor import colored

from utils.memory import check_memory, clear_all_cuda_memory
from utils.role_assignments import label_content_roles
from utils.role_templates import load_chat_template
from utils.store_outputs import convert_outputs_to_df_fast

# Device that input tensors are moved to before each forward pass
main_device = 'cuda:0'
# NOTE(review): `seed` is not referenced anywhere else in the visible notebook - confirm it is still needed
seed = 1234

clear_all_cuda_memory()
check_memory()

# Workspace root; the data, probe, and template paths below are built from it
ws = '/workspace/deliberative-alignment-jailbreaks'

# Load models & probes

## Load model

In [None]:
"""
Load the base tokenizer/model
"""
# Key into the model table inside `get_model` (0 = gpt-oss-20b, 1 = gpt-oss-120b)
selected_model_index = 0

def get_model(index):
    """
    Look up the configuration tuple for a supported model.

    Params:
        @index: Integer key of the model (0 or 1).

    Returns:
        A 6-tuple: (HF model ID, model prefix, model architecture, attn implementation,
        whether to use the HF library implementation, number of layers).

    Raises:
        KeyError if `index` is not a known model key.
    """
    flash_attn_kernel = 'kernels-community/vllm-flash-attn3'
    registry = {
        0: ('openai/gpt-oss-20b', 'gptoss20', 'gptoss', flash_attn_kernel, True, 24),
        1: ('openai/gpt-oss-120b', 'gptoss120', 'gptoss', flash_attn_kernel, True, 36),
    }
    return registry[index]

def load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf):
    """
    Build the tokenizer and the causal LM for `model_id`, using `/workspace/hf` as the HF cache directory.

    Params:
        @model_id: HF model ID passed to both `from_pretrained` calls.
        @model_prefix: Accepted for signature compatibility; not used in this function.
        @model_attn: Value forwarded as `attn_implementation`.
        @model_use_hf: When True, the model is loaded with `trust_remote_code = False` (and vice versa).

    Returns:
        (tokenizer, model), with the model switched to eval mode.
    """
    cache_dir = '/workspace/hf'

    tokenizer_kwargs = dict(
        cache_dir = cache_dir,
        add_eos_token = False,
        add_bos_token = False,
        padding_side = 'left',
        trust_remote_code = True
    )
    model_kwargs = dict(
        cache_dir = cache_dir,
        dtype = torch.bfloat16,
        trust_remote_code = not model_use_hf,
        device_map = 'auto',
        attn_implementation = model_attn
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs)
    model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
    model = model.eval()
    return tokenizer, model

# Unpack the selected model's settings, then load its tokenizer and weights
model_id, model_prefix, model_architecture, model_attn, model_use_hf, model_n_layers = get_model(selected_model_index)
tokenizer, model = load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf)

## Load probes

In [None]:
"""
Load the pickled probes for the selected model and list the layer each one targets
"""
probe_path = f'{ws}/experiments/da-role-analysis/probes/{model_prefix}.pkl'

# NOTE: unpickling runs arbitrary code - only load probe files you produced yourself
with open(probe_path, 'rb') as probe_file:
    probes = pickle.load(probe_file)

probe_layers = [probe_entry['layer_ix'] for probe_entry in probes]
print(f"Probes: {len(probes)}")
probes

# Run and project test set

## Test tokenizer

In [None]:
"""
Tests two important functions:
- `load_chat_template`: Overwrites the existing chat template in the tokenizer with one that better handles discrepancies.
    See the function's docstring for details.
- `label_content_roles`: Takes an instruct-formatted text, then assigns each token to the correct role.
"""
# Reload the helper modules and rebind the two functions, so edits to the modules are picked up without a kernel restart
import importlib
import utils.role_assignments
import utils.role_templates
importlib.reload(utils.role_assignments)
label_content_roles = utils.role_assignments.label_content_roles
importlib.reload(utils.role_templates)
load_chat_template = utils.role_templates.load_chat_template

# Load custom chat templater (the tokenizer's stock template is kept for side-by-side comparison)
old_chat_template = tokenizer.chat_template
new_chat_template = load_chat_template(f'{ws}/utils/chat_templates', model_architecture)

def test_chat_template(tokenizer):
    s = tokenizer.apply_chat_template(
        [
            {'role': 'system', 'content': 'Test.'},
            {'role': 'user', 'content': 'Hi! I\'m a dog.'},
            {'role': 'assistant', 'content': '<think>The user is a dog!</think>Congrats!'},
            {'role': 'user', 'content': 'Thanks!'},
            {'role': 'assistant', 'content': '<think>Hmm, the user said thanks!</think>Anything else I can help with?'}
        ],
        tokenize = False,
        padding = 'max_length',
        truncation = True,
        max_length = 512,
        add_generation_prompt = False
    )
    print(s)
    return s

# Render the sample conversation with the stock template first, then with the custom one.
# The tokenizer is left with the custom template installed, and `s` holds the custom rendering.
for template in (old_chat_template, new_chat_template):
    tokenizer.chat_template = template
    s = test_chat_template(tokenizer)

# Token-level frame for the custom rendering, with the index columns passed on to `label_content_roles`
z = pd.DataFrame({'input_ids': tokenizer(s)['input_ids']})
z = z.assign(token = tokenizer.convert_ids_to_tokens(z['input_ids']))
z = z.assign(prompt_ix = 0, batch_ix = 0, sequence_ix = 0, token_ix = z.index)

label_content_roles(model_architecture, z).tail(50)

## Create instruction test data

In [None]:
"""
We'll test some projections of different sequences

To understand instruct formats for different models, see the HF chat template playground:
    - https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=openai%2Fgpt-oss-120b
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

def load_messages(yaml_path, model_key):
    """
    Read the conversation YAML and return the `content` of every message listed under `model_key`
    (system-like, assistant-like, user-like, etc. messages), in file order.
    """
    with open(yaml_path, 'r', encoding = 'utf-8') as handle:
        conversations = yaml.safe_load(handle)

    contents = []
    for message in conversations[model_key]:
        contents.append(message['content'])
    return contents

# Role style of each base message; positionally aligned with `base_messages`
base_message_types = ['system', 'user', 'cot', 'assistant-final', 'user', 'cot', 'assistant-final']

base_messages = load_messages(f"{ws}/experiments/da-role-analysis/prompts/standard-conversations.yaml", model_prefix)

# Joined once and reused. Keeping the '\n' out of the f-string expression below also avoids a
# SyntaxError on Python < 3.12, where backslashes are not allowed inside f-string expressions.
joined_messages = '\n'.join(base_messages)

prompts = {}

# 1) Raw text with no chat formatting at all
prompts['basic_no_format'] = joined_messages

# 2) The whole conversation wrapped in a single user turn
prompts['everything_in_user_tags'] = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': joined_messages}],
    tokenize = False, add_generation_prompt = False
)

# 3) The whole conversation wrapped in a single assistant turn (qwen3moe gets an empty think block prepended)
if model_architecture == 'gptoss':
    assistant_content = joined_messages
elif model_architecture == 'qwen3moe':
    assistant_content = f"<think></think>{joined_messages}"
else:
    raise Exception('Unsupported architecture!')

prompts['everything_in_assistant_tags'] = tokenizer.apply_chat_template(
    [{'role': 'assistant', 'content': assistant_content}],
    tokenize = False, add_generation_prompt = False
)

# 4) Each base message placed under its proper role (requires at least 7 base messages)
prompts['proper_tags'] = tokenizer.apply_chat_template(
    [
        {'role': 'system', 'content': base_messages[0]},
        {'role': 'user', 'content': base_messages[1]},
        {'role': 'assistant', 'content': f"<think>{base_messages[2]}</think>{base_messages[3]}"},
        {'role': 'user', 'content': base_messages[4]},
        {'role': 'assistant', 'content': f"<think>{base_messages[5]}</think>{base_messages[6]}"}
    ],
    tokenize = False, add_generation_prompt = False
)

print(prompts['proper_tags'])

# One row per prompt variant; prompt_ix follows the insertion order of `prompts`
test_input_df =\
    pd.DataFrame({
        'prompt': list(prompts.values()),
        'prompt_key': list(prompts.keys())
    })\
    .assign(prompt_ix = lambda df: list(range(0, len(df))))

# Unshuffled loader over the prompt variants (sequences capped at 512 tokens, batches of 16)
test_dl = DataLoader(
    ReconstructableTextDataset(test_input_df['prompt'].tolist(), tokenizer, max_length = 512, prompt_ix = test_input_df['prompt_ix'].tolist()),
    batch_size = 16,
    shuffle = False,
    collate_fn = stack_collate
)

In [None]:
# print(tokenizer.apply_chat_template([
#     {
#         'role': 'assistant',
#         'thinking': 'The user is asking about the content in Austin',
#         'content': "<think>\nThe user is asking about the weather in Austin. I should use the weather tool to get this information.\n</think>\nI'll check the current weather in New York for you.",
#     }
# ], tokenize = False, add_generation_prompt = False))

In [None]:
# Show the properly role-tagged prompt again for reference
print(prompts['proper_tags'])

## Collect test activations

In [None]:
"""
Run forward passes
"""
# Resolve the architecture-specific forward-pass helper by name, e.g. utils.pretrained_models.gptoss.run_gptoss_return_topk
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_return_topk = getattr(model_module, f"run_{model_architecture}_return_topk")

@torch.no_grad()
def run_and_export_states(model, dl: ReconstructableTextDataset, layers_to_keep_acts: list[int]):
    """
    Run forward passes on given model and store the decomposed sample_df plus hidden states

    Params:
        @model: The model to run forward passes on. It is handed to `run_model_return_topk`, whose output must expose
          `logits` and `all_pre_mlp_hidden_states` (the only keys read here).
        @dl: A sized iterable of batches (the notebook passes a DataLoader wrapping a ReconstructableTextDataset). Each batch
          must provide `input_ids`, `attention_mask`, `original_tokens`, and `prompt_ix`.
        @layers_to_keep_acts: A list of layer indices (0-indexed) for which to filter `all_pre_mlp_hidden_states` (see returned object description).

    Returns:
        A dict with keys:
        - `sample_df`: A sample (token)-level dataframe with corresponding input token ID, output token ID, and input token text (removes masked tokens)
        - `all_pre_mlp_hs`: A tensor of size n_samples x len(layers_to_keep_acts) x D holding the hidden state for each retained layer. Each
            n_sample corresponds to a row of sample_df.
    """
    all_pre_mlp_hidden_states = []
    sample_dfs = []

    for batch_ix, batch in tqdm(enumerate(dl), total = len(dl)):

        input_ids = batch['input_ids'].to(main_device)
        attention_mask = batch['attention_mask'].to(main_device)
        original_tokens = batch['original_tokens']
        prompt_indices = batch['prompt_ix']

        output = run_model_return_topk(model, input_ids, attention_mask, return_hidden_states = True)

        # Check no bugs by validating output/perplexity (first batch only)
        if batch_ix == 0:
            # Pad positions get label -100 before being passed to the loss
            labels = torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids)
            loss = ForCausalLMLoss(output['logits'], labels, model.config.vocab_size).detach().cpu().item()
            for i in range(min(20, input_ids.size(0))):
                decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = False)
                next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
                print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
            print("PPL:", torch.exp(torch.tensor(loss)).item())

        # (sequence_ix, token_ix) -> original token text
        original_tokens_df = pd.DataFrame(
            [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)],
            columns = ['sequence_ix', 'token_ix', 'token']
        )

        # sequence_ix -> index of the prompt the sequence came from
        prompt_indices_df = pd.DataFrame(
            list(enumerate(prompt_indices)),
            columns = ['sequence_ix', 'prompt_ix']
        )

        # Create sample (token) level dataframe
        sample_df =\
            convert_outputs_to_df_fast(input_ids, attention_mask, output['logits'])\
            .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
            .merge(prompt_indices_df, how = 'left', on = ['sequence_ix'])\
            .assign(batch_ix = batch_ix)

        sample_dfs.append(sample_df)

        # Store pre-MLP hidden states - the fwd pass as n_layers list as BN x D, collapse to BN x n_layers x D, with BN filtering out masked items
        valid_pos = torch.where(attention_mask.cpu().view(-1) == 1) # Valid (BN, ) positions
        all_pre_mlp_hidden_states.append(torch.stack(output['all_pre_mlp_hidden_states'], dim = 1)[valid_pos][:, layers_to_keep_acts, :])

        # No manual `batch_ix += 1` here: enumerate() already advances the counter each iteration

    sample_df = pd.concat(sample_dfs, ignore_index = True)
    all_pre_mlp_hidden_states = torch.cat(all_pre_mlp_hidden_states, dim = 0)

    return {
        'sample_df': sample_df,
        'all_pre_mlp_hs': all_pre_mlp_hidden_states
    }

# Forward the test prompts and keep activations for every layer (0 .. model_n_layers - 1)
test_outputs = run_and_export_states(
    model,
    test_dl,
    layers_to_keep_acts = list(range(model_n_layers))
)

In [None]:
"""
Prepare the token-level test frame and a per-layer view of the test activations
"""
# Attach prompt text/key to every token row
test_sample_df = test_outputs['sample_df'].merge(test_input_df, on = 'prompt_ix', how = 'inner')

# sample_ix: group number of each (batch, sequence, token) triple; token_in_prompt_ix: running count within a prompt
test_sample_df = test_sample_df.assign(
    sample_ix = test_sample_df.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup(),
    token_in_prompt_ix = test_sample_df.groupby(['prompt_ix']).cumcount()
)

display(test_sample_df)

# fp16 copy of the activations, split into {layer_ix: (n_samples, D)}; all layers were kept, so slot == layer index
test_pre_mlp_hs = test_outputs['all_pre_mlp_hs'].to(torch.float16)
test_pre_mlp_hs = {layer_ix: test_pre_mlp_hs[:, layer_ix, :] for layer_ix in range(model_n_layers)}

In [None]:
"""
Map each token to the base message it belongs to (tokens already carry their prompt key; this adds `base_message_ix`).

A token is assigned to the first base message whose text contains either the 9-token window that starts at the token
(`tok_roll`) or the 9-token window that ends at it (`tok_back`). Tokens with no match get `base_message_ix = None`.
"""
# Steps of the chain below:
# 1. _t1.._t8 / _b1.._b8: the next / previous 8 tokens within the same prompt (NaN when the window leaves the prompt)
# 2. has_roll / has_back: True only when the full forward / backward window exists
# 3. tok_roll / tok_back: the window's token strings concatenated in reading order
# 4. hit_p{i}: whether a complete window is a substring of base message i
# 5. base_message_ix: index of the first base message with a hit (np.select takes the first True condition)
test_sample_df_labeled =\
    test_sample_df\
    .sort_values(['prompt_ix', 'token_ix'])\
    .assign(
        _t1 = lambda d: d.groupby('prompt_ix')['token'].shift(-1),
        _t2 = lambda d: d.groupby('prompt_ix')['token'].shift(-2),
        _t3 = lambda d: d.groupby('prompt_ix')['token'].shift(-3),
        _t4 = lambda d: d.groupby('prompt_ix')['token'].shift(-4),
        _t5 = lambda d: d.groupby('prompt_ix')['token'].shift(-5),
        _t6 = lambda d: d.groupby('prompt_ix')['token'].shift(-6),
        _t7 = lambda d: d.groupby('prompt_ix')['token'].shift(-7),
        _t8 = lambda d: d.groupby('prompt_ix')['token'].shift(-8),
        _b1 = lambda d: d.groupby('prompt_ix')['token'].shift(1),
        _b2 = lambda d: d.groupby('prompt_ix')['token'].shift(2),
        _b3 = lambda d: d.groupby('prompt_ix')['token'].shift(3),
        _b4 = lambda d: d.groupby('prompt_ix')['token'].shift(4),
        _b5 = lambda d: d.groupby('prompt_ix')['token'].shift(5),
        _b6 = lambda d: d.groupby('prompt_ix')['token'].shift(6),
        _b7 = lambda d: d.groupby('prompt_ix')['token'].shift(7),
        _b8 = lambda d: d.groupby('prompt_ix')['token'].shift(8)
    )\
    .assign(
        has_roll = lambda d: d[['_t1','_t2','_t3','_t4','_t5', '_t6', '_t7', '_t8']].notna().all(axis = 1),
        has_back = lambda d: d[['_b1','_b2','_b3','_b4','_b5', '_b6', '_b7', '_b8']].notna().all(axis = 1),
    )\
    .assign(
        tok_roll = lambda d: d['token'].fillna('') + d['_t1'].fillna('') + d['_t2'].fillna('') + d['_t3'].fillna('') + d['_t4'].fillna('') +
            d['_t5'].fillna('')  + d['_t6'].fillna('')  + d['_t7'].fillna('')  + d['_t8'].fillna(''),
        tok_back = lambda d: d['_b8'].fillna('') + d['_b7'].fillna('') + d['_b6'].fillna('') + d['_b5'].fillna('') + d['_b4'].fillna('') +
            d['_b3'].fillna('') + d['_b2'].fillna('') + d['_b1'].fillna('') + d['token'].fillna('')
    )\
    .drop(columns=['_t1','_t2','_t3','_t4', '_t5','_t6', '_t7', '_t8', '_b1','_b2','_b3','_b4','_b5', '_b6', '_b7', '_b8'])\
    .pipe(lambda d: d.join(
        pd.concat(
            [
                (
                    # `t=t` binds the current base message into the lambda at definition time
                    (d['has_roll'] & d['tok_roll'].apply(lambda s, t=t: s in t)) |
                    (d['has_back'] & d['tok_back'].apply(lambda s, t=t: s in t))
                    # d['tok_roll'].apply(lambda s, t=t: bool(s) and (s in t)) |
                    # d['tok_back'].apply(lambda s, t=t: bool(s) and (s in t))
                ).rename(f'hit_p{i}')
                for i, t in enumerate(base_messages)
            ],
            axis = 1
        )
    ))\
    .assign(
        base_message_ix = lambda d: np.select(
            [d[f'hit_p{i}'] for i in range(len(base_messages))],
            list(range(len(base_messages))),
            default = None
        ),
        # True when a token's windows match more than one base message (dropped again below)
        ambiguous = lambda d: d[[f'hit_p{i}' for i in range(len(base_messages))]].sum(axis=1) > 1,
        # base_name = lambda d: d['base_ix'].map({i: f"p{i}" for i in range(len(base_messages))})
    )\
    .drop(columns = [f'hit_p{i}' for i in range(len(base_messages))])\
    .drop(columns = ['tok_roll', 'tok_back', 'ambiguous', 'batch_ix', 'sequence_ix', 'token_ix'])\
    [[
        'sample_ix',
        'prompt_ix', 'prompt_key', 'prompt',
        'token_in_prompt_ix',
        'base_message_ix',
        'token_id',
        'token',
        'output_id',
        'output_prob' ,
        # 'tok_roll',
        # 'tok_back',
        # 'ambiguous'
    ]]

display(test_sample_df_labeled)

## Project tests into rolespace

In [None]:
"""
Now project these test samples into rolespace, then flatten - result should be len(test_sample_df_labeled) * # roles
(before the role-tag tokens are culled at the end)
"""
# Middle entry of probe_layers (the lower-middle one when the count is even)
test_layer = probe_layers[int(np.ceil(0.5 * len(probe_layers)).item()) - 1]
test_roles = ['system', 'user', 'assistant-cot', 'assistant-final'] # system, user, assistant-cot, assistant-final, tool

# Pick the probe trained on exactly these roles at the chosen layer; fail with context if there is none
matching_probes = [x for x in probes if sorted(x['roles']) == sorted(test_roles) and x['layer_ix'] == test_layer]
if not matching_probes:
    raise IndexError(f"No probe found for layer {test_layer} with roles {sorted(test_roles)}")
probe = matching_probes[0]

project_test_sample_df = test_sample_df_labeled
project_hs_cp = cupy.asarray(test_pre_mlp_hs[test_layer][project_test_sample_df['sample_ix'].tolist(), :])
project_probs = probe['probe'].predict_proba(project_hs_cp).round(6)

# One probability column per role; floor at 1e-6
proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
if len(proj_results) != len(project_test_sample_df):
    raise Exception("Error!")

# Attach sample_ix to the probabilities. proj_results rows are positional (row k <-> k-th row of the sample frame),
# while the sample frame keeps its pre-sort index labels, so reset them: concat(axis = 1) aligns on index labels.
role_level_df = pd.concat([proj_results, project_test_sample_df[['sample_ix']].reset_index(drop = True)], axis = 1)

# Long format: one row per (token, role_space), joined with token metadata and the base message's role style
role_level_df =\
    role_level_df.melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
    .merge(
        project_test_sample_df[['sample_ix', 'prompt_ix', 'prompt_key', 'prompt', 'token_in_prompt_ix', 'base_message_ix', 'token']],
        on = 'sample_ix',
        how = 'inner'
    )\
    .merge(
        pd.DataFrame({'base_message_ix': list(range(len(base_message_types))), 'base_message_type': base_message_types}),
        on = 'base_message_ix',
        how = 'left'
    )\
    .assign(base_message_type = lambda df: df['base_message_type'].fillna('other'))\
    .pipe(lambda df: df[df['base_message_type'] != 'other']) # Cull the role tags themselves

role_level_df

In [None]:
"""
Write the test-set role projections to CSV for plotting
"""
role_level_df.to_csv(f"{ws}/experiments/da-role-analysis/projections/test-role-projections-{model_prefix}.csv", index = False)

## Plot and visualize

In [None]:
"""
Plot tests: one faceted scatter per prompt variant, one facet row per role space
"""
facet_order = probe['roles']
all_message_types = role_level_df['base_message_type'].unique()
plotly_palette = px.colors.qualitative.Plotly
color_map = {msg_type: plotly_palette[i % len(plotly_palette)] for i, msg_type in enumerate(all_message_types)}

# Smoothed probabilities per (prompt, role space): 2-token simple moving average and an EWMA with alpha = .9
role_level_sorted = role_level_df.sort_values('sample_ix')
prob_by_series = role_level_sorted.groupby(['prompt_ix', 'role_space'])['prob']
smooth_df = role_level_sorted.assign(
    prob_sma = prob_by_series.rolling(window = 2, min_periods = 1).mean().reset_index(level = [0, 1], drop = True),
    prob_ewma = prob_by_series.ewm(alpha = .9, min_periods = 1).mean().reset_index(level = [0, 1], drop = True)
)

display(smooth_df)

def pretty(a):
    """Turn a facet annotation ("role_space=<name>") into a centred, horizontal strip label showing only <name>."""
    a.update(
        text = a.text.split("=")[-1],
        x = 0.5, xanchor = "center",
        y = a.y + 0.08,
        textangle = 0,
        font = dict(size = 12),
        bgcolor = 'rgba(255, 255, 255, 0.9)',  # light strip look
        showarrow = False
    )

for this_prompt_key in smooth_df['prompt_key'].unique().tolist():
    this_df = smooth_df[smooth_df['prompt_key'] == this_prompt_key]

    fig = px.scatter(
        this_df, x = 'token_in_prompt_ix', y = 'prob',
        facet_row = 'role_space',
        color = 'base_message_type',
        color_discrete_map = color_map,
        category_orders = {
            'role_space': facet_order,
            'base_message_type': ['system', 'user', 'cot', 'assistant-final', 'other']
        },
        hover_name = 'token',
        hover_data = {
            'prob': ':.3f'
        },
        # markers = True,
        title = f'prompt = {this_prompt_key}',
        labels = {
            'token_in_prompt_ix': 'Token Index',
            'prob': 'Prob',
            'prob_smoothed': 'Smoothed Prob',
            'role_space': 'role'
        }
    )
    fig = fig.update_yaxes(range = [0, 1], side = 'left')
    fig = fig.update_layout(height = 500, width = 800)

    fig.for_each_annotation(pretty)
    fig.update_yaxes(title_text = None)
    fig.show()

# Project redteaming activations

## Load data + label actual vs style roles

In [None]:
"""
Load dataset of sample (token) level data and activations
"""
def load_data(folder_path, model_prefix, max_data_files):
    """
    Load data saved by `export-c4-activations.ipynb` - you can reduce max_data_files if not enough memory

    Params:
        @folder_path: Sub-folder of `{ws}/experiments/da-role-analysis` holding the export.
        @model_prefix: Model sub-folder inside `folder_path`.
        @max_data_files: Numbered chunk folders `00` .. `max_data_files - 1` are considered; missing ones are skipped.

    Returns:
        (sample_df, topk_df, all_pre_mlp_hs, layers) where the first three are the per-chunk objects concatenated in
        folder order and `layers` is `metadata['all_pre_mlp_hidden_states_layers']` from the export's metadata.pkl.

    Raises:
        ValueError if none of the numbered chunk folders exist.
    """
    base_dir = f'{ws}/experiments/da-role-analysis/{folder_path}/{model_prefix}'
    folders = [f'{base_dir}/{i:02d}' for i in range(max_data_files)]
    folders = [f for f in folders if os.path.isdir(f)]

    # Fail early with the offending path instead of pandas' generic "No objects to concatenate"
    if not folders:
        raise ValueError(f'No data folders found under {base_dir} (max_data_files = {max_data_files})')

    all_pre_mlp_hs = []
    sample_df = []
    topk_df = []

    # NOTE: read_pickle / torch.load / pickle.load all unpickle - only point this at exports you trust
    for folder in tqdm(folders):
        sample_df.append(pd.read_pickle(f'{folder}/samples.pkl'))
        topk_df.append(pd.read_pickle(f'{folder}/topks.pkl'))
        all_pre_mlp_hs.append(torch.load(f'{folder}/all-pre-mlp-hidden-states.pt'))

    sample_df = pd.concat(sample_df)
    topk_df = pd.concat(topk_df)
    all_pre_mlp_hs = torch.concat(all_pre_mlp_hs)

    with open(f'{base_dir}/metadata.pkl', 'rb') as metadata_file:
        metadata = pickle.load(metadata_file)

    gc.collect()
    return sample_df, topk_df, all_pre_mlp_hs, metadata['all_pre_mlp_hidden_states_layers']

# Load up to 10 chunk folders of red-teaming activations; the top-k frame is discarded
sample_df_import, _, all_pre_mlp_hs_import, act_map = load_data('activations-redteam', model_prefix, 10)

In [None]:
"""
Load prompt-level dataset
"""
# Later cells read `redteam_prompt_ix`, `policy_style`, `qualifier_type`, and `output_class` from this frame
prompts_df = pd.read_csv(f"{ws}/experiments/da-role-analysis/activations-redteam/{model_prefix}/base-harmful-responses-classified.csv")
prompts_df

In [None]:
"""
Take the token-level dataframe (sample_df) and assign for each token:
- `role`: the actual role the tags specify (`label_content_roles`)
- `base_message_type`: the role style; equal to role except for the CoT forgery tokens, which are set to "forged_cot"

A token counts as CoT forgery when its prompt's policy_style is not "no_policy" and it is a non-newline user-content
token located after the last double line break of its user segment.
"""
# Assign real content roles (col roles + seg_id indicating continuous role-segment within prompt)
# - `label_content_roles` is called with the prompt column temporarily renamed to `prompt_ix`
# - sample_ix is the group number of each (redteam_prompt_ix, token_ix) pair, taken BEFORE token_ix is overwritten
#   with a running count within the prompt
# NOTE(review): token_in_seg_ix is the running count of NON-content tokens per segment minus 1 - confirm this is the
#   intended definition (it is not used again in the visible cells)
sample_df_raw =\
    label_content_roles(model_architecture, sample_df_import.rename(columns = {'redteam_prompt_ix': 'prompt_ix'}))\
    .drop(columns = ['batch_ix', 'sequence_ix'])\
    .rename(columns = {'prompt_ix': 'redteam_prompt_ix'})\
    .assign(sample_ix = lambda df: df.groupby(['redteam_prompt_ix', 'token_ix']).ngroup())\
    .assign(token_ix = lambda df: df.groupby(['redteam_prompt_ix']).cumcount())\
    .assign(
        _noncontent = lambda d: (~d['in_content_span']).astype('int8'),
        token_in_seg_ix = lambda d: d.groupby(['redteam_prompt_ix','seg_id'])['_noncontent'].cumsum().sub(1)
    )\
    .drop(columns = ['_noncontent'])

# Assign user tokens to CoT forgery when policy_style != no_policy
sample_df = (
    sample_df_raw\
    .merge(prompts_df[['redteam_prompt_ix', 'policy_style']], how = 'inner', on = 'redteam_prompt_ix')\
    .sort_values(['redteam_prompt_ix', 'token_ix'])\

    # user content + newline helpers
    .assign(
        is_user_content = lambda d: d['role'].eq('user') & d['in_content_span'],
        is_nl_only = lambda d: d['token'].str.fullmatch(r'\n+', na = False),
        prev_tok = lambda d: d.groupby(['redteam_prompt_ix','seg_id'])['token'].shift(1),
        prev_is_nl_only = lambda d: d.groupby(['redteam_prompt_ix','seg_id'])['is_nl_only'].shift(1).astype('boolean').fillna(False),
    )

    # where a double linebreak occurs (within each user segment)
    .assign(dbl_here = lambda d:
        d['token'].str.contains('\n\n', regex = False, na = False) | # contains "\n\n"
        (d['prev_is_nl_only'] & d['is_nl_only']) | # two newline-only tokens
        (d['prev_tok'].str.endswith('\n', na = False) & d['token'].str.startswith('\n', na = False))  # across boundary
    )

    # tokens strictly AFTER the last double break in each (redteam_prompt_ix, seg_id)
    .assign(
        # reversed cumulative sum = number of double breaks at or after this token; 0 means none remain
        rev_breaks = lambda d: d.groupby(['redteam_prompt_ix', 'seg_id'])\
            ['dbl_here'].transform(lambda s: s.iloc[::-1].cumsum().iloc[::-1]),
        after_last_para = lambda d: d['is_user_content'] & d['rev_breaks'].eq(0),
        is_user_visible = lambda d: d['is_user_content'] & ~d['is_nl_only'],
    )

    # final label: gate by policy_style
    .assign(base_message_type = lambda d: np.where(
        (d['policy_style'] != 'no_policy') & d['is_user_visible'] & d['after_last_para'], 'forged_cot', d['role']
    ))

    # helper columns are no longer needed once the label exists
    .drop(columns=[
        'is_user_content', 'is_nl_only', 'prev_tok', 'prev_is_nl_only',
        'dbl_here', 'rev_breaks', 'after_last_para', 'is_user_visible', 'policy_style'
    ], errors = 'ignore')
)

# Free the intermediate frames before the activations are converted
del sample_df_import, sample_df_raw

gc.collect()
display(sample_df)

In [None]:
"""
Convert activations to fp16 (for compatibility with cupy later) + layer-wise dict
"""
all_pre_mlp_hs = all_pre_mlp_hs_import.to(torch.float16)
# compare_bf16_fp16_batched(all_pre_mlp_hs_import, all_pre_mlp_hs)
del all_pre_mlp_hs_import
# act_map lists the layer index stored at each slot of dim 1, so this yields {layer_ix: (n_samples, D)}
all_pre_mlp_hs = {layer_ix: all_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(act_map)}

gc.collect()

In [None]:
"""
Count distribution of base_message_type tokens: share of tokens per base_message_type within each
(policy_style, qualifier_type) combination, pivoted to one column per base_message_type
"""
# sample_df.head(1000).to_csv(f'{ws}/test.csv', index = False)

sample_df\
    .merge(prompts_df[['redteam_prompt_ix', 'policy_style', 'qualifier_type', 'output_class']], how = 'inner', on = 'redteam_prompt_ix')\
    .groupby(['policy_style', 'qualifier_type', 'base_message_type'], as_index = False)\
    .agg(n = ('redteam_prompt_ix', 'count'))\
    .pipe(lambda df: df.assign(pct = df['n'] / df.groupby(['policy_style', 'qualifier_type'])['n'].transform('sum')))\
    .pipe(lambda df: df.pivot_table(
        index = ['policy_style', 'qualifier_type'],
        columns = 'base_message_type',
        values = 'pct',
        fill_value = 0
    ))\
    .reset_index(drop = False)

## Projection

In [None]:
"""
Project the red-teaming activations into role space with a single mid-layer probe.

Note: the all-probes cell below rebuilds `role_projection_df`, so this result is only a quick single-probe view.
"""
# NOTE(review): the test-set projection cell uses probe_layers[ceil(n / 2) - 1] while this cell omits the "- 1",
# so the two cells can select different layers - confirm which one is intended
test_layer = probe_layers[int(np.ceil(0.5 * len(probe_layers)).item())]
test_roles = ['system', 'user', 'assistant-cot', 'assistant-final'] # system, user, assistant-cot, assistant-final, tool

# Pick the probe trained on exactly these roles at the chosen layer; fail with context if there is none
matching_probes = [x for x in probes if sorted(x['roles']) == sorted(test_roles) and x['layer_ix'] == test_layer]
if not matching_probes:
    raise IndexError(f"No probe found for layer {test_layer} with roles {sorted(test_roles)}")
probe = matching_probes[0]

project_test_sample_df = sample_df
project_hs_cp = cupy.asarray(all_pre_mlp_hs[test_layer][project_test_sample_df['sample_ix'].tolist(), :])
project_probs = probe['probe'].predict_proba(project_hs_cp).round(6)

# One probability column per role; floor at 1e-6
proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
if len(proj_results) != len(project_test_sample_df):
    raise Exception('Error!')

# Attach sample_ix to the probabilities. proj_results rows are positional, while sample_df keeps the index labels it
# had before sort_values, so reset them: concat(axis = 1) aligns on index labels, not on row position.
role_projection_df = pd.concat([proj_results, project_test_sample_df[['sample_ix']].reset_index(drop = True)], axis = 1)

# Merge with token-level metadata and cull missing-role tokens (equivalent to culling instruct-special toks only)
# Verify equality:
# - sample_df.assign(role = lambda df: df['role'].fillna('o')).groupby(['in_content_span', 'role'], as_index = False).agg(n=('sample_ix', 'count'))
role_projection_df =\
    role_projection_df\
    .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
    .merge(project_test_sample_df, on = 'sample_ix', how = 'inner')\
    .assign(
        role = lambda df: df['role'].fillna('other'),
        base_message_type = lambda df: df['base_message_type'].fillna('other')
    )\
    .pipe(lambda df: df[df['role'] != 'other']) # Cull the role tags themselves

role_projection_df

In [None]:
"""
Project with probes

The final projected dataset is at a (token-sample, role_space, probe) level. The columns include:
- sample_ix: the token-sample index
- role_space: the role space projected to (one of the probe's roles)
- probe_ix: position of the probe in `probes`; `probe_mapping_df` maps it to the probe's layer_ix and roles
- prob: the probability the probe assigns to role_space for that token-sample's activations at the probe's layer
- redteam_prompt_ix: the prompt index of the prompt the token-sample belongs to
- role: the actual role tags this token-sample belongs to; tag tokens with no role are labelled "other" and culled
- base_message_type: equal to the actual role EXCEPT for the cot forgery tokens of cot forged prompts, which are set to "forged_cot"

Prompt-level columns (policy_style, qualifier_type, output_class) are not part of this frame; the "Checks" cell
merges them in from `prompts_df`.
"""
role_projection_dfs = []
probe_mapping_dfs = []

# Loop-invariant pieces: the sample frame, its row selector into the activations, and a position-indexed sample_ix
# column (concat(axis = 1) aligns on index labels, while the probe output is positional)
project_test_sample_df = sample_df
sample_row_ix = project_test_sample_df['sample_ix'].tolist()
sample_ix_col = project_test_sample_df[['sample_ix']].reset_index(drop = True)

# Wrap the list (not the enumerate) so tqdm knows the total and can draw a progress bar
for probe_ix, probe in enumerate(tqdm(probes)):

    test_layer = probe['layer_ix']
    test_roles = probe['roles']

    project_hs_cp = cupy.asarray(all_pre_mlp_hs[test_layer][sample_row_ix, :])
    project_probs = probe['probe'].predict_proba(project_hs_cp).round(6)

    # One probability column per role; floor at 1e-6
    proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
    if len(proj_results) != len(project_test_sample_df):
        raise Exception('Error!')

    role_projection_df = pd.concat([proj_results, sample_ix_col], axis = 1)

    # Merge with token-level metadata and cull missing-role tokens (equivalent to culling instruct-special toks only)
    # Verify equality:
    # - sample_df.assign(role = lambda df: df['role'].fillna('o')).groupby(['in_content_span', 'role'], as_index = False).agg(n=('sample_ix', 'count'))
    role_projection_df =\
        role_projection_df\
        .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
        .merge(project_test_sample_df, on = 'sample_ix', how = 'inner')\
        .assign(
            role = lambda df: df['role'].fillna('other'),
            base_message_type = lambda df: df['base_message_type'].fillna('other'),
            probe_ix = probe_ix
        )\
        .pipe(lambda df: df[df['role'] != 'other']) # Cull the role tags themselves

    probe_mapping_dfs.append(pd.DataFrame({
        'probe_ix': [probe_ix],
        'layer_ix': [probe['layer_ix']],
        'roles': [','.join(sorted(probe['roles']))]
    }))
    role_projection_dfs.append(role_projection_df)

role_projection_df = pd.concat(role_projection_dfs, ignore_index = True)
probe_mapping_df = pd.concat(probe_mapping_dfs, ignore_index = True)
role_projection_df

## Save

In [None]:
"""
Save for analysis: the projections as feather, the probe_ix -> (layer, roles) mapping as CSV
"""
role_projection_df\
    .to_feather(f'{ws}/experiments/da-role-analysis/projections/redteam-role-projections-{model_prefix}.feather')

probe_mapping_df\
    .to_csv(f'{ws}/experiments/da-role-analysis/projections/redteam-role-probe-mapping-{model_prefix}.csv', index = False)

In [None]:
"""
Checks
"""
# Merge with prompt-level metadata. Dropping the columns first keeps this cell idempotent: re-running it
# would otherwise add suffixed duplicates (policy_style_x / policy_style_y, ...)
prompt_level_cols = ['policy_style', 'qualifier_type', 'output_class']
role_projection_df =\
    role_projection_df\
    .drop(columns = prompt_level_cols, errors = 'ignore')\
    .merge(
        prompts_df[['redteam_prompt_ix'] + prompt_level_cols],
        on = 'redteam_prompt_ix',
        how = 'inner'
    )

# Count forged-CoT user rows per policy style and output class.
# NOTE(review): the first group key was an empty string in the original (and the chain was broken by a blank
# line, which is a SyntaxError); 'policy_style' is a best guess - confirm the intended key
role_projection_df\
    .pipe(lambda df: df[df['role'] == 'user'])\
    .pipe(lambda df: df[df['base_message_type'] == 'forged_cot'])\
    .groupby(['policy_style', 'output_class'], as_index = False)\
    .agg(s = ('sample_ix', 'count'))

In [None]:
# Row counts of user-role tokens, split by base_message_type
role_projection_df\
    .pipe(lambda df: df[df['role'] == 'user'])\
    .groupby(['base_message_type'], as_index = False)\
    .agg(s = ('sample_ix', 'count'))

In [None]:
# Inspection only: display the prompt-level frame
prompts_df

In [None]:
# Inspection only: display the token-level frame that was last projected
project_test_sample_df

In [None]:
"""
Role projection df
"""
# FIXME: unfinished cell - no group keys are given. pandas is expected to reject an empty key list
# (ValueError "No group keys passed!") - TODO confirm and fill in the intended keys
role_projection_df\
    .groupby([])
