In [None]:
"""
Test role confusion with hidden states
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import cupy
import cuml
import importlib
import gc
import pickle
import os
from tqdm import tqdm
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

from utils.memory import check_memory, clear_all_cuda_memory
from utils.role_assignments import label_content_roles
from utils.store_outputs import convert_outputs_to_df_fast

main_device = 'cuda:0'
seed = 1234

clear_all_cuda_memory()
check_memory()

# Apply the seed to every RNG this notebook touches (it was previously defined but never used)
torch.manual_seed(seed)
np.random.seed(seed)
cupy.random.seed(seed)

ws = '/workspace/deliberative-alignment-jailbreaks'

## Load models & data

### Load model

In [None]:
"""
Load the base tokenizer/model
"""
selected_model_index = 1

def get_model(index):
    """
    Return the configuration tuple for one of the supported models.

    Tuple fields, in order:
        (HF model ID, model prefix, model architecture, attn implementation,
         whether to use the HF library implementation, number of layers)

    Raises:
        KeyError: if `index` is not one of the known model indices.
    """
    models = {
        0: ('openai/gpt-oss-120b', 'gptoss120', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 36), # Will load experts in MXFP4 if triton kernels installed
        1: ('openai/gpt-oss-20b', 'gptoss20', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 24),
        2: ('Qwen/Qwen3-30B-A3B-Thinking-2507', 'qwen3moe', 'qwen3moe', None, True, 48)
    }
    if index not in models:
        raise KeyError(f"Unknown model index {index!r}; expected one of {sorted(models)}")
    return models[index]

def load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf, cache_dir = '/workspace/hf'):
    """
    Load the model and tokenizer from HF, reusing `cache_dir` if already downloaded.

    Params:
        @model_id: HF model ID.
        @model_prefix: unused here; kept for call-site compatibility.
        @model_attn: value passed as `attn_implementation` (None lets transformers choose).
        @model_use_hf: True to use the native transformers implementation. Repo-supplied code
            (`trust_remote_code`) is only enabled, for both tokenizer and model, when this is False.
        @cache_dir: HF download/cache directory.

    Returns:
        (tokenizer, model) with a left-padding tokenizer (no auto BOS/EOS) and a bf16 model in eval mode.
    """
    # Keep tokenizer and model consistent: never execute remote repo code for one but not the other
    trust_remote = not model_use_hf
    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir = cache_dir, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = trust_remote)
    model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir = cache_dir, dtype = torch.bfloat16, trust_remote_code = trust_remote, device_map = 'auto', attn_implementation = model_attn).eval()
    return tokenizer, model

model_id, model_prefix, model_architecture, model_attn, model_use_hf, model_n_layers = get_model(selected_model_index)
tokenizer, model = load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf)

### Load c4 role activations

In [None]:
"""
Load dataset
"""
def load_data(folder_path, model_prefix, max_data_files):
    """
    Load data saved by `export-activations.ipynb`.

    Looks for numbered chunk folders `00`, `01`, ... (fewer than `max_data_files`) under
    `{ws}/experiments/role-confusion/{folder_path}/{model_prefix}/`, skips those that do not exist,
    and concatenates each chunk's `samples.pkl`, `topks.pkl` and `all-pre-mlp-hidden-states.pt` in folder order.

    Returns:
        (sample_df, topk_df, all_pre_mlp_hs, layers): both frames get a fresh 0..n-1 index,
        `all_pre_mlp_hs` is the row-wise concatenation of the saved tensors, and `layers` is
        metadata['all_pre_mlp_hidden_states_layers'] from the sibling `metadata.pkl`.

    Raises:
        FileNotFoundError: if none of the chunk folders exist.

    Note: pickle / torch.load can execute code embedded in a file; only point this at trusted exports.
    """
    root = f'{ws}/experiments/role-confusion/{folder_path}/{model_prefix}'
    folders = [f'{root}/{i:02d}' for i in range(max_data_files)]
    folders = [f for f in folders if os.path.isdir(f)]
    if not folders:
        # Fail with a clear message instead of pd.concat's opaque "No objects to concatenate"
        raise FileNotFoundError(f'No chunk folders 00..{max_data_files - 1:02d} found under {root}')

    sample_dfs = []
    topk_dfs = []
    hs_chunks = []

    for f in tqdm(folders):
        sample_dfs.append(pd.read_pickle(f'{f}/samples.pkl'))
        topk_dfs.append(pd.read_pickle(f'{f}/topks.pkl'))
        hs_chunks.append(torch.load(f'{f}/all-pre-mlp-hidden-states.pt'))

    # Chunk frames may each carry their own index (e.g. all starting at 0); rebuild a unique one
    sample_df = pd.concat(sample_dfs, ignore_index = True)
    topk_df = pd.concat(topk_dfs, ignore_index = True)
    all_pre_mlp_hs = torch.concat(hs_chunks)

    with open(f'{root}/metadata.pkl', 'rb') as fh:
        metadata = pickle.load(fh)

    gc.collect()
    return sample_df, topk_df, all_pre_mlp_hs, metadata['all_pre_mlp_hidden_states_layers']

sample_df_import, topk_df_import, all_pre_mlp_hs_import, act_map = load_data(
    folder_path = 'activations',
    model_prefix = model_prefix,
    max_data_files = 10,
)

In [None]:
# Re-import the role-labelling helper so edits to utils/role_assignments.py are picked up without a kernel restart
import importlib
import utils.role_assignments
label_content_roles = importlib.reload(utils.role_assignments).label_content_roles

In [None]:
"""
Test the valid role assignments function label_content_roles
"""
# A small hand-built conversation for eyeballing where label_content_roles places role boundaries
demo_conversation = [
    {'role': 'system', 'content': 'Test.'},
    {'role': 'user', 'content': 'Hi! I am a dog and I like to bark'},
    {'role': 'assistant', 'content': 'I am a dog and I like to bark'}
]
# With tokenize = False the template returns a plain string, so padding/truncation options have no effect
s = tokenizer.apply_chat_template(demo_conversation, tokenize = False, add_generation_prompt = False)

demo_ids = tokenizer(s)['input_ids']
z = pd.DataFrame({'input_ids': demo_ids})
z = z.assign(
    token = tokenizer.convert_ids_to_tokens(z['input_ids']),
    prompt_ix = 0,
    batch_ix = 0,
    sequence_ix = 0,
    token_ix = z.index
)

label_content_roles(model_architecture, z).tail(50)

In [None]:
"""
Let's clean up the mappings here. We'll get everything to a sample_ix level first. We'll also flag the role wrapper tokens. Drop nothing. 
"""
def clean_mappings(model_architecture, sample_df_import, topk_df_import):
    """
    Label content roles and attach sample-level indices to the sample and top-k frames.

    Roles are assigned with label_content_roles rather than read from input_mappings.csv, so the same
    logic can be reused later and more than one role per sequence/prompt_ix is allowed.

    Added sample columns:
        sample_ix        - ngroup over (batch_ix, sequence_ix, token_ix)
        token_in_seq_ix  - running count within prompt_ix
        token_in_role_ix - running count within (prompt_ix, role)

    Returns:
        (sample_df, topk_df).
        - sample_df: the labelled frame without batch_ix/sequence_ix.
        - topk_df: the top-k rows inner-joined on (sequence_ix, token_ix, batch_ix) to gain
          sample_ix and prompt_ix, with the join keys dropped.

    NOTE(review): sample_ix is later used as a row index into the hidden-state tensors, which assumes
    rows are stored in (batch_ix, sequence_ix, token_ix) order with unique keys — confirm against the export.
    """
    key_cols = ['batch_ix', 'sequence_ix', 'token_ix']

    labeled = label_content_roles(model_architecture, sample_df_import)
    sample_df_raw = labeled.assign(
        sample_ix = labeled.groupby(key_cols).ngroup(),
        token_in_seq_ix = labeled.groupby(['prompt_ix']).cumcount(),
        token_in_role_ix = labeled.groupby(['prompt_ix', 'role']).cumcount(),
    )

    sample_keys = sample_df_raw[['sample_ix', 'prompt_ix', *key_cols]]
    c4_topk_df = (
        topk_df_import
        .merge(sample_keys, how = 'inner', on = ['sequence_ix', 'token_ix', 'batch_ix'])
        .drop(columns = ['sequence_ix', 'token_ix', 'batch_ix'])
    )

    c4_sample_df = sample_df_raw.drop(columns = ['batch_ix', 'sequence_ix'])
    return c4_sample_df, c4_topk_df

sample_df, topk_df = clean_mappings(model_architecture, sample_df_import, topk_df_import)

# The raw imports are superseded by the cleaned frames; free them before displaying
del sample_df_import, topk_df_import
gc.collect()

display(sample_df, topk_df)

In [None]:
"""
Convert activations to fp16 (for compatibility with cupy later) + layer-wise dict
"""
_stacked_fp16 = all_pre_mlp_hs_import.to(torch.float16)
# compare_bf16_fp16_batched(all_pre_mlp_hs_import, _stacked_fp16)
del all_pre_mlp_hs_import

# Axis 1 holds one slice per entry of act_map, in order; key each slice by its real layer index
all_pre_mlp_hs = {}
for pos, layer_ix in enumerate(act_map):
    all_pre_mlp_hs[layer_ix] = _stacked_fp16[:, pos, :]
del _stacked_fp16

gc.collect()

## Role spaces (logistic)

In [None]:
"""
Run logistic regression on c4 to get a model for role predictions
"""
def run_lr(x_cp, y_cp):
    """
    Fit a strongly L2-regularised (C = 1e-4) cuML logistic regression on an 80/20 split.

    Returns:
        (fitted model, accuracy on the held-out 20%). The split uses random_state = 123.
    """
    splits = cuml.train_test_split(x_cp, y_cp, test_size = 0.2, random_state = 123)
    x_train, x_test, y_train, y_test = splits

    clf = cuml.linear_model.LogisticRegression(penalty = 'l2', C = 0.0001, max_iter = 1000, fit_intercept = True)
    clf.fit(x_train, y_train)
    return clf, clf.score(x_test, y_test)

test_layer = round(model_n_layers * .75)
print(f"Test layer: {test_layer}")
label_to_id = {'system': 0, 'user': 1, 'assistant-cot': 2, 'assistant-final': 3}
# label_to_id = {'system': 0, 'user': 1, 'assistant-cot': 2, 'assistant-final': 3, 'tool': 4}
id_to_label = {v: k for k, v in label_to_id.items()}

# Select labelled content tokens once; X rows and y labels come from the same filtered frame, so they
# stay aligned even if sample_ix were ever non-unique (a separate isin() lookup could pick up extra rows).
valid_df = sample_df[
    sample_df['in_content_span'].eq(True)
    & sample_df['role'].notna()
    & sample_df['role'].isin(list(label_to_id.keys()))
]
valid_sample_ix = valid_df['sample_ix'].tolist()
role_labels = valid_df['role'].map(label_to_id).tolist()
assert len(role_labels) == len(valid_sample_ix)

role_labels_cp = cupy.asarray(role_labels)
# all_pre_mlp_hs values are already float16 (see the conversion cell)
x_cp = cupy.asarray(all_pre_mlp_hs[test_layer][valid_sample_ix, :].detach().cpu())
lr_model, test_acc = run_lr(x_cp, role_labels_cp)

print(test_acc)
lr_model

In [None]:
"""
We'll test some projections of different sequences
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

base_messages = [
    # User-like question
    "In a few sentences, how do I grow tomatoes indoors?",
    # Assistant-CoT style
    "We need to answer: In a few sentences, how to grow tomatoes indoors. Provide concise steps: choose a container, potting mix, sun or grow lights, watering, fertilizing, staking, pollination. Should mention location, temperature, etc. Should be short. Let's produce ~5-6 sentences.",
    # Assistant-final style
    "Use a 30–35 cm (12‑14 in) pot with durable drainage and fill it with a high‑quality, well‑draining potting mix (½ soil, ¼ perlite, ¼ peat or coconut coir). Plant a small‑seeding or a pre‑grown dwarf or determinate tomato cone, and keep the pot in a south‑facing windowsill or, better, under 12–16 h/day of bright LED or HPS grow lights at 50–80 cm distance. Maintain 18–24 °C (65–75 °F) day temperatures, 12–15 °C night temperatures, and water when the top 1 cm of soil feels dry, using a balanced 10‑20‑10 NPK fertilizer every 2–3 weeks. Space the plant with a stake or cage, trim suckers, and to aid pollination, gently shake or use a small fan to keep air moving. Once fruit sets, keep the plant in a warm, well‑ventilated area, watching for blossom end rot (ensure adequate calcium) and harvesting as tomatoes ripen.",
    # User-like question
    "Oh, thanks!! Is it easier to grow them indoors or outdoors?"
]

base_prompt = f"{base_messages[0]}\n{base_messages[1]}\n{base_messages[2]}\n{base_messages[3]}"

# Variants of the same text:
#   0 - raw text (gpt-oss: prefixed with its default system text, without special tokens)
#   1 - the whole text wrapped as a single user turn
#   2 - the whole text wrapped as a single assistant turn
#   3 - split into proper user / assistant (CoT + final) / user turns
prompts = {}
if model_architecture == 'gptoss':
    prompts[0] = f"You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\nCurrent date: 2025-10-06\n\nReasoning: medium\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.\n{base_prompt}"
else:
    prompts[0] = base_prompt

prompts[1] = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': base_prompt}],
    tokenize = False, add_generation_prompt = False
)

prompts[2] = tokenizer.apply_chat_template(
    [{'role': 'assistant', 'content': base_prompt}],
    tokenize = False, add_generation_prompt = False
)

if model_architecture == 'gptoss':
    # NOTE(review): it is unverified whether the gpt-oss chat template treats `<think>` tags as reasoning; routing the
    # CoT into the analysis channel may require the message's `thinking` field instead — confirm against the template.
    prompts[3] = tokenizer.apply_chat_template(
        [
            {'role': 'user', 'content': base_messages[0]},
            {'role': 'assistant', 'content': f"<think>\n{base_messages[1]}\n</think>\n{base_messages[2]}"},
            {'role': 'user', 'content': base_messages[3]}
        ],
        tokenize = False, add_generation_prompt = False
    )
elif model_architecture == 'qwen3moe':
    # tokenize = False is required: without it apply_chat_template returns token ids, not the prompt string
    prompts[3] = tokenizer.apply_chat_template(
        [
            {'role': 'user', 'content': base_messages[0]},
            {'role': 'assistant', 'content': f"<think>{base_messages[1]}</think>{base_messages[2]}"},
            {'role': 'user', 'content': base_messages[3]}
        ],
        tokenize = False, add_generation_prompt = False
    )

test_seqs =\
    pd.DataFrame({'prompt': list(prompts.values())})\
    .assign(prompt_ix = lambda df: list(range(len(df))))

# Wrap the test prompts in a DataLoader (max_length = 512, batches of 16) for the forward-pass cell below
test_dl = DataLoader(
    ReconstructableTextDataset(test_seqs['prompt'].tolist(), tokenizer, max_length = 512, prompt_ix = test_seqs['prompt_ix'].tolist()),
    batch_size = 16,
    shuffle = False,
    collate_fn = stack_collate
)

In [None]:
"""
Run forward passes
"""
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_return_topk = getattr(model_module, f"run_{model_architecture}_return_topk")


def _causal_lm_ppl(logits: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> float:
    """
    Compute the perplexity of next-token predictions for one batch, ignoring padded positions (sanity check only).

    Labels are input_ids shifted left by one. Positions where attention_mask == 0 are set to -100 so the loss skips them.
    """
    labels = input_ids.masked_fill(attention_mask == 0, -100).to(logits.device)
    shift_logits = logits[:, :-1, :].float()
    shift_labels = labels[:, 1:]
    loss = torch.nn.functional.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index = -100
    )
    return torch.exp(loss).item()


@torch.no_grad()
def run_and_export_states(model, dl: DataLoader, layers_to_keep_acts: list[int]):
    """
    Run forward passes on the given model and store the decomposed sample_df plus hidden states.

    Params:
        @model: The model to run forward passes on, via `run_model_return_topk(..., return_hidden_states = True)`.
            Its output must contain `logits` and `all_pre_mlp_hidden_states`.
        @dl: A DataLoader whose batches provide `input_ids`, `attention_mask`, `original_tokens` and `prompt_ix`.
        @layers_to_keep_acts: Layer indices (0-indexed) to keep from `all_pre_mlp_hidden_states`. Axis 1 of the
            returned tensor follows this list's order.

    Returns:
        A dict with keys:
        - `sample_df`: A sample (token)-level dataframe with input token ID, output token ID, input token text,
            prompt_ix and batch_ix (built by convert_outputs_to_df_fast from the unpadded positions).
        - `all_pre_mlp_hs`: A tensor of size n_samples x len(layers_to_keep_acts) x D holding the hidden state
            for each retained layer, restricted to positions where attention_mask == 1.

    Side effect: for the first batch, prints each decoded input with the model's greedy next token (in green)
    and the batch perplexity, as a quick correctness check.
    """
    all_pre_mlp_hidden_states = []
    sample_dfs = []

    for batch_ix, batch in tqdm(enumerate(dl), total = len(dl)):

        input_ids = batch['input_ids'].to(main_device)
        attention_mask = batch['attention_mask'].to(main_device)
        original_tokens = batch['original_tokens']
        prompt_indices = batch['prompt_ix']

        output = run_model_return_topk(model, input_ids, attention_mask, return_hidden_states = True)

        # Check no bugs by validating output/perplexity on the first batch
        if batch_ix == 0:
            for i in range(min(20, input_ids.size(0))):
                decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = False)
                next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
                next_token = tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>')
                # ANSI green for the predicted token
                print('---------\n' + decoded_input + f'\033[32m{next_token}\033[0m')
            print("PPL:", _causal_lm_ppl(output['logits'], input_ids, attention_mask))

        original_tokens_df = pd.DataFrame(
            [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)],
            columns = ['sequence_ix', 'token_ix', 'token']
        )

        prompt_indices_df = pd.DataFrame(
            list(enumerate(prompt_indices)),
            columns = ['sequence_ix', 'prompt_ix']
        )

        # Create sample (token) level dataframe
        sample_df = (
            convert_outputs_to_df_fast(input_ids, attention_mask, output['logits'])
            .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])
            .merge(prompt_indices_df, how = 'left', on = ['sequence_ix'])
            .assign(batch_ix = batch_ix)
        )
        sample_dfs.append(sample_df)

        # Store pre-MLP hidden states: the fwd pass presumably returns an n_layers list of (B*N, D) tensors;
        # stack to (B*N, n_layers, D), keep unmasked rows, then the requested layers
        valid_pos = torch.where(attention_mask.cpu().view(-1) == 1) # Valid (BN, ) positions
        all_pre_mlp_hidden_states.append(torch.stack(output['all_pre_mlp_hidden_states'], dim = 1)[valid_pos][:, layers_to_keep_acts, :])

    sample_df = pd.concat(sample_dfs, ignore_index = True)
    all_pre_mlp_hidden_states = torch.cat(all_pre_mlp_hidden_states, dim = 0)

    return {
        'sample_df': sample_df,
        'all_pre_mlp_hs': all_pre_mlp_hidden_states
    }

# Keep exactly the layers in `act_map`, so axis 1 of the returned hidden states lines up with act_map.
# The next cell keys the slices via enumerate(act_map), the same layout used for the c4 activations.
test_outputs = run_and_export_states(
    model,
    test_dl,
    layers_to_keep_acts = list(act_map)
)

In [None]:
"""
Prepare sample df
"""
_merged_test_df = test_outputs['sample_df'].merge(test_seqs, on = 'prompt_ix', how = 'inner')
test_sample_df = _merged_test_df.assign(
    sample_ix = _merged_test_df.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup(),
    token_in_seq_ix = _merged_test_df.groupby(['prompt_ix']).cumcount(),
    # token_in_role_ix = lambda df: df.groupby(['prompt_ix', 'role']).cumcount(),  (disabled)
)
del _merged_test_df

display(test_sample_df)

# Assumes axis 1 of the stacked tensor holds one slice per entry of act_map, in order,
# i.e. the forward pass kept exactly the layers listed in act_map
_test_hs_fp16 = test_outputs['all_pre_mlp_hs'].to(torch.float16)
test_pre_mlp_hs = {layer_ix: _test_hs_fp16[:, save_ix, :] for save_ix, layer_ix in enumerate(act_map)}
del _test_hs_fp16

In [None]:
def _label_base_messages(df, messages, window = 5):
    """
    Tag each token with the base message it appears in.

    For each token, within its prompt, build two strings:
    - a forward string: the token plus the next `window` tokens;
    - a backward string: the previous `window` tokens plus the token.

    Message i is a hit when either non-empty string is a substring of messages[i].

    Returns the frame sorted by (prompt_ix, token_ix) with three added columns:
    - base_ix: index of the first hit, NaN if none;
    - ambiguous: True when more than one message hit;
    - base_name: 'p<i>'.
    """
    ordered = df.sort_values(['prompt_ix', 'token_ix'])
    tokens_by_prompt = ordered.groupby('prompt_ix')['token']

    forward = ordered['token'].fillna('')
    backward = ordered['token'].fillna('')
    for k in range(1, window + 1):
        forward = forward + tokens_by_prompt.shift(-k).fillna('')
        backward = tokens_by_prompt.shift(k).fillna('') + backward

    hits = pd.concat(
        [
            (
                forward.apply(lambda s, t = text: bool(s) and (s in t)) |
                backward.apply(lambda s, t = text: bool(s) and (s in t))
            ).rename(f'hit_p{i}')
            for i, text in enumerate(messages)
        ],
        axis = 1
    )

    n_messages = len(messages)
    return ordered.assign(
        base_ix = np.select(
            [hits[f'hit_p{i}'] for i in range(n_messages)],
            list(range(n_messages)),
            default = np.nan
        ),
        ambiguous = hits.sum(axis = 1) > 1,
        base_name = lambda d: d['base_ix'].map({i: f"p{i}" for i in range(n_messages)})
    )


test_sample_df_labeled = _label_base_messages(test_sample_df, base_messages)

In [None]:
"""
Project into rolespace - all
"""
# Reset to a RangeIndex. The probability frame below is positional (row i = i-th token), while
# pd.concat(axis = 1) aligns on index labels, which need not be 0..n-1 after the labelling cell's sort_values.
project_test_sample_df = test_sample_df_labeled.reset_index(drop = True)
project_hs_cp = cupy.asarray(test_pre_mlp_hs[test_layer][project_test_sample_df['sample_ix'].tolist(), :])
project_probs = lr_model.predict_proba(project_hs_cp).round(6)
assert project_probs.shape[0] == len(project_test_sample_df)

id_cols = ['sample_ix', 'prompt_ix', 'token', 'token_in_seq_ix', 'base_name']

# Merge seq probs with inputs; probability columns follow class-id order (= label_to_id order)
plot_df =\
    pd.concat(
        [
            pd.DataFrame(cupy.asnumpy(project_probs), columns = list(label_to_id.keys())).clip(1e-6),
            project_test_sample_df[id_cols]
        ],
        axis = 1
    )

plot_df = plot_df.melt(id_vars = id_cols, var_name = 'role_space', value_name = 'prob')

plot_df

In [None]:
"""
Plot tests
"""
facet_order = list(label_to_id.keys())
_group_cols = ['prompt_ix', 'role_space']

# Per (prompt, role_space) smoothing: 2-token simple moving average and an EWMA with alpha = 0.9
smooth_df = plot_df.sort_values('sample_ix')
smooth_df = smooth_df.assign(
    prob_sma = smooth_df.groupby(_group_cols)['prob'].rolling(window = 2, min_periods = 1).mean().reset_index(level = [0, 1], drop = True),
    prob_ewma = smooth_df.groupby(_group_cols)['prob'].ewm(alpha = .9, min_periods = 1).mean().reset_index(level = [0, 1], drop = True)
)


def pretty(a):
    """Keep only the text after the last '=' in a facet annotation and restyle it as a centred label strip."""
    a.update(
        text = a.text.split("=")[-1],
        x = 0.5, xanchor = "center",
        y = a.y + 0.08,
        textangle = 0,
        font = dict(size = 12),
        bgcolor = 'rgba(255, 255, 255, 0.9)',  # light strip look
        showarrow = False
    )


for prompt_ix in smooth_df['prompt_ix'].unique().tolist():
    this_df = smooth_df[smooth_df['prompt_ix'] == prompt_ix]
    fig = px.scatter(
        this_df, x = 'sample_ix', y = 'prob',
        facet_row = 'role_space',
        color = this_df['base_name'].astype(str),
        category_orders = {'role_space': facet_order},
        hover_name = 'token',
        hover_data = {'prob': ':.3f'},
        title = f'prompt_ix = {str(prompt_ix)}',
        labels = {
            'sample_ix': 'Sample',
            'prob': 'Prob',
            'prob_smoothed': 'Smoothed Prob',
            'role_space': 'role_space'
        }
    )
    fig.update_yaxes(range = [0, 1], side = 'left')
    fig.update_layout(height = 600, width = 800)
    fig.for_each_annotation(pretty)
    fig.update_yaxes(title_text = None)
    fig.show()

In [None]:
# Inspect the smoothed per-token role probabilities (prob_sma / prob_ewma) built in the plotting cell
smooth_df

In [None]:
# Distinct prompt indices, in order of first appearance
smooth_df['prompt_ix'].drop_duplicates().tolist()

In [None]:
    # Alternative smoothing: per-(prompt, role_space) simple moving average with a configurable window
window_size = 2  # same window as prob_sma in the plotting cell

plot_df\
    .sort_values('sample_ix')\
    .assign(probs_smoothed = lambda df: df.groupby(['prompt_ix', 'role_space'])['prob'].rolling(window_size, min_periods = 1).mean().reset_index(level = [0, 1], drop = True))


In [None]:
"""
Get averages across roles
"""
# NOTE(review): `redteam_seqs_df` is not defined earlier in this notebook; confirm which cell/file loads it.
_prompt_type_labels = {
    'synthetic_policy': 'CoT Forgery',
    'destyled_policy': 'Destyled CoT Forgery',
    'base': 'Base',
}

_answered = redteam_seqs_df[redteam_seqs_df['output_class'].isin(['REFUSAL', 'HARMFUL_RESPONSE'])]
plot_df = (
    _answered
    .groupby(['redteam_prompt_type', 'output_class'])
    .agg(count = ('redteam_seq_ix', 'count'))
    .reset_index()
    .assign(redteam_prompt_type = lambda df: df['redteam_prompt_type'].map(lambda v: _prompt_type_labels.get(v, v)))
)

# Convert counts to percentages within each prompt type
df_percent = plot_df.copy()
df_percent['percent'] = 100 * df_percent['count'] / df_percent.groupby('redteam_prompt_type')['count'].transform('sum')

# Plot percentages instead of counts
fig = px.bar(
    df_percent,
    x = 'redteam_prompt_type',
    y = 'percent',
    width = 600,
    color = 'output_class',
    barmode = 'group',
    labels = {
        'redteam_prompt_type': 'Prompt Type',
        'percent': 'Percentage (%)',
        'output_class': 'Output Class'
    },
    title = 'Percentage Distribution of Output Classes by Redteam Prompt Type'
)

fig.update_layout(template = 'plotly_white', yaxis = dict(ticksuffix = '%'))
fig.show()

In [None]:
"""
Project redteam activations into roles - start with just one sequence
"""
# NOTE(review): `redteam_seqs_df`, `redteam_sample_df` and `redteam_all_pre_mlp_hs` are not defined earlier in this
# notebook; confirm where they are loaded before running this section.
# Test sequences - get all variants
test_variants = redteam_seqs_df#.pipe(lambda df: df[df['harm_index'] == redteam_seqs_df['harm_index'].tolist()[19]])

# predict_proba has one column per class id in ascending order, i.e. label_to_id order
# (system, user, assistant-cot, assistant-final), the same convention as the "Project into rolespace" cell.
# Short names 'cot' / 'assistant' keep the downstream ('cot', 'mean') aggregation and role_order working.
_short_role_names = {'assistant-cot': 'cot', 'assistant-final': 'assistant'}
proba_columns = [_short_role_names.get(label, label) for label, _ in sorted(label_to_id.items(), key = lambda kv: kv[1])]

plot_dfs = []
for row in tqdm(test_variants.to_dict('records')[0:10]):
    seq_samples = redteam_sample_df.pipe(lambda df: df[df['seq_id'] == row['redteam_seq_ix']]).pipe(lambda df: df[df['role'].notna()]).reset_index(drop = True)

    seq_hs_cp = cupy.asarray(redteam_all_pre_mlp_hs[test_layer][seq_samples['sample_ix'].tolist(), :])

    seq_probs = lr_model.predict_proba(seq_hs_cp).round(6)

    # Merge seq probs with inputs
    plot_df =\
        pd.concat(
            [
                pd.DataFrame(cupy.asnumpy(seq_probs), columns = proba_columns).clip(1e-6),
                seq_samples[['token_in_seq_ix', 'token', 'role']]
            ],
            axis = 1
        )\
        .assign(
            variant = row['redteam_prompt_type'],
            output_class = row['output_class']
        )

    plot_dfs.append(plot_df)


In [None]:
# Number of per-sequence projection frames collected by the previous cell
len(plot_dfs)

In [None]:
# Average CoT probability per (position, role, variant), then average those position means per (variant, role)
_cot_tokens = pd.concat(plot_dfs)
_cot_tokens = _cot_tokens[_cot_tokens['token_in_seq_ix'] <= 1_000]
_per_position = _cot_tokens.groupby(['token_in_seq_ix', 'role', 'variant'], as_index = False).agg(avg_cotness = ('cot', 'mean'))
plot_df = _per_position.groupby(['variant', 'role'], as_index = False).agg(avg_cotness = ('avg_cotness', 'mean'))
del _cot_tokens, _per_position

# Display labels and orderings for the Variant x Role table
variant_map = {
    "synthetic_policy": "CoT forgery",
    "destyled_policy": "Destyled CoT forgery",
    "base": "Base",
}
variant_order = ["Base", "CoT forgery", "Destyled CoT forgery"]
role_order = ["system", "user", "cot", "assistant"]

df = plot_df.copy()
df["variant_label"] = df["variant"].map(variant_map).fillna(df["variant"])

# Pivot to Variant x Role and convert to %
tbl = (
    df.pivot_table(index = "variant_label", columns = "role", values = "avg_cotness", aggfunc = "mean")
    .reindex(variant_order)
    .reindex(columns = role_order)
)
tbl_pct = tbl.mul(100).round(1)

# --- Lighter color scale for cells ---
scale = px.colors.sequential.Blues[:6]  # use lighter half
_pct_values = tbl_pct.to_numpy()
min_v = np.nanmin(_pct_values)
max_v = np.nanmax(_pct_values)


def val_to_color(v):
    """Map a percentage onto `scale` (white for NaN or when every cell has the same value)."""
    if max_v == min_v or pd.isna(v):
        return "#ffffff"
    frac = (v - min_v) / (max_v - min_v)
    return scale[int(round(frac * (len(scale) - 1)))]


# Column-wise fill colors; the first column is the white row-header column
fill_colors = [["#ffffff"] * len(tbl_pct)] + [
    [val_to_color(v) for v in tbl_pct[c].tolist()] for c in tbl_pct.columns
]

# Values for display
cell_values = [tbl_pct.index.tolist()]
for col in tbl_pct.columns:
    cell_values.append([f"{v:.1f}%" if pd.notna(v) else "" for v in tbl_pct[col]])

# --- Plotly Table (narrower width, lighter header bg) ---
_header = dict(
    values = ["Variant"] + [c.title() for c in tbl_pct.columns],
    fill_color = "#eef3fb",         # lighter header background
    font = dict(color = "#222", size = 12),
    align = "left",
    height = 30
)
_cells = dict(
    values = cell_values,
    fill_color = fill_colors,
    align = "left",
    height = 26,
    font = dict(size = 11, color = "#222")
)
fig = go.Figure(data = [go.Table(
    columnwidth = [170] + [110] * len(tbl_pct.columns),  # make table less wide
    header = _header,
    cells = _cells
)])

fig.update_layout(
    title = "Avg CoTness by Variant × Role",
    width = 680,  # narrower figure
    margin = dict(l = 10, r = 10, t = 50, b = 10)
)

fig.show()


## Role spaces (LDA)

In [None]:
"""
Split by roles
"""
# NOTE(review): `c4_sample_df`, `c4_topk_df` and `c4_all_pre_mlp_hs` are not defined earlier in this notebook.
# The cleaning cell produces `sample_df` / `topk_df` / `all_pre_mlp_hs`, with roles named 'assistant-cot' /
# 'assistant-final'. This section looks carried over from an older version; confirm the intended inputs.
c4_role_hs = {}
c4_role_sample_dfs = {}
c4_role_topk_dfs = {}

for role in ['system', 'user', 'cot', 'assistant']:
    role_samples = c4_sample_df[(c4_sample_df['role'] == role) & (c4_sample_df['is_role_wrapper'] == 0)]
    role_sample_ix = role_samples['sample_ix'].tolist()

    c4_role_sample_dfs[role] = role_samples
    c4_role_topk_dfs[role] = c4_topk_df[c4_topk_df['sample_ix'].isin(role_sample_ix)]

    # Get corresponding hidden states + cast to f32 for linalg needed for lda
    c4_role_hs[role] = {l: hs[role_sample_ix, :].float() for l, hs in c4_all_pre_mlp_hs.items()}

    # Report the shape at the first stored layer (layer 0 is not necessarily among the saved layers)
    first_layer = next(iter(c4_role_hs[role]))
    print(c4_role_hs[role][first_layer].shape)

In [None]:
def fit_role_plane(states: dict[str, list[torch.Tensor]], target_layer: int = 12, ridge_alpha: float = 1e-3):
    """
    Params
        @states: dict with keys {'system', 'user', 'cot', 'assistant'}. Each value is a list over layers; states[r][L] has shape (T, D) with matched token ordering across roles.
        @target_layer: layer to use for LDA.
        @ridge_alpha: small diagonal added to the pooled within-class covariance.
    """
    roles = ['system', 'user', 'cot', 'assistant']
    for r in roles:
        assert r in states, f"Missing role: {r}"

    X = {r: states[r][target_layer].to(dtype=torch.float32) for r in roles}  # (T, D)
    T, D = X['assistant'].shape
    device = X['assistant'].device

    # 1) Content-centering per token across roles (removes shared content identity)
    M_token = torch.stack([X[r] for r in roles], dim=0).mean(dim=0)            # (T, D)
    Xc = {r: X[r] - M_token for r in roles}

    # 2) Pooled within-class covariance (on centered data)
    mu = {r: Xc[r].mean(dim=0) for r in roles}                                  # (D,)
    residuals = torch.cat([Xc[r] - mu[r] for r in roles], dim=0)                # (4T, D)
    cov = (residuals.T @ residuals) / max(1, residuals.shape[0] - len(roles))   # (D, D)
    lam = ridge_alpha * (cov.trace() / D)
    cov = cov + lam * torch.eye(D, device=device, dtype=cov.dtype)

    # 3) Whitening transform: W such that W @ cov @ W^T = I
    evals, evecs = torch.linalg.eigh(cov)                                       # ascending
    inv_sqrt = torch.diag(torch.rsqrt(evals.clamp_min(1e-12)))
    W = (evecs @ inv_sqrt @ evecs.T).to(torch.float32)                           # (D, D)

    # 4) Whitened class centroids
    mu_w = {r: (Xc[r] @ W.T).mean(dim=0) for r in roles}                        # (D,)

    # 5) Define Role Plane axes in whitened space
    # Speaker: (assistant-like) vs (non-assistant)
    speaker = (mu_w['assistant'] + mu_w['cot'] - mu_w['user'] - mu_w['system']) / 2.0
    speaker = speaker / (speaker.norm() + 1e-12)

    # CoT: assistant-cot vs assistant-out, orthogonalized to speaker
    cot = mu_w['cot'] - mu_w['assistant']
    cot = cot - (cot @ speaker) * speaker
    cot = cot / (cot.norm() + 1e-12)

    P_role = torch.stack([speaker, cot], dim=0)                                  # (2, D)

    # 6) Fix axis signs for interpretability (assistant has +speaker; cot>assistant on +cot)
    asst_coords = mu_w['assistant'] @ P_role.T
    cot_coords  = mu_w['cot']       @ P_role.T
    if asst_coords[0] < 0:  # speaker axis
        P_role[0] = -P_role[0]
    if (cot_coords - asst_coords)[1] < 0:  # cot axis
        P_role[1] = -P_role[1]

    # 7) Axis standardization stats (z-scores) computed on training tokens
    Z_all = torch.cat([(Xc[r] @ W.T) @ P_role.T for r in roles], dim=0)          # (4T, 2)
    axis_mean = Z_all.mean(dim=0)
    axis_std  = Z_all.std(dim=0).clamp_min(1e-6)

    # Global mean as a fallback for unpaired inference (when M_token is unavailable)
    M_global = M_token.mean(dim=0)                                              # (D,)

    return {
        'roles': roles,
        'W': W,                   # (D, D) whitening
        'P_role': P_role,         # (2, D) rows: [speaker_axis, cot_axis]
        'mu_w': mu_w,             # role centroids in whitened space
        'axis_mean': axis_mean,   # (2,)
        'axis_std': axis_std,     # (2,)
        'M_global': M_global,     # (D,)
        'dtype': torch.float32,
        'device': device,
    }
    
    
def project_role_plane(x: torch.Tensor, rp: dict, content_mean: torch.Tensor | None = None):
    """
    Project hidden states x (..., D) into the Role Plane.

    content_mean: if you have the *paired* per-token mean across roles (M_token[i]), pass it here for precise content-centering. Otherwise we use rp['M_global'].

    Returns:
      coords: (..., 2) raw coordinates along (speaker, cot)
      zcoords: (..., 2) standardized z-scores along (speaker, cot)
    """
    W, P = rp['W'], rp['P_role']
    base = content_mean if content_mean is not None else rp['M_global']
    z = (x - base) @ W.T                     # whitened
    coords = z @ P.T                         # (..., 2)
    zcoords = (coords - rp['axis_mean']) / rp['axis_std']
    return coords, zcoords


# Fit the role plane at the probe layer. `lda` holds the dict returned by fit_role_plane
# (whitening W, axes P_role, centroids, z-score stats), not a scikit-learn LDA estimator.
lda = fit_role_plane(c4_role_hs, target_layer = test_layer, ridge_alpha = 1e-2)

In [None]:
# NOTE(review): `redteam_seqs_df`, `redteam_sample_df` and `redteam_all_pre_mlp_hs` are not defined earlier in this
# notebook; confirm where they are loaded.
test_variants = redteam_seqs_df.pipe(lambda df: df[df['harm_index'] == redteam_seqs_df['harm_index'].tolist()[0]])

test_variant_results = []
for row in test_variants.to_dict('records'):
    seq_samples = redteam_sample_df.pipe(lambda df: df[df['seq_id'] == row['redteam_seq_ix']]).pipe(lambda df: df[df['role'].notna()]).reset_index(drop = True)

    # torch.as_tensor accepts torch/numpy inputs of any dtype; then match the fitted plane's dtype/device
    seq_hs_pt = torch.as_tensor(
        redteam_all_pre_mlp_hs[test_layer][seq_samples['sample_ix'].tolist(), :]
    ).to(device = lda['device'], dtype = lda['dtype'])  # (N, D)

    coords, zcoords = project_role_plane(seq_hs_pt, lda)
    coords_np = coords.detach().cpu().numpy()

    # Merge role-plane coordinates with inputs. Coordinates are signed, so they are not clipped
    # (clipping at 1e-6 would erase every negative value).
    plot_df =\
        pd.concat(
            [
                pd.DataFrame({
                    'speaker_axis': coords_np[:, 0],
                    'channel_axis': coords_np[:, 1]
                }),
                seq_samples[['token_in_seq_ix', 'token', 'role']]
            ],
            axis = 1
        )

    test_variant_results.append({
        'variant': row['redteam_prompt_type'],
        'output_class': row['output_class'],
        'plot_df': plot_df
    })

test_variant_results[0]

In [None]:
def _prep_df(df: pd.DataFrame, smoothing: int = 25) -> pd.DataFrame:
    df = df.copy()
    # Expected columns: token_in_seq_ix, token, speaker_axis, channel_axis, role
    df = df.sort_values("token_in_seq_ix").reset_index(drop=True)

    # Rolling means (centered) for smoother overlay
    win = max(1, int(smoothing))
    minp = max(1, win // 2)
    df["speaker_smooth"] = (
        df["speaker_axis"].rolling(win, min_periods=minp, center=True).mean()
    )
    df["channel_smooth"] = (
        df["channel_axis"].rolling(win, min_periods=minp, center=True).mean()
    )

    # Role boundaries (indices where the role changes)
    role_change_mask = df["role"].ne(df["role"].shift(1))
    df["_role_change"] = role_change_mask
    df["_role_start_ix"] = df.loc[role_change_mask, "token_in_seq_ix"]
    return df

def _add_role_boundaries(fig: go.Figure, df: pd.DataFrame) -> None:
    """
    Add a faint dashed vertical line at every role transition in `df`.

    `df` must carry the `_role_change` flag produced by `_prep_df`. That flag is always
    True on the first row, so the first flagged position is skipped: it marks the start
    of the sequence, not a boundary.
    """
    boundary_positions = df["token_in_seq_ix"][df["_role_change"]].tolist()[1:]
    for position in boundary_positions:
        fig.add_vline(x=position, line_dash="dash", opacity=0.25)

def _auto_symmetric_range(series_list):
    # Symmetric y-range around 0 so the zero-line is centered
    vals = np.concatenate([np.asarray(s.dropna().values) for s in series_list if s is not None])
    if vals.size == 0:
        return None
    m = float(np.nanmax(np.abs(vals)))
    if m == 0 or np.isnan(m):
        return None
    pad = 0.05 * m
    return [-m - pad, m + pad]

def _role_scatter_kwargs(scatter_opacity: float) -> dict:
    """Return the px.scatter styling shared by all role-plane figures (fresh dicts on every call)."""
    return dict(
        color="role",
        hover_data={
            "token_in_seq_ix": True,
            "token": True,
            "speaker_axis": ":.3f",
            "channel_axis": ":.3f",
            "role": True,
        },
        opacity=scatter_opacity,
        category_orders={"role": ["system", "user", "cot", "assistant"]},
    )


def _axis_timeline_figure(
    df: pd.DataFrame,
    *,
    axis_col: str,
    smooth_col: str,
    title: str,
    y_title: str,
    smoothing: int,
    show_rangeslider: bool,
    scatter_opacity: float,
) -> go.Figure:
    """Plot one role-plane axis against token index, with a rolling-mean line, zero line and role boundaries."""
    fig = px.scatter(
        df,
        x="token_in_seq_ix",
        y=axis_col,
        title=title,
        **_role_scatter_kwargs(scatter_opacity),
    )
    fig.add_trace(
        go.Scatter(
            x=df["token_in_seq_ix"],
            y=df[smooth_col],
            mode="lines",
            name=f"rolling mean ({smoothing})",
        )
    )
    fig.add_hline(y=0, line_dash="dot")
    _add_role_boundaries(fig, df)
    yr = _auto_symmetric_range([df[axis_col], df[smooth_col]])
    if yr is not None:
        fig.update_yaxes(range=yr)
    fig.update_layout(
        xaxis_title="Token index in sequence",
        yaxis_title=y_title,
        hovermode="x unified",
        legend_title_text="Role",
    )
    if show_rangeslider:
        fig.update_layout(xaxis_rangeslider_visible=True)
    return fig


def make_role_plane_figures(
    df: pd.DataFrame,
    smoothing: int = 25,
    show_rangeslider: bool = True,
    scatter_opacity: float = 0.65,
):
    """
    Build the role-plane views of one sequence.

    Args:
        df: per-token frame with columns token_in_seq_ix, token, speaker_axis, channel_axis, role.
        smoothing: rolling-mean window (tokens) for the timeline overlays.
        show_rangeslider: add an x-axis range slider to the two timelines.
        scatter_opacity: marker opacity for all scatters.

    Returns three Plotly figures:
      (1) fig_speaker_timeline : token index vs speaker_axis (+ rolling mean)
      (2) fig_channel_timeline : token index vs channel_axis (+ rolling mean)
      (3) fig_role_plane       : 2D scatter of speaker_axis vs channel_axis
    """
    df = _prep_df(df, smoothing)
    timeline_opts = dict(
        smoothing=smoothing,
        show_rangeslider=show_rangeslider,
        scatter_opacity=scatter_opacity,
    )

    # ---- 1) Speaker axis over time ----
    fig_speaker = _axis_timeline_figure(
        df,
        axis_col="speaker_axis",
        smooth_col="speaker_smooth",
        title="Speaker axis over sequence (assistant‑ish ↔ non‑assistant)",
        y_title="Speaker axis",
        **timeline_opts,
    )

    # ---- 2) Channel/CoT axis over time ----
    fig_channel = _axis_timeline_figure(
        df,
        axis_col="channel_axis",
        smooth_col="channel_smooth",
        title="CoT / channel axis over sequence (assistant‑CoT ↔ assistant‑final)",
        y_title="CoT / channel axis",
        **timeline_opts,
    )

    # ---- 3) 2‑D Role Plane scatter ----
    fig_plane = px.scatter(
        df,
        x="speaker_axis",
        y="channel_axis",
        title="Role Plane (x: speaker axis, y: CoT/channel axis)",
        **_role_scatter_kwargs(scatter_opacity),
    )
    fig_plane.add_vline(x=0, line_dash="dot")
    fig_plane.add_hline(y=0, line_dash="dot")
    fig_plane.update_layout(
        xaxis_title="Speaker axis",
        yaxis_title="CoT / channel axis",
        legend_title_text="Role",
    )

    return fig_speaker, fig_channel, fig_plane

# --- Example usage: role-plane views of the first variant in test_variant_results ---
fig_speaker, fig_channel, fig_plane = make_role_plane_figures(test_variant_results[0]['plot_df'], smoothing=25)
fig_speaker.show()
fig_channel.show()
fig_plane.show()

In [None]:
def plot_role_prob_minimal(
    input_df: pd.DataFrame,
    *,
    title: str = '',
    value_col: str = "cot", # 'cot','assistant','user','system', or any other column of input_df
    rolling_window: int = 35, # used for line mode only
    use_points: bool = False, # True -> markers (raw), False -> line (smoothed)
    export_path: str | None = None # e.g. "figure.svg" (requires kaleido)
):
    """
    Minimal, publication-ready plot of a single per-token value across a sequence.

    - Line mode (default): centered rolling mean over `rolling_window` tokens (clean trend).
    - Points mode: raw per-token values as markers (no smoothing).

    Background bands mark contiguous runs of the same `role`. Runs spanning at least 10
    token positions get a text label above the plot.

    Args:
        input_df: must contain `token_in_seq_ix`, `role` and `value_col`.
        title: figure title.
        value_col: column to plot. The role names 'cot', 'assistant', 'user' and 'system'
            get dedicated colours and y-axis labels. Any other column present in
            `input_df` (e.g. 'channel_axis') is drawn in a neutral colour, with its own
            name as the y-axis label.
        rolling_window: window for the rolling mean (line mode only).
        use_points: plot raw markers instead of the smoothed line.
        export_path: if given, also write a static image there at scale=2.

    Returns:
        plotly.graph_objects.Figure

    Raises:
        ValueError: if `value_col` is neither a role name nor a column of `input_df`.
    """
    role_cols = {'system', 'user', 'cot', 'assistant'}
    if value_col not in role_cols and value_col not in input_df.columns:
        raise ValueError(
            f"value_col must be one of {role_cols} or a column of input_df, got {value_col!r}"
        )

    # Role-specific styling; any other column falls back to a neutral colour and its own name
    line_color = {
        "cot": "#0072B2",  # blue
        "assistant": "#D55E00",  # vermillion
        "user": "#009E73",  # bluish green
        "system": "#CC79A7",  # reddish purple
    }.get(value_col, "#333333")

    role_bg = {
        "user": "rgba(0,158,115,0.28)",
        "assistant": "rgba(213,94,0,0.24)",
        "cot": "rgba(0,114,178,0.24)",
        "system": "rgba(204,121,167,0.24)",
    }
    role_label = {"user": "USER", "assistant": "ASSISTANT", "cot": "REASONING", "system": "SYSTEM",}
    y_label = {
        "cot": "less ←       CoTness       ⟶ more",
        "assistant": "Assistantness",
        "user": "Userness ",
        "system": "Systemness   P(role = System)",
    }.get(value_col, value_col)

    # Data (stable sort keeps the original order of equal token indices)
    df = input_df.sort_values("token_in_seq_ix", kind="mergesort")
    x = df["token_in_seq_ix"].to_numpy()
    y_raw = df[value_col].to_numpy()
    roles = df["role"].to_numpy()

    # Rolling smoothing for line mode
    if use_points:
        y_plot = y_raw  # Points show raw values
    else:
        y_plot = pd.Series(y_raw).rolling(rolling_window, center=True, min_periods=1).mean().to_numpy()

    # Contiguous role regions: (role, first token index, last token index).
    # An empty frame has no regions and simply produces an empty figure.
    regions = []
    if len(df) > 0:
        start = 0
        for i in range(1, len(df)):
            if roles[i] != roles[i - 1]:
                regions.append((roles[start], int(x[start]), int(x[i - 1])))
                start = i
        regions.append((roles[start], int(x[start]), int(x[-1])))

    # Figure
    fig = go.Figure()

    # Background bands + labels
    for role, x0, x1 in regions:
        fig.add_vrect(
            x0=x0 - 0.5, x1=x1 + 0.5,
            fillcolor=role_bg.get(role, "rgba(200,200,200,0.22)"),
            line_width=0, layer="below"
        )
        if x1 - x0 >= 10:
            fig.add_annotation(
                x=(x0 + x1)/2, y=1.045, xref="x", yref="paper",
                text=role_label.get(role, role),
                showarrow=False,
                font=dict(size=12, color="rgba(0,0,0,0.70)", family="Helvetica, Arial, sans-serif")
            )

    if use_points:
        fig.add_trace(go.Scatter(
            x=x, y=y_plot, mode="markers",
            marker=dict(color=line_color, size=5, opacity=0.85, line=dict(width=0)),
            hoverinfo="skip", showlegend=False
        ))
    else:
        fig.add_trace(go.Scatter(
            x=x, y=y_plot, mode="lines",
            line=dict(color=line_color, width=3.6),
            hoverinfo="skip", showlegend=False
        ))

    # Layout
    fig.update_layout(
        template="simple_white",
        width=1100, height=420,
        margin=dict(l=70, r=30, t=64, b=54),
        title=dict(text=title, x=0.01, y=0.98, xanchor="left", font=dict(size=20, family="Helvetica, Arial, sans-serif")),
        font=dict(size=14, family="Helvetica, Arial, sans-serif"),
        showlegend=False
    )
    fig.update_yaxes(title_text=y_label, gridcolor="rgba(0,0,0,0.06)")
    fig.update_xaxes(title_text="Token index", showgrid=False, zeroline=False)

    if export_path:
        fig.write_image(export_path, scale=2)

    return fig


# plot_df holds role-plane coordinates (speaker_axis, channel_axis), not per-role
# probabilities, so expose the CoT/channel axis under the 'cot' column the plotter reads.
# NOTE(review): confirm the sign convention of channel_axis (is higher = more CoT-like?)
# before reading the "CoTness" y-label literally.
figs = []
for variant_result in test_variant_results:
    variant_df = variant_result['plot_df'].assign(cot=lambda d: d['channel_axis'])
    fig = plot_role_prob_minimal(variant_df, title = variant_result['variant'], value_col = 'cot', use_points = True, export_path = None)
    figs.append(fig)
    fig.show()


In [None]:
[]