In [None]:
"""
Generates a conversational dataset for each target model, used to evaluate the validity of role probes. User messages are
taken from ToxicChat and OASST1, and model-specific assistant responses are then generated for each model through OpenRouter.
Multi-turn conversations are supported. Models with no available API endpoint are run locally as a fallback.
"""
None

In [None]:
"""
Imports
"""
import pandas as pd
from dotenv import load_dotenv
from datasets import load_dataset
import os
from tqdm import tqdm

from utils.openrouter import get_openrouter_responses

# Seed passed to the dataset shuffle in load_raw_ds so sampling is repeatable
seed = 1234

# Load variables from the local .env file (presumably the OpenRouter credentials used by utils.openrouter - TODO confirm)
load_dotenv('.env')
# Sampling parameters merged into every API generation request
std_params = {'temperature': 1, 'top_p': 1, 'top_k': 0} # Most reasoning models require 1 temp
# Workspace root; output directories below are built from this path
ws = '/workspace/deliberative-alignment-jailbreaks'

# Import conversations + retain only user messages

In [None]:
"""
Load toxic-chat
"""
def load_raw_ds(n_samples = 1000):
    """
    Sample single-turn user prompts from the shuffled toxic-chat train split.

    Params:
        @n_samples: Stop scanning once this many prompts have been collected

    Returns:
        list[list[str]]: one single-element list per conversation, holding the user text.
        Only prompts between 100 and 500 characters (inclusive) are kept.
    """
    shuffled = load_dataset('lmsys/toxic-chat', 'toxicchat1123', split = 'train').shuffle(seed = seed)

    selected = []
    for record in shuffled:
        prompt = record['user_input']
        # Length filter: drop very short and very long prompts
        if 100 <= len(prompt) <= 500:
            selected.append([prompt])
        # Checked after every record, whether or not it was kept
        if len(selected) >= n_samples:
            break

    return selected

# Up to 100 toxic-chat conversations, each a one-element list holding the user text
toxicchat_ds = load_raw_ds(100)
toxicchat_ds

In [None]:
"""
Oasst data
"""
def load_oasst_conversations(n_samples = 1_000):
    """
    Rebuild one linear conversation per OpenAssistant oasst1 message tree.

    Only English messages from trees in the 'ready_for_export' state are used. For each tree the
    path starts at the first parentless message and always follows the first-listed child. A
    conversation is kept only if every prompter message on that path is 100-500 characters long.

    Params:
        @n_samples: Stop once this many conversations are collected; None keeps every valid one

    Returns:
        list[list[dict]]: each dict has message_tree_id, message_id, parent_id, role and text
    """
    kept_fields = ('message_tree_id', 'message_id', 'parent_id', 'role', 'text')
    dataset = load_dataset('OpenAssistant/oasst1', split = 'train', streaming = False) # Do not shuf

    # Bucket usable messages by tree, preserving dataset order
    messages_by_tree = {}
    for example in dataset:
        if example['lang'] != 'en' or example['tree_state'] != 'ready_for_export':
            continue
        record = {field: example[field] for field in kept_fields}
        messages_by_tree.setdefault(record['message_tree_id'], []).append(record)

    conversations = []
    for tree_msgs in messages_by_tree.values():
        # Per-tree lookups: id -> message, parent id -> child ids, plus the first parentless message
        by_id = {}
        child_ids = {}
        root = None
        for msg in tree_msgs:
            by_id[msg['message_id']] = msg
            parent = msg['parent_id']
            if parent:
                child_ids.setdefault(parent, []).append(msg['message_id'])
            elif root is None:
                root = msg

        if root is None:
            continue

        # Walk root -> leaf, always taking the first child (simple heuristic)
        cursor = root['message_id']
        conv = [by_id[cursor]]
        while cursor in child_ids:
            cursor = child_ids[cursor][0]
            conv.append(by_id[cursor])

        # A single out-of-bounds user message invalidates the whole conversation
        prompter_ok = all(
            100 <= len(msg['text']) <= 500
            for msg in conv
            if msg['role'] == 'prompter'
        )
        if prompter_ok:
            conversations.append(conv)

        if n_samples is not None and len(conversations) >= n_samples:
            break

    return conversations

oasst_convs = load_oasst_conversations(100)
# Keep only the prompter (user) texts of each conversation, in path order
oasst_user_only = [
    [m['text'] for m in conv if m['role'] == 'prompter']
    for conv in oasst_convs
]
oasst_user_only

In [None]:
"""
Assign conversation ids and user query indices within each conv; truncate convs to 2 queries
"""
# Build one row per (conversation, user query):
#   - conv_id is assigned per conversation before exploding, so every query of a conversation shares it
#   - user_query_ix is the 0-based position of the query within its conversation
#   - only the first two queries (ix 0 and 1) of each conversation are retained
user_queries_df =\
    pd.concat([
        pd.DataFrame({'convs': oasst_user_only}).assign(dataset = 'oasst'),
        pd.DataFrame({'convs': toxicchat_ds}).assign(dataset = 'toxicchat')
    ], ignore_index = True)\
    .assign(conv_id = lambda df: range(len(df)))\
    .explode('convs')\
    .assign(user_query_ix = lambda df: df.groupby('conv_id').cumcount())\
    .pipe(lambda df: df[df['user_query_ix'] <= 1])\
    .rename(columns = {'convs': 'user_query'})\
    [['conv_id', 'dataset', 'user_query_ix', 'user_query']]

# Sanity check 1: number of conversations per source dataset
display(user_queries_df.groupby('dataset', as_index = False).agg(convs = ('conv_id', 'nunique')))
# Sanity check 2: how many conversations have 1 vs 2 user messages.
# n_msg counts rows per conversation; the max 0-based index would under-report the message count by one.
display(user_queries_df.groupby('conv_id', as_index = False).agg(n_msg = ('user_query_ix', 'count')).groupby('n_msg', as_index = False).agg(n_msg_count = ('n_msg', 'count')))

user_queries_df

# Generate assistant/cot responses by model

In [None]:
"""
Define target models
"""
# Each entry describes one model to generate conversations for:
#   'model'          - model identifier sent in the API request
#   'model_provider' - the single provider the request is pinned to (fallbacks are disabled at request time)
#   'model_prefix'   - short name used for the per-model output CSV filename
target_models = [
    # Baseline models
    {'model': 'moonshotai/kimi-k2-thinking', 'model_provider': 'google-vertex', 'model_prefix': 'kimi-k2-thinking'}, # U
    {'model': 'anthropic/claude-haiku-4.5', 'model_provider': 'google-vertex', 'model_prefix': 'haiku-4.5'}, # U
    {'model': 'minimax/minimax-m2.1', 'model_provider': 'minimax/fp8', 'model_prefix': 'minimax-m2.1'}, # U
    {'model': 'qwen/qwen3-next-80b-a3b-thinking', 'model_provider': 'google-vertex', 'model_prefix': 'qwen3-next'}, # U

    # Test cases
    {'model': 'openai/gpt-oss-20b', 'model_provider': 'nebius/fp4', 'model_prefix': 'gptoss-20b'}, # U
    {'model': 'openai/gpt-oss-120b', 'model_provider': 'nebius/fp4', 'model_prefix': 'gptoss-120b'}, # U
    {'model': 'nvidia/nemotron-3-nano-30b-a3b', 'model_provider': 'deepinfra/bf16', 'model_prefix': 'nemotron-3-nano'}, # U
    # NOTE(review): model id is 'glm-4.6v' but the prefix says 'glm-4.6v-flash' - confirm which variant is intended
    {'model': 'z-ai/glm-4.6v', 'model_provider': 'z-ai/fp8', 'model_prefix': 'glm-4.6v-flash'}, # U
    {'model': 'qwen/qwen3-30b-a3b-thinking-2507', 'model_provider': 'alibaba', 'model_prefix': 'qwen3-30b-a3b'}, # U
    {'model': 'z-ai/glm-4.7-flash', 'model_provider': 'novita/bf16', 'model_prefix': 'glm-4.7-flash'}, # U
]

def _validate_and_extract_response(llm_response):
    """
    Extract content/reasoning from response
    
    Params:
        @llm_response: The LLM response object
    """
    if 'choices' not in llm_response:
        print(llm_response)
        return {'reasoning': None, 'output': None}

    choice = llm_response['choices'][0]
    if choice['finish_reason'] == 'length':
        print(f"Warning - early stop: {choice['finish_reason']}")
        # print(f"  CONTENT: {choice['message'].get('content')}")
        # print(f"  REASONING: {choice['message'].get('reasoning')}")
        return {'reasoning': None, 'output': None}

    content = choice['message'].get('content')
    if content is None or (isinstance(content, str) and content.strip() == ''):
        return {'reasoning': None, 'output': None}

    return {
        'reasoning': choice['message'].get('reasoning'),
        'output': content,
    }

In [None]:
"""
Generate data and save them by model!
"""
BATCH_SIZE = 10
MAX_TOKS = 4_000 # Cull to deal with models with endless reasoning

async def generate_convs_data(user_queries_df, target_dir, target_models):
    """
    Generate assistant responses for every conversation, one target model at a time, and save one CSV per model.

    Conversations are processed round by round (round = user_query_ix): round 0 sends only the first user
    query; later rounds resend the earlier user queries together with the model's own earlier outputs.
    A conversation whose response fails in any round is excluded from later rounds and dropped from the output.

    Params:
        @user_queries_df: DataFrame with columns conv_id, user_query_ix, user_query (other columns are carried through)
        @target_dir: Directory where '<model_prefix>.csv' files are written; created if missing
        @target_models: list of dicts with keys 'model', 'model_provider', 'model_prefix'

    Returns:
        Concatenated DataFrame of the cleaned results of all models processed in this call,
        or None if no model produced any rows. Models whose CSV already exists are skipped and not included.
    """ 
    os.makedirs(target_dir, exist_ok=True)
    all_model_results = []

    # (conv_id, user_query_ix) -> user_query, used to rebuild earlier user turns
    query_lookup = user_queries_df.set_index(['conv_id', 'user_query_ix'])['user_query'].to_dict()
    max_rounds = user_queries_df['user_query_ix'].max() + 1

    for target_model in target_models:
        model_name = target_model['model']
        model_prefix = target_model['model_prefix']
        out_path = os.path.join(target_dir, model_prefix + '.csv')

        # An existing CSV acts as a cache: delete the file to regenerate this model
        if os.path.exists(out_path):
            print(f"Skipping {model_name} - already exists")
            continue

        print(f"\n{'='*50}\nProcessing model: {model_name}\n{'='*50}")
        responses = {} # (conv_id, round_ix) -> {'reasoning': ..., 'output': ...}
        failed_conv_ids = set()

        for round_ix in range(max_rounds):
            print(f"Processing round {round_ix}...")
            current_round_df = user_queries_df[user_queries_df['user_query_ix'] == round_ix]

            message_histories = []
            conv_ids = []

            for _, row in current_round_df.iterrows():
                cid = row['conv_id']

                # Skip if conversation is already marked as failed
                if cid in failed_conv_ids:
                    responses[(cid, round_ix)] = {'output': None, 'reasoning': None}
                    continue

                # Rebuild the history: earlier user queries interleaved with this model's earlier outputs
                # (only the final output is replayed, not the reasoning)
                messages = []
                for prev_ix in range(round_ix):
                    prev_query = query_lookup.get((cid, prev_ix))
                    messages.append({'role': 'user', 'content': prev_query})

                    prev_response = responses[(cid, prev_ix)]
                    messages.append({'role': 'assistant', 'content': prev_response['output']})

                messages.append({'role': 'user', 'content': row['user_query']})
                message_histories.append(messages)
                conv_ids.append(cid)

            if not message_histories:
                print(f"  All conversations skipped due to prior failures")
                continue

            # presumably returns one raw response per message history, in input order - the strict zip below relies on it
            raw_llm_responses = await get_openrouter_responses(
                message_histories,
                {
                    'model': target_model['model'],
                    # Pin the request to the configured provider; no fallback to other providers
                    'provider': {'order': [target_model['model_provider']], 'allow_fallbacks': False},
                    'reasoning': {'effort': 'medium', 'enabled': True},
                    'max_tokens': MAX_TOKS,
                    **std_params
                },
                batch_size = BATCH_SIZE
            )

            # strict = True raises if the number of responses differs from the number of requests
            for conv_id, raw_response in zip(conv_ids, raw_llm_responses, strict = True):
                extracted = _validate_and_extract_response(raw_response)
                responses[(conv_id, round_ix)] = extracted

                # Mark as failed immediately if output is None
                if extracted['output'] is None:
                    failed_conv_ids.add(conv_id)

            print(f"  Failed: {len(failed_conv_ids)}")

        responses_df = pd.DataFrame([
            {'conv_id': c, 'user_query_ix': ix, **r}
            for (c, ix), r in responses.items()
        ]).rename(columns={'output': 'assistant', 'reasoning': 'cot'})

        if responses_df.empty:
             print(f"  Warning: No responses generated for {model_name}")
             continue

        # Left merge keeps every user query, so queries without a usable response show up as missing 'assistant'
        model_results_df = user_queries_df.merge(responses_df, on=['conv_id', 'user_query_ix'], how='left')
        model_results_df['model'] = model_name
        model_results_df['model_prefix'] = model_prefix

        # A conversation is "poisoned" if any of its turns has a missing or blank assistant response
        is_invalid = (
            model_results_df['assistant'].isna() | 
            (model_results_df['assistant'].astype(str).str.strip() == '')
        )
        poisoned_ids = model_results_df.loc[is_invalid, 'conv_id'].unique()

        if len(poisoned_ids) > 0:
            print(f"  Dropping {len(poisoned_ids)} conversations.")

        # Drop poisoned conversations entirely, including their successful turns
        clean_results_df = model_results_df[~model_results_df['conv_id'].isin(poisoned_ids)].copy()

        # Nothing is written when every conversation failed, so the model is retried on the next run
        if not clean_results_df.empty:
            clean_results_df.to_csv(out_path, index=False)
            all_model_results.append(clean_results_df)

    if not all_model_results:
        return None

    return pd.concat(all_model_results, ignore_index=True)


result_df = await generate_convs_data(
    user_queries_df,
    target_dir = f"{ws}/experiments/role-analysis/data/conversations",
    target_models = target_models
)

# Local fallback for unavailable models

In [None]:
"""
All below code is a fallback for models which do NOT have available endpoints. Avoid if possible.
"""
import torch

from utils.loader import load_model_and_tokenizer
from utils.memory import check_memory

# Device that prompt tensors are moved to before generation
main_device = 'cuda:0'

# Notebook-level model choice: later cells read model_prefix to name the '-raw.csv' and final '.csv' outputs
model_prefix = 'gptoss-20b'
tokenizer, model, model_architecture, model_n_layers = load_model_and_tokenizer(model_prefix, device = main_device)

check_memory()

In [None]:
"""
Run generations
"""
MAX_INPUT_TOKENS = 1024 * 12
MAX_NEW_TOKENS = 4_000
TEMPERATURE = 1.0
BATCH_SIZE = 16

def chunk_list(lst, n):
    """
    Split a sequence into consecutive slices of at most n items; the last slice may be shorter.

    Params:
        @lst: Sliceable sequence to split
        @n: Maximum slice length

    Returns:
        list of slices, in original order
    """
    chunks = []
    for start in range(0, len(lst), n):
        chunks.append(lst[start:start + n])
    return chunks

@torch.inference_mode()
def generate_local_batch(message_histories):
    """
    Sample one completion per conversation from the locally loaded model.

    Params:
        @message_histories: list[list[{'role': str, 'content': str}]]

    Returns:
        (assistant_texts, state_texts), two lists aligned with the input:
        - assistant_texts: newly generated tokens decoded without special tokens, stripped; None if blank
        - state_texts: prompt (padding removed) + generated tokens decoded with special tokens kept
    """
    encoded = tokenizer.apply_chat_template(
        message_histories,
        add_generation_prompt = True,
        tokenize = True,
        padding = True,
        truncation = True,
        max_length = MAX_INPUT_TOKENS,
        return_tensors = 'pt',
        return_dict = True
    )

    padded_prompts = encoded['input_ids'].to(main_device)
    prompt_mask = encoded['attention_mask'].to(main_device)
    # Every row shares the padded prompt width, so generated tokens begin at this column for all rows
    # (assumes the tokenizer pads on the left - TODO confirm in load_model_and_tokenizer)
    padded_prompt_len = padded_prompts.shape[1]

    generated = model.generate(
        input_ids = padded_prompts,
        attention_mask = prompt_mask,
        max_new_tokens = MAX_NEW_TOKENS,
        do_sample = True,
        temperature = TEMPERATURE,
        pad_token_id = tokenizer.pad_token_id,
        eos_token_id = tokenizer.eos_token_id,
        use_cache = True,
    )

    pad_id = tokenizer.pad_token_id
    assistant_texts = []
    state_texts = []

    for row_ix, row_out in enumerate(generated):
        # Real prompt tokens only: the attention mask removes the prompt padding
        real_prompt = padded_prompts[row_ix][prompt_mask[row_ix].bool()].tolist()

        # Generated part, with the trailing padding (from rows finishing at different lengths) trimmed
        completion = row_out[padded_prompt_len:].tolist()
        end = len(completion)
        while end > 0 and completion[end - 1] == pad_id:
            end -= 1
        completion = completion[:end]

        # Plain text that is fed back as the assistant turn in later rounds
        plain_text = tokenizer.decode(completion, skip_special_tokens = True).strip()
        assistant_texts.append(plain_text or None)

        # Full prompt + completion with special tokens / chat formatting intact
        state_texts.append(tokenizer.decode(real_prompt + completion, skip_special_tokens = False))

    return assistant_texts, state_texts


def generate_convs_data_local(user_queries_df, target_dir):
    """
    Local-model counterpart of the API generation loop: generate responses round by round with the
    loaded model and write the raw results to '<model_prefix>-raw.csv'.

    Params:
        @user_queries_df: DataFrame with columns conv_id, user_query_ix, user_query
        @target_dir: Output directory; created if missing

    Returns:
        user_queries_df left-merged with the generated 'assistant' and 'state_text' columns, plus
        'model' and 'model_prefix' columns. Turns skipped because an earlier turn of the same
        conversation produced no text have None in both generated columns.
    """
    os.makedirs(target_dir, exist_ok = True)
    out_path = os.path.join(target_dir, model_prefix + '-raw.csv')

    print(f"\n{'='*50}")
    print(f"Local generation model: {model_prefix}")

    # One-time (conv_id, user_query_ix) -> user_query lookup, so rebuilding a history
    # does not re-filter the whole frame for every earlier turn
    query_lookup = {
        (int(cid), int(qix)): query
        for cid, qix, query in zip(
            user_queries_df['conv_id'],
            user_queries_df['user_query_ix'],
            user_queries_df['user_query'],
        )
    }

    responses = {} # (conv_id, round_ix) -> {'assistant': ..., 'state_text': ...}
    max_rounds = int(user_queries_df['user_query_ix'].max()) + 1

    for round_ix in range(max_rounds):
        print(f"Round {round_ix}...")

        current_round_df = user_queries_df[user_queries_df["user_query_ix"] == round_ix]
        message_histories = []
        conv_ids = []
        skipped = 0

        for _, row in current_round_df.iterrows():
            conv_id = int(row["conv_id"])

            # Skip if any prior round is missing or produced no assistant text
            prior_failed = any(
                responses.get((conv_id, prev_ix), {}).get("assistant") is None
                for prev_ix in range(round_ix)
            )
            if prior_failed:
                skipped += 1
                responses[(conv_id, round_ix)] = {"assistant": None, "state_text": None}
                continue

            # Earlier turns: original user queries interleaved with this model's own earlier outputs
            msgs = []
            for prev_ix in range(round_ix):
                msgs.append({"role": "user", "content": query_lookup[(conv_id, prev_ix)]})
                msgs.append({"role": "assistant", "content": responses[(conv_id, prev_ix)]["assistant"]})

            # Current user turn
            msgs.append({"role": "user", "content": row["user_query"]})
            message_histories.append(msgs)
            conv_ids.append(conv_id)

        if not message_histories:
            print(f"  All skipped ({skipped}) due to prior failures")
            continue

        all_asst = []
        all_state = []
        for batch in tqdm(chunk_list(message_histories, BATCH_SIZE)):
            batch_asst, batch_state = generate_local_batch(batch)
            all_asst.extend(batch_asst)
            all_state.extend(batch_state)

        # strict = True: a batch returning the wrong number of outputs fails loudly instead of misaligning rows
        for conv_id, asst_text, state_text in zip(conv_ids, all_asst, all_state, strict = True):
            responses[(conv_id, round_ix)] = {"assistant": asst_text, "state_text": state_text}

        print(f"  Completed: {len(conv_ids)} convs, skipped: {skipped}")

    responses_df = pd.DataFrame(
        [{"conv_id": cid, "user_query_ix": qix, **resp}
        for (cid, qix), resp in responses.items()]
    )

    model_results_df = (
        user_queries_df
        .merge(responses_df, on = ["conv_id", "user_query_ix"], how="left")
        .assign(model = model_prefix, model_prefix = model_prefix)
    )

    model_results_df.to_csv(out_path, index = False)
    return model_results_df

local_result_df = generate_convs_data_local(
    user_queries_df,
    target_dir = f"{ws}/experiments/role-analysis/data/conversations",
)

In [None]:
"""
Look for -raw files, extract cot + assistant
"""
out_path = os.path.join(f"{ws}/experiments/role-analysis/data/conversations", model_prefix + "-raw.csv")

if os.path.exists(out_path):
    # prompt_ix is the row position in the raw file; it is the join key for the token-level frame
    raw_df = pd.read_csv(out_path).assign(prompt_ix = lambda df: list(range(0, len(df))))

    # Split into token level dataframe
    tok_dfs = []
    for row in raw_df.to_dict('records'):
        state_text = row['state_text']
        # Skipped turns are saved with an empty state_text, which reads back as NaN; there is nothing to tokenize
        if not isinstance(state_text, str):
            continue

        enc = tokenizer(state_text, add_special_tokens = False, return_offsets_mapping = True)
        # Recover each token's surface text by slicing the original string with its character offsets
        substrings = [state_text[s:e] for (s, e) in enc['offset_mapping']]
        tok_dfs.append(pd.DataFrame({'token': substrings, 'prompt_ix': row['prompt_ix'], 'token_ix': list(range(0, len(substrings)))}))

    # pd.concat raises on an empty list, so fall back to an empty frame with the same columns
    if tok_dfs:
        tok_df = pd.concat(tok_dfs, ignore_index = True)
    else:
        tok_df = pd.DataFrame(columns = ['token', 'prompt_ix', 'token_ix'])

In [None]:
"""
Split CoT/Assistant
"""
import importlib
import utils.role_assignments
importlib.reload(utils.role_assignments)
from utils.role_assignments import label_content_roles

# Label every token with a role and collapse each segment back into one text per (prompt, segment, role)
sample_df_labeled = (
    label_content_roles(model_prefix, tok_df)
    .groupby(['prompt_ix', 'seg_ix', 'role'], as_index = False)
    .agg(combined_text = ('token', ''.join))
    .sort_values(['prompt_ix', 'seg_ix'])
)

display(sample_df_labeled)


# Keep only prompts whose final segment is an assistant answer
mask = sample_df_labeled.groupby('prompt_ix')['role'].transform(
    lambda s: s.iloc[-1:].tolist() == ['assistant']
)
# Take the last TWO segments per prompt: the final assistant answer and, when it sits directly before
# it, the cot segment. Taking only the last segment would leave nothing but 'assistant' rows, so the
# pivot below could never produce a 'cot' column and cot would always be saved as None.
sample_df_tail = (
    sample_df_labeled[mask]
    .groupby('prompt_ix')
    .tail(2)
    .pipe(lambda df: df[df['role'].isin(['cot', 'assistant'])])
    # pivot raises on duplicate (prompt_ix, role) pairs; keep the later segment if that ever happens
    .drop_duplicates(subset = ['prompt_ix', 'role'], keep = 'last')
)

final = (
    sample_df_tail
    .pivot(index='prompt_ix', columns='role', values='combined_text')
    .reset_index()
)

# Ensure both columns exist (cot is absent when no prompt has a cot segment right before its answer)
if 'cot' not in final.columns:
    final['cot'] = None
if 'assistant' not in final.columns:
    final['assistant'] = None  # just in case

# Inner merge: prompts without a final assistant segment are dropped from the output
merged = (
    raw_df
    .drop(columns=['state_text', 'assistant'])
    .merge(final[['prompt_ix', 'cot', 'assistant']], on='prompt_ix', how='inner')
    .reset_index(drop=True)
    .drop(columns='prompt_ix')
)

display(merged)

merged.to_csv(os.path.join(f"{ws}/experiments/role-analysis/data/conversations", model_prefix + ".csv"), index = False)