## Install and Import Dependencies

Installs required packages and imports all libraries used below.


In [None]:
import json
import time
from pathlib import Path
from datetime import datetime

import nest_asyncio
import together
from collinear.client import Client
from together.abstract import api_requestor
from together.types import TogetherRequest

nest_asyncio.apply()


## Utility functions

Defines helpers to build steering personas, format conversations, and enrich results.


In [None]:
def conversation_lines(messages):
    """Render every message that has content as a ``"<role>: <content>"`` line.

    Messages whose ``content`` is missing or falsy (``None``, ``""``) are
    skipped. The role is interpolated as-is, so a message without a ``role``
    key is rendered with the literal text ``None``.
    """
    return [
        f"{entry.get('role')}: {entry.get('content')}"
        for entry in messages
        if entry.get('content')
    ]


def conversation_text(messages):
    """Return the transcript as a single newline-separated string.

    The individual lines come from ``conversation_lines``, so messages
    without content are omitted here as well.
    """
    rendered = conversation_lines(messages)
    return "\n".join(rendered)


def persona_from_steer(runner, steer):
    """Build a persona dict (``characteristics`` + ``traits``) for a steer.

    Returns an empty dict when ``steer`` is falsy. Otherwise the
    characteristics come from ``runner._user_characteristics_payload(steer)``
    and the traits from ``steer.traits``; either falls back to ``{}`` when
    missing or empty.
    """
    if not steer:
        return {}

    # Best effort: any failure while building the payload is swallowed and
    # the persona is emitted with empty characteristics instead.
    try:
        payload = runner._user_characteristics_payload(steer)
    except Exception:
        payload = None

    raw_traits = getattr(steer, 'traits', {})
    return {
        'characteristics': payload if payload else {},
        # Copy into a plain dict so the result does not alias steer.traits.
        'traits': dict(raw_traits if raw_traits else {}),
    }


def print_evaluation_results(path: Path) -> None:
    """Print a short summary for every result row in a JSONL file.

    For each non-blank line the function prints ``[n] score=... status=...``
    (``n`` is the 1-based line number in the file) followed by the rationale
    when one is present under ``feedback`` or ``rationale``.

    Blank lines are skipped. Lines that are not valid JSON, or that decode to
    something other than a JSON object, are reported as unparseable and
    echoed verbatim instead of raising.

    Args:
        path: Location of the UTF-8 encoded JSONL results file.
    """
    for idx, line in enumerate(path.read_text(encoding='utf-8').splitlines(), start=1):
        # Blank separator lines are not results; do not report them as errors.
        if not line.strip():
            continue
        try:
            row = json.loads(line)
        except json.JSONDecodeError:
            row = None
        # A scalar/list/null line is valid JSON but has no score fields to read.
        if not isinstance(row, dict):
            print(f"[{idx}] could not parse result")
            print(line)
            continue
        score = row.get('score', '-')
        passed = row.get('pass')
        rationale = row.get('feedback') or row.get('rationale')
        print(f"[{idx}] score={score} status={(passed if passed is not None else '-')}")
        if rationale:
            print(f"  rationale: {rationale}")


def _needs_fallback(response: str) -> bool:
    """Tell whether ``response`` is unusable as an assistant reply.

    A response is unusable when it is empty or whitespace-only, is exactly
    the ``###STOP###`` sentinel (after stripping), or starts
    (case-insensitively) with one of the known failure prefixes.
    """
    text = (response or "").strip()
    if not text or text == "###STOP###":
        return True
    failure_prefixes = ("assistant returned empty response", "error:")
    return text.lower().startswith(failure_prefixes)


## Load Config

Loads simulation, judge, and Together settings from JSON files.


In [None]:
# --- Simulation config -------------------------------------------------------
CONFIG_DIR = Path('configs')
SIMULATION_CONFIG_FILE = CONFIG_DIR / 'simulation_config.json'
# Decode explicitly as UTF-8 (as done for the judge prompt below) so parsing
# does not depend on the platform's default locale encoding.
config = json.loads(SIMULATION_CONFIG_FILE.read_text(encoding='utf-8'))

# --- Steering config ---------------------------------------------------------
steering_name = config.get('steering_config_file', 'steering_config_airline.json')
# Only the file name is honoured: any directory part in the configured value
# is dropped and the file is always looked up inside CONFIG_DIR.
STEERING_CONFIG_FILE = CONFIG_DIR / Path(steering_name).name
steer_config = json.loads(STEERING_CONFIG_FILE.read_text(encoding='utf-8'))
STEER_TASKS = steer_config.get('tasks') or []

# --- Client settings ---------------------------------------------------------
# `or {}` also covers a section that is present but null/empty in the JSON.
client_cfg = config.get('client', {}) or {}
CLIENT_ASSISTANT_MODEL_URL = client_cfg.get('assistant_model_url', 'https://api.together.xyz/v1')
CLIENT_ASSISTANT_MODEL_API_KEY = client_cfg.get('assistant_model_api_key')
CLIENT_ASSISTANT_MODEL_NAME = client_cfg.get('assistant_model_name', 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo')
CLIENT_STEER_API_KEY = client_cfg.get('steer_api_key', 'demo-001')
CLIENT_TIMEOUT = float(client_cfg.get('timeout', 120))
CLIENT_MAX_RETRIES = int(client_cfg.get('max_retries', 3))
CLIENT_RATE_LIMIT_RETRIES = int(client_cfg.get('rate_limit_retries', 6))

# --- Simulation settings -----------------------------------------------------
sim_cfg = config.get('simulate', {}) or {}
SIM_SAMPLES = sim_cfg.get('k', 3)
SIM_EXCHANGES = sim_cfg.get('num_exchanges', 2)
SIM_DELAY = sim_cfg.get('batch_delay', 0.2)
SIM_STEER_TEMPERATURE = sim_cfg.get('steer_temperature', 0.7)
SIM_STEER_MAX_TOKENS = sim_cfg.get('steer_max_tokens', 256)
SIM_MIX_TRAITS = bool(sim_cfg.get('mix_traits', False))
SIM_MAX_CONCURRENCY = int(sim_cfg.get('max_concurrency', 8))

# --- Judge settings ----------------------------------------------------------
assess_cfg = config.get('assess', {}) or {}
# No default: stays None when the config does not name a judge model.
ASSESS_JUDGE_MODEL_NAME = assess_cfg.get('judge_model_name')

# --- Together evaluation settings --------------------------------------------
together_cfg = config.get('together', {}) or {}
# Results always go into a `results` sub-folder of the configured directory;
# it is created here if it does not exist yet.
RESULTS_DIR = Path(together_cfg.get('output_directory', '.')).joinpath('results')
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
JUDGE_SYSTEM_PROMPT = Path(together_cfg.get('judge_system_prompt', 'configs/judge_system_prompt.jinja')).read_text(encoding='utf-8')
TOGETHER_UPLOAD_PURPOSE = together_cfg.get('upload_purpose', 'eval')
TOGETHER_EVAL_TYPE = together_cfg.get('evaluation_type', 'score')
TOGETHER_MODEL_TO_EVALUATE = together_cfg.get('model_to_evaluate', 'assistant_response')
TOGETHER_JUDGE_MODEL_SOURCE = together_cfg.get('judge_model_source', 'serverless')
TOGETHER_MIN_SCORE = together_cfg.get('min_score', 1.0)
TOGETHER_MAX_SCORE = together_cfg.get('max_score', 10.0)
TOGETHER_PASS_THRESHOLD = together_cfg.get('pass_threshold', 7.0)
TOGETHER_POLL_TIMEOUT_SECONDS = int(together_cfg.get('poll_timeout_seconds', 300))
TOGETHER_POLL_INTERVAL_SECONDS = int(together_cfg.get('poll_interval_seconds', 5))
raw_prefix = together_cfg.get('results_filename_prefix') or together_cfg.get('output_filename') or 'together_eval'
# Drop everything after the last '.' (e.g. a file extension) and any trailing
# underscores so the prefix joins cleanly with the `_<RUN_ID>` suffix.
FILENAME_BASE = (str(raw_prefix).rsplit('.', 1)[0]).rstrip('_')
# Local-time timestamp used to keep the output files of each run apart.
RUN_ID = datetime.now().strftime('%Y%m%d_%H%M%S')

print(f'Loaded simulation config {SIMULATION_CONFIG_FILE}')
print(f'Steering config {STEERING_CONFIG_FILE} | tasks: {STEER_TASKS or "<none>"}')


In [None]:
import os

# Export the configured credentials/endpoints as environment variables.
# Unset (falsy) settings are skipped so existing environment values for them
# are left untouched.
for env_name, env_value in (
    ('OPENAI_API_KEY', CLIENT_ASSISTANT_MODEL_API_KEY),
    ('OPENAI_BASE_URL', CLIENT_ASSISTANT_MODEL_URL),
    ('STEER_API_KEY', CLIENT_STEER_API_KEY),
):
    if env_value:
        os.environ[env_name] = env_value


## Client setup

Initializes the Collinear client (failing fast when the assistant API key is missing) and grabs its simulation runner.


In [None]:
# Fail fast: nothing below can run without the assistant model API key.
if not CLIENT_ASSISTANT_MODEL_API_KEY:
    raise RuntimeError('assistant_model_api_key must be set in configs/simulation_config.json')

# Collect the constructor arguments from the loaded configuration first, then
# build the client in one call.
client_kwargs = {
    'assistant_model_url': CLIENT_ASSISTANT_MODEL_URL,
    'assistant_model_api_key': CLIENT_ASSISTANT_MODEL_API_KEY,
    'assistant_model_name': CLIENT_ASSISTANT_MODEL_NAME,
    'steer_api_key': CLIENT_STEER_API_KEY,
    'timeout': CLIENT_TIMEOUT,
    'max_retries': CLIENT_MAX_RETRIES,
    'rate_limit_retries': CLIENT_RATE_LIMIT_RETRIES,
}
client = Client(**client_kwargs)

# The runner is kept separately; persona_from_steer() needs it later on.
runner = client.simulation_runner


## Generate simulated user interactions

Runs simulations and writes a JSONL dataset with conversation, assistant response, and steering persona.


In [None]:
# Run the simulations with the settings loaded from the config.
simulations = client.simulate(
    steer_config=steer_config,
    k=SIM_SAMPLES,
    num_exchanges=SIM_EXCHANGES,
    batch_delay=SIM_DELAY,
    steer_temperature=SIM_STEER_TEMPERATURE,
    steer_max_tokens=SIM_STEER_MAX_TOKENS,
    mix_traits=SIM_MIX_TRAITS,
    max_concurrency=SIM_MAX_CONCURRENCY,
)

# Turn each simulation into a row: conversation, final assistant response and
# the persona that steered the simulated user.
rows = []
for sim in simulations:
    messages = list(sim.conv_prefix)
    assistant_response = (sim.response or "").strip()

    if _needs_fallback(assistant_response):
        # The final response is unusable: walk the conversation backwards and
        # reuse the most recent non-empty assistant turn that does not carry
        # the stop sentinel. If none exists, the row is left as it is.
        for position in reversed(range(len(messages))):
            earlier = messages[position]
            if earlier.get("role") != "assistant":
                continue
            candidate = (earlier.get("content") or "").strip()
            if candidate and "###STOP###" not in candidate:
                assistant_response = candidate
                # NOTE(review): the slice keeps the fallback turn itself, so it
                # appears both in the conversation and as assistant_response —
                # confirm whether `position` (exclusive) was intended.
                messages = messages[: position + 1]
                break

    rows.append(
        {
            "conversation_messages": messages,
            "assistant_response": assistant_response,
            "steering_persona": persona_from_steer(runner, getattr(sim, "steer", None)),
        }
    )

# Persist the rows as JSONL, with the conversation flattened to plain text.
dataset_path = RESULTS_DIR / f"{FILENAME_BASE}_{RUN_ID}_dataset.jsonl"
with dataset_path.open("w", encoding="utf-8") as fh:
    for row in rows:
        record = {
            "conversation": "\n".join(conversation_lines(row["conversation_messages"])),
            "assistant_response": row["assistant_response"],
            "steering_persona": row["steering_persona"],
        }
        fh.write(json.dumps(record, ensure_ascii=False) + "\n")
print(f"Saved {len(rows)} simulations to {dataset_path}")

# Echo every conversation for a quick visual check.
heavy_rule = "=" * 40
light_rule = "-" * 40
for idx, row in enumerate(rows, start=1):
    print()
    print(heavy_rule)
    print(f"Conversation {idx}")
    print(light_rule)
    persona = row["steering_persona"] or {}
    print("Persona:")
    if persona:
        print(json.dumps(persona, indent=2, ensure_ascii=False))
    else:
        print("  <none>")
    print()
    print("Transcript:")
    turns = conversation_lines(row["conversation_messages"])
    for turn_no, line in enumerate(turns, start=1):
        print(f"{turn_no:02d}. {line}")
    # The final assistant response is numbered as the turn after the prefix.
    final_turn = row['assistant_response'] or '<no response>'
    print(f"{len(turns) + 1:02d}. assistant: {final_turn}")


## Upload simulations as dataset and load judge model on Together

Uploads the dataset to Together and starts a safety-score evaluation.


In [None]:
# The Together client reuses the assistant model API key from the config.
together_client = together.Together(api_key=CLIENT_ASSISTANT_MODEL_API_KEY)

# Upload the JSONL dataset; the response may expose the id as an attribute or
# as a mapping key, so both access styles are tried.
upload = together_client.files.upload(file=str(dataset_path), purpose=TOGETHER_UPLOAD_PURPOSE)
upload_id = getattr(upload, 'id', None)
if upload_id is None:
    upload_id = upload['id']

# Start the evaluation through the low-level requestor with a hand-built
# payload assembled from the config values.
requestor = api_requestor.APIRequestor(client=together_client.client)
payload = {
    'type': TOGETHER_EVAL_TYPE,
    'parameters': {
        'judge': {
            'model': ASSESS_JUDGE_MODEL_NAME,
            'model_source': TOGETHER_JUDGE_MODEL_SOURCE,
            'system_template': JUDGE_SYSTEM_PROMPT,
        },
        'input_data_file_path': upload_id,
        'model_to_evaluate': TOGETHER_MODEL_TO_EVALUATE,
        'min_score': TOGETHER_MIN_SCORE,
        'max_score': TOGETHER_MAX_SCORE,
        'pass_threshold': TOGETHER_PASS_THRESHOLD,
    },
}
response, _, _ = requestor.request(
    options=TogetherRequest(method='POST', url='evaluation', params=payload),
    stream=False,
)
# Unwrap `.data` when the response has it; otherwise use the response itself.
evaluation = getattr(response, 'data', response)
# Resolve the workflow id the same way as upload_id above. Passing
# evaluation['workflow_id'] as the getattr default would evaluate the
# subscript eagerly and raise on objects that only expose the attribute.
workflow_id = getattr(evaluation, 'workflow_id', None)
if workflow_id is None:
    workflow_id = evaluation['workflow_id']
print(f'Started evaluation {workflow_id}')


## Eval results and analysis

Polls for completion, enriches results with personas, and prints a summary.


In [None]:
# Poll the evaluation until it reaches a terminal state or the deadline passes.
deadline = time.time() + TOGETHER_POLL_TIMEOUT_SECONDS
results_path = None
while time.time() < deadline:
    status_obj = together_client.evaluation.status(workflow_id)
    status_raw = str(getattr(status_obj, "status", "pending"))
    # Keep only the part after the last '.', lower-cased, so dotted status
    # strings compare against the plain state names below.
    state = status_raw.lower().split('.')[-1]
    print(f"status: {status_raw}")

    if state not in {"completed", "success", "failed", "error", "user_error"}:
        # Still running: wait and poll again.
        time.sleep(TOGETHER_POLL_INTERVAL_SECONDS)
        continue

    # Terminal state reached. Without a downloadable result file there is
    # nothing further to show.
    results = getattr(status_obj, "results", None)
    if not (isinstance(results, dict) and results.get("result_file_id")):
        break

    results_path = RESULTS_DIR / f"{FILENAME_BASE}_{RUN_ID}_{workflow_id}_results.jsonl"
    together_client.files.retrieve_content(results["result_file_id"], output=str(results_path))
    print(f"Downloaded results to {results_path}")

    # Load the judged rows, ignoring blank lines.
    with results_path.open("r", encoding="utf-8") as fh:
        evaluation_rows = [json.loads(text) for text in map(str.strip, fh) if text]

    # Results are matched to the simulated rows by position; rows beyond the
    # end of the results file get an empty assessment.
    for idx, row in enumerate(rows, start=1):
        evaluation = evaluation_rows[idx - 1] if idx <= len(evaluation_rows) else {}
        print()
        print("=" * 40)
        print(f"Conversation {idx}")
        print("-" * 40)
        persona = row.get("steering_persona") or {}
        print("Persona:")
        if persona:
            print(json.dumps(persona, indent=2, ensure_ascii=False))
        else:
            print("  <none>")
        print()
        print("Transcript:")
        turns = conversation_lines(row.get("conversation_messages", []))
        for turn_no, line in enumerate(turns, start=1):
            print(f"{turn_no:02d}. {line}")
        final_turn = row.get('assistant_response') or '<no response>'
        print(f"{len(turns) + 1:02d}. assistant: {final_turn}")
        score = evaluation.get("score", "-")
        passed = evaluation.get("pass")
        rationale = evaluation.get("feedback") or evaluation.get("rationale")
        print()
        print("Assessment:")
        print(f"  score: {score}")
        if passed is not None:
            print(f"  pass: {passed}")
        if rationale:
            print(f"  rationale: {rationale}")
    break
else:
    # Runs only when the loop ended without `break`, i.e. the deadline passed.
    print("Timed out waiting for evaluation to finish.")
