In [None]:
"""
Load saved role probes and project red-team (and, optionally, agent) token activations onto them
"""
None  # last expression is None, so the cell displays nothing

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss
import pandas as pd
import numpy as np
import cupy
import cuml
import importlib
import gc
import yaml
import pickle
import os
from tqdm import tqdm
import plotly.express as px
from termcolor import colored

from utils.memory import check_memory, clear_all_cuda_memory
from utils.loader import load_model_and_tokenizer
from utils.role_assignments import label_content_roles
from utils.role_templates import load_chat_template
from utils.store_outputs import convert_outputs_to_df_fast

main_device = 'cuda:0'
seed = 1234

# Apply the seed so any stochastic step in this notebook is reproducible
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the CPU and all CUDA devices

clear_all_cuda_memory()
check_memory()

ws = '/workspace/deliberative-alignment-jailbreaks'

# Load models & probes

## Load model

In [None]:
"""
Load the base tokenizer and model onto the main device
"""
model_prefix = 'gptoss-20b'  # one of: gptoss-20b, gptoss-120b

(
    tokenizer,
    model,
    model_architecture,
    model_n_layers,
) = load_model_and_tokenizer(model_prefix, device = main_device)

check_memory()  # after loading the weights

## Load probes

In [None]:
"""
Load probes
"""
# NOTE: pickle.load can execute arbitrary code; only load probe files you produced yourself
with open(f'{ws}/experiments/role-analysis/probes/{model_prefix}.pkl', 'rb') as f:
    probes = pickle.load(f)

probe_layers = sorted({x['layer_ix'] for x in probes})
probe_roles = sorted({tuple(x['roles']) for x in probes})

# Build the joined strings outside the f-strings: a backslash inside an f-string replacement
# field (e.g. {'\n '.join(...)}) is a SyntaxError before Python 3.12.
probe_layers_str = ', '.join(str(x) for x in probe_layers)
probe_roles_str = '\n '.join(str(list(x)) for x in probe_roles)

print(f"Num probes: {len(probes)}")
print(f"Probe layers:\n  {probe_layers_str}")
print(f"Probe roles:\n {probe_roles_str}")

# Run and project test set

## Test tokenizer

In [None]:
"""
Checks two key helpers:
- `load_chat_template`: replaces the tokenizer's stock chat template with one that handles discrepancies better
    (see that function's docstring for details).
- `label_content_roles`: takes instruct-formatted text and assigns every token to its role.
"""
import importlib
import utils.role_assignments
import utils.role_templates

# Hot-reload both helper modules (importlib.reload returns the module) so edits are picked up without a kernel restart
label_content_roles = importlib.reload(utils.role_assignments).label_content_roles
load_chat_template = importlib.reload(utils.role_templates).load_chat_template

# Keep the stock template so the two can be compared below
old_chat_template = tokenizer.chat_template
new_chat_template = load_chat_template(f'{ws}/utils/chat_templates', model_architecture)

def test_chat_template(tokenizer):
    s = tokenizer.apply_chat_template(
        [
            {'role': 'system', 'content': 'Test.'},
            {'role': 'user', 'content': 'Hi! I\'m a dog.'},
            {'role': 'assistant', 'content': '<think>The user is a dog!</think>Congrats!'},
            {'role': 'user', 'content': 'Thanks!'},
            {'role': 'assistant', 'content': '<think>Hmm, the user said thanks!</think>Anything else I can help with?'}
        ],
        tokenize = False,
        padding = 'max_length',
        truncation = True,
        max_length = 512,
        add_generation_prompt = False
    )
    print(s)
    return s

# Render the same conversation with the stock template, then with the custom one (which stays active afterwards)
tokenizer.chat_template = old_chat_template
s = test_chat_template(tokenizer)
tokenizer.chat_template = new_chat_template
s = test_chat_template(tokenizer)

# Tokenize the rendered text and label every token with its role.
# add_special_tokens=False: the rendered string already carries the template's special tokens. Letting the
# tokenizer add its own (e.g. a BOS) would duplicate them and shift every token_ix.
z = (
    pd.DataFrame({'input_ids': tokenizer(s, add_special_tokens = False)['input_ids']})
    .assign(token = lambda df: tokenizer.convert_ids_to_tokens(df['input_ids']))
    .assign(prompt_ix = 0, batch_ix = 0, sequence_ix = 0, token_ix = lambda df: df.index)
)

label_content_roles(model_architecture, z).tail(50)

# Project redteaming activations

## Load data + label actual vs style roles

In [None]:
"""
Load dataset of sample (token) level data and activations
"""
def load_data(folder_path, model_prefix, max_data_files, root_dir = None):
    """
    Load samples + activations from numbered shard folders.

    Reads `<root_dir>/<folder_path>/<model_prefix>/{00 .. max_data_files-1}/`, skips shard folders that do not
    exist, and concatenates each shard's `samples.pkl`, `topks.pkl` and `all-pre-mlp-hidden-states.pt` in shard order.

    Args:
        folder_path: activation dataset folder (e.g. 'activations-redteam').
        model_prefix: model sub-folder (e.g. 'gptoss-20b').
        max_data_files: number of shard indices to probe (00 .. max_data_files - 1).
        root_dir: base directory; defaults to `{ws}/experiments/role-injection-analysis`.

    Returns:
        (sample_df, topk_df, all_pre_mlp_hs, layers):
        - sample_df, topk_df: the concatenated DataFrames (per-shard indices are kept, not reset).
        - all_pre_mlp_hs: the hidden-state tensors concatenated along dim 0.
        - layers: `metadata['all_pre_mlp_hidden_states_layers']` from the dataset's metadata.pkl.

    Raises:
        FileNotFoundError: if none of the shard folders exist. Without this check, pd.concat would fail with an
            opaque "No objects to concatenate".
    """
    if root_dir is None:
        root_dir = f'{ws}/experiments/role-injection-analysis'
    dataset_dir = f'{root_dir}/{folder_path}/{model_prefix}'

    folders = [f'{dataset_dir}/{i:02d}' for i in range(max_data_files)]
    folders = [f for f in folders if os.path.isdir(f)]
    if not folders:
        raise FileNotFoundError(f'No shard folders 00..{max_data_files - 1:02d} found under {dataset_dir}')

    all_pre_mlp_hs = []
    sample_df = []
    topk_df = []

    # NOTE: pickle / torch.load deserialize arbitrary objects; only load data you produced yourself
    for f in tqdm(folders):
        sample_df.append(pd.read_pickle(f'{f}/samples.pkl'))
        topk_df.append(pd.read_pickle(f'{f}/topks.pkl'))
        all_pre_mlp_hs.append(torch.load(f'{f}/all-pre-mlp-hidden-states.pt'))

    sample_df = pd.concat(sample_df)
    topk_df = pd.concat(topk_df)
    all_pre_mlp_hs = torch.concat(all_pre_mlp_hs)

    with open(f'{dataset_dir}/metadata.pkl', 'rb') as f:
        metadata = pickle.load(f)

    gc.collect()
    return sample_df, topk_df, all_pre_mlp_hs, metadata['all_pre_mlp_hidden_states_layers']

# topk_df (the second return value) is not needed here
sample_df_import, _, all_pre_mlp_hs_import, act_map = load_data(
    folder_path = 'activations-redteam',
    model_prefix = model_prefix,
    max_data_files = 10,
)

In [None]:
"""
Load prompt-level dataset
"""
prompts_df = pd.read_csv(f"{ws}/experiments/role-injection-analysis/activations-redteam/{model_prefix}/base-harmful-responses-classified.csv")

# Token rows are later merged onto this table by redteam_prompt_ix, so the key must be present and unique.
# A repeated id would silently duplicate that prompt's tokens (and their sample_ix) downstream.
missing_cols = {'redteam_prompt_ix', 'policy_style', 'qualifier_type', 'output_class'} - set(prompts_df.columns)
assert not missing_cols, f'prompts_df is missing columns: {sorted(missing_cols)}'
assert prompts_df['redteam_prompt_ix'].is_unique, 'prompts_df has duplicate redteam_prompt_ix values'

prompts_df

In [None]:
"""
Take the token-level dataframe (sample_df) and assign for each token:
- `role`: the actual role the tags specify (`label_content_roles`)
- `base_message_type`: the role style; equal to role except for the CoT forgery tokens, which is equal to "forged_cot".
  A token is treated as CoT forgery when all of these hold:
    - its prompt's policy_style is not 'no_policy';
    - it is user content that is not newline-only;
    - it comes strictly after the last double line break of its user segment.
"""
# Assign real content roles (col roles + seg_id indicating continuous role-segment within prompt)
# - label_content_roles receives the prompt id under the temporary name `prompt_ix`; it is renamed back afterwards.
# - sample_ix: dense id over (redteam_prompt_ix, original token_ix), numbered in sorted key order. Later cells use it
#   to index rows of the activation tensors. NOTE(review): assumes the saved activations follow that same order; confirm.
# - token_ix is then renumbered 0..n-1 within each prompt, following the current row order.
# - NOTE(review): token_in_seg_ix is the running count of NON-content (tag) tokens within (prompt, seg_id) minus 1, so
#   it only advances on tag tokens. If it is meant to be each token's position inside the segment's content, it should
#   probably cumsum `in_content_span` instead; confirm.
sample_df_raw =\
    label_content_roles(model_architecture, sample_df_import.rename(columns = {'redteam_prompt_ix': 'prompt_ix'}))\
    .drop(columns = ['batch_ix', 'sequence_ix'])\
    .rename(columns = {'prompt_ix': 'redteam_prompt_ix'})\
    .assign(sample_ix = lambda df: df.groupby(['redteam_prompt_ix', 'token_ix']).ngroup())\
    .assign(token_ix = lambda df: df.groupby(['redteam_prompt_ix']).cumcount())\
    .assign(
        _noncontent = lambda d: (~d['in_content_span']).astype('int8'),
        token_in_seg_ix = lambda d: d.groupby(['redteam_prompt_ix','seg_id'])['_noncontent'].cumsum().sub(1)
    )\
    .drop(columns = ['_noncontent'])

# Assign user tokens to CoT forgery when policy_style != no_policy
# (the merge is inner, so prompts missing from prompts_df are dropped)
sample_df = (
    sample_df_raw\
    .merge(prompts_df[['redteam_prompt_ix', 'policy_style']], how = 'inner', on = 'redteam_prompt_ix')\
    .sort_values(['redteam_prompt_ix', 'token_ix'])\

    # user content + newline helpers
    .assign(
        is_user_content = lambda d: d['role'].eq('user') & d['in_content_span'],
        is_nl_only = lambda d: d['token'].str.fullmatch(r'\n+', na = False),
        prev_tok = lambda d: d.groupby(['redteam_prompt_ix','seg_id'])['token'].shift(1),
        prev_is_nl_only = lambda d: d.groupby(['redteam_prompt_ix','seg_id'])['is_nl_only'].shift(1).astype('boolean').fillna(False),
    )

    # where a double linebreak occurs (within each user segment)
    .assign(dbl_here = lambda d:
        d['token'].str.contains('\n\n', regex = False, na = False) | # contains "\n\n"
        (d['prev_is_nl_only'] & d['is_nl_only']) | # two newline-only tokens
        (d['prev_tok'].str.endswith('\n', na = False) & d['token'].str.startswith('\n', na = False))  # across boundary
    )

    # tokens strictly AFTER the last double break in each (redteam_prompt_ix, seg_id)
    # rev_breaks counts the double breaks at or after each token, so it is 0 only past the last one.
    # A user segment with no double break at all therefore has every visible token labelled forged_cot
    # (for prompts whose policy_style is not 'no_policy').
    .assign(
        rev_breaks = lambda d: d.groupby(['redteam_prompt_ix', 'seg_id'])\
            ['dbl_here'].transform(lambda s: s.iloc[::-1].cumsum().iloc[::-1]),
        after_last_para = lambda d: d['is_user_content'] & d['rev_breaks'].eq(0),
        is_user_visible = lambda d: d['is_user_content'] & ~d['is_nl_only'],
    )

    # final label: gate by policy_style
    .assign(base_message_type = lambda d: np.where(
        (d['policy_style'] != 'no_policy') & d['is_user_visible'] & d['after_last_para'], 'forged_cot', d['role']
    ))

    .drop(columns=[
        'is_user_content', 'is_nl_only', 'prev_tok', 'prev_is_nl_only',
        'dbl_here', 'rev_breaks', 'after_last_para', 'is_user_visible', 'policy_style'
    ], errors = 'ignore')
)

# Free the raw token frames; after this, the cell cannot be re-run without reloading sample_df_import
del sample_df_import, sample_df_raw

gc.collect()
display(sample_df)

In [None]:
"""
Convert activations to fp16 (for compatibility with cupy later) + layer-wise dict
"""
# fp16 saturates at +/-65504, while the saved activations (presumably bf16) can hold far larger values. Values past
# that range would become inf after the cast and silently poison the probe projections, so refuse to cast them.
# max()/min() are plain reductions, so this check does not allocate a tensor-sized temporary.
fp16_max = torch.finfo(torch.float16).max
if all_pre_mlp_hs_import.numel() > 0:
    hs_absmax = max(all_pre_mlp_hs_import.max().item(), -all_pre_mlp_hs_import.min().item())
    if hs_absmax > fp16_max:
        raise ValueError(
            f'Activations reach |x| = {hs_absmax:.4g}, beyond the fp16 range (+/-{fp16_max:.0f}); '
            'casting to fp16 would produce inf. Rescale or keep a wider dtype before projecting.'
        )

all_pre_mlp_hs = all_pre_mlp_hs_import.to(torch.float16)
del all_pre_mlp_hs_import
# act_map[k] is the model layer stored at position k along dim 1 of the saved tensor
all_pre_mlp_hs = {layer_ix: all_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(act_map)}

gc.collect()

In [None]:
"""
Count distribution of base_message_type tokens: share of tokens per base_message_type within each
(policy_style, qualifier_type) group
"""
(
    sample_df
    .merge(prompts_df[['redteam_prompt_ix', 'policy_style', 'qualifier_type', 'output_class']], how = 'inner', on = 'redteam_prompt_ix')
    .groupby(['policy_style', 'qualifier_type', 'base_message_type'], as_index = False)
    .agg(n = ('redteam_prompt_ix', 'count'))
    # pct: this base_message_type's token share within its (policy_style, qualifier_type) group
    .assign(pct = lambda counts: counts['n'] / counts.groupby(['policy_style', 'qualifier_type'])['n'].transform('sum'))
    .pivot_table(
        index = ['policy_style', 'qualifier_type'],
        columns = 'base_message_type',
        values = 'pct',
        fill_value = 0
    )
    .reset_index()
)

## Projection

In [None]:
"""
Project with probes

The final projected dataset is at a (token-sample, probe, role_space) level. The columns include:
- sample_ix: the token-sample index (also the row index into the per-layer activation tensors)
- role_space: the role space projected to (one of the probe's `roles`)
- prob: the probability the probe assigns to the role_space role for that token-sample's activations (clipped below at 1e-6)
- probe_ix: index of the probe in `probes`; join probe_mapping_df on it for the probe's layer_ix and role set
- redteam_prompt_ix: the prompt index of the prompt the token-sample belongs to
- role: the actual role tags this token-sample belongs to; tokens with no role (the harmony tags themselves) are dropped
- base_message_type: equal to the actual role EXCEPT for the cot forgery tokens of cot forged prompts, which are set to "forged_cot"
- the remaining token-level columns of sample_df (token, token_ix, seg_id, in_content_span, ...)

Prompt-level fields (output_class, policy_style, qualifier_type) are not included; join prompts_df on redteam_prompt_ix for them.
"""
role_projection_dfs = []
probe_mapping_dfs = []

# Activation rows to project, in sample_df row order (loop-invariant, so built once)
sample_ixs = sample_df['sample_ix'].tolist()

for probe_ix, probe in tqdm(enumerate(probes), total = len(probes)):

    test_layer = probe['layer_ix']
    if test_layer not in all_pre_mlp_hs:
        continue  # no saved activations for this probe's layer

    project_hs_cp = cupy.asarray(all_pre_mlp_hs[test_layer][sample_ixs, :])
    project_probs = probe['probe'].predict_proba(project_hs_cp).round(6)

    # Row i of proj_results is the projection of sample_df row i
    proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
    if len(proj_results) != len(sample_df):
        raise ValueError(f'Probe {probe_ix} returned {len(proj_results)} rows for {len(sample_df)} token-samples')

    # Attach sample_ix by POSITION. pd.concat(axis = 1) would align on index labels instead, pairing probabilities
    # with the wrong tokens (or yielding NaNs) whenever sample_df's index is not exactly 0..n-1 in row order.
    role_projection_df = proj_results.assign(sample_ix = sample_df['sample_ix'].to_numpy())

    # Merge with token-level metadata and cull missing-role tokens (equivalent to culling instruct-special toks only)
    # Verify equality:
    # - sample_df.assign(role = lambda df: df['role'].fillna('o')).groupby(['in_content_span', 'role'], as_index = False).agg(n=('sample_ix', 'count'))
    role_projection_df = (
        role_projection_df
        .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')
        .merge(sample_df, on = 'sample_ix', how = 'inner')
        .assign(
            role = lambda df: df['role'].fillna('other'),
            base_message_type = lambda df: df['base_message_type'].fillna('other'),
            probe_ix = probe_ix
        )
        .pipe(lambda df: df[df['role'] != 'other'])  # Cull the role tags themselves
    )

    probe_mapping_dfs.append(pd.DataFrame({
        'probe_ix': [probe_ix],
        'layer_ix': [probe['layer_ix']],
        'roles': [','.join(sorted(probe['roles']))]
    }))
    role_projection_dfs.append(role_projection_df)

role_projection_df = pd.concat(role_projection_dfs, ignore_index = True)
probe_mapping_df = pd.concat(probe_mapping_dfs, ignore_index = True)
role_projection_df

## Save

In [None]:
"""
Save for analysis
"""
projections_dir = f'{ws}/experiments/role-injection-analysis/projections'
# to_feather / to_csv do not create missing parent directories
os.makedirs(projections_dir, exist_ok = True)

role_projection_df.to_feather(f'{projections_dir}/redteam-role-projections-{model_prefix}.feather')
probe_mapping_df.to_csv(f'{projections_dir}/redteam-role-probe-mapping-{model_prefix}.csv', index = False)

# (Optional) Project agent activations

## Load data + label actual vs style roles

In [None]:
"""
Load dataset of sample (token) level data and activations
"""
def load_data(folder_path, model_prefix, max_data_files, root_dir = None):
    """
    Load samples + activations from numbered shard folders.

    Reads `<root_dir>/<folder_path>/<model_prefix>/{00 .. max_data_files-1}/`, skips shard folders that do not
    exist, and concatenates each shard's `samples.pkl`, `topks.pkl` and `all-pre-mlp-hidden-states.pt` in shard order.

    Args:
        folder_path: activation dataset folder (e.g. 'activations-agent').
        model_prefix: model sub-folder (e.g. 'gptoss-20b').
        max_data_files: number of shard indices to probe (00 .. max_data_files - 1).
        root_dir: base directory; defaults to `{ws}/experiments/role-injection-analysis`.

    Returns:
        (sample_df, topk_df, all_pre_mlp_hs, layers):
        - sample_df, topk_df: the concatenated DataFrames (per-shard indices are kept, not reset).
        - all_pre_mlp_hs: the hidden-state tensors concatenated along dim 0.
        - layers: `metadata['all_pre_mlp_hidden_states_layers']` from the dataset's metadata.pkl.

    Raises:
        FileNotFoundError: if none of the shard folders exist. Without this check, pd.concat would fail with an
            opaque "No objects to concatenate".
    """
    if root_dir is None:
        root_dir = f'{ws}/experiments/role-injection-analysis'
    dataset_dir = f'{root_dir}/{folder_path}/{model_prefix}'

    folders = [f'{dataset_dir}/{i:02d}' for i in range(max_data_files)]
    folders = [f for f in folders if os.path.isdir(f)]
    if not folders:
        raise FileNotFoundError(f'No shard folders 00..{max_data_files - 1:02d} found under {dataset_dir}')

    all_pre_mlp_hs = []
    sample_df = []
    topk_df = []

    # NOTE: pickle / torch.load deserialize arbitrary objects; only load data you produced yourself
    for f in tqdm(folders):
        sample_df.append(pd.read_pickle(f'{f}/samples.pkl'))
        topk_df.append(pd.read_pickle(f'{f}/topks.pkl'))
        all_pre_mlp_hs.append(torch.load(f'{f}/all-pre-mlp-hidden-states.pt'))

    sample_df = pd.concat(sample_df)
    topk_df = pd.concat(topk_df)
    all_pre_mlp_hs = torch.concat(all_pre_mlp_hs)

    with open(f'{dataset_dir}/metadata.pkl', 'rb') as f:
        metadata = pickle.load(f)

    gc.collect()
    return sample_df, topk_df, all_pre_mlp_hs, metadata['all_pre_mlp_hidden_states_layers']

# topk_df (the second return value) is not needed here
sample_df_import, _, all_pre_mlp_hs_import, act_map = load_data(
    folder_path = 'activations-agent',
    model_prefix = model_prefix,
    max_data_files = 10,
)

In [None]:
"""
Load prompt-level dataset
"""
prompts_df = pd.read_csv(f"{ws}/experiments/role-injection-analysis/activations-agent/{model_prefix}/agent-outputs-classified.csv")

# Token rows are later merged onto this table by redteam_prompt_ix, so the key must be present and unique.
# A repeated id would silently duplicate that prompt's tokens (and their sample_ix) downstream.
missing_cols = {'redteam_prompt_ix', 'variant'} - set(prompts_df.columns)
assert not missing_cols, f'prompts_df is missing columns: {sorted(missing_cols)}'
assert prompts_df['redteam_prompt_ix'].is_unique, 'prompts_df has duplicate redteam_prompt_ix values'

prompts_df

In [None]:
"""
Take the token-level dataframe (sample_df) and assign for each token:
- `role`: the actual role the tags specify (`label_content_roles`)
- `base_message_type`: the role style; equal to role except for tool-output tokens that fall inside a known
  injection paragraph, which are labelled "user-injection" or "forged-cot"
  (note the hyphen: the red-team section labels its CoT forgery tokens "forged_cot")
"""
# Get injections - we will substring match to detect them
with open(f'{ws}/experiments/tool-injections/prompts/injections.yaml', 'r') as f:
    injections = yaml.safe_load(f)

# Assign real content roles (col roles + seg_id indicating continuous role-segment within prompt)
# - label_content_roles receives the prompt id under the temporary name `prompt_ix`; it is renamed back afterwards.
# - sample_ix: dense id over (redteam_prompt_ix, original token_ix), numbered in sorted key order. Later cells use it
#   to index rows of the activation tensors. NOTE(review): assumes the saved activations follow that same order; confirm.
# - token_ix is then renumbered 0..n-1 within each prompt, following the current row order.
# - NOTE(review): token_in_seg_ix is the running count of NON-content (tag) tokens within (prompt, seg_id) minus 1, so
#   it only advances on tag tokens. If it is meant to be each token's position inside the segment's content, it should
#   probably cumsum `in_content_span` instead; confirm.
sample_df_raw =\
    label_content_roles(model_architecture, sample_df_import.rename(columns = {'redteam_prompt_ix': 'prompt_ix'}))\
    .drop(columns = ['batch_ix', 'sequence_ix'])\
    .rename(columns = {'prompt_ix': 'redteam_prompt_ix'})\
    .assign(sample_ix = lambda df: df.groupby(['redteam_prompt_ix', 'token_ix']).ngroup())\
    .assign(token_ix = lambda df: df.groupby(['redteam_prompt_ix']).cumcount())\
    .assign(
        _noncontent = lambda d: (~d['in_content_span']).astype('int8'),
        token_in_seg_ix = lambda d: d.groupby(['redteam_prompt_ix','seg_id'])['_noncontent'].cumsum().sub(1)
    )\
    .drop(columns = ['_noncontent'])

# Attach each prompt's injection variant (left merge: prompts missing from prompts_df get NaN),
# then order tokens by prompt and position
sample_df = (
    sample_df_raw
    .merge(prompts_df[['redteam_prompt_ix', 'variant']], how = 'left', on = 'redteam_prompt_ix')
    .sort_values(['redteam_prompt_ix', 'token_ix'])
)

####### Below identifies BMTs (base_message_types) for base-injection/cot-forgery-injection
# --- 1. Build injection lexicon ---
# - Variant 1 (base-injection): 1 paragraph → user-injection
# - Variant 2 (cot-forgery-injection): 2 paragraphs → user-injection (p1), forged-cot (p2)
import json


def _injection_prompts(key):
    """Return injections[key] as a DataFrame with a `prompt` column, empty when the key is missing, null or empty."""
    return pd.DataFrame(injections.get(key) or [], columns = ['prompt'])


def _json_text_variants(text):
    """
    Return the distinct spellings under which `text` may appear inside a tool's JSON blob.

    The spellings are, in this order:
    - raw text;
    - newlines escaped;
    - double quotes escaped;
    - both escapes together;
    - the full JSON string encoding from json.dumps (with and without ensure_ascii), which also covers
      backslashes, tabs and non-ASCII characters.
    """
    nl = text.replace('\n', r'\n')
    spellings = [
        text,
        nl,
        text.replace('"', r'\"'),
        nl.replace('"', r'\"'),
        json.dumps(text)[1:-1],
        json.dumps(text, ensure_ascii = False)[1:-1],
    ]
    return list(dict.fromkeys(spellings))  # dedupe while keeping order


# Variant 1 (base-injection): single paragraph="user-injection"
base_injections_df = (
    _injection_prompts('base_injections')
    .assign(variant = 'base-injection', injection_type = 'user-injection', injection_text = lambda d: d['prompt'])
    [['variant', 'injection_type', 'injection_text']]
)

# Variant 2 (cot-forgery-injection): two paragraphs split on FIRST double newline
prompt_injections_split_df = (
    _injection_prompts('prompt_injections')
    .assign(_parts = lambda d: d['prompt'].str.split('\n\n', n = 1, expand = False))
    .assign(p1 = lambda d: d['_parts'].str[0], p2 = lambda d: d['_parts'].str[1])
)

# Paragraph 1 → "user-injection"
prompt_injections_user_df = (
    prompt_injections_split_df
    .assign(variant = 'cot-forgery-injection', injection_type = 'user-injection', injection_text = lambda d: d['p1'])
    [['variant', 'injection_type', 'injection_text']]
)

# Paragraph 2 → "forged-cot"
prompt_injections_cot_df = (
    prompt_injections_split_df
    .assign(variant = 'cot-forgery-injection', injection_type = 'forged-cot', injection_text = lambda d: d['p2'])
    [['variant', 'injection_type', 'injection_text']]
)

# Unified lexicon: one row per (paragraph, spelling) so it can match raw text and JSON-escaped tool blobs
injection_lexicon_df = (
    pd.concat([base_injections_df, prompt_injections_user_df, prompt_injections_cot_df], ignore_index = True)
    # Drop paragraphs that do not exist or are empty. A prompt without a blank line has no p2 (NaN), and a NaN
    # pattern would later crash str.find. An empty pattern matches trivially and labels nothing.
    .loc[lambda d: d['injection_text'].notna() & d['injection_text'].astype(str).ne('')]
    .assign(injection_text = lambda d: d['injection_text'].astype(str))
    .assign(injection_text_variant = lambda d: d['injection_text'].map(_json_text_variants))
    .explode('injection_text_variant')
    .drop_duplicates(subset = ['variant', 'injection_type', 'injection_text_variant'])
)

# --- 2. Build tool-output segments with char offsets ---
# Tool-content tokens in segment order. Each token gets a [char_start, char_end) offset within the concatenated
# text of its (redteam_prompt_ix, seg_id) segment.
tool_content_tokens_df = (
    sample_df
    .sort_values(['redteam_prompt_ix', 'seg_id', 'token_ix'])
    .assign(is_tool_content = lambda df: df['role'].eq('tool') & df['in_content_span'])
    .query('is_tool_content')
    .assign(
        char_len = lambda df: df['token'].astype(str).str.len(),
        char_end = lambda df: df.groupby(['redteam_prompt_ix', 'seg_id'])['char_len'].cumsum(),
        char_start = lambda df: df['char_end'] - df['char_len'],
    )
)

# One row per tool segment: its full text and the conversation-level variant
tool_segments_df = (
    tool_content_tokens_df
    .groupby(['redteam_prompt_ix', 'seg_id'], as_index = False)
    .agg(
        segment_text = ('token', ''.join),
        segment_variant = ('variant', 'first'),
    )
)

# --- 3. Find injection matches per segment (exact substring) ---
# Candidate pairs: every tool segment x every lexicon entry of that segment's own variant.
# - The inner join drops segments whose variant is missing or has no lexicon entries. With a left join those rows
#   would reach str.find with a NaN pattern and raise TypeError.
# - Lexicon rows without a usable pattern are dropped for the same reason.
segment_candidates_df = (
    tool_segments_df
    .merge(injection_lexicon_df, left_on = 'segment_variant', right_on = 'variant', how = 'inner')
    .dropna(subset = ['injection_text_variant'])
)

segment_matches_df = (
    segment_candidates_df
    .assign(
        # Offset of the first occurrence of the pattern in the segment text (-1 if absent). Built with a
        # comprehension because DataFrame.apply(axis=1) returns a DataFrame, not a Series, on empty input.
        match_start = lambda d: pd.Series(
            [str(text).find(str(pattern)) for text, pattern in zip(d['segment_text'], d['injection_text_variant'])],
            index = d.index,
            dtype = 'int64',
        ),
        match_end = lambda d: d['match_start'] + d['injection_text_variant'].map(lambda p: len(str(p))),
    )
    .query('match_start >= 0')
    [['redteam_prompt_ix', 'seg_id', 'injection_type', 'match_start', 'match_end']]
)

if not segment_matches_df.empty:
    # --- 4) Map char spans back to tokens; assign base_message_type ---
    # A token belongs to a matched paragraph when its [char_start, char_end) overlaps [match_start, match_end).
    token_span_labels_df = (
        tool_content_tokens_df[['redteam_prompt_ix', 'seg_id', 'token_ix', 'char_start', 'char_end']]
        .merge(segment_matches_df, on = ['redteam_prompt_ix', 'seg_id'], how = 'left')
        .assign(
            _has_match = lambda df: df['match_start'].notna(),
            _overlap = lambda df: df['_has_match'] & (df['char_end'] > df['match_start']) & (df['char_start'] < df['match_end']),
            token_label = lambda df: df['injection_type'].where(df['_overlap']),
        )
        .dropna(subset = ['token_label'])
        .groupby(['redteam_prompt_ix', 'token_ix'], as_index = False)
        # A token overlapping several matches takes the first non-null label
        .agg(base_message_type = ('token_label', 'first'))
    )

    sample_df = (
        sample_df
        .merge(token_span_labels_df, on = ['redteam_prompt_ix', 'token_ix'], how = 'left')
        # Tokens outside every matched injection keep their actual role
        .assign(base_message_type = lambda df: df['base_message_type'].fillna(df['role']))
        .infer_objects(copy = False)
    )
else:
    # No injection text was found in any tool output, so every token's base message type is just its role
    sample_df = sample_df.assign(base_message_type = lambda df: df['role']).infer_objects(copy = False)

# del sample_df_import, sample_df_raw

gc.collect()
display(sample_df)

In [None]:
"""
Check if it looks reasonable: for prompt 0, show each (variant, segment, base_message_type) group with its text,
in token order. Tag tokens (no base_message_type) are dropped by the groupby.
"""
(
    sample_df
    .pipe(lambda df: df[df['redteam_prompt_ix'] == 0])
    .groupby(['variant', 'seg_id', 'base_message_type'], as_index = False)
    # Join the group's tokens into one string (an identity lambda returns a non-scalar per group, which pandas
    # rejects for non-object dtypes). Use the 'min' string alias rather than the builtin, which pandas deprecates here.
    .agg(text = ('token', lambda x: ''.join(x.astype(str))), token_ix = ('token_ix', 'min'))
    .sort_values(by = 'token_ix')
)

In [None]:
"""
Convert activations to fp16 (for compatibility with cupy later) + layer-wise dict
"""
# fp16 saturates at +/-65504, while the saved activations (presumably bf16) can hold far larger values. Values past
# that range would become inf after the cast and silently poison the probe projections, so refuse to cast them.
# max()/min() are plain reductions, so this check does not allocate a tensor-sized temporary.
fp16_max = torch.finfo(torch.float16).max
if all_pre_mlp_hs_import.numel() > 0:
    hs_absmax = max(all_pre_mlp_hs_import.max().item(), -all_pre_mlp_hs_import.min().item())
    if hs_absmax > fp16_max:
        raise ValueError(
            f'Activations reach |x| = {hs_absmax:.4g}, beyond the fp16 range (+/-{fp16_max:.0f}); '
            'casting to fp16 would produce inf. Rescale or keep a wider dtype before projecting.'
        )

all_pre_mlp_hs = all_pre_mlp_hs_import.to(torch.float16)
del all_pre_mlp_hs_import
# act_map[k] is the model layer stored at position k along dim 1 of the saved tensor
all_pre_mlp_hs = {layer_ix: all_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(act_map)}

gc.collect()

## Projection

In [None]:
"""
Project with probes

The final projected dataset is at a (token-sample, probe, role_space) level. The columns include:
- sample_ix: the token-sample index (also the row index into the per-layer activation tensors)
- role_space: the role space projected to (one of the probe's `roles`)
- prob: the probability the probe assigns to the role_space role for that token-sample's activations (clipped below at 1e-6)
- probe_ix: index of the probe in `probes`; join probe_mapping_df on it for the probe's layer_ix and role set
- redteam_prompt_ix: the prompt index of the prompt the token-sample belongs to
- variant: the prompt's injection variant (from prompts_df)
- role: the actual role tags this token-sample belongs to; tokens with no role (the harmony tags themselves) are dropped
- base_message_type: equal to the actual role EXCEPT for tool-output tokens inside a matched injection paragraph,
  which are set to "user-injection" or "forged-cot"
- the remaining token-level columns of sample_df (token, token_ix, seg_id, in_content_span, ...)

Other prompt-level fields are not included; join prompts_df on redteam_prompt_ix for them.
"""
role_projection_dfs = []
probe_mapping_dfs = []

# Activation rows to project, in sample_df row order (loop-invariant, so built once)
sample_ixs = sample_df['sample_ix'].tolist()

for probe_ix, probe in tqdm(enumerate(probes), total = len(probes)):

    test_layer = probe['layer_ix']
    if test_layer not in all_pre_mlp_hs:
        continue  # no saved activations for this probe's layer

    project_hs_cp = cupy.asarray(all_pre_mlp_hs[test_layer][sample_ixs, :])
    project_probs = probe['probe'].predict_proba(project_hs_cp).round(6)

    # Row i of proj_results is the projection of sample_df row i
    proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
    if len(proj_results) != len(sample_df):
        raise ValueError(f'Probe {probe_ix} returned {len(proj_results)} rows for {len(sample_df)} token-samples')

    # Attach sample_ix by POSITION. pd.concat(axis = 1) would align on index labels instead, pairing probabilities
    # with the wrong tokens (or yielding NaNs) whenever sample_df's index is not exactly 0..n-1 in row order.
    role_projection_df = proj_results.assign(sample_ix = sample_df['sample_ix'].to_numpy())

    # Merge with token-level metadata and cull missing-role tokens (equivalent to culling instruct-special toks only)
    # Verify equality:
    # - sample_df.assign(role = lambda df: df['role'].fillna('o')).groupby(['in_content_span', 'role'], as_index = False).agg(n=('sample_ix', 'count'))
    role_projection_df = (
        role_projection_df
        .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')
        .merge(sample_df, on = 'sample_ix', how = 'inner')
        .assign(
            role = lambda df: df['role'].fillna('other'),
            base_message_type = lambda df: df['base_message_type'].fillna('other'),
            probe_ix = probe_ix
        )
        .pipe(lambda df: df[df['role'] != 'other'])  # Cull the role tags themselves
    )

    probe_mapping_dfs.append(pd.DataFrame({
        'probe_ix': [probe_ix],
        'layer_ix': [probe['layer_ix']],
        'roles': [','.join(sorted(probe['roles']))]
    }))
    role_projection_dfs.append(role_projection_df)

role_projection_df = pd.concat(role_projection_dfs, ignore_index = True)
probe_mapping_df = pd.concat(probe_mapping_dfs, ignore_index = True)
role_projection_df

## Save

In [None]:
"""
Save for analysis
"""
projections_dir = f'{ws}/experiments/role-injection-analysis/projections'
# to_feather / to_csv do not create missing parent directories
os.makedirs(projections_dir, exist_ok = True)

role_projection_df.to_feather(f'{projections_dir}/agent-role-projections-{model_prefix}.feather')
probe_mapping_df.to_csv(f'{projections_dir}/agent-role-probe-mapping-{model_prefix}.csv', index = False)