In [1]:
"""
Test role confusion with hidden states
"""
# Notebook purpose: train a logistic-regression probe that predicts a token's chat role from one layer's
# pre-MLP hidden states, then project hand-built conversations (wrapped in different role tags) through it.
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import cupy
import cuml
import importlib
import gc
import pickle
import os
from tqdm import tqdm
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

from utils.memory import check_memory, clear_all_cuda_memory
from utils.role_assignments import label_content_roles
from utils.role_templates import load_chat_template
from utils.store_outputs import convert_outputs_to_df_fast

# Workspace root: data exports, chat templates and figure exports are all resolved relative to it
ws = '/workspace/deliberative-alignment-jailbreaks'

# Device that each input batch is moved to before the forward pass
main_device = 'cuda:0'
# NOTE(review): `seed` is not referenced anywhere else in this notebook (run_lr uses its own random_state)
seed = 1234

# Reset / report CUDA memory via the project helpers in utils.memory
clear_all_cuda_memory()
check_memory()

## Load models & data

### Load model

In [None]:
"""
Load the base tokenizer/model
"""
# Key into the `get_model` registry: 0 = gpt-oss-120b, 1 = gpt-oss-20b, 2 = Qwen3-30B-A3B-Thinking
selected_model_index = 1

def get_model(index):
    """
    Look up the configuration of one of the supported models.

    Params:
        @index: Integer key into the registry below (0, 1 or 2).

    Returns:
        A 6-tuple (HF model ID, model prefix, model architecture, attn implementation,
        whether to use the HF library implementation, number of layers).

    Raises:
        KeyError: if `index` is not registered; the message lists the valid keys.
    """
    models = {
        # Experts load in MXFP4 if triton kernels are installed
        0: ('openai/gpt-oss-120b', 'gptoss120', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 36),
        1: ('openai/gpt-oss-20b', 'gptoss20', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 24),
        2: ('Qwen/Qwen3-30B-A3B-Thinking-2507', 'qwen3moe', 'qwen3moe', None, True, 48),
    }
    if index not in models:
        raise KeyError(f"Unknown model index {index!r}; expected one of {sorted(models)}")
    return models[index]

def load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf):
    """
    Fetch the tokenizer and bf16 causal LM for `model_id` (HF cache at /workspace/hf; reused if already downloaded).

    Params:
        @model_id: HF hub model ID.
        @model_prefix: Not used here; kept so the call signature matches the rest of the notebook.
        @model_attn: `attn_implementation` for the model (may be None for the library default).
        @model_use_hf: If True, the transformers-native implementation is used (remote code disabled for the model).

    Returns:
        (tokenizer, model), with the model switched to eval mode and placed with device_map='auto'.
    """
    hf_cache = '/workspace/hf'
    # Left padding so that batched sequences all end at the final position
    tok = AutoTokenizer.from_pretrained(
        model_id,
        cache_dir = hf_cache,
        add_eos_token = False,
        add_bos_token = False,
        padding_side = 'left',
        trust_remote_code = True,
    )
    lm = AutoModelForCausalLM.from_pretrained(
        model_id,
        cache_dir = hf_cache,
        dtype = torch.bfloat16,
        trust_remote_code = not model_use_hf,
        device_map = 'auto',
        attn_implementation = model_attn,
    )
    lm = lm.eval()
    return tok, lm

# Unpack the registry entry, then load the weights
(
    model_id, model_prefix, model_architecture,
    model_attn, model_use_hf, model_n_layers,
) = get_model(selected_model_index)
tokenizer, model = load_model_and_tokenizer(
    model_id, model_prefix, model_attn, model_use_hf
)

In [None]:
def load_tokenizer(model_id, cache_dir = '/workspace/hf'):
    """
    Reload only the tokenizer for `model_id`, with the same settings used when loading the model.

    This deliberately uses its own name. Redefining `load_model_and_tokenizer` here would silently change that
    function to return a bare tokenizer instead of a (tokenizer, model) tuple.
    """
    return AutoTokenizer.from_pretrained(
        model_id, cache_dir = cache_dir, add_eos_token = False, add_bos_token = False,
        padding_side = 'left', trust_remote_code = True
    )

# A fresh tokenizer carries its stock chat template (a later cell overwrites `tokenizer.chat_template` with a custom one)
tokenizer = load_tokenizer(model_id)

print(tokenizer.chat_template)

### Load probe training activations

In [None]:
"""
Load dataset
"""
def load_data(folder_path, model_prefix, max_data_files):
    """
    Load data saved by `export-activations.ipynb`.

    Reads the numbered shard folders `{ws}/experiments/role-confusion/{folder_path}/{model_prefix}/00`,
    `01`, ... (the first `max_data_files` slots; missing shards are skipped), plus the sibling `metadata.pkl`.

    Returns:
        (sample_df, topk_df, all_pre_mlp_hs, layers):
        - the shard-wise concatenations of `samples.pkl`, `topks.pkl` and `all-pre-mlp-hidden-states.pt`
          (dataframe indices are not reset, so each shard keeps its own index);
        - `layers` = `metadata['all_pre_mlp_hidden_states_layers']`.

    Raises:
        FileNotFoundError: if none of the shard folders exist.
    """
    base_dir = f'{ws}/experiments/role-confusion/{folder_path}/{model_prefix}'
    folders = [f'{base_dir}/{i:02d}' for i in range(max_data_files)]
    folders = [f for f in folders if os.path.isdir(f)]
    if not folders:
        # Fail with an actionable message instead of pandas' "No objects to concatenate"
        raise FileNotFoundError(f'No activation shards found under {base_dir} (checked {max_data_files} shard slots)')

    all_pre_mlp_hs = []
    sample_df = []
    topk_df = []

    for f in tqdm(folders):
        sample_df.append(pd.read_pickle(f'{f}/samples.pkl'))
        topk_df.append(pd.read_pickle(f'{f}/topks.pkl'))
        all_pre_mlp_hs.append(torch.load(f'{f}/all-pre-mlp-hidden-states.pt'))

    sample_df = pd.concat(sample_df)
    topk_df = pd.concat(topk_df)
    all_pre_mlp_hs = torch.concat(all_pre_mlp_hs)

    # NOTE: pickle.load can execute arbitrary code - only point this at locally produced exports
    with open(f'{base_dir}/metadata.pkl', 'rb') as fh:
        metadata = pickle.load(fh)

    gc.collect()
    return sample_df, topk_df, all_pre_mlp_hs, metadata['all_pre_mlp_hidden_states_layers']

# Probe-training data: up to 10 activation shards exported for this model
sample_df_import, topk_df_import, all_pre_mlp_hs_import, act_map = load_data(
    folder_path = 'activations',
    model_prefix = model_prefix,
    max_data_files = 10,
)

In [None]:
"""
Test the valid role assignments function label_content_roles
"""
import importlib
import utils.role_assignments
import utils.role_templates

# Reload the project helpers so edits under utils/ are picked up without restarting the kernel
importlib.reload(utils.role_assignments)
importlib.reload(utils.role_templates)
label_content_roles = utils.role_assignments.label_content_roles
load_chat_template = utils.role_templates.load_chat_template

# Swap in the custom chat template for this architecture
tokenizer.chat_template = load_chat_template(f'{ws}/utils/chat_templates', model_architecture)

# A tiny three-turn conversation to eyeball the role labelling on
s = tokenizer.apply_chat_template(
    [
        {'role': 'system', 'content': 'Test.'},
        {'role': 'user', 'content': 'Hi! I am a dog and I like to bark'},
        {'role': 'assistant', 'content': 'I am a dog and I like to bark'},
    ],
    tokenize = False,
    padding = 'max_length',
    truncation = True,
    max_length = 512,
    add_generation_prompt = False,
)

# One row per token, with the index columns label_content_roles expects
z = pd.DataFrame({'input_ids': tokenizer(s)['input_ids']})
z = z.assign(
    token = lambda df: tokenizer.convert_ids_to_tokens(df['input_ids']),
    prompt_ix = 0,
    batch_ix = 0,
    sequence_ix = 0,
    token_ix = lambda df: df.index,
)

label_content_roles(model_architecture, z).tail(50)

In [None]:
"""
Let's clean up the mappings here. We'll get everything to a sample_ix level first. We'll also flag the role wrapper tokens. Drop nothing. 
"""
def clean_mappings(model_architecture, sample_df_import, topk_df_import):
    """
    Attach role labels and sample-level ids to the raw exports. No rows are dropped from the sample side.

    Roles are derived from the raw tokens via `label_content_roles` rather than read from input_mappings.csv.
    This lets the same logic be reused later, and it does not assume one role per sequence/prompt_ix.

    Params:
        @model_architecture: Passed through to `label_content_roles`.
        @sample_df_import: Token-level dataframe with `batch_ix`, `sequence_ix`, `token_ix` and `prompt_ix`.
        @topk_df_import: Top-k dataframe keyed by `batch_ix`, `sequence_ix` and `token_ix`.

    Returns:
        (sample_df, topk_df):
        - sample_df: labelled tokens plus `sample_ix`, `token_in_seq_ix` (running count within prompt_ix) and
          `token_in_role_ix` (running count within (prompt_ix, role)), with `batch_ix`/`sequence_ix` removed.
        - topk_df: top-k rows inner-joined to `sample_ix`/`prompt_ix`, with the three join keys removed.
    """
    labeled = label_content_roles(model_architecture, sample_df_import)
    # NOTE(review): sample_ix is one id per distinct (batch_ix, sequence_ix, token_ix); it is only unique per token
    # if batch_ix does not restart in each loaded shard - confirm against the export notebook.
    labeled = labeled.assign(sample_ix = labeled.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup())
    labeled = labeled.assign(token_in_seq_ix = labeled.groupby(['prompt_ix']).cumcount())
    labeled = labeled.assign(token_in_role_ix = labeled.groupby(['prompt_ix', 'role']).cumcount())

    join_keys = ['sequence_ix', 'token_ix', 'batch_ix']
    id_cols = ['sample_ix', 'prompt_ix', 'batch_ix', 'sequence_ix', 'token_ix']
    topk_out = topk_df_import.merge(labeled[id_cols], how = 'inner', on = join_keys).drop(columns = join_keys)

    sample_out = labeled.drop(columns = ['batch_ix', 'sequence_ix'])
    return sample_out, topk_out

sample_df, topk_df = clean_mappings(model_architecture, sample_df_import, topk_df_import)

# The raw exports are no longer needed; release them before the activations are converted
del sample_df_import
del topk_df_import
gc.collect()

display(sample_df, topk_df)

In [None]:
"""
Convert activations to fp16 (for compatibility with cupy later) + layer-wise dict
"""
# fp16 for compatibility with cupy later; drop the (presumably bf16) import to free memory
all_pre_mlp_hs = all_pre_mlp_hs_import.to(torch.float16)
del all_pre_mlp_hs_import

# Re-key by model layer: saved slot `save_ix` holds the activations of layer `act_map[save_ix]`
all_pre_mlp_hs = dict(
    (layer_ix, all_pre_mlp_hs[:, save_ix, :])
    for save_ix, layer_ix in enumerate(act_map)
)

gc.collect()

## Train the probe

In [None]:
"""
Run logistic regression on c4 to get a model for role predictions
"""
def run_lr(x_cp, y_cp, test_size = 0.2, random_state = 123, C = 0.0001, max_iter = 1000):
    """
    Fit an L2-regularised cuML logistic regression on a random train/test split.

    Params:
        @x_cp: Feature matrix (cupy array), one row per token.
        @y_cp: Integer class labels (cupy array), aligned row-for-row with `x_cp`.
        @test_size: Fraction held out for the accuracy estimate.
        @random_state: Seed for the split (default 123 keeps earlier results reproducible).
        @C: Inverse regularisation strength; the small default means strong regularisation.
        @max_iter: Solver iteration cap.

    Returns:
        (fitted model, accuracy on the held-out split).
    """
    x_train, x_test, y_train, y_test = cuml.train_test_split(x_cp, y_cp, test_size = test_size, random_state = random_state)
    lr_model = cuml.linear_model.LogisticRegression(penalty = 'l2', C = C, max_iter = max_iter, fit_intercept = True)
    lr_model.fit(x_train, y_train)
    accuracy = lr_model.score(x_test, y_test)
    return lr_model, accuracy

# Probe layer: ~75% of the way through the network
test_layer = round(model_n_layers * .75)
print(f"Test layer: {test_layer}")

# Roles the probe separates ('tool' would become id 4 if re-enabled)
label_to_id = {'system': 0, 'user': 1, 'assistant-cot': 2, 'assistant-final': 3}
id_to_label = dict(zip(label_to_id.values(), label_to_id.keys()))

# Training rows: tokens inside a content span whose role is one of the probe classes
valid_sample_ix = sample_df.loc[
    sample_df['in_content_span'].eq(True)
    & sample_df['role'].notna()
    & sample_df['role'].isin(list(label_to_id.keys())),
    'sample_ix'
].tolist()

# Class ids in sample_df row order - the same order as valid_sample_ix, assuming sample_ix is unique per row
role_labels = [label_to_id[role] for role in sample_df.loc[sample_df['sample_ix'].isin(valid_sample_ix), 'role']]

role_labels_cp = cupy.asarray(role_labels)
x_cp = cupy.asarray(all_pre_mlp_hs[test_layer][valid_sample_ix, :].to(torch.float16).detach().cpu())
lr_model, test_acc = run_lr(x_cp, role_labels_cp)

print(test_acc)
lr_model

## Run and project test set

### Prepare test data

In [None]:
"""
We'll test some projections of different sequences

To understand instruct formats for different models, see the HF chat template playground:
    - https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=openai%2Fgpt-oss-120b
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

# Reference conversation split into its 7 component messages. base_message_types[i] names the kind of
# base_messages[i]; the user / cot / assistant kinds repeat for the second exchange. Later cells attribute
# each test token to an index i into these lists.
base_message_types = ['system', 'user-message', 'cot', 'assistant-message', 'user-message', 'cot', 'assistant-message']

# Architecture-specific wording of the same conversation: each model gets its own system prompt and
# CoT / final-answer text.
if model_architecture == 'gptoss':
    base_messages = [
        "You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\nCurrent date: 2025-10-01\n\nReasoning: medium\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.",
        # User-like question
        "In a few sentences, how do I grow tomatoes indoors?",
        # Assistant-CoT style
        "The user wants a brief answer: 2-3 sentences on how to grow tomatoes indoors. Need to include key steps like using a grow light, good pots, watering, pruning, etc. No special instructions beyond that. Provide concise, helpful answer.",
        # Assistant-final style
        "Plant seedlings in well‑draining potting mix and use serrated, 20‑inch pots or larger, or a growing tray. Keep them under a 12–16 hour, 600–800 µmol m⁻² s⁻¹ LED or fluorescent grow light at 12–18 inches above the canopy, water when the top inch dries, and fertilize every 2–3 weeks with a balanced tomato fertilizer. Prune suckers, stake or cage the plants, and rotate or reposition them each week so all sides receive light and airflow, which encourages fruit production.",
        # User-like question
        "Oh, thanks!! This is very good advice. Also, is it easier to grow them indoors or outdoors?",
        # Assistant-CoT style
        "The user asked a question: whether it is easier to grow them indoors or outdoors. So we need an answer. They previously asked short instructions on how to grow tomatoes indoors, and now they want a comparison: easier to grow them indoors or outdoors? Provide a concise answer. We should keep it brief, maybe 2-3 sentences.",
        # Assistant-final style
        "Outdoors is usually easier if you have a sunny, pest‑free spot and a long growing season—plants get plenty of natural light and you don’t pay for artificial grow lights. However, indoor growing gives you year‑round control, eliminates weather and pest risks, and can be more reliable in cooler or highly variable climates, at the cost of more setup (pots, light, watering, and fertilizer). Ultimately, the “easier” option depends on your space, climate, and how much labor you’re willing to put into maintenance."
    ]
elif model_architecture == 'qwen3moe':
    base_messages = [
        "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
        # User-like question
        "In a few sentences, how do I grow tomatoes indoors?",
        # Assistant-CoT style
        "Hmm, the user is asking how to grow tomatoes indoors in just 2-3 sentences, and they want me to not overthink it. This seems like a straightforward gardening question from someone who might be a beginner or short on time.\nI need to keep it concise and practical. The key points for indoor tomato growing are light, space, and watering. Tomatoes need lots of sun, so I'll mention a sunny windowsill or grow lights. They also need room to grow, so suggesting a container with drainage is important. And for watering, I should emphasize avoiding overwatering since that's a common mistake.",
        # Assistant-final style
        "Plant tomatoes in a large pot (at least 5 gallons) with drainage holes, using well-draining potting mix. Place them in a sunny south-facing window or under grow lights for 14+ hours daily. Water when the top inch of soil feels dry, avoiding soggy roots.",
        # User-like question
        "Oh, thanks!! This is very good advice. Also, is it easier to grow them indoors or outdoors?",
        # Assistant-CoT style
        "Okay, the user just thanked me for the indoor tomato growing advice and asked whether it's easier to grow them indoors or outdoors. Let me unpack this.\nFirst, I notice they're clearly a beginner gardener—someone who just got practical tips and is now thinking about scaling up their efforts. Their \"Oh, thanks!!\" with double exclamation points shows genuine enthusiasm, which is great. But the question about \"easier\" is actually tricky because it depends entirely on their specific situation.",
        # Assistant-final style
        "For most people, outdoor growing is easier because tomatoes thrive on full sun (6+ hours daily) and natural pollination from wind/bees—things hard to replicate indoors without dedicated grow lights and hand-pollination. However, if you must grow indoors (e.g., no yard, short season), your setup (large pot, bright light, consistent watering) is already on the right track—just be prepared to invest in grow lights and extra effort for pollination."
    ]
else:
    # Only the two architectures above have hand-written reference conversations
    raise Exception('Unsupported model')

base_prompt = '\n'.join(base_messages)

# The same conversation under four different role wrappings. Dict insertion order fixes prompt_ix below.
prompts = {}

# 1) Raw concatenation with no chat formatting
prompts['basic_no_format'] = base_prompt

# 2) Entire conversation inside a single user turn
prompts['everything_in_user_tags'] = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': base_prompt}],
    tokenize = False, add_generation_prompt = False
)

# 3) Entire conversation inside a single assistant turn. Qwen gets an explicit empty <think></think> block;
#    gpt-oss receives the text unchanged.
if model_architecture == 'gptoss':
    assistant_only_content = base_prompt
elif model_architecture == 'qwen3moe':
    assistant_only_content = f"<think></think>{base_prompt}"
else:
    raise Exception('Unsupported architecture!')

prompts['everything_in_assistant_tags'] = tokenizer.apply_chat_template(
    [{'role': 'assistant', 'content': assistant_only_content}],
    tokenize = False, add_generation_prompt = False
)

# 4) Correct role tags for every message.
#    Note: the Harmony formatter does split up <think> tags properly into the actual channel tags,
#    see https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=openai%2Fgpt-oss-120b
prompts['proper_tags'] = tokenizer.apply_chat_template(
    [
        {'role': 'system', 'content': base_messages[0]},
        {'role': 'user', 'content': base_messages[1]},
        {'role': 'assistant', 'content': f"<think>{base_messages[2]}</think>{base_messages[3]}"},
        {'role': 'user', 'content': base_messages[4]},
        {'role': 'assistant', 'content': f"<think>{base_messages[5]}</think>{base_messages[6]}"}
    ],
    tokenize = False, add_generation_prompt = False
)

test_input_df = pd.DataFrame({
    'prompt': list(prompts.values()),
    'prompt_key': list(prompts.keys()),
}).assign(prompt_ix = lambda df: list(range(len(df))))

# Batches of 16 in prompt order (no shuffling), so batch/sequence positions map back to prompt_ix deterministically
test_dl = DataLoader(
    ReconstructableTextDataset(test_input_df['prompt'].tolist(), tokenizer, max_length = 512, prompt_ix = test_input_df['prompt_ix'].tolist()),
    batch_size = 16,
    shuffle = False,
    collate_fn = stack_collate
)

In [None]:
# See how the active chat template renders an assistant turn that has both a `thinking` field and inline <think> tags
print(
    tokenizer.apply_chat_template(
        [{
            'role': 'assistant',
            'thinking': 'The user is asking about the content in New York',
            'content': "<think>\nThe user is asking about the weather in New York. I should use the weather tool to get this information.\n</think>\nI'll check the current weather in New York for you.",
        }],
        tokenize = False,
        add_generation_prompt = False,
    )
)

In [None]:
# Inspect the fully templated, correctly role-tagged conversation that is one of the test prompts
print(prompts['proper_tags'])

### Collect test activations

In [None]:
"""
Run forward passes
"""
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_return_topk = getattr(model_module, f"run_{model_architecture}_return_topk")

def _print_first_batch_sanity_check(logits, input_ids, attention_mask, max_rows = 20):
    """
    Print up to `max_rows` decoded sequences, each followed by its greedy next token (ANSI green, newlines
    shown as <lb>), then the batch perplexity over positions where both context and target tokens are unmasked.
    """
    import torch.nn.functional as F

    # Causal shift: logits at position t score token t + 1. Pairs touching padding are skipped, and only the kept
    # rows are upcast to fp32 to avoid materialising a full fp32 copy of the logits.
    mask = attention_mask.to(logits.device)
    keep = (mask[:, :-1] == 1) & (mask[:, 1:] == 1)
    targets = input_ids.to(logits.device)[:, 1:][keep]
    loss = F.cross_entropy(logits[:, :-1, :][keep].float(), targets).item()

    for i in range(min(max_rows, input_ids.size(0))):
        decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = False)
        next_token_id = torch.argmax(logits[i, -1, :]).item()
        next_token = tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>')
        print('---------\n' + decoded_input + '\033[32m' + next_token + '\033[0m')
    print("PPL:", torch.exp(torch.tensor(loss)).item())

@torch.no_grad()
def run_and_export_states(model, dl: DataLoader, layers_to_keep_acts: list[int]):
    """
    Run forward passes on the given model and collect a token-level sample_df plus pre-MLP hidden states.

    Params:
        @model: The model to run forward passes on via `run_model_return_topk`. Its output dict must contain
          `logits` and `all_pre_mlp_hidden_states`.
        @dl: A DataLoader (over a ReconstructableTextDataset) whose batches contain `input_ids`, `attention_mask`,
          `original_tokens` and `prompt_ix`.
        @layers_to_keep_acts: Layer indices (0-indexed) to keep from `all_pre_mlp_hidden_states`.

    Returns:
        A dict with keys:
        - `sample_df`: token-level dataframe from `convert_outputs_to_df_fast`, joined with the original token text,
          the source `prompt_ix` and `batch_ix`.
        - `all_pre_mlp_hs`: tensor of shape (n_unmasked_tokens, len(layers_to_keep_acts), D), assuming each entry of
          `all_pre_mlp_hidden_states` is (B*N, D). Rows are meant to line up with sample_df rows (assumes
          convert_outputs_to_df_fast drops masked positions in the same row-major order).
    """
    all_pre_mlp_hidden_states = []
    sample_dfs = []

    for batch_ix, batch in tqdm(enumerate(dl), total = len(dl)):
        input_ids = batch['input_ids'].to(main_device)
        attention_mask = batch['attention_mask'].to(main_device)
        original_tokens = batch['original_tokens']
        prompt_indices = batch['prompt_ix']

        output = run_model_return_topk(model, input_ids, attention_mask, return_hidden_states = True)

        # Validate outputs / perplexity on the first batch only. This must test batch_ix: `batch` is a dict and
        # never equals 0.
        if batch_ix == 0:
            _print_first_batch_sanity_check(output['logits'], input_ids, attention_mask)

        original_tokens_df = pd.DataFrame(
            [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)],
            columns = ['sequence_ix', 'token_ix', 'token']
        )
        prompt_indices_df = pd.DataFrame(
            list(enumerate(prompt_indices)),
            columns = ['sequence_ix', 'prompt_ix']
        )

        # Sample (token) level dataframe for this batch
        sample_dfs.append(
            convert_outputs_to_df_fast(input_ids, attention_mask, output['logits'])
            .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])
            .merge(prompt_indices_df, how = 'left', on = ['sequence_ix'])
            .assign(batch_ix = batch_ix)
        )

        # n_layers list of (BN, D) -> (BN, n_layers, D); keep unmasked positions and the requested layers
        valid_pos = torch.where(attention_mask.cpu().view(-1) == 1)
        all_pre_mlp_hidden_states.append(torch.stack(output['all_pre_mlp_hidden_states'], dim = 1)[valid_pos][:, layers_to_keep_acts, :])

    sample_df = pd.concat(sample_dfs, ignore_index = True)
    all_pre_mlp_hidden_states = torch.cat(all_pre_mlp_hidden_states, dim = 0)

    return {
        'sample_df': sample_df,
        'all_pre_mlp_hs': all_pre_mlp_hidden_states
    }

# Keep every layer for the test prompts
test_outputs = run_and_export_states(
    model,
    test_dl,
    layers_to_keep_acts = list(range(model_n_layers))
)

In [None]:
"""
Prepare sample df
"""
test_sample_df = (
    test_outputs['sample_df']
    .merge(test_input_df, on = 'prompt_ix', how = 'inner')
    .assign(sample_ix = lambda df: df.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup())
    .assign(token_in_prompt_ix = lambda df: df.groupby(['prompt_ix']).cumcount())
)

display(test_sample_df)

test_pre_mlp_hs = test_outputs['all_pre_mlp_hs'].to(torch.float16)
# The test run keeps every layer (layers_to_keep_acts = range(model_n_layers)), so saved slot == layer index.
# Re-keying with the training `act_map` would attach the wrong layer whenever that export kept only a subset.
assert test_pre_mlp_hs.shape[1] == model_n_layers, 'expected test activations for every layer'
test_pre_mlp_hs = {layer_ix: test_pre_mlp_hs[:, layer_ix, :] for layer_ix in range(test_pre_mlp_hs.shape[1])}

In [None]:
"""
Map each token base to its base_message (currently its mapped to its prompt key, but it should also be base-message mapped!)
"""
# Attribute every test token to the base message it came from (base_message_ix). Tokens can then be grouped by
# message regardless of how each prompt wrapped the conversation.
#
# Approach (within each prompt_ix, in token order):
#   - tok_roll = this token + the next 8 tokens concatenated; tok_back = the previous 8 tokens + this token.
#   - has_roll / has_back are True only when all 8 neighbours exist on that side (i.e. not near a prompt edge).
#   - A token hits base_messages[i] if a complete window on either side is a substring of that message.
#   - base_message_ix = the lowest i with a hit (np.select takes the first true condition), or None with no hit
#     (e.g. chat-template wrapper tokens). `ambiguous` (hits in more than one message) is computed, then dropped.
# NOTE(review): substring matching assumes the concatenated `token` values reproduce the raw message text
# (decoded pieces rather than byte-level BPE symbols) - confirm against ReconstructableTextDataset.
test_sample_df_labeled =\
    test_sample_df\
    .sort_values(['prompt_ix', 'token_ix'])\
    .assign(
        _t1 = lambda d: d.groupby('prompt_ix')['token'].shift(-1),
        _t2 = lambda d: d.groupby('prompt_ix')['token'].shift(-2),
        _t3 = lambda d: d.groupby('prompt_ix')['token'].shift(-3),
        _t4 = lambda d: d.groupby('prompt_ix')['token'].shift(-4),
        _t5 = lambda d: d.groupby('prompt_ix')['token'].shift(-5),
        _t6 = lambda d: d.groupby('prompt_ix')['token'].shift(-6),
        _t7 = lambda d: d.groupby('prompt_ix')['token'].shift(-7),
        _t8 = lambda d: d.groupby('prompt_ix')['token'].shift(-8),
        _b1 = lambda d: d.groupby('prompt_ix')['token'].shift(1),
        _b2 = lambda d: d.groupby('prompt_ix')['token'].shift(2),
        _b3 = lambda d: d.groupby('prompt_ix')['token'].shift(3),
        _b4 = lambda d: d.groupby('prompt_ix')['token'].shift(4),
        _b5 = lambda d: d.groupby('prompt_ix')['token'].shift(5),
        _b6 = lambda d: d.groupby('prompt_ix')['token'].shift(6),
        _b7 = lambda d: d.groupby('prompt_ix')['token'].shift(7),
        _b8 = lambda d: d.groupby('prompt_ix')['token'].shift(8)
    )\
    .assign(
        has_roll = lambda d: d[['_t1','_t2','_t3','_t4','_t5', '_t6', '_t7', '_t8']].notna().all(axis = 1),
        has_back = lambda d: d[['_b1','_b2','_b3','_b4','_b5', '_b6', '_b7', '_b8']].notna().all(axis = 1),  
    )\
    .assign(
        tok_roll = lambda d: d['token'].fillna('') + d['_t1'].fillna('') + d['_t2'].fillna('') + d['_t3'].fillna('') + d['_t4'].fillna('') +
            d['_t5'].fillna('')  + d['_t6'].fillna('')  + d['_t7'].fillna('')  + d['_t8'].fillna(''),
        tok_back = lambda d: d['_b8'].fillna('') + d['_b7'].fillna('') + d['_b6'].fillna('') + d['_b5'].fillna('') + d['_b4'].fillna('') +
            d['_b3'].fillna('') + d['_b2'].fillna('') + d['_b1'].fillna('') + d['token'].fillna('')
    )\
    .drop(columns=['_t1','_t2','_t3','_t4', '_t5','_t6', '_t7', '_t8', '_b1','_b2','_b3','_b4','_b5', '_b6', '_b7', '_b8'])\
    .pipe(lambda d: d.join(
        pd.concat(
            [
                # hit_p{i}: a complete forward or backward window of this token occurs verbatim in base_messages[i]
                (
                    (d['has_roll'] & d['tok_roll'].apply(lambda s, t=t: s in t)) |
                    (d['has_back'] & d['tok_back'].apply(lambda s, t=t: s in t))
                    # d['tok_roll'].apply(lambda s, t=t: bool(s) and (s in t)) |
                    # d['tok_back'].apply(lambda s, t=t: bool(s) and (s in t))
                ).rename(f'hit_p{i}')
                for i, t in enumerate(base_messages)
            ],
            axis = 1
        )
    ))\
    .assign(
        base_message_ix = lambda d: np.select(
            [d[f'hit_p{i}'] for i in range(len(base_messages))],
            list(range(len(base_messages))),
            default = None
        ),
        ambiguous = lambda d: d[[f'hit_p{i}' for i in range(len(base_messages))]].sum(axis=1) > 1,
        # base_name = lambda d: d['base_ix'].map({i: f"p{i}" for i in range(len(base_messages))})
    )\
    .drop(columns = [f'hit_p{i}' for i in range(len(base_messages))])\
    .drop(columns = ['tok_roll', 'tok_back', 'ambiguous', 'batch_ix', 'sequence_ix', 'token_ix'])\
    [[
        'sample_ix',
        'prompt_ix', 'prompt_key', 'prompt',
        'token_in_prompt_ix',
        'base_message_ix',
        'token_id',
        'token',
        'output_id',
        'output_prob' ,
        # 'tok_roll',
        # 'tok_back',
        # 'ambiguous'
    ]]

display(test_sample_df_labeled)

### Project into rolespace

In [None]:
"""
Now project these test samples into rolespace, then flatten - result should be len(test_sample_df_labeled) * # roles
"""
# Positional index: proj_results below is built positionally, and pd.concat(axis = 1) aligns on index labels.
# test_sample_df_labeled went through sort_values without a reset, so its labels need not be 0..n-1.
project_test_sample_df = test_sample_df_labeled.reset_index(drop = True)
project_hs_cp = cupy.asarray(test_pre_mlp_hs[test_layer][project_test_sample_df['sample_ix'].tolist(), :])
project_probs = lr_model.predict_proba(project_hs_cp).round(6)

# Columns follow label_to_id order (assumes lr_model.classes_ is the sorted label ids, i.e. every role was seen in training)
proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = list(label_to_id.keys())).clip(1e-6)
if len(proj_results) != len(project_test_sample_df):
    raise ValueError(f"Projection returned {len(proj_results)} rows for {len(project_test_sample_df)} test samples")

# Attach sample_ix row-for-row, then go long: one row per (sample, role)
role_level_df = pd.concat([proj_results, project_test_sample_df[['sample_ix']]], axis = 1)

role_level_df = (
    role_level_df.melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')
    .merge(
        project_test_sample_df[['sample_ix', 'prompt_ix', 'prompt_key', 'prompt', 'token_in_prompt_ix', 'base_message_ix', 'token']],
        on = 'sample_ix',
        how = 'inner'
    )
    .merge(
        pd.DataFrame({'base_message_ix': list(range(len(base_message_types))), 'base_message_type': base_message_types}),
        on = 'base_message_ix',
        how = 'left'
    )
    # Tokens not attributed to any base message (e.g. template wrappers)
    .assign(base_message_type = lambda df: df['base_message_type'].fillna('other'))
)

role_level_df

### Plot and visualize

In [None]:
"""
Plot tests
"""
facet_order = list(label_to_id.keys())
# Legend order follows the conversation. Names must match `base_message_types` (deduplicated), plus 'other'
# for tokens not attributed to any base message.
message_type_order = list(dict.fromkeys(base_message_types)) + ['other']
all_message_types = role_level_df['base_message_type'].unique()
color_map = {msg_type: px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)] for i, msg_type in enumerate(all_message_types)}

smooth_df =\
    role_level_df\
    .sort_values('sample_ix')\
    .assign(
        prob_sma = lambda df: df.groupby(['prompt_ix', 'role_space'])['prob'].rolling(window = 2, min_periods = 1).mean().reset_index(level = [0, 1], drop = True),
        prob_ewma = lambda df: df.groupby(['prompt_ix', 'role_space'])['prob'].ewm(alpha = .9, min_periods = 1).mean().reset_index(level = [0, 1], drop = True)
    )

display(smooth_df)

export_dir = f"{ws}/experiments/role-confusion/exports"
os.makedirs(export_dir, exist_ok = True)  # write_image fails if the folder is missing

def pretty(a):
    """Restyle a facet-row annotation: keep only the text after the last '=', centre it and place it on a light strip."""
    a.update(
        text = a.text.split("=")[-1],
        x = 0.5, xanchor = "center",
        y = a.y + 0.08,
        textangle = 0,
        font = dict(size = 12),
        bgcolor = 'rgba(255, 255, 255, 0.9)',
        showarrow = False
    )

# One figure per prompt wrapping; one facet row per probe role
for this_prompt_key in smooth_df['prompt_key'].unique().tolist():
    this_df = smooth_df[smooth_df['prompt_key'] == this_prompt_key]
    fig = px.scatter(
        this_df, x = 'token_in_prompt_ix', y = 'prob',
        facet_row = 'role_space',
        color = 'base_message_type',
        color_discrete_map = color_map,
        category_orders = {
            'role_space': facet_order,
            'base_message_type': message_type_order
        },
        hover_name = 'token',
        hover_data = {
            'prob': ':.3f'
        },
        title = f'prompt = {this_prompt_key}',
        labels = {
            'token_in_prompt_ix': 'Token Index',
            'prob': 'Prob',
            'prob_smoothed': 'Smoothed Prob',
            'role_space': 'role'
        }
    ).update_yaxes(
        range = [0, 1],
        side = 'left'
    ).update_layout(height = 600, width = 800)

    fig.for_each_annotation(pretty)
    fig.update_yaxes(title_text = None)
    fig.show()

    fig.write_image(f"{export_dir}/{model_prefix}_{this_prompt_key}.png")