In [None]:
"""
This runs forward passes on samples and stores: (1) pre-MLP activations; (2) top-k expert selections; (3) sample metadata.
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Glm4vForConditionalGeneration, AutoModelForImageTextToText
from transformers.loss.loss_utils import ForCausalLMLoss
from datasets import load_dataset
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import gc
from termcolor import colored
import importlib
import cupy
import cuml

from utils.memory import check_memory, clear_all_cuda_memory
from utils.store_outputs import convert_outputs_to_df_fast
from utils.misc import flatten_list

main_device = 'cuda:0'
seed = 123

# Seed the global RNGs so stochastic calls later in the notebook that do not take an explicit seed
# are reproducible (reproducibility still assumes cells run top-to-bottom).
np.random.seed(seed)
torch.manual_seed(seed)

ws = '/workspace/deliberative-alignment-jailbreaks'

clear_all_cuda_memory()
check_memory()

# Load base model

In [None]:
"""
Load the base tokenizer/model
"""
selected_model_index = 5

def get_model(index):
    """
    Look up the configuration tuple for a supported model.

    Params:
        @index: Integer key into the model registry below.

    Returns:
        A 7-tuple: (HF model ID, model prefix, model architecture, attn implementation,
        whether to use the HF library implementation, number of decoder layers, is reasoning model).

    Raises:
        KeyError: if `index` is not a registered model.
    """
    models = {
        0: ('openai/gpt-oss-20b', 'gptoss20', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 24, True),
        1: ('openai/gpt-oss-120b', 'gptoss120', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 36, True), # Will load experts in MXFP4 if triton kernels installed
        2: ('allenai/Olmo-3-7B-Think', 'olmo3-7', 'olmo3', None, True, 32, True), # OlMo-3
        3: ('allenai/Olmo-3-32B-Think', 'olmo3-32', 'olmo3', None, True, 64, True),
        4: ('zai-org/GLM-4.6V-Flash', 'glm-46v-flash', 'glm46v', None, True, 40, True),
        5: ('ServiceNow-AI/Apriel-1.6-15b-Thinker', 'apriel-16', 'apriel', None, True, 48, True),
        6: ('mistralai/Devstral-Small-2-24B-Instruct-2512', 'mistral3', 'devstral', None, True, 40, False)
        # 3: ('zai-org/GLM-4.5-Air-FP8', 'glm45air', 'glm4moe', None, True, 45), # GLM-4.5 has one dense pre-layer, this is MoE layers
    }
    if index not in models:
        raise KeyError(f"Unknown model index {index}; valid indices: {sorted(models)}")
    return models[index]

def load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf, model_architecture = None):
    """
    Load the tokenizer and model from HF (downloads are cached under /workspace/hf).

    Params:
        @model_id: HF model ID.
        @model_prefix: Short model name (not used here; kept for interface compatibility).
        @model_attn: Passed through as `attn_implementation`.
        @model_use_hf: If True, use the transformers-native implementation (trust_remote_code = False for the model).
        @model_architecture: Architecture key that selects the model class. If None, falls back to the
            notebook-level `model_architecture` global (the previous implicit behavior).

    Returns:
        (tokenizer, model), with the model moved to `main_device` and set to eval mode.
    """
    if model_architecture is None:
        model_architecture = globals()['model_architecture']

    cache_dir = '/workspace/hf'
    # Left padding so the final position of every row in a batch is a real token
    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir = cache_dir, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)
    load_params = {'cache_dir': cache_dir, 'dtype': 'auto', 'trust_remote_code': not model_use_hf, 'device_map': None, 'attn_implementation': model_attn}

    # Multimodal checkpoints need their own class; everything else loads as a causal LM
    model_cls = {
        'glm46v': Glm4vForConditionalGeneration,
        'apriel': AutoModelForImageTextToText,
    }.get(model_architecture, AutoModelForCausalLM)
    model = model_cls.from_pretrained(model_id, **load_params).to(main_device).eval()

    return tokenizer, model

model_id, model_prefix, model_architecture, model_attn, model_use_hf, model_n_layers, model_is_reasoning = get_model(selected_model_index)
tokenizer, model = load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf, model_architecture = model_architecture)

check_memory()

In [None]:
"""
Some checks for special models (GPT-OSS - check attn + expert precision)
"""
if model_architecture == 'gptoss':
    # Repr of the first layer's expert down-projection, to see which weight format was loaded
    # (presumably MXFP4 when triton kernels are installed, per the model registry note — confirm)
    print(model.model.layers[0].mlp.experts.down_proj)
    # Attention backend recorded in the model config after loading
    print(model.model.config._attn_implementation)

In [None]:
"""
Load reverse-engineered forward pass functions that return topk expert IDs and weights
"""
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_return_topk = getattr(model_module, f"run_{model_architecture}_return_topk")

@torch.no_grad()
def test_custom_forward_pass(model, pad_token_id):
    """
    Check that the custom forward pass reproduces the model's own forward pass exactly.

    Two short prompts are left-padded to 640 tokens and run through both `model(...)` and
    `run_model_return_topk`; the logits must be bit-identical. Also prints the LM loss of the custom
    logits and the count/shape of the returned pre-MLP and post-layer hidden states.

    Params:
        @model: The loaded model (inputs are moved to `model.device`).
        @pad_token_id: Not used for masking any more; kept for interface compatibility. Padding is excluded
            from the loss via the attention mask, which stays correct even if the pad id equals a content token.
    """
    inputs = tokenizer(['Hi! I am a dog and I like to bark', 'Vegetables are good for'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 640).to(model.device)
    original_results = model(**inputs, use_cache = False)
    custom_results = run_model_return_topk(model, inputs['input_ids'], inputs['attention_mask'], return_hidden_states = True)
    assert torch.equal(original_results.logits, custom_results['logits']), 'Error in custom forward'

    # -100 is the ignore index for the LM loss; mask exactly the padded positions
    labels = inputs['input_ids'].masked_fill(inputs['attention_mask'] == 0, -100)
    loss = ForCausalLMLoss(custom_results['logits'], labels, custom_results['logits'].size(-1)).detach().cpu().item()
    print(f"LM loss: {loss}")
    print(f"Hidden states layers (pre-mlp | post-layer): {len(custom_results['all_pre_mlp_hidden_states'])} | {len(custom_results['all_hidden_states'])}")
    print(f"Hidden state size (pre-mlp | post-layer): {(custom_results['all_pre_mlp_hidden_states'][0].shape)} | {(custom_results['all_hidden_states'][0].shape)}")

test_custom_forward_pass(model, tokenizer.pad_token_id)

# Get dataset

In [None]:
"""
Run some IFT tests
"""
import utils.role_templates
importlib.reload(utils.role_templates)
from utils.role_templates import render_single_message, render_mixed_cot

# Test: render a user/assistant exchange with the tokenizer's built-in chat template, then render content through
# the project's render helpers, so the formats can be compared by eye.
# NOTE(review): with tokenize = False, padding/truncation/max_length appear to have no effect — confirm.
s = tokenizer.apply_chat_template(
    [
        {'role': 'user', 'content': 'Hi! I am a dog and I like to bark'},
        {'role': 'assistant', 'content': 'Hello! What a lovely dog you are!'}
    ],
    tokenize = False,
    padding = 'max_length',
    truncation = True,
    max_length = 512,
    add_generation_prompt = False
)
print(s)
print('--')
print(render_single_message(model_architecture, role = 'assistant-final', content ='Hi'))
print('--')
print(render_mixed_cot(model_architecture, 'The user...', 'Yes!'))
print('--')

In [None]:
"""
Load dataset - C4 + HPLT2
"""
n_sample_size = 400

def load_raw_ds():
    """
    Stream up to `n_sample_size` documents from a seeded shuffle of the C4 English validation split.

    Returns:
        A list of dicts `{'text': ..., 'source': 'c4-en'}`. It is shorter than `n_sample_size` only if
        the stream runs out first.
    """
    def get_c4():
        # Streaming shuffle over a 50k-example buffer, seeded with the notebook seed
        return load_dataset('allenai/c4', 'en', split = 'validation', streaming = True).shuffle(seed = seed, buffer_size = 50_000)

    # def get_hplt2():
    #     return load_dataset('HPLT/HPLT2.0_cleaned', 'eng_Latn', split = 'train', streaming = True).shuffle(seed = seed, buffer_size = 50_000)

    def get_data(ds, n_samples, data_source): # en
        # zip stops as soon as range(n_samples) is exhausted, so at most n_samples items are pulled from ds
        return [
            {'text': sample['text'], 'source': data_source}
            for _, sample in zip(range(n_samples), ds)
        ]

    return get_data(get_c4(), int(n_sample_size), 'c4-en') #+ get_data(get_hplt2(), int(n_sample_size/2), 'hplt-en')

raw_data = load_raw_ds()

In [None]:
"""
Define opening text to append at the start of each role
"""
import textwrap

# Opening text placed before every sample prompt. Later it is located in the token stream with
# flag_message_types(...) so its tokens can be excluded from the probe data.
# NOTE(review): tokenizer.bos_token can be None for some tokenizers, which would make the `+` below raise — confirm per model.
if model_architecture == 'gptoss':
    # NOTE(review): the system content below already contains `<|start|>system<|message|>` ... `<|end|>`;
    # if render_single_message(..., 'system', ...) also adds these tags, the system header ends up doubled.
    # Confirm against utils.role_templates.
    prepend =\
        tokenizer.bos_token +\
        render_single_message(
            model_architecture,
            'system', 
            textwrap.dedent("""
                <|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
                Knowledge cutoff: 2024-06
                Current date: 2026-01-01

                Reasoning: medium

                # Valid channels: analysis, commentary, final. Channel must be included for every message.
                Calls to these tools must go to the commentary channel: 'functions'.<|end|>
            """).strip()
        )

elif model_architecture == 'olmo3':
    # BOS + OLMo default system prompt rendered as a system message
    prepend =\
        tokenizer.bos_token +\
        render_single_message(
            model_architecture,
            'system', 
            textwrap.dedent(
            """
                You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <functions></functions>
            """).strip()
        )
    
elif model_architecture in ['glm46v', 'glm4vmoe']:
    # GLM: only the fixed prefix tokens, no system message
    prepend =\
        """[gMASK]<sop>"""
        
elif model_architecture == 'apriel':
    # BOS + Apriel reasoning system prompt rendered as a system message
    prepend =\
        tokenizer.bos_token +\
        render_single_message(
            model_architecture,
            'system',
            textwrap.dedent(
                """
                You are a thoughtful, systematic AI assistant from ServiceNow Language Models (SLAM) lab. Analyze each question carefully, present your reasoning step-by-step, then provide the final response after the marker [BEGIN FINAL RESPONSE].
                """
            ).strip()
        )

else:
    # Any other architecture (e.g. 'devstral') has no prepend defined yet
    raise ValueError('Missing model architecture')
    
print(prepend)

In [None]:
# """
# Create sample sequences
# """
# truncated_texts = tokenizer.batch_decode(tokenizer([t['text'] for t in raw_data], padding = False, truncation = True, max_length = 512).input_ids)

# def get_sample_seqs(sample_str):
#     sample_seqs = {
#         'system': ('system', sample_str, None),
#         'user': ('user', sample_str, None),
#         'cot': ('assistant-cot', sample_str, None),
#         'assistant': ('assistant-final', sample_str, None),
#         'tool': ('tool', sample_str, None)
#     }
#     return [
#         {'role': k, 'prompt': prepend + render_single_message(model_architecture, v[0], v[1], v[2])}
#         for k, v in sample_seqs.items()
#     ]

# input_list = flatten_list([
#     [
#         {'question_ix': i, 'question': x, **p}
#         for p in get_sample_seqs(x)
#     ]
#     for i, x in enumerate(truncated_texts)
# ])

# input_df =\
#     pd.DataFrame(input_list)\
#     .assign(prompt_ix = lambda df: list(range(len(df))))

# display(input_df)

# # Print examples
# for p in [row['prompt'] for row in input_df.pipe(lambda df: df[df['question_ix'] == 1]).to_dict('records')]:
#     print(p)

In [None]:
"""
Create sample sequences - merged CoT strategy
To splice, don't just merge - prepend only
"""
truncated_texts = tokenizer.batch_decode(
    tokenizer([t['text'] for t in raw_data], padding = False, truncation = True, max_length = 512).input_ids
)
n_seqs = len(truncated_texts)

def get_sample_seqs(probe_text, partner_text):
    """
    Build the four role-tagged prompts for one probe text, each starting with `prepend`.

    Params
        @probe_text: The text we're extracting states from (appears in all roles)
        @partner_text: Random paired text (only used in merged sample's assistant position)

    Returns
        A list of `{'role', 'prompt'}` dicts for roles system, user, tool, then the merged assistant_cot
        sample (probe text as the CoT, partner text as the final assistant answer).
    """
    single_role_seqs = [
        {'role': role, 'prompt': prepend + render_single_message(model_architecture, role = role, content = probe_text)}
        for role in ('system', 'user', 'tool')
    ]

    # Merged assistant - CoT sample
    merged_seq = {
        'role': 'assistant_cot',
        'prompt': prepend + render_mixed_cot(model_architecture, cot = probe_text, assistant = partner_text)
    }

    return single_role_seqs + [merged_seq]

# Pair every text with a *different* text (a derangement) to serve as the final answer of its merged
# assistant sample. A seeded generator makes the pairing reproducible; rejection sampling needs about e
# draws on average. With a single text no derangement exists, so it is paired with itself instead of
# looping forever.
pairing_rng = np.random.default_rng(seed)
perm = pairing_rng.permutation(n_seqs)
while n_seqs > 1 and np.any(perm == np.arange(n_seqs)):
    perm = pairing_rng.permutation(n_seqs)

input_list = []
for base_ix, base_text in enumerate(truncated_texts):
    partner_ix = int(perm[base_ix])
    partner_text = truncated_texts[partner_ix]
    for seq in get_sample_seqs(base_text, partner_text):
        row = {
            'question_ix': base_ix,
            'question': base_text,
            'partner_ix': partner_ix,
            'partner_text': partner_text,
            **seq
        }
        input_list.append(row)

# prompt_ix: global id of each prompt, used to join tokens/activations back to their prompt later
input_df = pd.DataFrame(input_list).assign(prompt_ix = lambda df: list(range(len(df))))

display(input_df)

# Print examples for a particular base text index
for p in input_df.query('question_ix == 1')['prompt']:
    print(p)
    print("=" * 80)
    print("=" * 80)

In [None]:
# """
# SHREK: Stacked Heterogeneous Role Encoding with randomized Kontext
# Takes individual prompts and concatenates them into mega-sequences
# """
# # Step 1: Generate all individual prompts (your existing code)
# truncated_texts = tokenizer.batch_decode(tokenizer([t['text'] for t in raw_data], padding = False, truncation = True, max_length = 384).input_ids)
# n_seqs = len(truncated_texts)

# def get_sample_seqs(probe_text, partner_text):
#     """Generate individual role prompts"""
#     seqs = []
    
#     for role in ['system', 'user', 'tool']:
#         seqs.append({
#             'role': role,
#             'prompt': render_single_message(model_architecture, role = role, content = probe_text)
#         })
    
#     # Merged assistant-CoT
#     seqs.append({
#         'role': 'assistant_cot',
#         'prompt': render_mixed_cot(model_architecture, cot = probe_text, assistant = partner_text)
#     })
    
#     return seqs

# # Generate pairings
# perm = np.random.permutation(n_seqs)
# while n_seqs > 1 and np.any(perm == np.arange(n_seqs)):
#     perm = np.random.permutation(n_seqs)

# # Create all individual prompts
# all_prompts = []
# for base_ix, base_text in enumerate(truncated_texts):
#     partner_ix = int(perm[base_ix])
#     partner_text = truncated_texts[partner_ix]
    
#     for seq in get_sample_seqs(base_text, partner_text):
#         all_prompts.append({
#             'question_ix': base_ix,
#             'question': base_text,
#             'partner_ix': partner_ix,
#             'partner_text': partner_text,
#             **seq
#         })

# print(f"Generated {len(all_prompts)} individual prompts")

# # ============================================================
# # SHREK STEP 2: Concatenate into mega-sequences
# # ============================================================
# import random

# MAX_TOKENS = 1024
# PREPEND_LENGTH = 100  # Fixed assumed max length for prepend/BOS tokens
# random.seed(seed)

# # Shuffle all prompts randomly
# shuffled_prompts = all_prompts.copy()
# random.shuffle(shuffled_prompts)

# mega_sequences = []
# current_sequence = []
# current_token_count = PREPEND_LENGTH  # Start with prepend budget

# for prompt_data in tqdm(shuffled_prompts, desc="Building SHREK sequences"):
#     prompt_text = prompt_data['prompt']
    
#     # Tokenize to get length
#     prompt_tokens = tokenizer(prompt_text, add_special_tokens = False)
#     prompt_length = len(prompt_tokens['input_ids'])
    
#     # Check if adding this prompt would exceed limit
#     if current_token_count + prompt_length > MAX_TOKENS:
#         # Save current sequence and start new one
#         if current_sequence:
#             mega_sequences.append(current_sequence)
#         current_sequence = [prompt_data]
#         current_token_count = PREPEND_LENGTH + prompt_length  # Reset with prepend budget
#     else:
#         # Add to current sequence
#         current_sequence.append(prompt_data)
#         current_token_count += prompt_length

# # Don't forget the last sequence
# if current_sequence:
#     mega_sequences.append(current_sequence)

# print(f"Created {len(mega_sequences)} mega-sequences")

# # ============================================================
# # SHREK STEP 3: Format mega-sequences for processing
# # ============================================================

# shrek_input_list = []

# for mega_ix, sequence in enumerate(mega_sequences):
#     # Concatenate all prompts (no prepends in them)
#     concatenated_body = ''.join([p['prompt'] for p in sequence])
    
#     # Add prepend ONCE at the beginning
#     full_prompt = prepend + concatenated_body
    
#     shrek_input_list.append({
#         'mega_ix': mega_ix,
#         'num_prompts': len(sequence),
#         'prompt': full_prompt,
#         'source_prompts': sequence
#     })

# input_df = pd.DataFrame(shrek_input_list).assign(prompt_ix = lambda df: list(range(len(df))))

# print(f"\nSHREK Statistics:")
# print(f"- Total mega-sequences: {len(input_df)}")
# print(f"- Avg prompts per sequence: {input_df['num_prompts'].mean():.1f}")

# display(input_df[['mega_ix', 'num_prompts']])

# # Example
# print("\n" + "="*80)
# print("EXAMPLE MEGA-SEQUENCE:")
# print("="*80)
# print(input_df.iloc[1]['prompt'])

In [None]:
""" 
Load dataset into a dataloader. The dataloader returns the original tokens - this is important for BPE tokenizers as otherwise it's difficult to reconstruct the correct string later!
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

def chunk_list(input_list, max_length):
    """Split `input_list` into consecutive slices of at most `max_length` items (the last one may be shorter)."""
    chunks = []
    for start in range(0, len(input_list), max_length):
        chunks.append(input_list[start:start + max_length])
    return chunks

# Batches of 16 prompts in input_df order (shuffle = False), carrying each prompt's prompt_ix.
# NOTE(review): max_length = 512 * 2 + 64 presumably budgets for two 512-token texts (probe + partner in the
# merged assistant/CoT prompt) plus template overhead; the prepend is not counted explicitly. Confirm how
# ReconstructableTextDataset handles prompts longer than this.
test_dl = DataLoader(
    ReconstructableTextDataset(input_df['prompt'].tolist(), tokenizer, max_length = 512 * 2 + 64, prompt_ix = input_df['prompt_ix'].tolist()), #512 * 2 + 64
    batch_size = 16,
    shuffle = False,
    collate_fn = stack_collate
)

# Get activations

In [None]:
"""
Run forward passes
"""
@torch.no_grad()
def run_and_export_states(model, dl: DataLoader, layers_to_keep_acts: list[int]):
    """
    Run forward passes on the given model and store the decomposed sample_df plus pre-MLP hidden states.

    Params:
        @model: The model to run forward passes on. `run_model_return_topk(model, ...)` must return a dict with
            keys `logits` and `all_pre_mlp_hidden_states`.
        @dl: A DataLoader (over a ReconstructableTextDataset) whose batches contain `input_ids`, `attention_mask`,
            `original_tokens`, and `prompt_ix`.
        @layers_to_keep_acts: Non-empty list of layer indices (0-indexed) to keep from `all_pre_mlp_hidden_states`.

    Returns:
        A dict with keys:
        - `sample_df`: A sample (token)-level dataframe with corresponding input token ID, output token ID, input
            token text and prompt_ix (removes masked tokens).
        - `all_hs`: A tensor of size n_samples x len(layers_to_keep_acts) x D with the pre-MLP hidden state of each
            retained layer for every unmasked position. Each n_sample is meant to correspond to a row of sample_df
            (NOTE(review): relies on convert_outputs_to_df_fast keeping (sequence, token) row-major order — confirm).

    Raises:
        ValueError: if `layers_to_keep_acts` is empty.

    Example:
        test_outputs = run_and_export_states(model, test_dl, layers_to_keep_acts = list(range(model_n_layers)))
    """
    if len(layers_to_keep_acts) == 0:
        raise ValueError('layers_to_keep_acts must contain at least one layer index')

    all_hidden_states = []
    sample_dfs = []

    for batch_ix, batch in tqdm(enumerate(dl), total = len(dl)):

        input_ids = batch['input_ids'].to(main_device)
        attention_mask = batch['attention_mask'].to(main_device)
        original_tokens = batch['original_tokens']
        prompt_indices = batch['prompt_ix']

        output = run_model_return_topk(model, input_ids, attention_mask, return_hidden_states = True)

        # Check no bugs by validating output/perplexity on the first batch
        if batch_ix == 0:
            # Ignore (-100) exactly the padded positions; robust even if pad_token_id matches a content token
            labels = input_ids.masked_fill(attention_mask == 0, -100)
            loss = ForCausalLMLoss(output['logits'], labels, output['logits'].size(-1)).detach().cpu().item()
            for i in range(min(20, input_ids.size(0))):
                decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = False)
                next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
                print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
            print(f"PPL:", torch.exp(torch.tensor(loss)).item())

        original_tokens_df = pd.DataFrame(
            [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)],
            columns = ['sequence_ix', 'token_ix', 'token']
        )

        prompt_indices_df = pd.DataFrame(
            [(seq_i, seq_source) for seq_i, seq_source in enumerate(prompt_indices)],
            columns = ['sequence_ix', 'prompt_ix']
        )

        # Create sample (token) level dataframe
        sample_df =\
            convert_outputs_to_df_fast(input_ids, attention_mask, output['logits'])\
            .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
            .merge(prompt_indices_df, how = 'left', on = ['sequence_ix'])\
            .assign(batch_ix = batch_ix)

        sample_dfs.append(sample_df)

        # Pre-MLP hidden states come back as an n_layers list of BN x D tensors. Select the kept layers *before*
        # stacking (-> BN x n_keep x D) so activations of discarded layers are never materialised together, then
        # keep only unmasked (BN,) positions.
        valid_pos = torch.where(attention_mask.cpu().view(-1) == 1)
        kept_layer_states = [output['all_pre_mlp_hidden_states'][layer_ix] for layer_ix in layers_to_keep_acts]
        all_hidden_states.append(torch.stack(kept_layer_states, dim = 1)[valid_pos])

    sample_df = pd.concat(sample_dfs, ignore_index = True).drop(columns = ['batch_ix', 'sequence_ix']) # Drop batch/seq_ix, since prompt_ix identifies
    all_hidden_states = torch.cat(all_hidden_states, dim = 0)

    return {
        'sample_df': sample_df,
        'all_hs': all_hidden_states
    }

# Probe every 2nd layer for models with <= 24 layers, every 4th layer otherwise (always starting at layer 0)
layers_to_probe = list(range(model_n_layers))[::(2 if model_n_layers <= 24 else 4)]
res = run_and_export_states(model, test_dl, layers_to_keep_acts = layers_to_probe)

# Run probes


In [None]:
"""
Label roles and omit prepend
"""
import utils.role_assignments2
importlib.reload(utils.role_assignments2)
from utils.role_assignments2 import label_content_roles
from utils.substring_assignments import flag_message_types

probe_sample_df = res['sample_df']

# Tag every token with its role, then give each row a running index used to look up its stored activations.
# NOTE(review): sample_ix only lines up with res['all_hs'] rows if label_content_roles keeps row count and order — confirm.
probe_sample_df = label_content_roles(model_architecture, res['sample_df']).assign(sample_ix = lambda df: range(0, len(df)))

# Flag tokens belonging to the shared prepend text (drop the helper column)
probe_sample_df = flag_message_types(probe_sample_df, [prepend]).drop(columns = 'base_message')

# Keep only role-tagged content tokens that are not part of the prepend
valid_probe_sample_df = probe_sample_df[
    probe_sample_df['base_message_ix'].isna()
    & probe_sample_df['is_content'].eq(True)
    & probe_sample_df['role'].notna()
]

# Token counts per role (should be exactly equal for most models, slight discrepancy for Olmo3 since <think> isn't a single tok)
display(
    valid_probe_sample_df
    .groupby('role', as_index = False)
    .agg(count = ('sample_ix', 'count'))
)

# Eyeball the role segmentation of the first prompts: concatenated text per (prompt, segment, role)
(
    valid_probe_sample_df
    .loc[lambda df: df['prompt_ix'] <= 14]
    .groupby(['prompt_ix', 'seg_ix', 'role'], as_index = False)
    .agg(combined_text = ('token', ''.join))
    .assign(eot = lambda df: df['combined_text'].str[-30:])
)

# probe_sample_df.pipe(lambda df: df[df['prompt_ix'] == 3]).to_csv('dump.csv')

In [None]:
# Debug view: the last 50 of the first 80 rows of the role-labelled token dataframe
label_content_roles(model_architecture, res['sample_df'])\
    .head(80).tail(50)

In [None]:
"""
Convert activations to fp16 (for compatibility with cupy later) + layer-wise dict
"""
# fp16 has a much smaller range than bf16: any |activation| above its max becomes inf after the cast and would
# silently corrupt the probes. Check with two reductions before casting.
fp16_max = torch.finfo(torch.float16).max
if res['all_hs'].numel() > 0:
    hs_absmax = max(abs(res['all_hs'].amin().item()), abs(res['all_hs'].amax().item()))
    if hs_absmax > fp16_max:
        print(colored(f"Warning: max |activation| {hs_absmax:.1f} exceeds the fp16 max {fp16_max:.1f}; these values will overflow to inf", 'red'))

all_probe_hs = res['all_hs'].to(torch.float16)
# compare_bf16_fp16_batched(res['all_hs], all_probe_hs)
del res['all_hs']
# Layer-keyed views: save_ix is the position in the stored layer axis, layer_ix the model layer it came from
all_probe_hs = {layer_ix: all_probe_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(layers_to_probe)}
gc.collect()

In [None]:
# Full token text of prompt 3, then only its tokens labelled 'assistant-final' after filtering.
# NOTE(review): the probe mappings below use role names 'assistant'/'cot'; if label_content_roles does not emit
# 'assistant-final', the second print is an empty line — confirm.
print(''.join(probe_sample_df.pipe(lambda df: df[df['prompt_ix'] == 3])['token'].tolist()))
print(''.join(valid_probe_sample_df.pipe(lambda df: df[(df['prompt_ix'] == 3) & (df['role'] == 'assistant-final')])['token'].tolist()))

In [None]:
"""
Run logistic regression probes to get role space models
"""
def fit_lr(x_train, y_train, x_test, y_test, C = 10, max_iter = 10_000):
    """
    Fit an L2-regularised logistic-regression probe on already-split data and evaluate it on the test set.

    The train/test split itself is done by the caller (`get_probe_result` splits 80:20 by prompt).

    Params:
        @x_train, @y_train: Training features and integer labels (the caller passes cupy arrays).
        @x_test, @y_test: Held-out features and labels.
        @C: Inverse regularisation strength (larger = weaker L2 penalty). Default 10 matches the previous fixed value.
        @max_iter: Maximum solver iterations.

    Returns:
        (fitted model, test accuracy from `lr_model.score`, predicted test labels)
    """
    lr_model = cuml.linear_model.LogisticRegression(penalty = 'l2', C = C, max_iter = max_iter, fit_intercept = True)
    lr_model.fit(x_train, y_train)
    accuracy = lr_model.score(x_test, y_test)
    y_test_pred = lr_model.predict(x_test)
    return lr_model, accuracy, y_test_pred

def get_probe_result(layer_ix, label2id):
    """
    Train and evaluate one role probe for a single layer and label mapping.

    Params:
        @layer_ix: The layer index to train the probe on (a key of `all_probe_hs`).
        @label2id: Mapping from role name to integer class id; its keys are the roles included in this probe.

    Description:
        Trains only on content tokens (rows of `valid_probe_sample_df`) whose role is in `label2id`. Prompts are
        split 80:20 into train/test, so all tokens of one prompt fall on the same side of the split.

    Returns:
        A dict with the layer, mapping, role list, fitted probe, test accuracy and two diagnostic tables:
        `acc_by_role` (counts per true/predicted role pair) and `acc_by_pos` (count and accuracy per token_in_seg_ix).

    Raises:
        ValueError: if some role in `label2id` has no training tokens (this aborts the caller's loop).
    """
    roles = list(label2id.keys())
    id2label = {v: k for k, v in label2id.items()}

    # Valid samples for these roles (local name avoids shadowing the global probe_sample_df)
    role_df = valid_probe_sample_df[valid_probe_sample_df['role'].isin(roles)]

    # 80:20 split over prompt ids rather than tokens
    prompt_ix_train, prompt_ix_test = cuml.train_test_split(role_df['prompt_ix'].unique(), test_size = 0.2, random_state = seed)
    train_df = role_df[role_df['prompt_ix'].isin(prompt_ix_train)]
    test_df = role_df[role_df['prompt_ix'].isin(prompt_ix_test)]

    # y: integer role labels
    role_labels_train_cp = cupy.asarray([label2id[r] for r in train_df['role']])
    role_labels_test_cp = cupy.asarray([label2id[r] for r in test_df['role']])

    # Every role must appear in train; check before the (expensive) activation transfer
    uniq_train = np.unique(role_labels_train_cp.get())
    if len(uniq_train) < len(label2id):
        raise ValueError(f"Layer {layer_ix}, mapping {label2id}: missing roles in train (label ids present: {uniq_train.tolist()})")

    # x: fp32 activations of the selected rows (sample_ix indexes rows of all_probe_hs[layer_ix])
    x_train_cp = cupy.asarray(all_probe_hs[layer_ix][train_df['sample_ix'].tolist(), :].to(torch.float32).detach().cpu())
    x_test_cp = cupy.asarray(all_probe_hs[layer_ix][test_df['sample_ix'].tolist(), :].to(torch.float32).detach().cpu())

    lr_model, test_acc, y_test_pred = fit_lr(x_train_cp, role_labels_train_cp, x_test_cp, role_labels_test_cp)

    print(f"Layer [{layer_ix}] with roles [{'+'.join(roles)}]: {test_acc:.2f}")

    # Per-token predictions mapped back to role names
    results_df =\
        test_df\
        .assign(pred = y_test_pred.tolist())\
        .assign(pred = lambda df: df['pred'].map(id2label))\
        .assign(is_acc = lambda df: df['role'] == df['pred'])

    acc_by_role =\
        results_df\
        .groupby(['role', 'pred'], as_index = False)\
        .agg(count = ('sample_ix', 'count'))

    acc_by_pos =\
        results_df\
        .groupby('token_in_seg_ix', as_index = False)\
        .agg(count = ('sample_ix', 'count'), acc = ('is_acc', 'mean'))

    return {
        'layer_ix': layer_ix, 'label2id': label2id,
        'roles': roles, 'probe': lr_model, 'accuracy': test_acc,
        'acc_by_role': acc_by_role, 'acc_by_pos': acc_by_pos
    }

# Role subsets to probe; each dict maps role name -> class id. One probe is trained per (layer, mapping).
mappings_to_test = [
    {'user': 0, 'cot': 1, 'assistant': 2},
    {'system': 0, 'user': 1, 'cot': 2, 'assistant': 3},
    {'system': 0, 'user': 1, 'assistant': 2, 'tool': 3},
    {'user': 0, 'assistant': 1, 'tool': 2},
    {'user': 0, 'cot': 1, 'assistant': 2, 'tool': 3}
    # {'system': 0, 'user': 1, 'assistant-cot': 2, 'assistant-final': 3, 'tool': 4},
    # {'user': 0, 'assistant-cot': 1, 'assistant-final': 2, 'tool': 3}
]

# Explicit loop (not a comprehension) so probes finished before an exception in get_probe_result stay in all_probes
all_probes = []
for layer_ix in tqdm(layers_to_probe):
    for mapping in mappings_to_test:
        all_probes.append(get_probe_result(layer_ix, mapping))

In [None]:
# Render a user turn, an assistant turn with a tool call, and a tool result through the tokenizer's chat template
# to see how tool messages are formatted
messages = [
    {"role": "user", "content": "Look up the capital of France, then answer."},
    {
        "role": "assistant",
        "content": "Here are my reasoning steps:\nI should call a lookup tool.",
        "tool_calls": [
            {
                "function": {
                    "name": "lookup_capital",
                    "arguments": '{"country":"France"}'
                },
                "id": "call_0",
            }
        ],
    },
    {"role": "tool", "content": '{"capital":"Paris"}'},

]

formatted = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,  # Apriel template forces this anyway
)

print(formatted)

In [None]:
"""
Check val accuracy by layer/role/pos
"""
# Test accuracy pivot: rows = layer, columns = role set abbreviated by first letters (all current roles have distinct initials)
display(
    pd.DataFrame(all_probes)[['layer_ix', 'roles', 'accuracy']]\
        .assign(accuracy = lambda df: df['accuracy'].round(2), roles = lambda df: df['roles'].apply(lambda x: ','.join(([r[0] for r in x]))))\
        .pivot(index = 'layer_ix', columns = 'roles', values = 'accuracy')
)

acc_by_role = pd.concat([p['acc_by_role'].assign(model = model_prefix, layer = p['layer_ix'], roles = ','.join(p['roles'])) for p in all_probes], ignore_index = True)
acc_by_pos =\
    pd.concat([p['acc_by_pos'].assign(model = model_prefix, layer = p['layer_ix'], roles = ','.join(p['roles'])) for p in all_probes], ignore_index = True)\
    .assign(acc = lambda df: df['acc'].round(4))

display(acc_by_role)
display(acc_by_pos)

# to_csv does not create missing directories, so make sure the output folder exists
probe_out_dir = f'{ws}/experiments/role-analysis/probes'
os.makedirs(probe_out_dir, exist_ok = True)
acc_by_role.to_csv(f'{probe_out_dir}/acc_by_role_{model_prefix}.csv', index = False)
acc_by_pos.to_csv(f'{probe_out_dir}/acc_by_pos_{model_prefix}.csv', index = False)

In [None]:
"""
Check accuracy counts
"""
# Total test tokens per (true role, layer, role set)
base_sums =\
    acc_by_role\
    .groupby(['role', 'layer', 'roles'], as_index = False)\
    .agg(base_sum = ('count', 'sum'))

# Correctly classified tokens per (true role, layer, role set)
correct_sums =\
    acc_by_role\
    .pipe(lambda df: df[df['role'] == df['pred']])\
    .groupby(['role', 'layer', 'roles'], as_index = False)\
    .agg(sum = ('count', 'sum'))

# Per-role recall. Left-join onto the totals so a role with zero correct predictions shows 0.0 rather than
# disappearing (an inner join would drop it and leave NaN in the pivot).
base_sums\
    .merge(correct_sums, on = ['layer', 'roles', 'role'], how = 'left')\
    .fillna({'sum': 0})\
    .assign(acc = lambda df: df['sum']/df['base_sum'])\
    .pivot(index = ['roles', 'layer'], columns = 'role', values = 'acc')

# Generalization test

In [None]:
# Normalise legacy role names so every probe uses the same vocabulary
role_map = {'assistant-cot': 'cot', 'assistant-final': 'assistant'}
for probe in all_probes:
    probe['roles'] = tuple(role_map.get(r, r) for r in probe['roles'])

probe_layers = sorted({x['layer_ix'] for x in all_probes})
probe_roles = sorted({tuple(x['roles']) for x in all_probes})

# Join outside the f-strings: a backslash inside an f-string expression is a SyntaxError before Python 3.12
probe_layers_str = ', '.join(str(x) for x in probe_layers)
probe_roles_str = '\n '.join(str(list(x)) for x in probe_roles)
print(f"Num probes: {len(all_probes)}")
print(f"Probe layers:\n  {probe_layers_str}")
print(f"Probe roles:\n {probe_roles_str}")

In [None]:
"""
Validate IFT on the full chat template, which tokenizes the whole thing e2e. This should match the format of the original messages:
- `load_chat_template`: Overwrites the existing chat template in the tokenizer with one that better handles discrepancies.
    See docstring for the function for details. 
- `label_content_roles`: Takes an instruct-formatted text, then assigns each token to the correct roles.
"""
# Reload the helper modules so edits to them are picked up without restarting the kernel
import importlib
import utils.role_assignments2
import utils.role_templates
importlib.reload(utils.role_assignments2)
label_content_roles = utils.role_assignments2.label_content_roles
importlib.reload(utils.role_templates)
load_chat_template = utils.role_templates.load_chat_template

# Load custom chat templater (keep the original so both can be compared)
old_chat_template = tokenizer.chat_template
new_chat_template = load_chat_template(f'{ws}/utils/chat_templates', model_architecture)

def test_chat_template(tokenizer):
    s = tokenizer.apply_chat_template(
        [
            {'role': 'system', 'content': 'Test.'},
            {'role': 'user', 'content': 'Hi! I\'m a dog.'},
            {'role': 'assistant', 'content': '<think>The user is a dog!</think>Congrats!'},
            {'role': 'user', 'content': 'Thanks!'},
            {'role': 'assistant', 'content': '<think>Hmm, the user said thanks!</think>Anything else?'}
        ],
        tokenize = False,
        padding = 'max_length',
        truncation = True,
        max_length = 512,
        add_generation_prompt = False
    )
    print(s)
    return s

# Render with the original template, then the custom one. The tokenizer is left using the custom template
# for all later cells.
tokenizer.chat_template = old_chat_template
s = test_chat_template(tokenizer)
tokenizer.chat_template = new_chat_template
s = test_chat_template(tokenizer)

# Token-level df (single sequence, prompt_ix 0) for the custom-template render, to check its role labels
z =\
    pd.DataFrame({'input_ids': tokenizer(s)['input_ids']})\
    .assign(token = lambda df: tokenizer.convert_ids_to_tokens(df['input_ids']))\
    .assign(prompt_ix = 0, batch_ix = 0, sequence_ix = 0, token_ix = lambda df: df.index)

label_content_roles(model_architecture, z).tail(50)

In [None]:
"""
Validate IFT
"""
# Single user turn after the prepend, ending with the generation prompt
conv = prepend + tokenizer.apply_chat_template(
    [
        {'role': 'user', 'content': 'Hi stinky pupper'}
    ],
    tokenize = False,
    add_generation_prompt = True
)

# NOTE(review): add_special_tokens defaults to True here; prepend may already contain BOS — check for a duplicate
inputs = tokenizer(conv, return_tensors = 'pt')

# Greedy generation; generate returns a (batch, seq) id tensor, so decode the first (only) row
generated_ids = model.generate(
    inputs['input_ids'].to(main_device),
    attention_mask = inputs['attention_mask'].to(main_device),
    max_new_tokens = 500,
    do_sample = False
)
print(tokenizer.decode(generated_ids[0]))

In [None]:
# Regenerate greedily from the `inputs` of the previous cell and tabulate each output token id with its token
# string, to inspect how the special tags are tokenized
gen_ids = model.generate(
    inputs['input_ids'].to(main_device),
    attention_mask=inputs['attention_mask'].to(main_device),
    max_new_tokens=500,
    do_sample=False
)
token_ids = gen_ids[0].tolist()
tokens = tokenizer.convert_ids_to_tokens(token_ids)
df = pd.DataFrame({
    "token_id": token_ids,
    "token": tokens
})

df

In [None]:
df.head(50)  # first 50 rows of the generated-token table

In [None]:
tokenizer.encode('<|begin_system|>')  # how many token ids this tag string encodes to

In [None]:
"""
Import conversations
"""
# Created by get-conversations-data.ipynb
# Apriel reuses the gptoss20 conversation file (NOTE(review): confirm intended); other models use their own
user_queries_df = pd.read_csv(f"{ws}/experiments/role-analysis/convs/{'gptoss20' if model_prefix == 'apriel-16' else model_prefix}.csv")

print(f"Starting length: {user_queries_df['conv_id'].nunique(dropna = False)} convs")

# 1. Drop every conversation with at least one missing user_query / cot / assistant value
convs_with_missing = user_queries_df.loc[user_queries_df[['user_query', 'cot', 'assistant']].isna().any(axis = 1), 'conv_id'].unique()
user_queries_df = user_queries_df[~user_queries_df['conv_id'].isin(convs_with_missing)]

print(f"After dropping missing: {user_queries_df['conv_id'].nunique(dropna = False)} convs")

# 2. Keep conversations with < 20,000 characters in total whose shortest message has >= 30 characters and whose
#    shortest user message has > 50 characters (very short messages cause issues with the flag_message_type fn)
user_queries_df =\
    user_queries_df\
    .assign(total_len = lambda df:
        df[['user_query', 'cot', 'assistant']]
        .apply(lambda col: col.str.len())
        .sum(axis = 1)
        .groupby(df['conv_id'])
        .transform('sum')
    )\
    .assign(min_len = lambda df: df[['user_query', 'cot', 'assistant']].apply(lambda x: x.str.len()).min(axis=1))\
    .assign(conv_min_len = lambda df: df.groupby('conv_id')['min_len'].transform('min'))\
    .assign(conv_min_user_len = lambda df: df.groupby('conv_id')['user_query'].transform(lambda x: x.str.len().min()))\
    .query('total_len < 20000 and conv_min_len >= 30 and conv_min_user_len > 50')\
    .drop(columns = ['total_len', 'min_len', 'conv_min_len', 'conv_min_user_len'])

print(f"After dropping long convs: {user_queries_df['conv_id'].nunique(dropna = False)} convs")

# 3. Filter out convs where any message is a substring of another message
#    (this causes dupe issues with the flag_message_type fn, which does not handle tokens in multiple string contexts)
def has_substring_messages(group):
    """
    Return True if any message text of a conversation occurs inside another message of the same conversation.

    `group` holds one conversation's rows. The 'user_query', 'cot' and 'assistant' columns of all rows are
    pooled (column by column) and every ordered pair of distinct positions is tested with `in`, so two
    identical messages at different positions also count as a match.
    """
    pooled = [
        text
        for col in ('user_query', 'cot', 'assistant')
        for text in group[col].tolist()
    ]
    return any(
        needle in haystack
        for needle_pos, needle in enumerate(pooled)
        for haystack_pos, haystack in enumerate(pooled)
        if needle_pos != haystack_pos
    )

bad_convs = user_queries_df.groupby('conv_id').filter(has_substring_messages)['conv_id'].unique()
user_queries_df = user_queries_df.pipe(lambda df: df[~df['conv_id'].isin(bad_convs)])
print(f"After dropping substr: {len(user_queries_df['conv_id'].unique())} convs")

# Number of conversations to keep. The sample is seeded with the notebook-level `seed` so reruns select the
# same conversations (an unseeded .sample draws a different subset each run unless the global NumPy RNG is
# seeded elsewhere); min() avoids a ValueError when fewer conversations survive the filters.
n_sample_convs = 40

# Combine into conv-level df: one row per conversation, turns ordered by user_query_ix, each turn expanded
# into user -> cot -> assistant messages; 'dataset' is taken from the conversation's first row
convs_df =\
    user_queries_df\
    .sort_values(['conv_id', 'user_query_ix'])\
    .assign(conv_id = lambda df: df.groupby('conv_id', sort = False).ngroup())\
    .groupby('conv_id')\
    .apply(lambda g: pd.Series({
        'dataset': g['dataset'].values[0],
        'messages': [
            msg
            for _, row in g.iterrows()
            for msg in [
                {'role': 'user', 'content': row['user_query']},
                {'role': 'cot', 'content': row['cot']},
                {'role': 'assistant', 'content': row['assistant']}
            ]
        ]
    }))\
    .reset_index()\
    [['conv_id', 'dataset', 'messages']]\
    .pipe(lambda df: df.sample(n = min(n_sample_convs, len(df)), random_state = seed))\
    .reset_index(drop = True)\
    .assign(conv_id = lambda df: range(len(df)))

print(f"Final: {len(convs_df['conv_id'].unique())} convs")

# One row per message, in conversation order (message_ix counts from 0 within each conversation)
messages_df =\
    convs_df\
    .explode('messages')\
    .assign(message_ix = lambda df: df.groupby('conv_id').cumcount())\
    .assign(
        role = lambda df: df['messages'].apply(lambda x: x['role']),
        content = lambda df: df['messages'].apply(lambda x: x['content'])
    )\
    .drop(columns = 'messages')\
    .reset_index(drop = True)

convs = convs_df['messages'].tolist()

display(convs_df.head(5))
display(messages_df.head(5))
print(f"Total convs: {len(convs_df)}")

In [None]:
"""
Prep untagged conversations
"""
all_convs_by_type = {}

# Prep conversations
def prep_conv(conv):
    return prepend + '\n\n'.join([x['content'] for x in conv])

all_convs_by_type['untagged'] = [prep_conv(conv) for conv in convs]
print(all_convs_by_type['untagged'][0])

# Find max input length
max_input_length =\
    tokenizer(all_convs_by_type['untagged'], padding = True, truncation = True, max_length = 1024 * 8, return_tensors = 'pt')\
    ['attention_mask'].sum(dim = 1).max().item()

print(f"Max input length: {max_input_length}")
print(f"Untagged convs: {len(all_convs_by_type['untagged'])}")

In [None]:
"""
Prep tagged conversations
"""
def fold_cot_into_final(convs):
    """Fold CoT into the following assistant message as a <think></think> tag."""
    result = []
    for conv in convs:
        new_conv = []
        it = iter(enumerate(conv))
        for i, msg in it:
            if msg['role'] == 'cot':
                if i + 1 >= len(conv) or conv[i + 1]['role'] != 'assistant':
                    raise ValueError("cot must be followed by assistant")
                _, next_msg = next(it)
                new_conv.append({'role': 'assistant', 'content': f"<think>{msg['content']}</think>{next_msg['content']}"})
            else:
                new_conv.append({'role': msg['role'], 'content': msg['content']})
        result.append(new_conv)
    return result

final_convs = fold_cot_into_final(convs)
# Render every folded conversation with the model's chat template (no generation prompt), then add the prefix
templated_convs = tokenizer.apply_chat_template(final_convs, tokenize = False, add_generation_prompt = False)
all_convs_by_type['tagged'] = [prepend + templated for templated in templated_convs]
print(all_convs_by_type['tagged'][0])

# Longest tagged input in tokens (count of non-padding positions; truncation caps inputs at 1024 * 8 tokens)
tagged_attention = tokenizer(
    all_convs_by_type['tagged'],
    padding = True, truncation = True, max_length = 1024 * 8, return_tensors = 'pt'
)['attention_mask']
max_input_length = tagged_attention.sum(dim = 1).max().item()

print(f"Max input length: {max_input_length}")
print(f"Tagged convs: {len(all_convs_by_type['tagged'])}")

In [None]:
"""
Run forward passes
"""
# Get input df
input_convs_df =\
    pd.concat([
        pd.DataFrame({'convs': v, 'conv_type': conv_type})\
            .assign(conv_id = lambda df: range(0, len(df)))\
            .merge(convs_df, how = 'inner', on = 'conv_id')
        for conv_type, v in all_convs_by_type.items()
    ], ignore_index = True)\
    .assign(prompt_ix = lambda df: list(range(0, len(df))))
    
display(input_convs_df)

convs_dl = DataLoader(
    ReconstructableTextDataset(input_convs_df['convs'].tolist(), tokenizer, max_length = 1024 * 5, prompt_ix = list(range(0, len(input_convs_df)))),
    batch_size = 8,
    shuffle = False,
    collate_fn = stack_collate
)

convs_outputs = run_and_export_states(model, convs_dl, layers_to_keep_acts = probe_layers)

In [None]:
"""
Label messages
"""
import utils.substring_assignments
importlib.reload(utils.substring_assignments)
from utils.substring_assignments import flag_message_types

sample_dfs_by_conv = [group for _, group in convs_outputs['sample_df'].groupby('prompt_ix', sort = True)]
metadata_by_conv = input_convs_df.to_dict('records')

all_res = []

# Iterate through (input metadata, token-level sample df) pairs
for conv_metadata, sample_df_for_conv in tqdm(zip(metadata_by_conv, sample_dfs_by_conv)):

    content_spans = [msg['content'] for msg in conv_metadata['messages']]
    content_roles = [msg['role'] for msg in conv_metadata['messages']]
    
    try:
        # print(content_spans)
        # Note: set the last arg of flag_message_types to True for Olmo3 and other models which DIRECTLY 
        # quote the entire user message (one base_message is a subset of another)
        res = \
            flag_message_types(sample_df_for_conv, content_spans, any(x in model_prefix for x in ['olmo', 'glm']))\
            .merge(
                pd.DataFrame({'role': content_roles, 'base_message_ix': range(len(content_spans))}),
                on = 'base_message_ix',
                how = 'left'
            )
    except Exception as e:
        print(e)
        continue
    
    all_res.append(res)

display(pd.concat(all_res).pipe(lambda df: df[df['role'].isna()]))
        
sample_df_labeled =\
    pd.concat(all_res).reset_index(drop = True)\
    .assign(sample_ix = lambda df: range(0, len(df)))\
    .assign(token_in_prompt_ix = lambda df: df.groupby(['prompt_ix']).cumcount())

display(sample_df_labeled.pipe(lambda df: df[~df['role'].isna()]))

sample_df_labeled\
    .pipe(lambda df: df[~df['role'].isna()])\
    .groupby(['prompt_ix', 'role'], as_index = False)\
    .agg(count = ('sample_ix', 'count'))\
    .pivot(index = ['prompt_ix'], columns = 'role', values = 'count')\
    .head(20)

In [None]:
# Debug: print one base_message of prompt_ix 8 — the first row among that prompt's last 50 tokens
print(
    sample_df_labeled
    .loc[lambda df: df['prompt_ix'] == 8, 'base_message']
    .tail(50)
    .iloc[0]
)

In [None]:
"""
Clean up conversation pre_mlp_hs
"""
convs_hs = convs_outputs['all_hs'].to(torch.float16)
convs_hs = {layer_ix: convs_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(probe_layers)} # Match layers_to_keep_act
print(convs_hs[0].shape)
print(len(sample_df_labeled))
gc.collect()

In [None]:
"""
Run projections
"""
test_layer = 40 # probe_layers[int(np.ceil(0.5 * len(probe_layers)).item())]
test_roles = ['user', 'cot', 'assistant', 'tool'] # system, user, cot, assistant, tool
probe = [x for x in all_probes if sorted(x['roles']) == sorted(test_roles) and x['layer_ix'] == test_layer][0]

project_sample_df = sample_df_labeled.pipe(lambda df: df[~df['role'].isna()]) # Or drop non-roles
project_hs_cp = cupy.asarray(convs_hs[test_layer][project_sample_df['sample_ix'].tolist(), :])
project_probs = probe['probe'].predict_proba(project_hs_cp).round(8)

proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
if len(proj_results) != len(project_sample_df):
    raise Exception("Error!")

role_level_df =\
    pd.concat([proj_results, project_sample_df[['sample_ix']]], axis = 1)\
    .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
    .merge(project_sample_df, on = 'sample_ix', how = 'inner')\
    .merge(input_convs_df[['prompt_ix', 'dataset', 'conv_id', 'conv_type']], on = 'prompt_ix')

role_level_df

In [None]:
# Inspect the per-token probe probabilities (one column per probe role, clipped at 1e-6)
proj_results

In [None]:
"""
Tagged vs untagged
"""
print('Accuracy: Tagged and untagged by role')
display(
    role_level_df\
    .pipe(lambda df: df[~df['role'].isna()])\
    .pipe(lambda df: df[(df['token_in_prompt_ix'] >= 0) & (df['token_in_prompt_ix'] <= 20000)])\
    .groupby(['conv_type', 'role_space', 'role', 'prompt_ix'], as_index = False)\
    .agg(mean_prob = ('prob', 'mean'))\
    .groupby(['conv_type', 'role_space', 'role'], as_index = False)\
    .agg(mean_prob = ('mean_prob', 'mean'))\
    .pivot(index = ['conv_type', 'role'], columns = 'role_space', values = 'mean_prob')
)

In [None]:
# Debug: rebuild prompt text from role_level_df. Each token appears once per role_space, so tokens are joined
# within each role_space and only the first role_space's text is printed.
debug_text_views = [
    (30, None),   # all rows of prompt 30
    (60, 'cot'),  # only the cot-labeled rows of prompt 60
]
for view_pos, (debug_prompt_ix, debug_role) in enumerate(debug_text_views):
    if view_pos > 0:
        print('-----------')
    view_rows = role_level_df[role_level_df['prompt_ix'] == debug_prompt_ix]
    if debug_role is not None:
        view_rows = view_rows[view_rows['role'] == debug_role]
    joined_by_space = view_rows.groupby('role_space').agg(combined_text = ('token', ''.join))
    print(joined_by_space['combined_text'].tolist()[0])

In [None]:
"""
All-layer projections
"""
all_projs = []

project_sample_df = sample_df_labeled # Or drop non-roles

for probe in tqdm(all_probes):

    project_hs_cp = cupy.asarray(convs_hs[probe['layer_ix']][project_sample_df['sample_ix'].tolist(), :])

    project_probs = probe['probe'].predict_proba(project_hs_cp).round(4)
    proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
    if len(proj_results) != len(project_sample_df):
        raise Exception("Error!")

    role_level_df =\
        pd.concat([proj_results, project_sample_df[['sample_ix']]], axis = 1)\
        .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
        .merge(project_sample_df, on = 'sample_ix', how = 'inner')\
        .merge(input_convs_df[['prompt_ix', 'dataset', 'conv_id', 'conv_type']], on = 'prompt_ix')\
        .assign(layer_ix = probe['layer_ix'])\
        .assign(roles = ','.join(probe['roles']))\
        .drop(columns = ['base_message'])

    all_projs.append(role_level_df)

all_projs_df =\
    pd.concat(all_projs)\
    .reset_index(drop = True)\
    .drop(columns = ['output_prob', 'output_id'])\
    .pipe(lambda df: df[~df['role'].isna()])

In [None]:
"""
Validation
"""
ft =\
    all_projs_df\
    .pipe(lambda df: df[df['roles'] == 'user,cot,assistant'])\
    .pipe(lambda df: df[df['role_space'] == df['role']])\
    .drop(columns = 'role_space')\
    .rename(columns = {'prob': 'acc'})

print('Tagged, all layers')
display(
    ft\
        .pipe(lambda df: df[df['conv_type'] == 'tagged'])\
        .groupby(['roles', 'layer_ix', 'role', 'prompt_ix'], as_index = False)\
        .agg(mean_acc = ('acc', 'mean'))\
        .groupby(['roles', 'layer_ix', 'role'], as_index = False)\
        .agg(mean_acc = ('mean_acc', 'mean'))\
        .pivot(index = ['roles', 'layer_ix'], columns = 'role', values = 'mean_acc')\
        .reset_index()\
        .rename_axis(columns = None)
)

print('Untagged, all layers')
display(
    ft
        .pipe(lambda df: df[df['conv_type'] == 'untagged'])\
        .groupby(['roles', 'layer_ix', 'role', 'prompt_ix'], as_index = False)\
        .agg(mean_acc = ('acc', 'mean'))\
        .groupby(['roles', 'layer_ix', 'role'], as_index = False)\
        .agg(mean_acc = ('mean_acc', 'mean'))\
        .pivot(index = ['roles', 'layer_ix'], columns = 'role', values = 'mean_acc')\
        .reset_index()\
        .rename_axis(columns = None)
)

In [None]:
# Probability each 'system,user,cot,assistant' probe assigns to the token's TRUE role (renamed 'acc'),
# tabulated per layer and true role for each conversation format
role_match = all_projs_df['role_space'] == all_projs_df['role']
ft = (
    all_projs_df[(all_projs_df['roles'] == 'system,user,cot,assistant') & role_match]
    .drop(columns = 'role_space')
    .rename(columns = {'prob': 'acc'})
)

for conv_type_value in ('tagged', 'untagged'):
    print(f"{conv_type_value.capitalize()}, all layers")
    per_prompt_acc = (
        ft[ft['conv_type'] == conv_type_value]
        .groupby(['roles', 'layer_ix', 'role', 'prompt_ix'], as_index = False)
        .agg(mean_acc = ('acc', 'mean'))
    )
    display(
        per_prompt_acc
        .groupby(['roles', 'layer_ix', 'role'], as_index = False)
        .agg(mean_acc = ('mean_acc', 'mean'))
        .pivot(index = ['roles', 'layer_ix'], columns = 'role', values = 'mean_acc')
        .reset_index()
        .rename_axis(columns = None)
    )

In [None]:
# Distinct role_space values present for probes whose roles string is 'user,cot,assistant,tool'
all_projs_df.loc[all_projs_df['roles'] == 'user,cot,assistant,tool', ['role_space']].drop_duplicates()

In [None]:
# Inspect the most recently built `ft` (true-role probabilities per token, layer and probe role set)
ft

In [None]:
# Deliberate stop: halts "Run All" here so the tomato-test cells below only run when executed manually
raise Exception('Stop!')

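# Tomato test: one scripted conversation rendered three ways (no role formatting, all inside one user turn, proper role tags)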

In [None]:
import yaml

def load_messages(yaml_path, model_key):
    """
    Load the message contents stored under `model_key` in a YAML prompt file.

    The file is parsed with yaml.safe_load; data[model_key] is expected to be a list of mappings with a
    'content' key (system-like, user-like, assistant-like, ... messages). Only the contents are returned, in order.
    """
    with open(yaml_path, 'r', encoding = 'utf-8') as f:
        data = yaml.safe_load(f)
    return [x['content'] for x in data[model_key]]

# Role of each scripted message, by its position in the YAML list
base_message_types = ['system', 'user', 'cot', 'assistant', 'user', 'cot', 'assistant']
base_messages = load_messages(f"{ws}/experiments/role-analysis/prompts/standard-conversations.yaml", 'gptoss20')
# The prompts below index base_messages[0..6], and tokens are later mapped base_message_ix -> base_message_types,
# so the YAML must contain exactly one message per entry of base_message_types
if len(base_messages) != len(base_message_types):
    raise ValueError(f"Expected {len(base_message_types)} scripted messages, got {len(base_messages)}")

tomato_prompts = {}

# 1. Raw message contents joined by newlines, no role markup
tomato_prompts['basic_no_format'] = prepend + '\n'.join(base_messages)

# 2. The whole conversation pasted into a single user turn
tomato_prompts['everything_in_user_tags'] = prepend + tokenizer.apply_chat_template(
    [{'role': 'user', 'content': '\n'.join(base_messages)}],
    tokenize = False, add_generation_prompt = False
)

# 3. Proper roles; each cot is folded into its assistant turn as <think>...</think>
tomato_prompts['proper_tags'] = prepend + tokenizer.apply_chat_template(
    [
        {'role': 'system', 'content': base_messages[0]},
        {'role': 'user', 'content': base_messages[1]},
        {'role': 'assistant', 'content': f"<think>{base_messages[2]}</think>{base_messages[3]}"},
        {'role': 'user', 'content': base_messages[4]},
        {'role': 'assistant', 'content': f"<think>{base_messages[5]}</think>{base_messages[6]}"}
    ],
    tokenize = False, add_generation_prompt = False
)

print(tomato_prompts['proper_tags'])

tomato_input_df =\
    pd.DataFrame({
        'prompt': list(tomato_prompts.values()),
        'prompt_key': list(tomato_prompts.keys())
    })\
    .assign(prompt_ix = lambda df: list(range(0, len(df))))

In [None]:
"""
Run asdfsdf
"""
tomato_dl = DataLoader(
    ReconstructableTextDataset(tomato_input_df['prompt'].tolist(), tokenizer, max_length = 512 * 4, prompt_ix = tomato_input_df['prompt_ix'].tolist()),
    batch_size = 16,
    shuffle = False,
    collate_fn = stack_collate
)

tomato_outputs = run_and_export_states(model, tomato_dl, layers_to_keep_acts = probe_layers)

tomato_pre_mlp_hs = tomato_outputs['all_pre_mlp_hs'].to(torch.float16)
tomato_pre_mlp_hs = {layer_ix: tomato_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(probe_layers)} # Match layers_to_keep_act
print(tomato_pre_mlp_hs[0].shape)
gc.collect()

In [None]:
# Label each tomato token with the scripted message it falls in (base_message_ix / base_message) and attach the
# prompt metadata; sample_ix is the row position later used to index tomato_pre_mlp_hs
tomato_sample_df_labeled = (
    flag_message_types(tomato_outputs['sample_df'], base_messages)
    .assign(
        sample_ix = lambda df: range(0, len(df)),
        token_in_prompt_ix = lambda df: df.groupby(['prompt_ix']).cumcount()
    )
    .merge(tomato_input_df, on = 'prompt_ix', how = 'inner')
)

# Token count per (prompt, scripted message)
(
    tomato_sample_df_labeled
    .groupby(['prompt_ix', 'prompt_key', 'base_message_ix', 'base_message'], as_index = False)
    .agg(count = ('token', 'count'))
)

In [None]:
# Inspect the raw token-level output table produced by the tomato forward pass
tomato_outputs['sample_df']

In [None]:
"""
Run projections
"""
test_layer = 16 # probe_layers[int(np.ceil(0.5 * len(probe_layers)).item())]
test_roles = ['user', 'cot', 'assistant'] # system, user, cot, assistant, tool
probe = [x for x in all_probes if sorted(x['roles']) == sorted(test_roles) and x['layer_ix'] == test_layer][0]

project_hs_cp = cupy.asarray(tomato_pre_mlp_hs[test_layer][tomato_sample_df_labeled['sample_ix'].tolist(), :])
project_probs = probe['probe'].predict_proba(project_hs_cp).round(8)

proj_results = pd.DataFrame(cupy.asnumpy(project_probs), columns = probe['roles']).clip(1e-6)
if len(proj_results) != len(tomato_sample_df_labeled):
    raise Exception("Error!")

role_level_df =\
    pd.concat([proj_results, tomato_sample_df_labeled[['sample_ix']]], axis = 1)\
    .melt(id_vars = ['sample_ix'], var_name = 'role_space', value_name = 'prob')\
    .merge(tomato_sample_df_labeled, on = 'sample_ix', how = 'inner')\
    .merge(
        pd.DataFrame({'base_message_ix': list(range(len(base_message_types))), 'base_message_type': base_message_types}),
        on = 'base_message_ix',
        how = 'left'
    )\
    .assign(base_message_type = lambda df: df['base_message_type'].fillna('other'))\
    .pipe(lambda df: df[df['base_message_type'] != 'other']) # Cull the role tags themselves

role_level_df

In [None]:
import plotly.express as px

facet_order = probe['roles']
all_message_types = role_level_df['base_message_type'].unique()
color_map = {msg_type: px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)] for i, msg_type in enumerate(all_message_types)}
# Legend order follows the scripted roles in first-seen order. The old hard-coded list used 'assistant-final',
# which never occurs (base_message_types uses 'assistant'), so assistant points fell to the end of the legend.
message_type_order = list(dict.fromkeys(base_message_types)) + ['other']

# Per (prompt, role_space) smoothing of the probe probability along token order
smooth_df =\
    role_level_df\
    .sort_values('sample_ix')\
    .assign(
        prob_sma = lambda df: df.groupby(['prompt_ix', 'role_space'])['prob'].rolling(window = 2, min_periods = 1).mean().reset_index(level = [0, 1], drop = True),
        prob_ewma = lambda df: df.groupby(['prompt_ix', 'role_space'])['prob'].ewm(alpha = .5, min_periods = 1).mean().reset_index(level = [0, 1], drop = True)
    )

display(smooth_df)

def pretty(a):
    """Restyle a facet-row annotation: keep only the text after the last '=', centered and raised above its subplot on a light strip."""
    a.update(
        text = a.text.split("=")[-1],
        x = 0.5, xanchor = "center",
        y = a.y + 0.08,
        textangle = 0,
        font = dict(size = 12),
        bgcolor = 'rgba(255, 255, 255, 0.9)',  # light strip look
        showarrow = False
    )

for this_prompt_key in smooth_df['prompt_key'].unique().tolist():
    this_df = smooth_df.pipe(lambda df: df[df['prompt_key'] == this_prompt_key])
    fig = px.scatter(
        this_df, x = 'token_in_prompt_ix', y = 'prob_ewma',
        facet_row = 'role_space',
        color = 'base_message_type',
        color_discrete_map = color_map,
        category_orders = {
            'role_space': facet_order,
            'base_message_type': message_type_order
        },
        hover_name = 'token',
        hover_data = {
            'prob': ':.3f'
        },
        title = f'prompt = {this_prompt_key}',
        labels = {
            'token_in_prompt_ix': 'Token Index',
            'prob': 'Prob',
            'prob_ewma': 'Smoothed Prob',  # the plotted column (was keyed as the nonexistent 'prob_smoothed')
            'role_space': 'role'
        }
    ).update_yaxes(
        range = [0, 1],
        side = 'left'
    ).update_layout(height = 500, width = 800)

    fig.for_each_annotation(pretty)
    fig.update_yaxes(title_text = None)
    fig.show()