## Install packages and set up imports

In [1]:
!pip install collinear --upgrade



In [7]:
import json
from pathlib import Path

import nest_asyncio
from collinear.client import Client

# Necessary to run in a Jupyter notebook
nest_asyncio.apply()
import argparse
from pprint import pprint


In [15]:
def header(title: str) -> None:
    """Print ``title`` with a rule of ``=`` characters above and below it.

    The rule has exactly as many characters as ``title``.
    """
    rule = "=" * len(title)
    print(rule, title, rule, sep="\n")

def display_persona(sim_runner, simulation):
    """Pretty-print the steering persona attached to ``simulation``.

    The persona comes from ``build_steering_persona``, called with the
    simulation's ``steer`` attribute (``None`` when the attribute is absent).
    Only the ``characteristics`` and ``traits`` entries are shown, each copied
    into a plain ``dict`` before printing.
    """
    steer = getattr(simulation, 'steer', None)
    persona_source = build_steering_persona(sim_runner, steer)
    pprint({
        section: dict(persona_source.get(section, {}))
        for section in ('characteristics', 'traits')
    })

def make_dataset_row(simulation_runner, simulation):
    """Assemble one dataset record for ``simulation``.

    The record holds the formatted conversation prefix, the simulation's
    final response, and the steering persona (characteristics and traits).
    """
    row = {}
    row['conversation'] = format_conversation(simulation.conv_prefix)
    row['last_response'] = simulation.response
    steer = getattr(simulation, 'steer', None)
    row['steering_persona'] = build_steering_persona(simulation_runner, steer)
    return row

def format_conversation(conversation_prefix):
    """Render message dicts as newline-separated ``role: content`` lines.

    Messages whose ``content`` is missing or empty are left out; a missing
    ``role`` is rendered as an empty string.
    """
    rendered = []
    for message in conversation_prefix:
        text = message.get('content')
        if not text:
            continue
        rendered.append(f"{message.get('role', '')}: {text}")
    return '\n'.join(rendered)

def build_steering_persona(simulation_runner, steer_combination):
    """Return persona metadata via the runner's normalization helpers.

    Args:
        simulation_runner: Object exposing ``_user_characteristics_payload``,
            which is called with ``steer_combination``.
        steer_combination: The steering combination of a simulation, or
            ``None`` when the simulation carries none.

    Returns:
        A dict with two keys: ``'characteristics'`` (the runner's payload, or
        ``{}`` when the call fails) and ``'traits'`` (the combination's
        ``traits`` attribute, or ``{}`` when it is missing or falsy).
    """
    import warnings  # local import keeps this helper cell self-contained

    if steer_combination is None:
        return {'characteristics': {}, 'traits': {}}

    try:
        characteristics = simulation_runner._user_characteristics_payload(steer_combination)
    except Exception as exc:
        # Best effort: the helper is underscore-prefixed on the runner, so it
        # may fail or be absent. Keep the empty fallback, but make the failure
        # visible instead of silently producing empty personas.
        warnings.warn(
            'Could not build persona characteristics '
            f'({type(exc).__name__}: {exc}); using an empty mapping.',
            RuntimeWarning,
            stacklevel=2,
        )
        characteristics = {}

    traits = getattr(steer_combination, 'traits', {}) or {}
    return {'characteristics': characteristics, 'traits': traits}

## Load model, setup client
Please update `configs/simulation_config.json` with your assistant model name, assistant API key, and Steer API key before proceeding to the next step.

In [9]:
# Config Variables (from configs/simulation_config.json and steering_config_*.json)

CONFIG_DIR = Path('configs')
SIMULATION_CONFIG_FILE = CONFIG_DIR / 'simulation_config.json'
# Read as UTF-8 explicitly so the result does not depend on the platform's
# default encoding (the dataset files in this notebook are UTF-8 as well).
config_data = json.loads(SIMULATION_CONFIG_FILE.read_text(encoding='utf-8'))

# Steering config: an absolute path is used as given; a relative value is
# resolved by file name only inside CONFIG_DIR (any directory part is dropped).
steering_config_value = config_data.get('steering_config_file') or 'steering_config_airline.json'
steering_candidate = Path(steering_config_value)
if not steering_candidate.is_absolute():
    steering_candidate = CONFIG_DIR / steering_candidate.name
STEERING_CONFIG_FILE = steering_candidate
STEER_CONFIG = json.loads(STEERING_CONFIG_FILE.read_text(encoding='utf-8'))
STEER_TASKS = list(STEER_CONFIG.get('tasks', []))

# Client options (the `or {}` also covers an explicit null section)
client_settings = config_data.get('client', {}) or {}
CLIENT_ASSISTANT_MODEL_URL = client_settings.get('assistant_model_url', 'https://api.openai.com/v1')
CLIENT_ASSISTANT_MODEL_API_KEY = client_settings.get('assistant_model_api_key')
CLIENT_ASSISTANT_MODEL_NAME = client_settings.get('assistant_model_name', 'gpt-4o-mini')
CLIENT_STEER_API_KEY = client_settings.get('steer_api_key', 'demo-001')
CLIENT_TIMEOUT = float(client_settings.get('timeout', 120))
CLIENT_MAX_RETRIES = int(client_settings.get('max_retries', 3))
CLIENT_RATE_LIMIT_RETRIES = int(client_settings.get('rate_limit_retries', 6))

# Simulation options
simulate_settings = config_data.get('simulate', {}) or {}
SIM_SAMPLES = simulate_settings.get('k', 1)
SIM_EXCHANGES = simulate_settings.get('num_exchanges', 2)
SIM_DELAY = simulate_settings.get('batch_delay', 0.2)
SIM_STEER_TEMPERATURE = simulate_settings.get('steer_temperature', 0.7)
SIM_STEER_MAX_TOKENS = simulate_settings.get('steer_max_tokens', 256)
SIM_MIX_TRAITS = bool(simulate_settings.get('mix_traits', False))
SIM_MAX_CONCURRENCY = int(simulate_settings.get('max_concurrency', 8))

# Assessment options (judge settings have no defaults and may be None)
assess_settings = config_data.get('assess', {}) or {}
ASSESS_JUDGE_MODEL_URL = assess_settings.get('judge_model_url')
ASSESS_JUDGE_MODEL_API_KEY = assess_settings.get('judge_model_api_key')
ASSESS_JUDGE_MODEL_NAME = assess_settings.get('judge_model_name')
ASSESS_TEMPERATURE = assess_settings.get('temperature', 0.0)
ASSESS_MAX_TOKENS = assess_settings.get('max_tokens', 512)

# Summarize what was loaded so a wrong path or empty task list is noticed early.
tasks_display = STEER_TASKS if STEER_TASKS else '<none>'
print(f'Loaded simulation: {SIMULATION_CONFIG_FILE} | steering: {STEERING_CONFIG_FILE} | tasks: {tasks_display}')


Loaded simulation: configs/simulation_config.json | steering: configs/steering_config_airline.json | tasks: ['airline support']


In [10]:
import os

# Export the configured credentials and endpoint as environment variables.
# Settings that are empty or missing leave the corresponding variable untouched.
os.environ.update({
    env_name: setting
    for env_name, setting in (
        ('OPENAI_API_KEY', CLIENT_ASSISTANT_MODEL_API_KEY),
        ('OPENAI_BASE_URL', CLIENT_ASSISTANT_MODEL_URL),
        ('STEER_API_KEY', CLIENT_STEER_API_KEY),
    )
    if setting
})


## Client setup

In [11]:
# Client setup

# Stop here with a clear message when no assistant API key was configured.
if not CLIENT_ASSISTANT_MODEL_API_KEY:
    raise RuntimeError('assistant_model_api_key must be set in configs/simulation_config.json')

# Gather the constructor arguments from the config constants, then build the client.
client_options = dict(
    assistant_model_url=CLIENT_ASSISTANT_MODEL_URL,
    assistant_model_api_key=CLIENT_ASSISTANT_MODEL_API_KEY,
    assistant_model_name=CLIENT_ASSISTANT_MODEL_NAME,
    steer_api_key=CLIENT_STEER_API_KEY,
    timeout=CLIENT_TIMEOUT,
    max_retries=CLIENT_MAX_RETRIES,
    rate_limit_retries=CLIENT_RATE_LIMIT_RETRIES,
)
client = Client(**client_options)

# The runner is passed to the persona helpers when dataset rows are built.
runner = client.simulation_runner


## Simulate samples

In [16]:
# Generate simulations

simulations = client.simulate(
    steer_config=STEER_CONFIG,
    k=SIM_SAMPLES,
    num_exchanges=SIM_EXCHANGES,
    batch_delay=SIM_DELAY,
    steer_temperature=SIM_STEER_TEMPERATURE,
    steer_max_tokens=SIM_STEER_MAX_TOKENS,
    mix_traits=SIM_MIX_TRAITS,
    max_concurrency=SIM_MAX_CONCURRENCY,
)


# Save to file: one JSON object per line (JSONL), written as UTF-8.
output_dir = Path()
output_dir.mkdir(parents=True, exist_ok=True)
dataset_path = output_dir / "simulated_rl.jsonl"
with dataset_path.open('w', encoding='utf-8') as dataset_file:
    for simulation in simulations:
        dataset_row = make_dataset_row(runner, simulation)
        dataset_file.write(json.dumps(dataset_row, ensure_ascii=False) + '\n')
print(f'Wrote dataset to: {dataset_path}')

# Read the dataset back from the same path and with the same encoding it was
# written with, parsing each line once.
with dataset_path.open('r', encoding='utf-8') as dataset_file:
    for i, line in enumerate(dataset_file, start=1):
        row = json.loads(line)
        print("=" * 8)
        print(f"Conversation {i}")
        print("=" * 8)
        pprint(row["steering_persona"])
        print(row["conversation"])
        # A final "###STOP###" is shown as a user turn; anything else is the
        # assistant's last reply.
        speaker = "user" if row["last_response"] == "###STOP###" else "assistant"
        print(f"{speaker}: ", row["last_response"])
        print()


User/Assistant turns:  87%|████████▋ | 61/70 [01:07<00:09,  1.10s/query]

Wrote dataset to: simulated_rl.jsonl
Conversation 1
{'characteristics': {'age': 70,
                     'gender': 'female',
                     'intent': 'track_baggage',
                     'language': 'English',
                     'location': 'Australia',
                     'occupation': 'freelancer',
                     'task': 'airline support'},
 'traits': {'confusion': 2}}
user: Hi, I'm having trouble tracking my baggage on my flight from Sydney to Perth with Flight 121. Can you help me check the status of my checked-in luggage?
assistant: I'm sorry to hear you're having trouble tracking your baggage. Unfortunately, I don't have real-time tracking capabilities for luggage. I recommend checking the airline's official website or app for live updates on your baggage status. You can also contact the airline's customer service directly for assistance. If you need further help, feel free to ask!
user: I'm getting a bit anxious about my luggage, I was really looking forward to w




## Assess agent in multi-turn setting

In [None]:
# Assess
result = client.assess(
    dataset=simulations,
    judge_model_url=ASSESS_JUDGE_MODEL_URL,
    judge_model_api_key=ASSESS_JUDGE_MODEL_API_KEY,
    judge_model_name=ASSESS_JUDGE_MODEL_NAME,
    temperature=ASSESS_TEMPERATURE,
    max_tokens=ASSESS_MAX_TOKENS,
)
print(f"Assessment: {result.message or '<no message>'}")

# Reload the saved rows so each score can be shown next to its conversation.
with dataset_path.open('r', encoding='utf-8') as dataset_file:
    dataset_rows = [json.loads(line) for line in dataset_file if line.strip()]

# Pairs results with rows by position; zip stops at the shorter sequence.
for i, (scores_map, row) in enumerate(zip(result.evaluation_result, dataset_rows), start=1):
    print('=' * 8)
    print(f"Conversation {i}")
    print('=' * 8)
    pprint(row['steering_persona'])
    print(row['conversation'])
    # Rows written by make_dataset_row store the final turn under
    # 'last_response' (there is no 'assistant_response' key). A final
    # "###STOP###" is shown as a user turn, as in the simulation cell.
    speaker = "user" if row['last_response'] == "###STOP###" else "assistant"
    print(f"{speaker}: ", row['last_response'])
    for _metric_name, scores in scores_map.items():
        print(f"  Score: {scores.score}")
        print(f"  Rationale: {scores.rationale}")
    print()


## Make **realistic** RL environments for tool-use agents (Tau-Bench-**hard**)

In [None]:
!uv pip install -i https://test.pypi.org/simple/ tau-trait -U
import argparse
from tau_trait.types import RunConfig
from tau_trait.run import run
from litellm import provider_list
from tau_trait.envs.user import UserStrategy

In [18]:
# Load the run configuration from configs/tau_hard_config.json.
# The file is read as UTF-8 explicitly so the platform default encoding does not matter.
with open("configs/tau_hard_config.json", "r", encoding="utf-8") as config_file:
    config = RunConfig(**json.load(config_file))

print("FOR CLARITY, TOOL CALLS ARE NOT STREAMED BUT CAN BE VIEWED IN THE RESULTS FILE")
# Results are saved under the results directory in the format of
# <agent_strategy>-<model>-<temperature>_range_<start_index>-<end_index>_user-<user_model>-<user_strategy>_traits-<trait_dict>_<timestamp>.json
# Example: tool_calling_agent-gpt-4o-mini-0.7_range_0-10_user-gpt-4o-mini-user_traits-skeptical-2_impatience-1_incoherence-0_confusion-0_2025-09-19_15-30-00.json
run(config)

FOR CLARITY, TOOL CALLS ARE NOT STREAMED BUT CAN BE VIEWED IN THE RESULTS FILE
Loading user with strategy: traitbasis
--------------------------------
THE PROVIDER IS steer
--------------------------------

assistant: Hi! How can I help you today?
user: I'm trying to track down a package that was supposed to be delivered yesterday; I had an order confirmation email that said it would arrive by now, and the tracking info is saying it's still in transit. Can you look up an order for me?

Running tasks [4] (checkpoint path: results/tool-calling-gpt-4o-mini-0.7_range_0--1_user-traitbasis_traits-impatience-1_confusion-0_skeptical-2_incoherence-0_0919174223.json)
--------------------------------
THE PROVIDER IS steer
--------------------------------

assistant: Hi! How can I help you today?
user: I'm trying to track down a package that's supposed to arrive today, and I think it might be stuck in customs. Can you look up the status of a package that was shipped from New York to California las

[EnvRunResult(task_id=4, reward=0.0, info={'task': {'user_id': 'yusuf_rossi_9620', 'actions': [{'name': 'find_user_id_by_name_zip', 'kwargs': {'first_name': 'Yusuf', 'last_name': 'Rossi', 'zip': '19122'}}, {'name': 'get_product_details', 'kwargs': {'product_id': '6086499569'}}, {'name': 'list_all_product_types', 'kwargs': {}}, {'name': 'get_product_details', 'kwargs': {'product_id': '9523456873'}}, {'name': 'get_user_details', 'kwargs': {'user_id': 'yusuf_rossi_9620'}}, {'name': 'get_order_details', 'kwargs': {'order_id': '#W6247578'}}, {'name': 'get_order_details', 'kwargs': {'order_id': '#W9711842'}}, {'name': 'get_order_details', 'kwargs': {'order_id': '#W4776164'}}, {'name': 'get_order_details', 'kwargs': {'order_id': '#W6679257'}}, {'name': 'get_order_details', 'kwargs': {'order_id': '#W2378156'}}, {'name': 'get_product_details', 'kwargs': {'product_id': '9523456873'}}, {'name': 'get_user_details', 'kwargs': {'user_id': 'yusuf_rossi_9620'}}, {'name': 'modify_pending_order_items', 