In [None]:
"""
Generate conversational datasets for each model
"""
None

In [None]:
"""
Imports
"""
import pandas as pd
import numpy as np
from dotenv import load_dotenv
import yaml
from datasets import load_dataset
import re
import duckdb
import os
import huggingface_hub

from utils.openrouter import get_openrouter_responses

# Seed passed to the Toxic-Chat streaming shuffle further down.
seed = 1234

# Load environment variables from the local .env file
# (presumably the API credentials used by the OpenRouter helper - TODO confirm).
load_dotenv('.env')
# Sampling parameters merged into every generation request.
std_params = {'temperature': 0, 'top_p': 1, 'top_k': 0}
# Workspace root; the WildChat cache path and the output directory are built from it.
# NOTE(review): hardcoded absolute path - adjust when running outside this environment.
ws = '/workspace/deliberative-alignment-jailbreaks'

# Load

In [None]:
"""
Load WildChat data
If it gets stuck, make sure you're logged in and the token is passed:
 - huggingface_hub.login(); huggingface_hub.whoami()
 - token = huggingface_hub.get_token(); duckdb.sql(f"CREATE SECRET IF NOT EXISTS hf_token (TYPE HUGGINGFACE, TOKEN '{token}')")
"""

def load_raw_ds(n_samples = 1000, n_shards = 4):
    """
    Draw a random sample of US / English WildChat conversations, cached on disk as parquet.

    Params:
        @n_samples: Number of conversations to return
        @n_shards: Number of parquet shards (the first N in sorted order) to sample from

    Returns:
        A dataframe with the columns `conversation`, `model` and `turn`

    Notes:
        - The cache is reused only when it holds at least `n_samples` rows (the first
          `n_samples` cached rows are returned); otherwise the sample is redrawn and the
          cache file is overwritten.
        - The cache does not record `n_shards`; delete the file to resample from a
          different number of shards.
    """
    cache_path = f"{ws}/experiments/role-analysis/wildchat_sample.parquet"

    if os.path.exists(cache_path):
        cached_df = pd.read_parquet(cache_path)
        # Previously the cache was returned regardless of n_samples, silently ignoring the argument
        if len(cached_df) >= n_samples:
            return cached_df.head(n_samples)
        print(f"Cached sample has {len(cached_df)} rows but {n_samples} were requested - resampling")

    files = huggingface_hub.list_repo_files("allenai/WildChat-4.8M", repo_type="dataset")
    shards = sorted([f"hf://datasets/allenai/WildChat-4.8M/{f}" for f in files if f.endswith('.parquet')])[:n_shards]

    # duckdb.sql() builds a lazy relation for SELECT statements; fetch the result so the
    # seed is actually applied before the sampling query runs.
    # NOTE(review): ORDER BY random() over a parallel parquet scan may still not be fully
    # reproducible across runs - the on-disk cache is what pins the sample in practice.
    duckdb.sql("SELECT setseed(0.1234)").fetchall()
    # `shards` is interpolated via its Python list repr, which doubles as a DuckDB list literal
    df = duckdb.sql(f"""
        SELECT conversation, model, turn
        FROM read_parquet({shards})
        WHERE country = 'United States' AND language = 'English' AND turn >= 1
        ORDER BY random()
        LIMIT {n_samples}
    """).df()

    os.makedirs(os.path.dirname(cache_path), exist_ok = True)
    df.to_parquet(cache_path)
    return df

# Keep the raw frame under its own name so `wildchat_ds` only ever holds the list-of-queries form
wildchat_raw_df = load_raw_ds(n_samples = 100)

# One list of user-turn texts per conversation, in conversation order (assistant turns are dropped)
wildchat_ds = [
    [msg['content'] for msg in conv if msg['role'] == 'user']
    for conv in wildchat_raw_df['conversation']
]

# Preview a few conversations instead of dumping the whole sample into the cell output
wildchat_ds[:3]

In [None]:
"""
Load Toxic-Chat
"""
def load_raw_ds(n_samples = 100):
    """
    Stream the Toxic-Chat train split (shuffled with the notebook-level `seed`) and collect
    the first `n_samples` user inputs.

    NOTE(review): this definition shadows the WildChat loader of the same name defined in
    an earlier cell, so re-running that earlier cell after this one calls the wrong function.

    Params:
        @n_samples: Number of prompts to collect; values <= 0 return an empty list

    Returns:
        A list of single-element lists `[user_input]`, i.e. one-query conversations
    """
    # The collection loop appends before it checks the limit, so without this guard
    # n_samples <= 0 would still return one sample.
    if n_samples <= 0:
        return []

    def get_convs():
        # Streaming + shuffle: samples are drawn through a 50k-element shuffle buffer
        return load_dataset('lmsys/toxic-chat', 'toxicchat1123', split = 'train', streaming = True).shuffle(seed = seed, buffer_size = 50_000)
    
    def get_data(ds, n_samples):
        raw_data = []
        for sample in ds:
            # Wrap each prompt in a list so it has the same shape as a multi-turn conversation
            raw_data.append([sample['user_input']])
            if len(raw_data) >= n_samples:
                break
        return raw_data
    
    return get_data(get_convs(), n_samples)    

# Sample 10 Toxic-Chat prompts; each one is returned as a one-query conversation.
toxicchat_ds = load_raw_ds(10)

In [None]:
"""
Assign conversation ids and user query indices within each conv; truncate convs to 5 queries
"""
# Stack both datasets, give every conversation an id, then flatten to one row per user query.
user_queries_df = (
    pd.concat(
        [
            pd.DataFrame({'convs': wildchat_ds}).assign(dataset = 'wildchat'),
            pd.DataFrame({'convs': toxicchat_ds}).assign(dataset = 'toxicchat'),
        ],
        ignore_index = True,
    )
    .assign(conv_id = lambda df: range(len(df)))
    .explode('convs')
    # Position of each query within its conversation (0-indexed)
    .assign(user_query_ix = lambda df: df.groupby('conv_id').cumcount())
    # Keep at most the first 5 queries of every conversation
    .loc[lambda df: df['user_query_ix'] < 5]
    .rename(columns = {'convs': 'user_query'})
    .loc[:, ['conv_id', 'dataset', 'user_query_ix', 'user_query']]
)

# How many conversations end at each query index
last_query_ix = user_queries_df.groupby('conv_id').agg({'user_query_ix': 'max'})
display(last_query_ix.value_counts().sort_index())

user_queries_df

In [None]:
"""
Define target models
"""
# Each entry names the model to query ('model') and the provider the request is pinned to
# ('model_provider'); both keys are read by generate_convs_data.
target_models = [
    {'model': 'qwen/qwen3-30b-a3b-thinking-2507', 'model_provider': 'alibaba'}
    # Disabled entries. NOTE(review): they carry a 'policy_prompt' key that the code visible here never reads.
    # {'model': 'openai/gpt-oss-120b', 'model_provider': 'nebius/fp4', 'policy_prompt': 'openai'},
    # {'model': 'qwen/qwen3-vl-30b-a3b-thinking', 'model_provider': 'novita/fp16', 'policy_prompt': 'qwen3-vl-30b-a3b'},
]

def _validate_and_extract_response(llm_response):
    """
    Extract content/reasoning from response
    
    Params:
        @llm_response: The LLM response object
    """
    if 'choices' not in llm_response:
        print(llm_response)
        return {'reasoning': None, 'output': None}

    choice = llm_response['choices'][0]
    if choice['finish_reason'] == 'length':
        print(f"Warning - early stop: {choice['finish_reason']}")
        print(f"  CONTENT: {choice['message']['content']}")
        print(f"  REASONING: {choice['message']['reasoning']}")
        return {'reasoning': None, 'output': None}

    return {
        'reasoning': choice['message']['reasoning'],
        'output': choice['message']['content'],
    }


In [None]:
"""
Generate data
"""
async def generate_convs_data(user_queries_df, target_dir, target_models):
    """
    Iterate over target models and run generations, one conversation round at a time.

    Round k sends, for every conversation that has a k-th user query, the earlier user
    queries interleaved with the model's own earlier outputs (without CoT), followed by
    the k-th query. A conversation with a failed earlier round (output is None) is not
    sent again; its remaining rounds are recorded with empty output/reasoning.

    Params:
        @user_queries_df: Dataframe with one row per user query; reads the columns
            `conv_id`, `user_query_ix` (0-indexed position within the conversation) and
            `user_query`. (conv_id, user_query_ix) pairs must be unique.
        @target_dir: Directory where one CSV per model is written (created if missing)
        @target_models: List of dicts with the keys 'model' and 'model_provider'

    Returns:
        The per-model result frames concatenated (input columns plus `assistant`, `cot`
        and `model`), or None when `target_models` is empty. Models whose CSV already
        exists are not regenerated; their saved results are loaded and included.

    Raises:
        ValueError: if `user_queries_df` contains duplicate (conv_id, user_query_ix) pairs
    """
    os.makedirs(target_dir, exist_ok = True)

    all_model_results = []

    # (conv_id, user_query_ix) -> query text, built once so that assembling a history
    # does not re-filter the whole dataframe for every previous turn
    query_lookup = {
        (conv_id, query_ix): query
        for conv_id, query_ix, query in zip(
            user_queries_df['conv_id'], user_queries_df['user_query_ix'], user_queries_df['user_query']
        )
    }
    if len(query_lookup) != len(user_queries_df):
        raise ValueError("user_queries_df contains duplicate (conv_id, user_query_ix) pairs")

    max_rounds = user_queries_df['user_query_ix'].max() + 1  # +1 since 0-indexed
    for target_model in target_models:
        model_name = target_model['model']
        filename = model_name.split('/')[-1] + ".csv"
        out_path = os.path.join(target_dir, filename)
        
        # Skip generation if already done, but still return the saved results so that a
        # rerun yields the same dataframe instead of dropping the model (or returning None)
        if os.path.exists(out_path):
            print(f"Skipping {model_name} - already exists at {out_path}")
            all_model_results.append(pd.read_csv(out_path))
            continue

        print(f"\n{'='*50}")
        print(f"Processing model: {model_name}")
        print(f"{'='*50}")

        # Store responses: (conv_id, user_query_ix) -> {'reasoning': ..., 'output': ...}
        responses = {}

        for round_ix in range(max_rounds):

            print(f"Processing round {round_ix}...")
            current_round_df = user_queries_df[user_queries_df['user_query_ix'] == round_ix]
            # Build conversation histories for each query in this round
            message_histories = []
            conv_ids = []
            skipped_count = 0

            for _, row in current_round_df.iterrows():
                conv_id = row['conv_id']

                # Skip conversations that had a failure in any previous round
                has_prior_failure = any(responses[(conv_id, prev_ix)]['output'] is None for prev_ix in range(round_ix))
                if has_prior_failure:
                    skipped_count += 1
                    # Record the skipped round as a failure too, so later rounds find an entry
                    responses[(conv_id, round_ix)] = {'output': None, 'reasoning': None}
                    continue
                
                messages = []
            
                # Add all previous turns (user + assistant pairs)
                for prev_ix in range(round_ix):
                    messages.append({'role': 'user', 'content': query_lookup[(conv_id, prev_ix)]})
                    # Previous assistant response (output only, no CoT)
                    messages.append({'role': 'assistant', 'content': responses[(conv_id, prev_ix)]['output']})
            
                # Add current round's user query
                messages.append({'role': 'user', 'content': row['user_query']})
                message_histories.append(messages)
                conv_ids.append(conv_id)
        
            if len(message_histories) == 0:
                print(f"  All {skipped_count} conversations skipped due to prior failures")
                continue
        
            # Call API
            raw_llm_responses = await get_openrouter_responses(
                message_histories,
                {
                    'model': target_model['model'],
                    # Pin the request to the configured provider; no fallback to other providers
                    'provider': {'order': [target_model['model_provider']], 'allow_fallbacks': False},
                    'reasoning': {'effort': 'medium', 'enabled': True},
                    'temperature': 0,
                    'max_tokens': 10_000,
                    # Spread last, so keys in std_params override the literals above
                    **std_params
                },
                batch_size = 20
            )
        
            # Store responses for use in subsequent rounds
            # (relies on the responses coming back in the same order as message_histories)
            for conv_id, raw_response in zip(conv_ids, raw_llm_responses):
                extracted = _validate_and_extract_response(raw_response)
                responses[(conv_id, round_ix)] = extracted
            
            print(f"  Completed: {len(conv_ids)} conversations, skipped: {skipped_count}")

        # Convert results to a dataframe for easier analysis
        responses_df =\
            pd.DataFrame([
                {'conv_id': conv_id, 'user_query_ix': user_query_ix, **response}
                for (conv_id, user_query_ix), response in responses.items()
            ])\
            .rename(columns = {'output': 'assistant', 'reasoning': 'cot'})
        model_results_df = user_queries_df.merge(responses_df, on = ['conv_id', 'user_query_ix'], how = 'left').assign(model = model_name)
        
        model_results_df.to_csv(out_path, index = False)
        all_model_results.append(model_results_df)

    if not all_model_results:
        return None
    
    return pd.concat(all_model_results, ignore_index = True)

# Run the generations for every target model; one CSV per model is written under target_dir.
# Top-level `await` relies on the notebook's already-running event loop.
# Note: result_df is None when generate_convs_data has no results to return.
result_df = await generate_convs_data(
    user_queries_df,
    target_dir = f"{ws}/experiments/role-analysis/convs",
    target_models = target_models
)

In [None]:
# generate_convs_data can return None (no results); guard so this cell does not raise AttributeError
result_df.head(50) if result_df is not None else None