In [None]:
"""
Create role probes and test them
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import cupy
import cuml
import importlib
import gc
import pickle
import os
from tqdm import tqdm

from utils.memory import check_memory, clear_all_cuda_memory
from utils.role_assignments import label_content_roles

main_device = 'cuda:0'
seed = 1234

# Seed the global torch / numpy RNGs so any stochastic operations in this notebook are reproducible
torch.manual_seed(seed)
np.random.seed(seed)

clear_all_cuda_memory()
check_memory()

ws = '/workspace/deliberative-alignment-jailbreaks'

# Load models

## Load model

In [None]:
"""
Load the base tokenizer/model
"""
selected_model_index = 0  # key into the model table in get_model(); 0 = openai/gpt-oss-20b

def get_model(index):
    """
    Look up the configuration of one of the supported models.

    Params:
        @index: Key into the model table below (0-3)

    Returns:
        A 6-tuple of
        (HF model ID, model prefix, model architecture, attn implementation,
         whether to use the HF library implementation, number of layers)

    Raises:
        KeyError: if `index` is not a key of the table (the message lists the valid keys)
    """
    models = {
        0: ('openai/gpt-oss-20b', 'gptoss20', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 24),
        1: ('openai/gpt-oss-120b', 'gptoss120', 'gptoss', 'kernels-community/vllm-flash-attn3', True, 36),
        2: ('Qwen/Qwen3-30B-A3B-Thinking-2507', 'qwen3-30b-a3b', 'qwen3moe', None, True, 48),
        3: ('zai-org/GLM-4.5-Air-FP8', 'glm4moe', 'glm4moe', None, True, 45) # layer count here is the number of MoE layers
    }
    if index not in models:
        raise KeyError(f'Unknown model index {index!r}; expected one of {sorted(models)}')
    return models[index]

def load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf):
    """
    Load the tokenizer and model for `model_id` with `from_pretrained`, using /workspace/hf as the HF cache
    (a copy already downloaded there is reused).

    Params:
        @model_id: HF hub model ID
        @model_prefix: Not used by this function
        @model_attn: Passed through as `attn_implementation`
        @model_use_hf: If True, the model is loaded with `trust_remote_code = False` (transformers' own implementation)

    Returns:
        (tokenizer, model): a left-padding tokenizer and the bf16 model in eval mode, placed with device_map = 'auto'
    """
    hf_cache = '/workspace/hf'

    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        cache_dir = hf_cache,
        add_eos_token = False,
        add_bos_token = False,
        padding_side = 'left',
        # NOTE(review): unlike the model below, the tokenizer always trusts remote code
        trust_remote_code = True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        cache_dir = hf_cache,
        dtype = torch.bfloat16,
        trust_remote_code = not model_use_hf,
        device_map = 'auto',
        attn_implementation = model_attn,
    )
    model.eval()

    return tokenizer, model

(model_id, model_prefix, model_architecture,
 model_attn, model_use_hf, model_n_layers) = get_model(selected_model_index)
tokenizer, model = load_model_and_tokenizer(
    model_id = model_id,
    model_prefix = model_prefix,
    model_attn = model_attn,
    model_use_hf = model_use_hf,
)

# Train probe

## Load probe training activations

In [None]:
"""
Load SFT dataset (probe training data)
"""
def load_data(folder_path, model_prefix, max_data_files):
    """
    Load data saved by `export-c4-activations.ipynb`.

    Looks for numbered shard folders `{ws}/experiments/role-analysis/{folder_path}/{model_prefix}/00`, `/01`, ...
    up to `max_data_files` of them, skipping any that do not exist. It then concatenates their sample frames,
    top-k frames and pre-MLP hidden-state tensors in shard order. The DataFrame indexes are not reset, so row
    labels may repeat across shards.

    Params:
        @folder_path: Subfolder under `{ws}/experiments/role-analysis/` (e.g. 'activations')
        @model_prefix: Model prefix subfolder (see `get_model`)
        @max_data_files: Number of shard indices (0 .. max_data_files - 1) to look for

    Returns:
        (sample_df, topk_df, all_pre_mlp_hs, layers), where `layers` is
        metadata['all_pre_mlp_hidden_states_layers'] from the folder's metadata.pkl

    Raises:
        FileNotFoundError: if none of the shard folders exist
    """
    base_dir = f'{ws}/experiments/role-analysis/{folder_path}/{model_prefix}'
    folders = [f'{base_dir}/{i:02d}' for i in range(max_data_files)]
    folders = [f for f in folders if os.path.isdir(f)]
    if not folders:
        # Without this guard, pd.concat([]) below fails with an unhelpful "No objects to concatenate"
        raise FileNotFoundError(f'No shard folders 00..{max_data_files - 1:02d} found under {base_dir}')

    all_pre_mlp_hs = []
    sample_df = []
    topk_df = []

    for f in tqdm(folders):
        sample_df.append(pd.read_pickle(f'{f}/samples.pkl'))
        topk_df.append(pd.read_pickle(f'{f}/topks.pkl'))
        # Always load to CPU so tensors saved from a GPU don't take GPU memory alongside the loaded model
        all_pre_mlp_hs.append(torch.load(f'{f}/all-pre-mlp-hidden-states.pt', map_location = 'cpu'))

    sample_df = pd.concat(sample_df)
    topk_df = pd.concat(topk_df)
    all_pre_mlp_hs = torch.concat(all_pre_mlp_hs)

    with open(f'{base_dir}/metadata.pkl', 'rb') as f:
        metadata = pickle.load(f)

    gc.collect()
    return sample_df, topk_df, all_pre_mlp_hs, metadata['all_pre_mlp_hidden_states_layers']

sample_df_import, topk_df_import, all_pre_mlp_hs_import, act_map = load_data(
    folder_path = 'activations',
    model_prefix = model_prefix,
    max_data_files = 10,
)

In [None]:
"""
Let's clean up the mappings here. We'll get everything to a sample_ix level first. We'll also flag the role wrapper tokens. Drop nothing. 
"""
def clean_mappings(model_architecture, sample_df_import, topk_df_import):
    """
    Role-label every token and give it a global `sample_ix`, then attach `sample_ix` / `prompt_ix` to the top-k rows.

    Roles are read from the raw token stream via `label_content_roles` instead of from the input_mappings file.
    This is more robust, doesn't assume one role per sequence/prompt_ix, and the same logic can be reused later.

    Note:
        Technically this could be done with the input_mappings file via
        .merge(input_mappings[['prompt_ix', 'question_ix', 'role']], how = 'left', on = ['prompt_ix']).

    Params:
        @model_architecture: Architecture key passed through to `label_content_roles`
        @sample_df_import: Token-level frame with batch_ix, sequence_ix, token_ix and prompt_ix columns
        @topk_df_import: Top-k frame keyed by batch_ix, sequence_ix, token_ix

    Returns:
        (sample_df, topk_df):
            sample_df: labelled tokens with sample_ix / token_in_seq_ix / token_in_role_ix added,
                batch_ix and sequence_ix dropped
            topk_df: top-k rows inner-joined to sample_ix and prompt_ix, with batch_ix / sequence_ix / token_ix dropped

    Raises:
        ValueError: if (batch_ix, sequence_ix, token_ix) does not uniquely identify a token row
    """
    token_key = ['batch_ix', 'sequence_ix', 'token_ix']

    labeled_df = label_content_roles(model_architecture, sample_df_import)

    # sample_ix is both the join key for the top-k rows and the row index used to pick activations, so the token key
    # must be unique. A duplicate (e.g. if batch_ix restarted per shard) would silently give distinct tokens the same
    # sample_ix and multiply rows in the merge below.
    n_dupes = int(labeled_df.duplicated(subset = token_key).sum())
    if n_dupes > 0:
        raise ValueError(
            f'{n_dupes} token rows repeat an earlier (batch_ix, sequence_ix, token_ix) key; sample_ix would not be unique'
        )

    # NOTE(review): ngroup() numbers keys in sorted order; using sample_ix as a tensor row index assumes the
    # activation rows were saved in that same order - confirm against the export notebook.
    # The cumcounts follow the frame's row order within each group.
    sample_df_raw =\
        labeled_df\
        .assign(sample_ix = lambda df: df.groupby(token_key).ngroup())\
        .assign(token_in_seq_ix = lambda df: df.groupby(['prompt_ix']).cumcount())\
        .assign(token_in_role_ix = lambda df: df.groupby(['prompt_ix', 'role']).cumcount())

    c4_topk_df =\
        topk_df_import\
        .merge(
            sample_df_raw[['sample_ix', 'prompt_ix', 'batch_ix', 'sequence_ix', 'token_ix']],
            how = 'inner',
            on = ['sequence_ix', 'token_ix', 'batch_ix']
        )\
        .drop(columns = ['sequence_ix', 'token_ix', 'batch_ix'])

    c4_sample_df =\
        sample_df_raw\
        .drop(columns = ['batch_ix', 'sequence_ix'])

    return c4_sample_df, c4_topk_df

sample_df, topk_df = clean_mappings(model_architecture, sample_df_import, topk_df_import)

# The raw frames are no longer needed; free them before the large activation work below
del topk_df_import, sample_df_import
gc.collect()

display(sample_df, topk_df)

In [None]:
"""
Convert activations to fp16 (for compatibility with cupy later) + layer-wise dict
"""
# Downcast to fp16 for cupy. fp16's largest finite value (~65504) is far below bf16's range, so unusually large
# activations turn into +/-inf here - count them rather than letting them reach the probes silently.
all_pre_mlp_hs_fp16 = all_pre_mlp_hs_import.to(torch.float16)
del all_pre_mlp_hs_import

# Checked in row chunks so the boolean mask never spans the whole tensor at once
n_nonfinite = sum(
    int((~torch.isfinite(chunk)).sum())
    for chunk in all_pre_mlp_hs_fp16.split(4096, dim = 0)
)
if n_nonfinite > 0:
    print(f'WARNING: {n_nonfinite:,} activation values are non-finite after fp16 conversion')

# Layer-wise dict of 2-D views: act_map[save_ix] is the model layer stored at position save_ix of dim 1
all_pre_mlp_hs = {layer_ix: all_pre_mlp_hs_fp16[:, save_ix, :] for save_ix, layer_ix in enumerate(act_map)}
del all_pre_mlp_hs_fp16

gc.collect()

## Run LR

In [None]:
"""
Run logistic regression probes to get role space models
"""
def fit_lr(x_cp, y_cp, test_size = 0.2, random_state = 123, C = 0.0001, max_iter = 1000):
    """
    Fit an L2-regularised logistic-regression probe on a train split and score it on the held-out split.

    Params:
        @x_cp: Feature matrix (cupy), one row per sample
        @y_cp: Integer class labels (cupy), aligned with the rows of x_cp
        @test_size: Fraction held out for testing (default 0.2, i.e. a standard 80/20 split)
        @random_state: Seed for the train/test split (default 123 reproduces the original split)
        @C: Inverse regularisation strength; the default 1e-4 is a very strong L2 penalty
        @max_iter: Solver iteration cap

    Returns:
        (lr_model, accuracy): the fitted cuml LogisticRegression and its score on the test split
    """
    x_train, x_test, y_train, y_test = cuml.train_test_split(x_cp, y_cp, test_size = test_size, random_state = random_state)
    lr_model = cuml.linear_model.LogisticRegression(penalty = 'l2', C = C, max_iter = max_iter, fit_intercept = True)
    lr_model.fit(x_train, y_train)
    accuracy = lr_model.score(x_test, y_test)
    return lr_model, accuracy

def get_probe_result(layer_ix, label2id):
    """
    Get probe results for a single layer and label combination.

    Params:
        @layer_ix: The layer index to train the probe on (a key of the global `all_pre_mlp_hs` dict)
        @label2id: The role-to-class-id mapping; its keys are the roles included in this probe

    Description:
        Trains only on content-span tokens (`in_content_span == True`) whose role is one of the mapping's keys.
        Features are the rows of `all_pre_mlp_hs[layer_ix]` selected by those tokens' `sample_ix`.
        NOTE(review): this assumes sample_ix equals the row position in the activation tensor - confirm against the export.

    Returns:
        dict with layer_ix, label2id, roles, probe (fitted model) and accuracy (held-out test accuracy)

    Raises:
        ValueError: if no content-span tokens carry any of the requested roles
    """
    roles = list(label2id.keys())

    # Filter once and take both sample indices and labels from the same rows, so features and labels stay aligned
    valid_df = sample_df[
        (sample_df['in_content_span'] == True) & (sample_df['role'].notna()) & (sample_df['role'].isin(roles))
    ]
    if valid_df.empty:
        raise ValueError(f'Layer {layer_ix}: no content-span tokens found for roles {roles}')

    present_roles = set(valid_df['role'].unique())
    missing_roles = [role for role in roles if role not in present_roles]
    if missing_roles:
        print(f'WARNING: layer {layer_ix}: no content-span tokens for roles {missing_roles}; the probe cannot learn them')

    valid_sample_ix = valid_df['sample_ix'].tolist()
    role_labels_cp = cupy.asarray([label2id[role] for role in valid_df['role'].tolist()])
    x_cp = cupy.asarray(all_pre_mlp_hs[layer_ix][valid_sample_ix, :].to(torch.float16).detach().cpu())
    lr_model, test_acc = fit_lr(x_cp, role_labels_cp)

    print(f'layer {layer_ix} | roles {roles} | test acc {test_acc}')

    return {
        'layer_ix': layer_ix, 'label2id': label2id,
        'roles': roles, 'probe': lr_model, 'accuracy': test_acc
    }

# Probe every saved layer (to subsample, filter this list, e.g. keep layers with layer_ix % 4 == 0)
layers_to_test = [*all_pre_mlp_hs]

# Role sets to probe; each role's class id is its position within its list
mappings_to_test = [
    {role: role_id for role_id, role in enumerate(role_list)}
    for role_list in (
        ['user', 'assistant-cot', 'assistant-final'],
        ['system', 'user', 'assistant-cot', 'assistant-final'],
        ['system', 'user', 'assistant-cot', 'assistant-final', 'tool'],
        ['user', 'assistant-cot', 'assistant-final', 'tool'],
    )
]

probes = []
for layer_ix in tqdm(layers_to_test):
    for mapping in mappings_to_test:
        probes.append(get_probe_result(layer_ix, mapping))

## Save probes

In [None]:
"""
Save
"""
probe_dir = f'{ws}/experiments/role-analysis/probes'
# open(..., 'wb') fails with FileNotFoundError if the directory doesn't exist yet
os.makedirs(probe_dir, exist_ok = True)
with open(f'{probe_dir}/{model_prefix}.pkl', 'wb') as f:
    pickle.dump(probes, f)