From 7356361c1beb36d7c5f911cccffcd79c86212363 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Wed, 13 Aug 2025 21:21:57 -0700 Subject: [PATCH 1/9] Finished Error Handling --- eval_protocol/mcp/execution/manager.py | 2 +- .../pytest/default_agent_rollout_processor.py | 5 +- .../default_mcp_gym_rollout_processor.py | 105 +++++--- .../default_single_turn_rollout_process.py | 2 + eval_protocol/pytest/evaluation_test.py | 105 +++++++- eval_protocol/pytest/plugin.py | 25 ++ tests/test_retry_mechanism.py | 138 ++++++++++ tests/test_rollout_error_handling.py | 251 ++++++++++++++++++ 8 files changed, 587 insertions(+), 46 deletions(-) create mode 100644 tests/test_retry_mechanism.py create mode 100644 tests/test_rollout_error_handling.py diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index 405e72b4..c1991bdc 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -138,7 +138,7 @@ async def _execute_with_semaphore(idx): if trajectory.terminated: if trajectory.termination_reason == TerminationReason.ERROR: evaluation_row.rollout_status.status = "error" - evaluation_row.rollout_status.error_message = trajectory.control_plane_summary.get( + evaluation_row.rollout_status.termination_reason = trajectory.control_plane_summary.get( "error_message", None ) else: diff --git a/eval_protocol/pytest/default_agent_rollout_processor.py b/eval_protocol/pytest/default_agent_rollout_processor.py index b3997c49..8836d3c8 100644 --- a/eval_protocol/pytest/default_agent_rollout_processor.py +++ b/eval_protocol/pytest/default_agent_rollout_processor.py @@ -126,7 +126,7 @@ async def default_agent_rollout_processor( async def process_row(row: EvaluationRow) -> EvaluationRow: """Process a single row with agent rollout.""" agent = Agent( - model=config.completion_params.model, row=row, config_path=config.mcp_config_path, logger=config.logger + model=config.completion_params["model"], row=row, config_path=config.mcp_config_path, logger=config.logger ) try: await agent.setup() @@ -141,7 +141,8 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: try: return await process_row(r) except Exception as e: - logger.exception(f"Error processing row {r.input_metadata.row_id}: {e}") + r.rollout_status.status = "error" + r.rollout_status.termination_reason = str(e) return r # Create all tasks diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py index 2b90239d..54c3914b 100644 --- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py +++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py @@ -6,12 +6,14 @@ import subprocess import time from pathlib import Path -from typing import AsyncIterator, List, Optional +from typing import Any, AsyncIterator, Dict, List, Optional import eval_protocol as ep from eval_protocol.models import EvaluationRow, Message from eval_protocol.pytest.types import RolloutProcessorConfig +CURRENT_RUN_STATE: Dict[str, Any] = {} + class MCPServerManager: """Manages MCP server lifecycle for testing.""" @@ -204,41 +206,78 @@ async def default_mcp_gym_rollout_processor( Args: rows: List of EvaluationRow objects containing messages and dataset info in input_metadata config: RolloutProcessorConfig with model and other parameters + - config.kwargs can include: + - start_server (bool): If True, create fresh server and environments. If False, reuse existing ones. Default: True. 
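+              When False, the processor reuses the server, environments, and
+              policy cached in CURRENT_RUN_STATE by an earlier start_server=True
+              call, so retries do not pay the server startup cost again.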
Returns: AsyncIterator of EvaluationRow objects with completed conversations """ - if config.server_script_path is None: - raise ValueError("server_script_path is required for default_mcp_gym_rollout_processor") - server = MCPServerManager(config.server_script_path, port=9700, **(config.kwargs or {})) - - try: - server.start() - - policy = ep.LiteLLMPolicy( - model_id=config.completion_params.model, - temperature=config.completion_params.get("temperature", 0.0), - max_tokens=config.completion_params.get("max_tokens", 4096), - reasoning_effort=config.completion_params.get("reasoning_effort", None), - ) + start_server = config.kwargs.get("start_server", True) if config.kwargs else True + if start_server: + # Create fresh MCP server and environments for this run + if config.server_script_path is None: + raise ValueError("server_script_path is required for default_mcp_gym_rollout_processor") - # Create MCP environments directly from evaluation_rows - envs = ep.make( - "http://localhost:9700/mcp/", - evaluation_rows=rows, - model_id=policy.model_id, - ) + server = MCPServerManager(config.server_script_path, port=9700, **(config.kwargs or {})) - # Run rollout with environments and policy - async for evaluation_row in ep.rollout( - envs, - policy=policy, - evaluation_rows=rows, - steps=config.steps, - max_concurrent_rollouts=config.max_concurrent_rollouts, - ): - yield evaluation_row - - finally: - # Always clean up the server - server.stop() + try: + server.start() + + policy = ep.LiteLLMPolicy( + model_id=config.completion_params.get("model", None), + temperature=config.completion_params.get("temperature", 0.0), + max_tokens=config.completion_params.get("max_tokens", 4096), + reasoning_effort=config.completion_params.get("reasoning_effort", None), + ) + + # Create MCP environments directly from evaluation_rows + envs = ep.make( + "http://localhost:9700/mcp/", + evaluation_rows=rows, + model_id=policy.model_id, + ) + + # Store in current run state for reuse within this run + CURRENT_RUN_STATE.update( + { + "server": server, + "envs": envs, + "policy": policy, + } + ) + + except Exception as e: + server.stop() + CURRENT_RUN_STATE.clear() + raise e + + else: + # Reuse existing MCP environments for retry + if not CURRENT_RUN_STATE: + raise RuntimeError("Cannot retry without existing server/environments. 
Call with start_server=True first.") + + server = CURRENT_RUN_STATE["server"] + envs = CURRENT_RUN_STATE["envs"] + policy = CURRENT_RUN_STATE["policy"] + + # Run rollout with environments and policy (automatically resets environments) + async for evaluation_row in ep.rollout( + envs, + policy=policy, + evaluation_rows=rows, + steps=config.steps, + max_concurrent_rollouts=config.max_concurrent_rollouts, + ): + yield evaluation_row + + +# Add cleanup method directly to the function object +def _cleanup_mcp_gym_rollout_processor(): + """Cleanup function for MCP gym rollout processor""" + if CURRENT_RUN_STATE and "server" in CURRENT_RUN_STATE: + CURRENT_RUN_STATE["server"].stop() + CURRENT_RUN_STATE.clear() # Clear for next run + + +# Attach cleanup method to the processor function +default_mcp_gym_rollout_processor.cleanup = _cleanup_mcp_gym_rollout_processor diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index ef2ad48b..e405d45f 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -112,6 +112,8 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: try: return await process_row(r) except Exception as e: + r.rollout_status.status = "error" + r.rollout_status.termination_reason = str(e) return r # Create all tasks diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index dd7ecb04..25627543 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -8,6 +8,7 @@ import re import statistics import time +from dataclasses import replace from typing import Any, Callable, Dict, List, Literal, Optional, Union import pytest @@ -269,6 +270,82 @@ def generate_combinations(): return combinations + async def rollout_processor_with_retry( + rollout_processor: RolloutProcessor, + fresh_dataset: List[EvaluationRow], + config: RolloutProcessorConfig, + max_retry: int, + ): + """ + Wrapper around rollout_processor that handles retry logic internally. + Uses async queue pattern to yield results immediately as they become available. + Yields both successful and failed results, leaving it up to the user to handle them in test_func. 
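+
+        Behavior is controlled by two environment variables (see plugin.py):
+        EP_MAX_RETRY caps retry attempts per failed rollout (default "0"), and
+        EP_FAIL_ON_PERMANENT_FAILURE, unless set to "false", raises once a row
+        still has error status after max_retry attempts (default "true").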
+ """ + + try: + queue = asyncio.Queue() + retry_counts = {r.execution_metadata.rollout_id: 0 for r in fresh_dataset} + failed_permanently = [] + + async def retry_handler(failed_row: EvaluationRow): + rollout_id = failed_row.execution_metadata.rollout_id + current_attempts = retry_counts.get(rollout_id, 0) + + if current_attempts >= max_retry: + assert ( + failed_row.rollout_status and failed_row.rollout_status.status == "error" + ), f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status" + failed_permanently.append(failed_row) + await queue.put(failed_row) # put failed row on queue + return + + retry_counts[rollout_id] = current_attempts + 1 + + # add kwargs start_server=False to config so we don't start new MCP server + retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False}) + + retry_call = rollout_processor([failed_row], retry_config) + + retry_result = await anext(retry_call) + if retry_result.rollout_status and retry_result.rollout_status.status == "finished": + await queue.put(retry_result) + else: + asyncio.create_task(retry_handler(retry_result)) # retry failed, spawn another retry + + async def initial_processor(): + """Process initial batch and spawn retries for failures""" + async for initial_row in rollout_processor(fresh_dataset, config): + if initial_row.rollout_status and initial_row.rollout_status.status == "finished": + await queue.put(initial_row) # rollout succeeded, put on queue + else: + asyncio.create_task(retry_handler(initial_row)) # rollout errored, spawn retry task + + processor_task = asyncio.create_task(initial_processor()) + + # yield results as they become available + completed_count = 0 + total_expected = len(fresh_dataset) + + while completed_count < total_expected: + finished_row = await queue.get() + + # only permanent failure rows are put on the queue, so we can check for them here + if finished_row.rollout_status and finished_row.rollout_status.status == "error": + if os.getenv("EP_FAIL_ON_PERMANENT_FAILURE", "true") != "false": + raise RuntimeError( + f"Rollout {finished_row.execution_metadata.rollout_id} failed after {max_retry} retries. Errors: {finished_row.rollout_status.termination_reason}" + ) + + completed_count += 1 + yield finished_row + + await processor_task # explicitly wait for task completion and catch any exceptions + + finally: + # processor clean up after themselves if they have a cleanup method + if hasattr(rollout_processor, "cleanup"): + rollout_processor.cleanup() + combinations = generate_combinations() if len(combinations) == 0: raise ValueError( @@ -410,6 +487,8 @@ def _log_eval_error( kwargs=rollout_processor_kwargs or {}, ) + max_retry = int(os.getenv("EP_MAX_RETRY", "0")) + for i in range(num_runs): # Regenerate outputs each run by deep-copying the pristine dataset # so model responses are not reused across runs. 
@@ -428,8 +507,6 @@ def _log_eval_error( for row in fresh_dataset: active_logger.log(row) - rollout_result = rollout_processor(fresh_dataset, config) - if mode == "pointwise": # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution semaphore = asyncio.Semaphore(max_concurrent_rollouts) @@ -437,6 +514,8 @@ def _log_eval_error( async def _execute_with_semaphore(row): async with semaphore: + # NOTE: we will still evaluate errored rows (give users control over this) + # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func result = await execute_with_params( test_func, processed_row=row, @@ -448,7 +527,10 @@ async def _execute_with_semaphore(row): ) return result - async for row in rollout_processor(fresh_dataset, config): + # Use wrapper that handles retry logic internally + async for row in rollout_processor_with_retry( + rollout_processor, fresh_dataset, config, max_retry + ): tasks.append(asyncio.create_task(_execute_with_semaphore(row))) all_results[i] = await asyncio.gather(*tasks) @@ -456,9 +538,12 @@ async def _execute_with_semaphore(row): else: # Batch mode: collect all results first, then evaluate (no pipelining) input_dataset = [] - async for row in rollout_result: + async for row in rollout_processor_with_retry( + rollout_processor, fresh_dataset, config, max_retry + ): input_dataset.append(row) - + # NOTE: we will still evaluate errored rows (give users control over this) + # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func results = await execute_with_params( test_func, processed_dataset=input_dataset, @@ -530,7 +615,7 @@ async def _execute_with_semaphore(row): should_print = os.getenv("EP_PRINT_SUMMARY") == "1" summary_path = os.getenv("EP_SUMMARY_JSON") suite_name = test_func.__name__ - model_used = config.completion_params.model + model_used = config.completion_params["model"] total_rows = len([item for sublist in all_results for item in sublist]) summary_obj = { "suite": suite_name, @@ -990,7 +1075,7 @@ def _deep_update_dict(base: dict, override: dict) -> dict: total_rows = len(all_results) summary_obj = { "suite": suite_name, - "model": config.completion_params.model, + "model": config.completion_params["model"], "agg_score": float(agg_score) if agg_score is not None else None, "num_runs": num_runs, "rows": total_rows, @@ -1001,11 +1086,11 @@ def _deep_update_dict(base: dict, override: dict) -> dict: if should_print: if ci_low is not None and ci_high is not None: print( - f"EP Summary | suite={suite_name} model={config.completion_params.model} agg={summary_obj['agg_score']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}] runs={num_runs} rows={total_rows}" + f"EP Summary | suite={suite_name} model={config.completion_params['model']} agg={summary_obj['agg_score']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}] runs={num_runs} rows={total_rows}" ) else: print( - f"EP Summary | suite={suite_name} model={config.completion_params.model} agg={summary_obj['agg_score']:.3f} runs={num_runs} rows={total_rows}" + f"EP Summary | suite={suite_name} model={config.completion_params['model']} agg={summary_obj['agg_score']:.3f} runs={num_runs} rows={total_rows}" ) if summary_path: import json as _json @@ -1037,7 +1122,7 @@ def _extract_effort_tag(params: dict) -> str | None: return None return None - model_slug = _sanitize_filename(config.completion_params.model) + model_slug = _sanitize_filename(config.completion_params["model"]) effort_tag = _extract_effort_tag(completion_params) or "" 
effort_suffix = f"__effort-{_sanitize_filename(effort_tag)}" if effort_tag else "" base_name = f"{suite_name}__{model_slug}{effort_suffix}__{mode}__runs{num_runs}.json" diff --git a/eval_protocol/pytest/plugin.py b/eval_protocol/pytest/plugin.py index 3a5ec0e2..4522caef 100644 --- a/eval_protocol/pytest/plugin.py +++ b/eval_protocol/pytest/plugin.py @@ -59,6 +59,23 @@ def pytest_addoption(parser) -> None: "Values: low|medium|high" ), ) + group.addoption( + "--ep-max-retry", + action="store", + type=int, + default=None, + help=("Failed rollouts (with rollout_status.status == 'error') will be retried up to this many times."), + ) + group.addoption( + "--ep-fail-on-permanent-failure", + action="store", + default=None, + choices=["true", "false"], + help=( + "Whether to fail the entire rollout when permanent failures occur after max retries. " + "Default: true (fail on permanent failures). Set to 'false' to continue with remaining rollouts." + ), + ) def _normalize_max_rows(val: Optional[str]) -> Optional[str]: @@ -100,6 +117,14 @@ def pytest_configure(config) -> None: if summary_json_path: os.environ["EP_SUMMARY_JSON"] = summary_json_path + max_retry = config.getoption("--ep-max-retry") + if max_retry is not None: + os.environ["EP_MAX_RETRY"] = str(max_retry) + + fail_on_permanent_failure = config.getoption("--ep-fail-on-permanent-failure") + if fail_on_permanent_failure is not None: + os.environ["EP_FAIL_ON_PERMANENT_FAILURE"] = fail_on_permanent_failure + # Allow ad-hoc overrides of input params via CLI flags try: import json as _json diff --git a/tests/test_retry_mechanism.py b/tests/test_retry_mechanism.py new file mode 100644 index 00000000..b9cfe916 --- /dev/null +++ b/tests/test_retry_mechanism.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +""" +Simple test to verify the retry mechanism works with evaluation_test. 
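+
+Expected timeline with EP_MAX_RETRY=2: rows 0 (3s) and 2-4 (5s) succeed on the
+first attempt, while row 1 fails at ~3s, is retried, and completes at ~6s, so
+results should arrive at roughly 3s, 5s, 5s, 5s, and 6s.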
+""" + +import asyncio +import os +import time +from dataclasses import dataclass +from typing import AsyncIterator, List + +from eval_protocol.models import EvaluateResult, EvaluationRow, Message, RolloutStatus +from eval_protocol.pytest.evaluation_test import evaluation_test +from eval_protocol.pytest.types import RolloutProcessorConfig + +os.environ["EP_MAX_RETRY"] = "2" # Allow up to 2 retries + +start_time = time.time() +timing_results = [] # Collect timing data for assertions + + +async def mock_rollout_processor_with_retries( + rows: List[EvaluationRow], config: RolloutProcessorConfig +) -> AsyncIterator[EvaluationRow]: + """Mock rollout processor that fails second task alphabetically on first attempt, succeeds on retry""" + row_setup = { + 0: {"delay": 3.0, "should_fail": False}, + 1: {"delay": 3.0, "should_fail": True}, + 2: {"delay": 5.0, "should_fail": False}, + 3: {"delay": 5.0, "should_fail": False}, + 4: {"delay": 5.0, "should_fail": False}, + } + + async def process_single_row(row: EvaluationRow, delay: float, should_fail: bool = False) -> EvaluationRow: + await asyncio.sleep(delay) + + if should_fail: + row.rollout_status = RolloutStatus(status="error", termination_reason="Simulated failure for testing") + else: + row.rollout_status = RolloutStatus(status="finished") + + return row + + # Create tasks for concurrent processing + tasks = [ + asyncio.create_task(process_single_row(row, row_setup[i]["delay"], row_setup[i]["should_fail"])) + for i, row in enumerate(rows) + ] + + # Yield results as they complete + for completed_task in asyncio.as_completed(tasks): + result = await completed_task + elapsed = time.time() - start_time + print(f"šŸŽ‰ FINISHED {result.rollout_status.status} at {elapsed:.2f}s: {result.execution_metadata.rollout_id}") + yield result + + +@evaluation_test( + completion_params=[{"model": "gpt-4o-mini", "temperature": 0}], + input_messages=[ + [Message(role="user", content="Task A")], + [Message(role="user", content="Task B")], + [Message(role="user", content="Task C")], + [Message(role="user", content="Task D")], + [Message(role="user", content="Task E")], + ], + rollout_processor=mock_rollout_processor_with_retries, + num_runs=1, + mode="pointwise", +) +def test_retry_mechanism(row: EvaluationRow) -> EvaluationRow: + """MOCK TEST: first 2 rows take 3s, last 3 take 5s, second row fails on first attempt, succeeds on retry. 
Should take around 6s total.""" + # Just print the timing - we'll parse it from output + elapsed = time.time() - start_time + print( + f"šŸ“Š EVALUATED at {elapsed:.2f}s: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})" + ) + + # Assign a score based on success/failure + score = 1.0 if row.rollout_status.status == "finished" else 0.0 + row.evaluation_result = EvaluateResult(score=score) + + return row + + +def test_timing_assertions(): + """Validate that timing results match expected pipeline behavior""" + global start_time + + # Reset and run the evaluation test + start_time = time.time() + + # Capture pytest output + import subprocess + import sys + + result = subprocess.run( + [sys.executable, "-m", "pytest", __file__ + "::test_retry_mechanism", "-v", "-s"], + capture_output=True, + text=True, + cwd=os.getcwd(), + ) + + print(result.stdout) # Show the original output + + # Parse timing from output + import re + + timing_results = [] + for line in result.stdout.split("\n"): + match = re.search(r"šŸ“Š EVALUATED at (\d+\.\d+)s:", line) + if match: + timing_results.append(float(match.group(1))) + + print(f"\nšŸ“Š PIPELINE TIMING ANALYSIS:") + print(f" Results received at: {[f'{t:.2f}s' for t in sorted(timing_results)]}") + + # Assertions for expected timing behavior + sorted_times = sorted(timing_results) + + assert len(sorted_times) == 5, f"Expected 5 evaluation results, got {len(sorted_times)}" + + # First result should be around 3s (row 0 success) + assert 2.5 <= sorted_times[0] <= 3.5, f"First result at {sorted_times[0]:.2f}s, expected ~3s" + + # Next three results should be around 5s (rows 2,3,4) + assert 4.5 <= sorted_times[1] <= 5.5, f"Second result at {sorted_times[1]:.2f}s, expected ~5s" + assert 4.5 <= sorted_times[2] <= 5.5, f"Third result at {sorted_times[2]:.2f}s, expected ~5s" + assert 4.5 <= sorted_times[3] <= 5.5, f"Fourth result at {sorted_times[3]:.2f}s, expected ~5s" + + # Last result should be around 6s (row 1 retry success) + assert 5.5 <= sorted_times[4] <= 6.5, f"Fifth result at {sorted_times[4]:.2f}s, expected ~6s (retry success)" + + print("āœ… All timing assertions passed! Pipeline behavior is correct.") + + +if __name__ == "__main__": + test_timing_assertions() diff --git a/tests/test_rollout_error_handling.py b/tests/test_rollout_error_handling.py new file mode 100644 index 00000000..767fa16d --- /dev/null +++ b/tests/test_rollout_error_handling.py @@ -0,0 +1,251 @@ +""" +Unit tests for rollout processor error handling. + +Tests that rollout processors properly set rollout_status.status = "error" when exceptions occur. 
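+
+Covers litellm.RateLimitError (429) and openai.BadRequestError for both the
+agent and single-turn processors, a mixed multi-row batch, preservation of the
+original row data on failure, and RolloutStatus default initialization.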
+""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from eval_protocol.dataset_logger import default_logger +from eval_protocol.models import EvaluationRow, Message, RolloutStatus +from eval_protocol.pytest.default_agent_rollout_processor import default_agent_rollout_processor +from eval_protocol.pytest.default_single_turn_rollout_process import default_single_turn_rollout_processor +from eval_protocol.pytest.types import RolloutProcessorConfig + + +class TestRolloutErrorHandling: + """Test that rollout processors handle errors correctly.""" + + @pytest.mark.asyncio + async def test_agent_rollout_processor_429_error(self): + """Test that agent rollout processor handles 429 rate limit errors correctly.""" + + # Create test row with initialized rollout_status + test_row = EvaluationRow( + messages=[Message(role="user", content="Hello")], rollout_status=RolloutStatus(status="running") + ) + + config = RolloutProcessorConfig( + completion_params={"model": "gpt-4"}, mcp_config_path="", logger=default_logger # Empty to avoid MCP setup + ) + + # Mock the LiteLLM policy to raise a 429 error + with patch("eval_protocol.pytest.default_agent_rollout_processor.LiteLLMPolicy") as mock_policy_class: + # Create a mock policy instance + mock_policy = AsyncMock() + mock_policy_class.return_value = mock_policy + + # Mock the _make_llm_call method to raise a 429 error + import litellm + + mock_policy._make_llm_call.side_effect = litellm.RateLimitError( + message="Rate limit exceeded: 429", llm_provider="openai", model="gpt-4" + ) + + # The agent rollout processor should catch the exception and set error status + result = [] + async for row in default_agent_rollout_processor([test_row], config): + result.append(row) + + assert len(result) == 1 + assert result[0].rollout_status.status == "error" + assert result[0].rollout_status.termination_reason is not None + assert ( + "429" in result[0].rollout_status.termination_reason + or "rate limit" in result[0].rollout_status.termination_reason.lower() + ) + + @pytest.mark.asyncio + async def test_agent_rollout_processor_bad_request_error(self): + """Test that agent rollout processor handles BadRequest errors correctly.""" + + test_row = EvaluationRow( + messages=[Message(role="user", content="Hello")], rollout_status=RolloutStatus(status="running") + ) + + config = RolloutProcessorConfig( + completion_params={"model": "gpt-4"}, mcp_config_path="", logger=default_logger + ) + + # Mock the LiteLLM policy to raise a BadRequest error like the one in your example + with patch("eval_protocol.pytest.default_agent_rollout_processor.LiteLLMPolicy") as mock_policy_class: + mock_policy = AsyncMock() + mock_policy_class.return_value = mock_policy + + import openai + + mock_policy._make_llm_call.side_effect = openai.BadRequestError( + "Invalid value for 'content': expected a string, got null.", response=MagicMock(), body=None + ) + + result = [] + async for row in default_agent_rollout_processor([test_row], config): + result.append(row) + + assert len(result) == 1 + assert result[0].rollout_status.status == "error" + assert result[0].rollout_status.termination_reason is not None + assert ( + "content" in result[0].rollout_status.termination_reason + or "BadRequest" in result[0].rollout_status.termination_reason + ) + + @pytest.mark.asyncio + async def test_single_turn_rollout_processor_429_error(self): + """Test that single turn rollout processor handles 429 rate limit errors correctly.""" + + test_row = EvaluationRow( + 
messages=[Message(role="user", content="Hello")], rollout_status=RolloutStatus(status="running") + ) + + config = RolloutProcessorConfig( + completion_params={"model": "gpt-4"}, mcp_config_path="", logger=default_logger + ) + + # Mock litellm.acompletion to raise a 429 error + with patch("importlib.import_module") as mock_import: + mock_litellm = MagicMock() + mock_import.return_value = mock_litellm + + import litellm + + mock_litellm.acompletion.side_effect = litellm.RateLimitError( + message="Rate limit exceeded: 429", llm_provider="openai", model="gpt-4" + ) + + result = [] + async for row in default_single_turn_rollout_processor([test_row], config): + result.append(row) + + assert len(result) == 1 + assert result[0].rollout_status.status == "error" + assert result[0].rollout_status.termination_reason is not None + assert ( + "429" in result[0].rollout_status.termination_reason + or "rate limit" in result[0].rollout_status.termination_reason.lower() + ) + + @pytest.mark.asyncio + async def test_single_turn_rollout_processor_bad_request_error(self): + """Test that single turn rollout processor handles BadRequest errors correctly.""" + + test_row = EvaluationRow( + messages=[Message(role="user", content="Hello")], rollout_status=RolloutStatus(status="running") + ) + + config = RolloutProcessorConfig( + completion_params={"model": "gpt-4"}, mcp_config_path="", logger=default_logger + ) + + # Mock litellm.acompletion to raise a BadRequest error + with patch("importlib.import_module") as mock_import: + mock_litellm = MagicMock() + mock_import.return_value = mock_litellm + + import openai + + mock_litellm.acompletion.side_effect = openai.BadRequestError( + "Invalid value for 'content': expected a string, got null.", response=MagicMock(), body=None + ) + + result = [] + async for row in default_single_turn_rollout_processor([test_row], config): + result.append(row) + + assert len(result) == 1 + assert result[0].rollout_status.status == "error" + assert result[0].rollout_status.termination_reason is not None + assert ( + "content" in result[0].rollout_status.termination_reason + or "BadRequest" in result[0].rollout_status.termination_reason + ) + + @pytest.mark.asyncio + async def test_multiple_rows_with_mixed_errors(self): + """Test that when some rows get 429 errors and some succeed, each gets the correct status.""" + + # Create test rows + row1 = EvaluationRow( + messages=[Message(role="user", content="Hello 1")], rollout_status=RolloutStatus(status="running") + ) + + row2 = EvaluationRow( + messages=[Message(role="user", content="Hello 2")], rollout_status=RolloutStatus(status="running") + ) + + config = RolloutProcessorConfig( + completion_params={"model": "gpt-4"}, mcp_config_path="", logger=default_logger + ) + + # Mock litellm.acompletion to raise 429 for both rows (simulating rate limiting) + with patch("importlib.import_module") as mock_import: + mock_litellm = MagicMock() + mock_import.return_value = mock_litellm + + import litellm + + mock_litellm.acompletion.side_effect = litellm.RateLimitError( + message="Rate limit exceeded: 429", llm_provider="openai", model="gpt-4" + ) + + result = [] + async for row in default_single_turn_rollout_processor([row1, row2], config): + result.append(row) + + assert len(result) == 2 + # Both should have error status due to 429 errors + for row in result: + assert row.rollout_status.status == "error" + assert row.rollout_status.termination_reason is not None + assert ( + "429" in row.rollout_status.termination_reason + or "rate limit" in 
row.rollout_status.termination_reason.lower() + ) + + @pytest.mark.asyncio + async def test_rollout_status_preserves_original_row_data_on_api_error(self): + """Test that when API errors occur, the original row data is preserved.""" + + original_message = Message(role="user", content="Original message") + test_row = EvaluationRow(messages=[original_message], rollout_status=RolloutStatus(status="running")) + + config = RolloutProcessorConfig( + completion_params={"model": "gpt-4"}, mcp_config_path="", logger=default_logger + ) + + # Mock the LiteLLM policy to raise an API error + with patch("eval_protocol.pytest.default_agent_rollout_processor.LiteLLMPolicy") as mock_policy_class: + mock_policy = AsyncMock() + mock_policy_class.return_value = mock_policy + + import litellm + + mock_policy._make_llm_call.side_effect = litellm.RateLimitError( + message="Rate limit exceeded: 429", llm_provider="openai", model="gpt-4" + ) + + result = [] + async for row in default_agent_rollout_processor([test_row], config): + result.append(row) + + assert len(result) == 1 + assert result[0].rollout_status.status == "error" + # Original message should be preserved + assert len(result[0].messages) == 1 + assert result[0].messages[0].content == "Original message" + + def test_rollout_status_initialization(self): + """Test that RolloutStatus initializes with correct default values.""" + + # Test default initialization + status = RolloutStatus() + assert status.status == "running" # Default from the model + assert status.termination_reason == "" # Default empty string + + # Test explicit initialization + status = RolloutStatus(status="error", termination_reason="Test error") + assert status.status == "error" + assert status.termination_reason == "Test error" From 062e4487785618623a91d79e0f352f662df6b953 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Thu, 14 Aug 2025 00:39:02 -0700 Subject: [PATCH 2/9] Address comments --- .github/workflows/ci.yml | 1 + eval_protocol/benchmarks/suites/gpqa.py | 5 +- eval_protocol/mcp/execution/manager.py | 19 +- eval_protocol/mcp_env.py | 24 +- .../pytest/default_agent_rollout_processor.py | 31 +-- .../default_mcp_gym_rollout_processor.py | 18 +- .../pytest/default_no_op_rollout_process.py | 16 +- .../default_single_turn_rollout_process.py | 33 +-- eval_protocol/pytest/evaluation_test.py | 38 ++- eval_protocol/pytest/types.py | 5 +- tests/test_retry_mechanism.py | 22 +- .../test_rollout_control_plane_integration.py | 28 +- tests/test_rollout_error_handling.py | 251 ------------------ 13 files changed, 116 insertions(+), 375 deletions(-) delete mode 100644 tests/test_rollout_error_handling.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a1cf6aec..a0184b62 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -92,6 +92,7 @@ jobs: --ignore=tests/pytest/test_frozen_lake.py \ --ignore=tests/pytest/test_lunar_lander.py \ --ignore=tests/pytest/test_tau_bench_airline.py \ + --ignore=tests/pytest/test_apps_coding.py \ --ignore=tests/test_tau_bench_airline_smoke.py \ --cov=eval_protocol --cov-append --cov-report=xml --cov-report=term-missing -v --durations=10 diff --git a/eval_protocol/benchmarks/suites/gpqa.py b/eval_protocol/benchmarks/suites/gpqa.py index 76967beb..ff745adc 100644 --- a/eval_protocol/benchmarks/suites/gpqa.py +++ b/eval_protocol/benchmarks/suites/gpqa.py @@ -1,3 +1,4 @@ +import asyncio import csv import io import re @@ -60,7 +61,7 @@ def _strip_gt_messages(msgs: List[Message]) -> List[Message]: return [m for m in msgs if not 
(m.role == "system" and (m.content or "").startswith("__GT__:"))] -async def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) -> List[EvaluationRow]: +def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) -> List[asyncio.Task[EvaluationRow]]: """Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to default processor.""" processed: List[EvaluationRow] = [] for r in rows: @@ -72,7 +73,7 @@ async def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) -> m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:")) ] processed.append(r) - return await default_single_turn_rollout_processor(processed, config) + return default_single_turn_rollout_processor(processed, config) @export_benchmark("gpqa") diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index c1991bdc..b0359d79 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -35,7 +35,7 @@ class ExecutionManager: Manage rollout for MCP environments. """ - async def execute_rollouts( + def execute_rollouts( self, envs: "GeneralMCPVectorEnv", policy: Union["LLMBasePolicy", Callable], @@ -43,7 +43,7 @@ async def execute_rollouts( openai_format_log_file: Optional[str] = None, max_concurrent_rollouts: int = 8, evaluation_rows: Optional[List[EvaluationRow]] = None, - ) -> AsyncIterator[EvaluationRow]: + ) -> List[asyncio.Task[EvaluationRow]]: """ Execute general rollouts using tool calling interface with automatic record/playback. @@ -66,7 +66,7 @@ async def execute_rollouts( - Set and file exists: Playback mode (uses recorded data) Returns: - AsyncIterator of EvaluationRow objects with unified evaluation data format + List of asyncio.Task objects for external handling """ start_time = time.time() @@ -151,18 +151,7 @@ async def _execute_with_semaphore(idx): # Create all tasks tasks = [asyncio.create_task(_execute_with_semaphore(i)) for i in range(envs.n)] - - # Yield results as they complete (note that they're not necessarily in original order) - try: - for task in asyncio.as_completed(tasks): - try: - yield await task - except Exception: - logger.exception("Error processing rollout") - finally: - for t in tasks: - t.cancel() - await asyncio.gather(*tasks, return_exceptions=True) + return tasks async def _execute_rollout( self, diff --git a/eval_protocol/mcp_env.py b/eval_protocol/mcp_env.py index 5d930a4e..5dc77c48 100644 --- a/eval_protocol/mcp_env.py +++ b/eval_protocol/mcp_env.py @@ -236,7 +236,7 @@ def make( return mcp_envs -async def rollout( +def rollout( envs: GeneralMCPVectorEnv, policy: Union[FireworksPolicy, LLMBasePolicy, Callable], *, @@ -246,7 +246,7 @@ async def rollout( steps: int = 512, openai_format_log_file: Optional[str] = None, max_concurrent_rollouts: int = 8, -) -> AsyncIterator[EvaluationRow]: +) -> List[asyncio.Task[EvaluationRow]]: """ Execute general rollouts using tool calling interface with automatic record/playback. 
@@ -274,14 +274,14 @@ async def rollout( - Set and file exists: Playback mode (uses recorded data) Returns: - List of EvaluationRow objects + List of asyncio.Task objects for external handling Example: # Live mode - evaluation_rows = await ep.rollout(envs, policy) + tasks = await ep.rollout(envs, policy) # Create environments automatically - trajectories = await ep.rollout( + tasks = await ep.rollout( "http://localhost:8000/mcp/", policy, evaluation_rows=my_evaluation_rows, @@ -290,10 +290,10 @@ async def rollout( # Recording mode os.environ["EP_PLAYBACK_FILE"] = "record.jsonl" - evaluation_rows = await ep.rollout(envs, policy, openai_format_log_file="sft_data.jsonl") + tasks = await ep.rollout(envs, policy, openai_format_log_file="sft_data.jsonl") # Playback mode (after recording file exists) - evaluation_rows = await ep.rollout(envs, policy) + tasks = await ep.rollout(envs, policy) """ # Automatically create environments if a base URL is provided if isinstance(envs, str): @@ -301,15 +301,15 @@ async def rollout( raise ValueError("Either 'evaluation_rows' or 'dataset' must be provided when envs is a URL") auto_model_id = model_id or getattr(policy, "model_id", "unknown") - envs = await make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id) + envs = make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id) # Use the new ExecutionManager for execution execution_manager = ExecutionManager() - async for evaluation_row in execution_manager.execute_rollouts( + tasks = execution_manager.execute_rollouts( envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows - ): - yield evaluation_row + ) + return tasks async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]: @@ -336,7 +336,7 @@ async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]: policy = FireworksPolicy("test-model") # Run short rollout - evaluation_rows = await rollout(envs, policy=policy, steps=10) + evaluation_rows = rollout(envs, policy=policy, steps=10) if evaluation_rows and len(evaluation_rows[0].messages) > 1: results["successful"] += 1 diff --git a/eval_protocol/pytest/default_agent_rollout_processor.py b/eval_protocol/pytest/default_agent_rollout_processor.py index 8836d3c8..ab1f596e 100644 --- a/eval_protocol/pytest/default_agent_rollout_processor.py +++ b/eval_protocol/pytest/default_agent_rollout_processor.py @@ -115,10 +115,10 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult) -> List[Tex return tool_result.content -async def default_agent_rollout_processor( +def default_agent_rollout_processor( rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> AsyncIterator[EvaluationRow]: - """Process agent rollouts with bounded concurrency and yield as they complete.""" +) -> List[asyncio.Task[EvaluationRow]]: + """Create agent rollout tasks and return them for external handling.""" max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8 semaphore = asyncio.Semaphore(max_concurrent) @@ -138,24 +138,9 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: async with semaphore: - try: - return await process_row(r) - except Exception as e: - r.rollout_status.status = "error" - r.rollout_status.termination_reason = str(e) - return r - - # Create all tasks - tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] + result = await process_row(r) + return result - # Yield results as they complete (note that 
they're not necessarily in original order) - try: - for task in asyncio.as_completed(tasks): - try: - yield await task - except Exception: - logger.exception("Error processing row") - finally: - for t in tasks: - t.cancel() - await asyncio.gather(*tasks, return_exceptions=True) + # Create and return tasks for external handling + tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] + return tasks diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py index 54c3914b..0d5350be 100644 --- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py +++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py @@ -194,14 +194,14 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False # Don't suppress exceptions -async def default_mcp_gym_rollout_processor( +def default_mcp_gym_rollout_processor( rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> AsyncIterator[EvaluationRow]: +) -> List[asyncio.Task[EvaluationRow]]: """ Rollout processor for tau bench environments. - This processor starts an MCP server, creates tau bench environments, and runs rollouts - using the eval_protocol framework, yielding results as they complete. + This processor starts an MCP server, creates tau bench environments, and returns rollout tasks + using the eval_protocol framework. Args: rows: List of EvaluationRow objects containing messages and dataset info in input_metadata @@ -210,7 +210,7 @@ async def default_mcp_gym_rollout_processor( - start_server (bool): If True, create fresh server and environments. If False, reuse existing ones. Default: True. Returns: - AsyncIterator of EvaluationRow objects with completed conversations + List of asyncio.Task objects for external handling """ start_server = config.kwargs.get("start_server", True) if config.kwargs else True if start_server: @@ -260,15 +260,15 @@ async def default_mcp_gym_rollout_processor( envs = CURRENT_RUN_STATE["envs"] policy = CURRENT_RUN_STATE["policy"] - # Run rollout with environments and policy (automatically resets environments) - async for evaluation_row in ep.rollout( + # Get rollout tasks from ep.rollout + tasks = ep.rollout( envs, policy=policy, evaluation_rows=rows, steps=config.steps, max_concurrent_rollouts=config.max_concurrent_rollouts, - ): - yield evaluation_row + ) + return tasks # Add cleanup method directly to the function object diff --git a/eval_protocol/pytest/default_no_op_rollout_process.py b/eval_protocol/pytest/default_no_op_rollout_process.py index 47cb17be..afcd9206 100644 --- a/eval_protocol/pytest/default_no_op_rollout_process.py +++ b/eval_protocol/pytest/default_no_op_rollout_process.py @@ -1,15 +1,21 @@ -from typing import AsyncIterator, List +import asyncio +from typing import List from eval_protocol.models import EvaluationRow from eval_protocol.pytest.types import RolloutProcessorConfig -async def default_no_op_rollout_processor( +def default_no_op_rollout_processor( rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> AsyncIterator[EvaluationRow]: +) -> List[asyncio.Task[EvaluationRow]]: """ Simply passes input dataset through to the test function. This can be useful if you want to run the rollout yourself. 
""" - for row in rows: - yield row + + async def return_row(row: EvaluationRow) -> EvaluationRow: + return row + + # Create tasks that immediately return the rows (no-op) + tasks = [asyncio.create_task(return_row(row)) for row in rows] + return tasks diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index e405d45f..4787cf41 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -15,10 +15,10 @@ logger = logging.getLogger(__name__) -async def default_single_turn_rollout_processor( +def default_single_turn_rollout_processor( rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> AsyncIterator[EvaluationRow]: - """Generate a single response from any supported model provider using LiteLLM.""" +) -> List[asyncio.Task[EvaluationRow]]: + """Generate single turn rollout tasks and return them for external handling.""" # Quiet LiteLLM logs in test runs unless user overrode try: @@ -103,30 +103,15 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: default_logger.log(row) return row - # Process rows with bounded concurrency and yield as they complete + # Process rows with bounded concurrency max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8 semaphore = asyncio.Semaphore(max_concurrent) async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: async with semaphore: - try: - return await process_row(r) - except Exception as e: - r.rollout_status.status = "error" - r.rollout_status.termination_reason = str(e) - return r - - # Create all tasks - tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] + result = await process_row(r) + return result - # Yield results as they complete (note that they're not necessarily in original order) - try: - for task in asyncio.as_completed(tasks): - try: - yield await task - except Exception: - logger.exception("Error processing row") - finally: - for t in tasks: - t.cancel() - await asyncio.gather(*tasks, return_exceptions=True) + # Create and return tasks for external handling + tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] + return tasks diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 25627543..93e586da 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -304,21 +304,37 @@ async def retry_handler(failed_row: EvaluationRow): # add kwargs start_server=False to config so we don't start new MCP server retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False}) - retry_call = rollout_processor([failed_row], retry_config) + retry_tasks = rollout_processor([failed_row], retry_config) - retry_result = await anext(retry_call) - if retry_result.rollout_status and retry_result.rollout_status.status == "finished": + try: + retry_result = await retry_tasks[0] + retry_result.rollout_status.status = "finished" await queue.put(retry_result) - else: - asyncio.create_task(retry_handler(retry_result)) # retry failed, spawn another retry + except Exception as e: + failed_row.rollout_status.status = "error" + failed_row.rollout_status.termination_reason = str(e) + asyncio.create_task(retry_handler(failed_row)) # retry failed, spawn another retry async def initial_processor(): """Process initial batch and spawn retries for failures""" - async for initial_row in rollout_processor(fresh_dataset, config): - if 
initial_row.rollout_status and initial_row.rollout_status.status == "finished": - await queue.put(initial_row) # rollout succeeded, put on queue - else: - asyncio.create_task(retry_handler(initial_row)) # rollout errored, spawn retry task + base_tasks = rollout_processor(fresh_dataset, config) + pending = set(base_tasks) + + while pending: + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + + for task in done: + task_index = base_tasks.index(task) + + try: + result = await task + result.rollout_status.status = "finished" + await queue.put(result) + except Exception as e: + failed_row = fresh_dataset[task_index] + failed_row.rollout_status.status = "error" + failed_row.rollout_status.termination_reason = str(e) + asyncio.create_task(retry_handler(failed_row)) # rollout errored, spawn retry task processor_task = asyncio.create_task(initial_processor()) @@ -606,7 +622,7 @@ async def _execute_with_semaphore(row): for result in all_results: for r in result: if r.eval_metadata is not None: - r.eval_metadata.status = "finished" + r.eval_metadata.status = "finished" # TODO: might not be needed r.eval_metadata.passed = passed active_logger.log(r) diff --git a/eval_protocol/pytest/types.py b/eval_protocol/pytest/types.py index 1a80254b..b2952cfb 100644 --- a/eval_protocol/pytest/types.py +++ b/eval_protocol/pytest/types.py @@ -2,8 +2,9 @@ Parameter types """ +import asyncio from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional +from typing import Any, Callable, Dict, List, Literal, Optional from eval_protocol.dataset_logger import default_logger from eval_protocol.dataset_logger.dataset_logger import DatasetLogger @@ -51,4 +52,4 @@ class RolloutProcessorConfig: kwargs: Dict[str, Any] = field(default_factory=dict) # any additional kwargs to pass to the rollout processor -RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], AsyncIterator[EvaluationRow]] +RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], List[asyncio.Task[EvaluationRow]]] diff --git a/tests/test_retry_mechanism.py b/tests/test_retry_mechanism.py index b9cfe916..a937192a 100644 --- a/tests/test_retry_mechanism.py +++ b/tests/test_retry_mechanism.py @@ -19,9 +19,9 @@ timing_results = [] # Collect timing data for assertions -async def mock_rollout_processor_with_retries( +def mock_rollout_processor_with_retries( rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> AsyncIterator[EvaluationRow]: +) -> List[asyncio.Task[EvaluationRow]]: """Mock rollout processor that fails second task alphabetically on first attempt, succeeds on retry""" row_setup = { 0: {"delay": 3.0, "should_fail": False}, @@ -34,25 +34,23 @@ async def mock_rollout_processor_with_retries( async def process_single_row(row: EvaluationRow, delay: float, should_fail: bool = False) -> EvaluationRow: await asyncio.sleep(delay) + elapsed = time.time() - start_time + print( + f"šŸŽ‰ FINISHED {'error' if should_fail else 'finished'} at {elapsed:.2f}s: {row.execution_metadata.rollout_id}" + ) + if should_fail: - row.rollout_status = RolloutStatus(status="error", termination_reason="Simulated failure for testing") - else: - row.rollout_status = RolloutStatus(status="finished") + raise Exception("Simulated failure for testing") return row - # Create tasks for concurrent processing + # Create and return tasks (let evaluation_test handle them) tasks = [ asyncio.create_task(process_single_row(row, row_setup[i]["delay"], 
row_setup[i]["should_fail"])) for i, row in enumerate(rows) ] - # Yield results as they complete - for completed_task in asyncio.as_completed(tasks): - result = await completed_task - elapsed = time.time() - start_time - print(f"šŸŽ‰ FINISHED {result.rollout_status.status} at {elapsed:.2f}s: {result.execution_metadata.rollout_id}") - yield result + return tasks @evaluation_test( diff --git a/tests/test_rollout_control_plane_integration.py b/tests/test_rollout_control_plane_integration.py index 1b92d5aa..8d176780 100644 --- a/tests/test_rollout_control_plane_integration.py +++ b/tests/test_rollout_control_plane_integration.py @@ -239,8 +239,10 @@ def mock_step_side_effect(env_index, tool_call): policy = MockPolicy(["right", "down", "right"]) # Execute rollout + tasks = self.execution_manager.execute_rollouts(mock_env, policy, steps=10) evaluation_rows = [] - async for row in self.execution_manager.execute_rollouts(mock_env, policy, steps=10): + for task in tasks: + row = await task evaluation_rows.append(row) # Validate results @@ -459,8 +461,10 @@ async def test_rollout_handles_control_plane_failure_gracefully(self): # Execute rollout with control plane failure policy = MockPolicy(["right"]) + tasks = self.execution_manager.execute_rollouts(mock_env, policy, steps=1) evaluation_rows = [] - async for row in self.execution_manager.execute_rollouts(mock_env, policy, steps=1): + for task in tasks: + row = await task evaluation_rows.append(row) # Should still work, but without control plane info @@ -497,7 +501,7 @@ async def test_rollout_creates_envs_from_url(self): policy = MockPolicy(["right"]) with ( - patch("eval_protocol.mcp_env.make", new_callable=AsyncMock) as mock_make, + patch("eval_protocol.mcp_env.make") as mock_make, patch("eval_protocol.mcp_env.ExecutionManager") as MockManager, ): mock_env = MagicMock() @@ -505,24 +509,30 @@ async def test_rollout_creates_envs_from_url(self): manager_instance = MockManager.return_value - # Mock execute_rollouts to return an async generator and track calls + # Mock execute_rollouts to return tasks and track calls call_args = [] - async def mock_execute_rollouts(*args, **kwargs): + async def mock_task(): + return "ok" + + def mock_execute_rollouts(*args, **kwargs): call_args.append((args, kwargs)) - for item in ["ok"]: - yield item + import asyncio + + return [asyncio.create_task(mock_task())] manager_instance.execute_rollouts = mock_execute_rollouts result = [] - async for row in ep.rollout( + tasks = ep.rollout( "http://localhost:1234/mcp/", policy, dataset=dataset, model_id="test_model", steps=5, - ): + ) + for task in tasks: + row = await task result.append(row) mock_make.assert_called_once_with( diff --git a/tests/test_rollout_error_handling.py b/tests/test_rollout_error_handling.py deleted file mode 100644 index 767fa16d..00000000 --- a/tests/test_rollout_error_handling.py +++ /dev/null @@ -1,251 +0,0 @@ -""" -Unit tests for rollout processor error handling. - -Tests that rollout processors properly set rollout_status.status = "error" when exceptions occur. 
-""" - -import asyncio -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from eval_protocol.dataset_logger import default_logger -from eval_protocol.models import EvaluationRow, Message, RolloutStatus -from eval_protocol.pytest.default_agent_rollout_processor import default_agent_rollout_processor -from eval_protocol.pytest.default_single_turn_rollout_process import default_single_turn_rollout_processor -from eval_protocol.pytest.types import RolloutProcessorConfig - - -class TestRolloutErrorHandling: - """Test that rollout processors handle errors correctly.""" - - @pytest.mark.asyncio - async def test_agent_rollout_processor_429_error(self): - """Test that agent rollout processor handles 429 rate limit errors correctly.""" - - # Create test row with initialized rollout_status - test_row = EvaluationRow( - messages=[Message(role="user", content="Hello")], rollout_status=RolloutStatus(status="running") - ) - - config = RolloutProcessorConfig( - completion_params={"model": "gpt-4"}, mcp_config_path="", logger=default_logger # Empty to avoid MCP setup - ) - - # Mock the LiteLLM policy to raise a 429 error - with patch("eval_protocol.pytest.default_agent_rollout_processor.LiteLLMPolicy") as mock_policy_class: - # Create a mock policy instance - mock_policy = AsyncMock() - mock_policy_class.return_value = mock_policy - - # Mock the _make_llm_call method to raise a 429 error - import litellm - - mock_policy._make_llm_call.side_effect = litellm.RateLimitError( - message="Rate limit exceeded: 429", llm_provider="openai", model="gpt-4" - ) - - # The agent rollout processor should catch the exception and set error status - result = [] - async for row in default_agent_rollout_processor([test_row], config): - result.append(row) - - assert len(result) == 1 - assert result[0].rollout_status.status == "error" - assert result[0].rollout_status.termination_reason is not None - assert ( - "429" in result[0].rollout_status.termination_reason - or "rate limit" in result[0].rollout_status.termination_reason.lower() - ) - - @pytest.mark.asyncio - async def test_agent_rollout_processor_bad_request_error(self): - """Test that agent rollout processor handles BadRequest errors correctly.""" - - test_row = EvaluationRow( - messages=[Message(role="user", content="Hello")], rollout_status=RolloutStatus(status="running") - ) - - config = RolloutProcessorConfig( - completion_params={"model": "gpt-4"}, mcp_config_path="", logger=default_logger - ) - - # Mock the LiteLLM policy to raise a BadRequest error like the one in your example - with patch("eval_protocol.pytest.default_agent_rollout_processor.LiteLLMPolicy") as mock_policy_class: - mock_policy = AsyncMock() - mock_policy_class.return_value = mock_policy - - import openai - - mock_policy._make_llm_call.side_effect = openai.BadRequestError( - "Invalid value for 'content': expected a string, got null.", response=MagicMock(), body=None - ) - - result = [] - async for row in default_agent_rollout_processor([test_row], config): - result.append(row) - - assert len(result) == 1 - assert result[0].rollout_status.status == "error" - assert result[0].rollout_status.termination_reason is not None - assert ( - "content" in result[0].rollout_status.termination_reason - or "BadRequest" in result[0].rollout_status.termination_reason - ) - - @pytest.mark.asyncio - async def test_single_turn_rollout_processor_429_error(self): - """Test that single turn rollout processor handles 429 rate limit errors correctly.""" - - test_row = EvaluationRow( - 
messages=[Message(role="user", content="Hello")], rollout_status=RolloutStatus(status="running") - ) - - config = RolloutProcessorConfig( - completion_params={"model": "gpt-4"}, mcp_config_path="", logger=default_logger - ) - - # Mock litellm.acompletion to raise a 429 error - with patch("importlib.import_module") as mock_import: - mock_litellm = MagicMock() - mock_import.return_value = mock_litellm - - import litellm - - mock_litellm.acompletion.side_effect = litellm.RateLimitError( - message="Rate limit exceeded: 429", llm_provider="openai", model="gpt-4" - ) - - result = [] - async for row in default_single_turn_rollout_processor([test_row], config): - result.append(row) - - assert len(result) == 1 - assert result[0].rollout_status.status == "error" - assert result[0].rollout_status.termination_reason is not None - assert ( - "429" in result[0].rollout_status.termination_reason - or "rate limit" in result[0].rollout_status.termination_reason.lower() - ) - - @pytest.mark.asyncio - async def test_single_turn_rollout_processor_bad_request_error(self): - """Test that single turn rollout processor handles BadRequest errors correctly.""" - - test_row = EvaluationRow( - messages=[Message(role="user", content="Hello")], rollout_status=RolloutStatus(status="running") - ) - - config = RolloutProcessorConfig( - completion_params={"model": "gpt-4"}, mcp_config_path="", logger=default_logger - ) - - # Mock litellm.acompletion to raise a BadRequest error - with patch("importlib.import_module") as mock_import: - mock_litellm = MagicMock() - mock_import.return_value = mock_litellm - - import openai - - mock_litellm.acompletion.side_effect = openai.BadRequestError( - "Invalid value for 'content': expected a string, got null.", response=MagicMock(), body=None - ) - - result = [] - async for row in default_single_turn_rollout_processor([test_row], config): - result.append(row) - - assert len(result) == 1 - assert result[0].rollout_status.status == "error" - assert result[0].rollout_status.termination_reason is not None - assert ( - "content" in result[0].rollout_status.termination_reason - or "BadRequest" in result[0].rollout_status.termination_reason - ) - - @pytest.mark.asyncio - async def test_multiple_rows_with_mixed_errors(self): - """Test that when some rows get 429 errors and some succeed, each gets the correct status.""" - - # Create test rows - row1 = EvaluationRow( - messages=[Message(role="user", content="Hello 1")], rollout_status=RolloutStatus(status="running") - ) - - row2 = EvaluationRow( - messages=[Message(role="user", content="Hello 2")], rollout_status=RolloutStatus(status="running") - ) - - config = RolloutProcessorConfig( - completion_params={"model": "gpt-4"}, mcp_config_path="", logger=default_logger - ) - - # Mock litellm.acompletion to raise 429 for both rows (simulating rate limiting) - with patch("importlib.import_module") as mock_import: - mock_litellm = MagicMock() - mock_import.return_value = mock_litellm - - import litellm - - mock_litellm.acompletion.side_effect = litellm.RateLimitError( - message="Rate limit exceeded: 429", llm_provider="openai", model="gpt-4" - ) - - result = [] - async for row in default_single_turn_rollout_processor([row1, row2], config): - result.append(row) - - assert len(result) == 2 - # Both should have error status due to 429 errors - for row in result: - assert row.rollout_status.status == "error" - assert row.rollout_status.termination_reason is not None - assert ( - "429" in row.rollout_status.termination_reason - or "rate limit" in 
row.rollout_status.termination_reason.lower() - ) - - @pytest.mark.asyncio - async def test_rollout_status_preserves_original_row_data_on_api_error(self): - """Test that when API errors occur, the original row data is preserved.""" - - original_message = Message(role="user", content="Original message") - test_row = EvaluationRow(messages=[original_message], rollout_status=RolloutStatus(status="running")) - - config = RolloutProcessorConfig( - completion_params={"model": "gpt-4"}, mcp_config_path="", logger=default_logger - ) - - # Mock the LiteLLM policy to raise an API error - with patch("eval_protocol.pytest.default_agent_rollout_processor.LiteLLMPolicy") as mock_policy_class: - mock_policy = AsyncMock() - mock_policy_class.return_value = mock_policy - - import litellm - - mock_policy._make_llm_call.side_effect = litellm.RateLimitError( - message="Rate limit exceeded: 429", llm_provider="openai", model="gpt-4" - ) - - result = [] - async for row in default_agent_rollout_processor([test_row], config): - result.append(row) - - assert len(result) == 1 - assert result[0].rollout_status.status == "error" - # Original message should be preserved - assert len(result[0].messages) == 1 - assert result[0].messages[0].content == "Original message" - - def test_rollout_status_initialization(self): - """Test that RolloutStatus initializes with correct default values.""" - - # Test default initialization - status = RolloutStatus() - assert status.status == "running" # Default from the model - assert status.termination_reason == "" # Default empty string - - # Test explicit initialization - status = RolloutStatus(status="error", termination_reason="Test error") - assert status.status == "error" - assert status.termination_reason == "Test error" From 3c54cc7fd128103d1795a86ccc9e5c40e9690f36 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Thu, 14 Aug 2025 01:43:13 -0700 Subject: [PATCH 3/9] Changing the rollout processors --- eval_protocol/benchmarks/suites/aime25.py | 4 +- eval_protocol/benchmarks/suites/gpqa.py | 42 ++- .../suites/livebench_data_analysis.py | 8 +- .../benchmarks/suites/tau_bench_retail.py | 4 +- eval_protocol/pytest/__init__.py | 19 +- .../pytest/default_agent_rollout_processor.py | 63 ++-- .../pytest/default_base_rollout_process.py | 31 ++ .../default_mcp_gym_rollout_processor.py | 149 +++++----- .../pytest/default_no_op_rollout_process.py | 21 -- .../default_single_turn_rollout_process.py | 196 ++++++------ eval_protocol/pytest/evaluation_test.py | 279 ++---------------- eval_protocol/pytest/types.py | 3 - eval_protocol/pytest/utils.py | 233 ++++++++++++++- examples/gpqa/tests/test_gpqa.py | 4 +- examples/healthbench/tests/test_evaluation.py | 4 +- tests/pytest/test_apps_coding.py | 4 +- tests/pytest/test_basic_coding.py | 4 +- tests/pytest/test_frozen_lake.py | 4 +- tests/pytest/test_hallucination.py | 4 +- tests/pytest/test_lunar_lander.py | 4 +- tests/pytest/test_markdown_highlighting.py | 4 +- ..._pytest_default_agent_rollout_processor.py | 4 +- tests/pytest/test_pytest_ensure_logging.py | 4 +- tests/pytest/test_pytest_flaky_sometimes.py | 4 +- tests/pytest/test_pytest_function_calling.py | 4 +- tests/pytest/test_pytest_ids.py | 6 +- tests/pytest/test_pytest_input_messages.py | 4 +- tests/pytest/test_pytest_json_schema.py | 4 +- tests/pytest/test_pytest_math_example.py | 4 +- .../pytest/test_pytest_math_format_length.py | 4 +- tests/pytest/test_pytest_mcp_config.py | 4 +- tests/pytest/test_pytest_mcp_url.py | 4 +- .../pytest/test_pytest_word_count_example.py | 4 +- 
tests/pytest/test_tau_bench_airline.py | 4 +- tests/test_retry_mechanism.py | 55 ++-- tests/test_tau_bench_airline_smoke.py | 4 +- 36 files changed, 615 insertions(+), 582 deletions(-) create mode 100644 eval_protocol/pytest/default_base_rollout_process.py delete mode 100644 eval_protocol/pytest/default_no_op_rollout_process.py diff --git a/eval_protocol/benchmarks/suites/aime25.py b/eval_protocol/benchmarks/suites/aime25.py index 3558eaa1..92d7bedc 100644 --- a/eval_protocol/benchmarks/suites/aime25.py +++ b/eval_protocol/benchmarks/suites/aime25.py @@ -3,7 +3,7 @@ from eval_protocol.benchmarks.registry import export_benchmark from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult from eval_protocol.pytest.default_single_turn_rollout_process import ( - default_single_turn_rollout_processor, + SingleTurnRolloutProcessor, ) from eval_protocol.pytest.evaluation_test import evaluation_test @@ -72,7 +72,7 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]: "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", } ], - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", passed_threshold=None, num_runs=8, diff --git a/eval_protocol/benchmarks/suites/gpqa.py b/eval_protocol/benchmarks/suites/gpqa.py index ff745adc..531dfcaa 100644 --- a/eval_protocol/benchmarks/suites/gpqa.py +++ b/eval_protocol/benchmarks/suites/gpqa.py @@ -8,10 +8,12 @@ from eval_protocol.benchmarks.registry import export_benchmark from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult +from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor from eval_protocol.pytest.default_single_turn_rollout_process import ( - default_single_turn_rollout_processor, + SingleTurnRolloutProcessor, ) from eval_protocol.pytest.evaluation_test import evaluation_test +from eval_protocol.pytest.types import RolloutProcessorConfig SYSTEM_PROMPT = ( "You are a helpful assistant. Read the question and options carefully. 
" @@ -61,19 +63,31 @@ def _strip_gt_messages(msgs: List[Message]) -> List[Message]: return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))] -def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) -> List[asyncio.Task[EvaluationRow]]: - """Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to default processor.""" - processed: List[EvaluationRow] = [] - for r in rows: - gt_tokens = [m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:")] - if gt_tokens: - gt_val = gt_tokens[-1].split(":", 1)[1].strip() - r.ground_truth = gt_val - r.messages = [ - m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:")) +class GPQAStripGTRolloutProcessor(BaseRolloutProcessor): + """Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to SingleTurnRolloutProcessor.""" + + def __init__(self): + super().__init__() + self.single_turn_processor = SingleTurnRolloutProcessor() + + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + """Preprocess rows and delegate to SingleTurnRolloutProcessor.""" + processed: List[EvaluationRow] = [] + + for r in rows: + gt_tokens = [ + m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:") ] - processed.append(r) - return default_single_turn_rollout_processor(processed, config) + if gt_tokens: + gt_val = gt_tokens[-1].split(":", 1)[1].strip() + r.ground_truth = gt_val + r.messages = [ + m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:")) + ] + processed.append(r) + + # Delegate to SingleTurnRolloutProcessor + return self.single_turn_processor(processed, config) @export_benchmark("gpqa") @@ -82,7 +96,7 @@ def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) -> List[a completion_params=[ {"extra_body": {"reasoning_effort": "low"}, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"} ], - rollout_processor=gpqa_strip_gt_rollout_processor, + rollout_processor=GPQAStripGTRolloutProcessor(), aggregation_method="mean", passed_threshold=None, num_runs=8, diff --git a/eval_protocol/benchmarks/suites/livebench_data_analysis.py b/eval_protocol/benchmarks/suites/livebench_data_analysis.py index fc5abb4e..da384439 100644 --- a/eval_protocol/benchmarks/suites/livebench_data_analysis.py +++ b/eval_protocol/benchmarks/suites/livebench_data_analysis.py @@ -5,7 +5,7 @@ from eval_protocol.benchmarks.registry import export_benchmark, register_composite_benchmark from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult from eval_protocol.pytest.default_single_turn_rollout_process import ( - default_single_turn_rollout_processor, + SingleTurnRolloutProcessor, ) from eval_protocol.pytest.evaluation_test import evaluation_test @@ -375,7 +375,7 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]: completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], input_messages=[[m for m in r.messages] for r in _CTA_ROWS], rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}], - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", passed_threshold=None, num_runs=4, @@ -418,7 +418,7 @@ def livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow: completion_params=[{"model": 
"fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS], rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}], - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", passed_threshold=None, num_runs=4, @@ -462,7 +462,7 @@ def livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow: completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS], rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}], - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", passed_threshold=None, num_runs=4, diff --git a/eval_protocol/benchmarks/suites/tau_bench_retail.py b/eval_protocol/benchmarks/suites/tau_bench_retail.py index 8e8aaea0..6c0a8a36 100644 --- a/eval_protocol/benchmarks/suites/tau_bench_retail.py +++ b/eval_protocol/benchmarks/suites/tau_bench_retail.py @@ -13,7 +13,7 @@ from eval_protocol.benchmarks.registry import export_benchmark from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message from eval_protocol.pytest import evaluation_test -from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor +from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor from vendor.tau2.data_model.message import ( AssistantMessage, SystemMessage, @@ -73,7 +73,7 @@ def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", } ], - rollout_processor=default_mcp_gym_rollout_processor, + rollout_processor=MCPGymRolloutProcessor(), rollout_processor_kwargs={"domain": "retail"}, num_runs=8, mode="pointwise", diff --git a/eval_protocol/pytest/__init__.py b/eval_protocol/pytest/__init__.py index 2d2576d6..7b5de324 100644 --- a/eval_protocol/pytest/__init__.py +++ b/eval_protocol/pytest/__init__.py @@ -1,18 +1,17 @@ -from .default_agent_rollout_processor import default_agent_rollout_processor +from .default_agent_rollout_processor import AgentRolloutProcessor +from .default_base_rollout_process import BaseRolloutProcessor from .default_dataset_adapter import default_dataset_adapter -from .default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor -from .default_no_op_rollout_process import default_no_op_rollout_processor -from .default_single_turn_rollout_process import default_single_turn_rollout_processor +from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor +from .default_single_turn_rollout_process import SingleTurnRolloutProcessor from .evaluation_test import evaluation_test -from .types import RolloutProcessor, RolloutProcessorConfig +from .types import RolloutProcessorConfig __all__ = [ - "default_agent_rollout_processor", - "default_mcp_gym_rollout_processor", - "default_no_op_rollout_processor", - "default_single_turn_rollout_processor", + "AgentRolloutProcessor", + "MCPGymRolloutProcessor", + "BaseRolloutProcessor", + "SingleTurnRolloutProcessor", "default_dataset_adapter", - "RolloutProcessor", "RolloutProcessorConfig", "evaluation_test", ] diff --git a/eval_protocol/pytest/default_agent_rollout_processor.py b/eval_protocol/pytest/default_agent_rollout_processor.py index ab1f596e..6a220344 100644 --- 
a/eval_protocol/pytest/default_agent_rollout_processor.py +++ b/eval_protocol/pytest/default_agent_rollout_processor.py @@ -13,6 +13,7 @@ from eval_protocol.mcp.execution.policy import LiteLLMPolicy from eval_protocol.mcp.mcp_multi_client import MCPMultiClient from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig logger = logging.getLogger(__name__) @@ -115,32 +116,36 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult) -> List[Tex return tool_result.content -def default_agent_rollout_processor( - rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> List[asyncio.Task[EvaluationRow]]: - """Create agent rollout tasks and return them for external handling.""" - - max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8 - semaphore = asyncio.Semaphore(max_concurrent) - - async def process_row(row: EvaluationRow) -> EvaluationRow: - """Process a single row with agent rollout.""" - agent = Agent( - model=config.completion_params["model"], row=row, config_path=config.mcp_config_path, logger=config.logger - ) - try: - await agent.setup() - await agent.call_agent() - return agent.evaluation_row - finally: - if agent.mcp_client: - await agent.mcp_client.cleanup() - - async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: - async with semaphore: - result = await process_row(r) - return result - - # Create and return tasks for external handling - tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] - return tasks +class AgentRolloutProcessor(BaseRolloutProcessor): + """Agent rollout processor for tool-calling agents.""" + + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + """Create agent rollout tasks and return them for external handling.""" + + max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8 + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_row(row: EvaluationRow) -> EvaluationRow: + """Process a single row with agent rollout.""" + agent = Agent( + model=config.completion_params["model"], + row=row, + config_path=config.mcp_config_path, + logger=config.logger, + ) + try: + await agent.setup() + await agent.call_agent() + return agent.evaluation_row + finally: + if agent.mcp_client: + await agent.mcp_client.cleanup() + + async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: + async with semaphore: + result = await process_row(r) + return result + + # Create and return tasks for external handling + tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] + return tasks diff --git a/eval_protocol/pytest/default_base_rollout_process.py b/eval_protocol/pytest/default_base_rollout_process.py new file mode 100644 index 00000000..be3597e6 --- /dev/null +++ b/eval_protocol/pytest/default_base_rollout_process.py @@ -0,0 +1,31 @@ +import asyncio +from typing import List + +from eval_protocol.models import EvaluationRow +from eval_protocol.pytest.types import RolloutProcessorConfig + + +class BaseRolloutProcessor: + """ + Base rollout processor - minimal implementation that all others inherit from. + + This is the Strategy pattern base class. It provides: + 1. __call__(rows, config) -> tasks (the main interface) + 2. cleanup() -> None (resource cleanup) + + All other processors inherit from this and override as needed. 
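# Hypothetical subclass sketch (annotation, not part of this patch): under the
# Strategy interface described above, a custom processor only has to override
# __call__ to return one asyncio.Task per row, plus cleanup() if it holds
# resources. StaticReplyRolloutProcessor is an illustrative name.
import asyncio
from typing import List

from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig


class StaticReplyRolloutProcessor(BaseRolloutProcessor):
    """Appends a canned assistant turn instead of calling a model."""

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        async def run(row: EvaluationRow) -> EvaluationRow:
            row.messages = list(row.messages) + [Message(role="assistant", content="ok")]
            return row

        return [asyncio.create_task(run(row)) for row in rows]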
+ """ + + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + """Process rows by returning them unchanged (no-op implementation).""" + + async def return_row(row: EvaluationRow) -> EvaluationRow: + return row + + # Create tasks that immediately return the rows (no-op) + tasks = [asyncio.create_task(return_row(row)) for row in rows] + return tasks + + def cleanup(self) -> None: + """No-op cleanup - override in subclasses if needed.""" + pass diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py index 0d5350be..491e4298 100644 --- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py +++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py @@ -10,6 +10,7 @@ import eval_protocol as ep from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor from eval_protocol.pytest.types import RolloutProcessorConfig CURRENT_RUN_STATE: Dict[str, Any] = {} @@ -194,90 +195,80 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False # Don't suppress exceptions -def default_mcp_gym_rollout_processor( - rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> List[asyncio.Task[EvaluationRow]]: +class MCPGymRolloutProcessor(BaseRolloutProcessor): """ Rollout processor for tau bench environments. This processor starts an MCP server, creates tau bench environments, and returns rollout tasks - using the eval_protocol framework. + using the eval_protocol framework with proper cleanup handling. + """ - Args: - rows: List of EvaluationRow objects containing messages and dataset info in input_metadata - config: RolloutProcessorConfig with model and other parameters - - config.kwargs can include: - - start_server (bool): If True, create fresh server and environments. If False, reuse existing ones. Default: True. 
+ def __init__(self): + self.current_run_state: Dict[str, Any] = {} - Returns: - List of asyncio.Task objects for external handling - """ - start_server = config.kwargs.get("start_server", True) if config.kwargs else True - if start_server: - # Create fresh MCP server and environments for this run - if config.server_script_path is None: - raise ValueError("server_script_path is required for default_mcp_gym_rollout_processor") + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + """Process evaluation rows with MCP gym environments.""" + start_server = config.kwargs.get("start_server", True) if config.kwargs else True - server = MCPServerManager(config.server_script_path, port=9700, **(config.kwargs or {})) + if start_server: + # Create fresh MCP server and environments for this run + if config.server_script_path is None: + raise ValueError("server_script_path is required for MCPGymRolloutProcessor") - try: - server.start() - - policy = ep.LiteLLMPolicy( - model_id=config.completion_params.get("model", None), - temperature=config.completion_params.get("temperature", 0.0), - max_tokens=config.completion_params.get("max_tokens", 4096), - reasoning_effort=config.completion_params.get("reasoning_effort", None), - ) - - # Create MCP environments directly from evaluation_rows - envs = ep.make( - "http://localhost:9700/mcp/", - evaluation_rows=rows, - model_id=policy.model_id, - ) - - # Store in current run state for reuse within this run - CURRENT_RUN_STATE.update( - { - "server": server, - "envs": envs, - "policy": policy, - } - ) - - except Exception as e: - server.stop() - CURRENT_RUN_STATE.clear() - raise e - - else: - # Reuse existing MCP environments for retry - if not CURRENT_RUN_STATE: - raise RuntimeError("Cannot retry without existing server/environments. Call with start_server=True first.") - - server = CURRENT_RUN_STATE["server"] - envs = CURRENT_RUN_STATE["envs"] - policy = CURRENT_RUN_STATE["policy"] - - # Get rollout tasks from ep.rollout - tasks = ep.rollout( - envs, - policy=policy, - evaluation_rows=rows, - steps=config.steps, - max_concurrent_rollouts=config.max_concurrent_rollouts, - ) - return tasks - - -# Add cleanup method directly to the function object -def _cleanup_mcp_gym_rollout_processor(): - """Cleanup function for MCP gym rollout processor""" - if CURRENT_RUN_STATE and "server" in CURRENT_RUN_STATE: - CURRENT_RUN_STATE["server"].stop() - CURRENT_RUN_STATE.clear() # Clear for next run - - -# Attach cleanup method to the processor function -default_mcp_gym_rollout_processor.cleanup = _cleanup_mcp_gym_rollout_processor + server = MCPServerManager(config.server_script_path, port=9700, **(config.kwargs or {})) + + try: + server.start() + + policy = ep.LiteLLMPolicy( + model_id=config.completion_params.get("model", None), + temperature=config.completion_params.get("temperature", 0.0), + max_tokens=config.completion_params.get("max_tokens", 4096), + reasoning_effort=config.completion_params.get("reasoning_effort", None), + ) + + # Store in instance state for cleanup + self.current_run_state.update( + { + "server": server, + "policy": policy, + } + ) + + except Exception as e: + server.stop() + self.current_run_state.clear() + raise e + + else: + # Reuse existing MCP environments for retry + if not self.current_run_state: + raise RuntimeError( + "Cannot retry without existing server/environments. Call with start_server=True first." 
+ ) + + server = self.current_run_state["server"] + policy = self.current_run_state["policy"] + + # Create MCP environments directly from evaluation_rows + envs = ep.make( + "http://localhost:9700/mcp/", + evaluation_rows=rows, + model_id=policy.model_id, + ) + + # Get rollout tasks from ep.rollout + tasks = ep.rollout( + envs, + policy=policy, + evaluation_rows=rows, + steps=config.steps, + max_concurrent_rollouts=config.max_concurrent_rollouts, + ) + return tasks + + def cleanup(self) -> None: + """Cleanup MCP server and environments.""" + if self.current_run_state and "server" in self.current_run_state: + self.current_run_state["server"].stop() + self.current_run_state.clear() diff --git a/eval_protocol/pytest/default_no_op_rollout_process.py b/eval_protocol/pytest/default_no_op_rollout_process.py deleted file mode 100644 index afcd9206..00000000 --- a/eval_protocol/pytest/default_no_op_rollout_process.py +++ /dev/null @@ -1,21 +0,0 @@ -import asyncio -from typing import List - -from eval_protocol.models import EvaluationRow -from eval_protocol.pytest.types import RolloutProcessorConfig - - -def default_no_op_rollout_processor( - rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> List[asyncio.Task[EvaluationRow]]: - """ - Simply passes input dataset through to the test function. This can be useful - if you want to run the rollout yourself. - """ - - async def return_row(row: EvaluationRow) -> EvaluationRow: - return row - - # Create tasks that immediately return the rows (no-op) - tasks = [asyncio.create_task(return_row(row)) for row in rows] - return tasks diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index 4787cf41..cb7ea347 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -10,108 +10,110 @@ from eval_protocol.dataset_logger import default_logger from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor from eval_protocol.pytest.types import RolloutProcessorConfig logger = logging.getLogger(__name__) -def default_single_turn_rollout_processor( - rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> List[asyncio.Task[EvaluationRow]]: - """Generate single turn rollout tasks and return them for external handling.""" - - # Quiet LiteLLM logs in test runs unless user overrode - try: - if os.environ.get("LITELLM_LOG") is None: - os.environ["LITELLM_LOG"] = "ERROR" - _llog = logging.getLogger("LiteLLM") - _llog.setLevel(logging.CRITICAL) - _llog.propagate = False - for _h in list(_llog.handlers): - _llog.removeHandler(_h) - except Exception: - pass - - # Do not modify global LiteLLM cache. Disable caching per-request instead. - - async def process_row(row: EvaluationRow) -> EvaluationRow: - """Process a single row asynchronously.""" - if len(row.messages) == 0: - raise ValueError("Messages is empty. 
Please provide a non-empty dataset") - - messages_payload = [{"role": m.role, "content": m.content} for m in row.messages] - - request_params = {"messages": messages_payload, **config.completion_params} - # Ensure caching is disabled only for this request (review feedback) - request_params["cache"] = {"no-cache": True} - # Single-level reasoning effort: expect `reasoning_effort` only - effort_val = None - - if "reasoning_effort" in config.completion_params: - effort_val = str(config.completion_params["reasoning_effort"]) # flat shape - elif ( - isinstance(config.completion_params.get("extra_body"), dict) - and "reasoning_effort" in config.completion_params["extra_body"] - ): - # Accept if user passed it directly inside extra_body - effort_val = str(config.completion_params["extra_body"]["reasoning_effort"]) # already in extra_body - - if effort_val: - # Always under extra_body so LiteLLM forwards to provider-specific param set - request_params.setdefault("extra_body", {}) - request_params["extra_body"]["reasoning_effort"] = effort_val - # Ensure unsupported top-level keys are not present - if "reasoning_effort" in request_params: - request_params.pop("reasoning_effort", None) - - if row.tools is not None: - request_params["tools"] = row.tools - - # Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet - import importlib - - _litellm = importlib.import_module("litellm") - acompletion = getattr(_litellm, "acompletion") - response = await acompletion(**request_params) - - assistant_content = response.choices[0].message.content or "" - tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None - - converted_tool_calls = None - if tool_calls: - converted_tool_calls = [ - ChatCompletionMessageToolCall( - id=tool_call.id, - type=tool_call.type, - function={ - "name": tool_call.function.name, - "arguments": tool_call.function.arguments, - }, +class SingleTurnRolloutProcessor(BaseRolloutProcessor): + """Single turn rollout processor for direct LLM calls.""" + + def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: + """Generate single turn rollout tasks and return them for external handling.""" + + # Quiet LiteLLM logs in test runs unless user overrode + try: + if os.environ.get("LITELLM_LOG") is None: + os.environ["LITELLM_LOG"] = "ERROR" + _llog = logging.getLogger("LiteLLM") + _llog.setLevel(logging.CRITICAL) + _llog.propagate = False + for _h in list(_llog.handlers): + _llog.removeHandler(_h) + except Exception: + pass + + # Do not modify global LiteLLM cache. Disable caching per-request instead. + + async def process_row(row: EvaluationRow) -> EvaluationRow: + """Process a single row asynchronously.""" + if len(row.messages) == 0: + raise ValueError("Messages is empty. 
Please provide a non-empty dataset") + + messages_payload = [{"role": m.role, "content": m.content} for m in row.messages] + + request_params = {"messages": messages_payload, **config.completion_params} + # Ensure caching is disabled only for this request (review feedback) + request_params["cache"] = {"no-cache": True} + # Single-level reasoning effort: expect `reasoning_effort` only + effort_val = None + + if "reasoning_effort" in config.completion_params: + effort_val = str(config.completion_params["reasoning_effort"]) # flat shape + elif ( + isinstance(config.completion_params.get("extra_body"), dict) + and "reasoning_effort" in config.completion_params["extra_body"] + ): + # Accept if user passed it directly inside extra_body + effort_val = str(config.completion_params["extra_body"]["reasoning_effort"]) # already in extra_body + + if effort_val: + # Always under extra_body so LiteLLM forwards to provider-specific param set + request_params.setdefault("extra_body", {}) + request_params["extra_body"]["reasoning_effort"] = effort_val + # Ensure unsupported top-level keys are not present + if "reasoning_effort" in request_params: + request_params.pop("reasoning_effort", None) + + if row.tools is not None: + request_params["tools"] = row.tools + + # Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet + import importlib + + _litellm = importlib.import_module("litellm") + acompletion = getattr(_litellm, "acompletion") + response = await acompletion(**request_params) + + assistant_content = response.choices[0].message.content or "" + tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None + + converted_tool_calls = None + if tool_calls: + converted_tool_calls = [ + ChatCompletionMessageToolCall( + id=tool_call.id, + type=tool_call.type, + function={ + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, + ) + for tool_call in tool_calls + ] + + messages = list(row.messages) + [ + Message( + role="assistant", + content=assistant_content, + tool_calls=converted_tool_calls, ) - for tool_call in tool_calls ] - messages = list(row.messages) + [ - Message( - role="assistant", - content=assistant_content, - tool_calls=converted_tool_calls, - ) - ] - - row.messages = messages - default_logger.log(row) - return row - - # Process rows with bounded concurrency - max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8 - semaphore = asyncio.Semaphore(max_concurrent) - - async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: - async with semaphore: - result = await process_row(r) - return result - - # Create and return tasks for external handling - tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] - return tasks + row.messages = messages + default_logger.log(row) + return row + + # Process rows with bounded concurrency + max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8 + semaphore = asyncio.Semaphore(max_concurrent) + + async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: + async with semaphore: + result = await process_row(r) + return result + + # Create and return tasks for external handling + tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] + return tasks diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 93e586da..70f976fd 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -24,8 +24,8 @@ InputMetadata, Message, ) +from 
eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter -from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor from eval_protocol.pytest.types import ( Dataset, DatasetPathParam, @@ -33,7 +33,6 @@ EvaluationTestMode, InputMessagesParam, ModelParam, - RolloutProcessor, RolloutProcessorConfig, RolloutProcessorInputParam, TestFunction, @@ -42,8 +41,14 @@ AggregationMethod, aggregate, create_dynamically_parameterized_wrapper, + deep_update_dict, execute_function, + extract_effort_tag, + generate_parameter_combinations, log_eval_status_and_rows, + parse_ep_max_rows, + rollout_processor_with_retry, + sanitize_filename, ) from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci @@ -56,7 +61,7 @@ def evaluation_test( # noqa: C901 input_messages: Optional[List[InputMessagesParam]] = None, input_dataset: Optional[List[DatasetPathParam]] = None, dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter, - rollout_processor: RolloutProcessor = default_no_op_rollout_processor, + rollout_processor: BaseRolloutProcessor = BaseRolloutProcessor(), evaluation_test_kwargs: Optional[List[EvaluationInputParam]] = None, rollout_processor_kwargs: Optional[RolloutProcessorInputParam] = None, aggregation_method: AggregationMethod = "mean", @@ -201,168 +206,15 @@ async def execute_with_params( return test_func(**kwargs) # Calculate all possible combinations of parameters - def _parse_ep_max_rows(default_value: int | None) -> int | None: - """Read EP_MAX_DATASET_ROWS env override as int or None.""" - raw = os.getenv("EP_MAX_DATASET_ROWS") - if raw is None: - return default_value - s = raw.strip().lower() - if s == "none": - return None - try: - return int(s) - except ValueError: - return default_value - - def _deep_update_dict(base: dict, override: dict) -> dict: - """Recursively update nested dictionaries in-place and return base.""" - for key, value in override.items(): - if isinstance(value, dict) and isinstance(base.get(key), dict): - _deep_update_dict(base[key], value) - else: - base[key] = value - return base - - def generate_combinations(): - combinations = [] - - # Handle optional parameters with defaults - # Optionally combine multiple dataset paths into one logical dataset, - # or parameterize to run one dataset per test invocation. - if input_dataset is not None: - if combine_datasets: - datasets: List[Optional[List[DatasetPathParam]]] = [input_dataset] # type: ignore - else: - # Fan out: one dataset path per parameterization - if isinstance(input_dataset, list): # type: ignore - datasets = [[p] for p in input_dataset] # type: ignore - else: - datasets = [[input_dataset]] # type: ignore - else: - datasets = [None] - cps: List[Optional[CompletionParams]] = completion_params if completion_params is not None else [None] # type: ignore - # Apply EP_MAX_DATASET_ROWS to input_messages, but do NOT parameterize over - # each row. Instead, pass the entire sliced list through in a single test run - # so summaries aggregate all rows together (AIME-style behavior). 
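# Behavior sketch for the two helpers this hunk removes (they reappear in
# eval_protocol.pytest.utils later in this patch); values are illustrative.
import os

from eval_protocol.pytest.utils import deep_update_dict, parse_ep_max_rows

os.environ["EP_MAX_DATASET_ROWS"] = "25"
assert parse_ep_max_rows(100) == 25      # env override wins over the default
os.environ["EP_MAX_DATASET_ROWS"] = "none"
assert parse_ep_max_rows(100) is None    # explicit "none" disables the row cap
os.environ["EP_MAX_DATASET_ROWS"] = "abc"
assert parse_ep_max_rows(100) == 100     # unparseable values fall back to the default

params = {"model": "m", "extra_body": {"reasoning_effort": "low"}}
deep_update_dict(params, {"extra_body": {"reasoning_effort": "high"}})
assert params["extra_body"] == {"reasoning_effort": "high"}  # nested keys merge in place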
- if input_messages is not None and isinstance(input_messages, list): - effective_max_rows = _parse_ep_max_rows(max_dataset_rows) - if effective_max_rows is not None: - sliced_messages = input_messages[:effective_max_rows] # type: ignore - else: - sliced_messages = input_messages # type: ignore - # Wrap as a single parameter payload - messages = [sliced_messages] # type: ignore - else: - messages = [None] # type: ignore - kwargs: List[Optional[EvaluationInputParam]] = evaluation_test_kwargs if evaluation_test_kwargs is not None else [None] # type: ignore - - # Generate all combinations - for ds in datasets: - for cp in cps: - for im in messages: - for etk in kwargs: - # if no dataset and no messages, raise an error - if ds is None and im is None: - raise ValueError( - "No dataset or messages provided. Please provide at least one of input_dataset or input_messages." - ) - combinations.append((ds, cp, im, etk)) - - return combinations - - async def rollout_processor_with_retry( - rollout_processor: RolloutProcessor, - fresh_dataset: List[EvaluationRow], - config: RolloutProcessorConfig, - max_retry: int, - ): - """ - Wrapper around rollout_processor that handles retry logic internally. - Uses async queue pattern to yield results immediately as they become available. - Yields both successful and failed results, leaving it up to the user to handle them in test_func. - """ - - try: - queue = asyncio.Queue() - retry_counts = {r.execution_metadata.rollout_id: 0 for r in fresh_dataset} - failed_permanently = [] - - async def retry_handler(failed_row: EvaluationRow): - rollout_id = failed_row.execution_metadata.rollout_id - current_attempts = retry_counts.get(rollout_id, 0) - - if current_attempts >= max_retry: - assert ( - failed_row.rollout_status and failed_row.rollout_status.status == "error" - ), f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status" - failed_permanently.append(failed_row) - await queue.put(failed_row) # put failed row on queue - return - retry_counts[rollout_id] = current_attempts + 1 - - # add kwargs start_server=False to config so we don't start new MCP server - retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False}) - - retry_tasks = rollout_processor([failed_row], retry_config) - - try: - retry_result = await retry_tasks[0] - retry_result.rollout_status.status = "finished" - await queue.put(retry_result) - except Exception as e: - failed_row.rollout_status.status = "error" - failed_row.rollout_status.termination_reason = str(e) - asyncio.create_task(retry_handler(failed_row)) # retry failed, spawn another retry - - async def initial_processor(): - """Process initial batch and spawn retries for failures""" - base_tasks = rollout_processor(fresh_dataset, config) - pending = set(base_tasks) - - while pending: - done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) - - for task in done: - task_index = base_tasks.index(task) - - try: - result = await task - result.rollout_status.status = "finished" - await queue.put(result) - except Exception as e: - failed_row = fresh_dataset[task_index] - failed_row.rollout_status.status = "error" - failed_row.rollout_status.termination_reason = str(e) - asyncio.create_task(retry_handler(failed_row)) # rollout errored, spawn retry task - - processor_task = asyncio.create_task(initial_processor()) - - # yield results as they become available - completed_count = 0 - total_expected = len(fresh_dataset) - - while completed_count < total_expected: - 
finished_row = await queue.get() - - # only permanent failure rows are put on the queue, so we can check for them here - if finished_row.rollout_status and finished_row.rollout_status.status == "error": - if os.getenv("EP_FAIL_ON_PERMANENT_FAILURE", "true") != "false": - raise RuntimeError( - f"Rollout {finished_row.execution_metadata.rollout_id} failed after {max_retry} retries. Errors: {finished_row.rollout_status.termination_reason}" - ) - - completed_count += 1 - yield finished_row - - await processor_task # explicitly wait for task completion and catch any exceptions - - finally: - # processor clean up after themselves if they have a cleanup method - if hasattr(rollout_processor, "cleanup"): - rollout_processor.cleanup() - - combinations = generate_combinations() + combinations = generate_parameter_combinations( + input_dataset, + completion_params, + input_messages, + evaluation_test_kwargs, + max_dataset_rows, + combine_datasets, + ) if len(combinations) == 0: raise ValueError( "No combinations of parameters were found. Please provide at least a model and one of input_dataset or input_messages." @@ -424,7 +276,7 @@ def _log_eval_error( else: data_jsonl = load_jsonl(ds_arg) # Apply env override for max rows if present - effective_max_rows = _parse_ep_max_rows(max_dataset_rows) + effective_max_rows = parse_ep_max_rows(max_dataset_rows) if effective_max_rows is not None: data_jsonl = data_jsonl[:effective_max_rows] data = dataset_adapter(data_jsonl) @@ -460,7 +312,7 @@ def _log_eval_error( if _env_override: override_obj = _json.loads(_env_override) if isinstance(override_obj, dict): - completion_params = _deep_update_dict(dict(completion_params), override_obj) + completion_params = deep_update_dict(dict(completion_params), override_obj) except Exception: pass @@ -688,35 +540,9 @@ async def _execute_with_semaphore(row): ) # As per project convention, avoid printing per-metric CI lines to reduce noise if summary_path: - - def _sanitize_filename(text: str) -> str: - safe = re.sub(r"[^A-Za-z0-9._-]+", "-", text.strip()) - return safe[:120] - - def _extract_effort_tag(params: dict) -> str | None: - try: - if not isinstance(params, dict): - return None - # Common locations - if "extra_body" in params and isinstance(params["extra_body"], dict): - eb = params["extra_body"] - if isinstance(eb.get("reasoning"), dict) and "effort" in eb["reasoning"]: - return str(eb["reasoning"]["effort"]).lower() - if "reasoning_effort" in eb: - return str(eb["reasoning_effort"]).lower() - if ( - "reasoning" in params - and isinstance(params["reasoning"], dict) - and "effort" in params["reasoning"] - ): - return str(params["reasoning"]["effort"]).lower() - except Exception: - return None - return None - - model_slug = _sanitize_filename(model_used) - effort_tag = _extract_effort_tag(completion_params) or "" - effort_suffix = f"__effort-{_sanitize_filename(effort_tag)}" if effort_tag else "" + model_slug = sanitize_filename(model_used) + effort_tag = extract_effort_tag(completion_params) or "" + effort_suffix = f"__effort-{sanitize_filename(effort_tag)}" if effort_tag else "" base_name = f"{suite_name}__{model_slug}{effort_suffix}__{mode}__runs{num_runs}.json" p = pathlib.Path(summary_path) @@ -734,7 +560,7 @@ def _extract_effort_tag(params: dict) -> str | None: parent.mkdir(parents=True, exist_ok=True) # If we detected an effort tag, fan out to separate files; otherwise write to the exact file if effort_tag: - out_file = parent / f"{p.stem}__{_sanitize_filename(effort_tag)}{p.suffix}" + out_file = parent / 
f"{p.stem}__{sanitize_filename(effort_tag)}{p.suffix}" else: out_file = p @@ -923,7 +749,7 @@ def run_evaluation_test_direct( input_dataset: Optional[List[DatasetPathParam]] = None, dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter, completion_params: Optional[CompletionParams] = None, - rollout_processor: RolloutProcessor = default_no_op_rollout_processor, + rollout_processor: BaseRolloutProcessor = BaseRolloutProcessor(), rollout_processor_kwargs: Optional[RolloutProcessorInputParam] = None, aggregation_method: AggregationMethod = "mean", passed_threshold: Optional[Union[EvaluationThreshold, float]] = None, @@ -945,26 +771,6 @@ def run_evaluation_test_direct( if passed_threshold is not None and not isinstance(passed_threshold, EvaluationThreshold): passed_threshold = EvaluationThreshold(success=passed_threshold) - def _parse_ep_max_rows(default_value: int | None) -> int | None: - raw = os.getenv("EP_MAX_DATASET_ROWS") - if raw is None: - return default_value - s = raw.strip().lower() - if s == "none": - return None - try: - return int(s) - except ValueError: - return default_value - - def _deep_update_dict(base: dict, override: dict) -> dict: - for key, value in override.items(): - if isinstance(value, dict) and isinstance(base.get(key), dict): - _deep_update_dict(base[key], value) - else: - base[key] = value - return base - # Build dataset/messages data: List[EvaluationRow] = [] if input_dataset is not None: @@ -972,12 +778,12 @@ def _deep_update_dict(base: dict, override: dict) -> dict: data_jsonl: List[Dict[str, Any]] = [] for p in input_dataset: data_jsonl.extend(load_jsonl(p)) - effective_max_rows = _parse_ep_max_rows(max_dataset_rows) + effective_max_rows = parse_ep_max_rows(max_dataset_rows) if effective_max_rows is not None: data_jsonl = data_jsonl[:effective_max_rows] data = dataset_adapter(data_jsonl) elif input_messages is not None: - effective_max_rows = _parse_ep_max_rows(max_dataset_rows) + effective_max_rows = parse_ep_max_rows(max_dataset_rows) msgs = input_messages if effective_max_rows is not None and isinstance(msgs, list): msgs = msgs[:effective_max_rows] # type: ignore @@ -997,7 +803,7 @@ def _deep_update_dict(base: dict, override: dict) -> dict: if _env_override: override_obj = _json.loads(_env_override) if isinstance(override_obj, dict): - completion_params = _deep_update_dict(dict(completion_params), override_obj) + completion_params = deep_update_dict(dict(completion_params), override_obj) except Exception: pass @@ -1111,36 +917,11 @@ def _deep_update_dict(base: dict, override: dict) -> dict: if summary_path: import json as _json import pathlib as _pathlib - import re as _re import time as _time - def _sanitize_filename(text: str) -> str: - safe = _re.sub(r"[^A-Za-z0-9._-]+", "-", text.strip()) - return safe[:120] - - def _extract_effort_tag(params: dict) -> str | None: - try: - if not isinstance(params, dict): - return None - if "extra_body" in params and isinstance(params["extra_body"], dict): - eb = params["extra_body"] - if isinstance(eb.get("reasoning"), dict) and "effort" in eb["reasoning"]: - return str(eb["reasoning"]["effort"]).lower() - if "reasoning_effort" in eb: - return str(eb["reasoning_effort"]).lower() - if ( - "reasoning" in params - and isinstance(params["reasoning"], dict) - and "effort" in params["reasoning"] - ): - return str(params["reasoning"]["effort"]).lower() - except Exception: - return None - return None - - model_slug = _sanitize_filename(config.completion_params["model"]) - effort_tag = 
_extract_effort_tag(completion_params) or "" - effort_suffix = f"__effort-{_sanitize_filename(effort_tag)}" if effort_tag else "" + model_slug = sanitize_filename(config.completion_params["model"]) + effort_tag = extract_effort_tag(completion_params) or "" + effort_suffix = f"__effort-{sanitize_filename(effort_tag)}" if effort_tag else "" base_name = f"{suite_name}__{model_slug}{effort_suffix}__{mode}__runs{num_runs}.json" p = _pathlib.Path(summary_path) @@ -1153,7 +934,7 @@ def _extract_effort_tag(params: dict) -> str | None: parent = p.parent parent.mkdir(parents=True, exist_ok=True) if effort_tag: - out_file = parent / f"{p.stem}__{_sanitize_filename(effort_tag)}{p.suffix}" + out_file = parent / f"{p.stem}__{sanitize_filename(effort_tag)}{p.suffix}" else: out_file = p with open(out_file, "w", encoding="utf-8") as f: diff --git a/eval_protocol/pytest/types.py b/eval_protocol/pytest/types.py index b2952cfb..8a3be489 100644 --- a/eval_protocol/pytest/types.py +++ b/eval_protocol/pytest/types.py @@ -50,6 +50,3 @@ class RolloutProcessorConfig: steps: int = 30 # max number of rollout steps logger: DatasetLogger = default_logger # logger to use during rollout for mid-rollout logs kwargs: Dict[str, Any] = field(default_factory=dict) # any additional kwargs to pass to the rollout processor - - -RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], List[asyncio.Task[EvaluationRow]]] diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index 23a5722d..0d3f926b 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -1,9 +1,20 @@ import asyncio import inspect -from typing import Any, Callable, List, Literal, Optional +import os +import re +from dataclasses import replace +from typing import Any, Callable, Dict, List, Literal, Optional, Union from eval_protocol.dataset_logger.dataset_logger import DatasetLogger from eval_protocol.models import EvalMetadata, EvaluationRow +from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor +from eval_protocol.pytest.types import ( + CompletionParams, + DatasetPathParam, + EvaluationInputParam, + InputMessagesParam, + RolloutProcessorConfig, +) def execute_function(func: Callable, **kwargs) -> Any: @@ -124,3 +135,223 @@ def log_eval_status_and_rows( if r.eval_metadata is not None: r.eval_metadata.status = status logger.log(r) + + +def parse_ep_max_rows(default_value: Optional[int]) -> Optional[int]: + """Read EP_MAX_DATASET_ROWS env override as int or None.""" + raw = os.getenv("EP_MAX_DATASET_ROWS") + if raw is None: + return default_value + s = raw.strip().lower() + if s == "none": + return None + try: + return int(s) + except ValueError: + return default_value + + +def deep_update_dict(base: dict, override: dict) -> dict: + """Recursively update nested dictionaries in-place and return base.""" + for key, value in override.items(): + if isinstance(value, dict) and isinstance(base.get(key), dict): + deep_update_dict(base[key], value) + else: + base[key] = value + return base + + +def generate_parameter_combinations( + input_dataset: Optional[List[DatasetPathParam]], + completion_params: List[CompletionParams], + input_messages: Optional[List[InputMessagesParam]], + evaluation_test_kwargs: Optional[List[EvaluationInputParam]], + max_dataset_rows: Optional[int], + combine_datasets: bool, +) -> List[tuple]: + """ + Generate all combinations of parameters for pytest parameterization. 
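# Illustrative fan-out for generate_parameter_combinations (annotation, not
# part of this patch): with combine_datasets=False, each dataset path is
# paired with each completion_params entry, one pytest parameterization per tuple.
combos = generate_parameter_combinations(
    input_dataset=["a.jsonl", "b.jsonl"],
    completion_params=[{"model": "m1"}, {"model": "m2"}],
    input_messages=None,
    evaluation_test_kwargs=None,
    max_dataset_rows=None,
    combine_datasets=False,
)
assert len(combos) == 4  # 2 datasets x 2 completion params
assert combos[0] == (["a.jsonl"], {"model": "m1"}, None, None)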
+ + Args: + input_dataset: Dataset paths to use + completion_params: Completion parameters to test + input_messages: Input messages to use + evaluation_test_kwargs: Additional kwargs for evaluation tests + max_dataset_rows: Maximum number of dataset rows to process + combine_datasets: Whether to combine multiple datasets into one test + + Returns: + List of parameter tuples for pytest.mark.parametrize + """ + combinations = [] + + # Handle optional parameters with defaults + # Optionally combine multiple dataset paths into one logical dataset, + # or parameterize to run one dataset per test invocation. + if input_dataset is not None: + if combine_datasets: + datasets: List[Optional[List[DatasetPathParam]]] = [input_dataset] # type: ignore + else: + # Fan out: one dataset path per parameterization + if isinstance(input_dataset, list): # type: ignore + datasets = [[p] for p in input_dataset] # type: ignore + else: + datasets = [[input_dataset]] # type: ignore + else: + datasets = [None] + + cps: List[Optional[CompletionParams]] = completion_params if completion_params is not None else [None] # type: ignore + + # Apply EP_MAX_DATASET_ROWS to input_messages, but do NOT parameterize over + # each row. Instead, pass the entire sliced list through in a single test run + # so summaries aggregate all rows together (AIME-style behavior). + if input_messages is not None and isinstance(input_messages, list): + effective_max_rows = parse_ep_max_rows(max_dataset_rows) + if effective_max_rows is not None: + sliced_messages = input_messages[:effective_max_rows] # type: ignore + else: + sliced_messages = input_messages # type: ignore + # Wrap as a single parameter payload + messages = [sliced_messages] # type: ignore + else: + messages = [None] # type: ignore + + kwargs: List[Optional[EvaluationInputParam]] = evaluation_test_kwargs if evaluation_test_kwargs is not None else [None] # type: ignore + + # Generate all combinations + for ds in datasets: + for cp in cps: + for im in messages: + for etk in kwargs: + # if no dataset and no messages, raise an error + if ds is None and im is None: + raise ValueError( + "No dataset or messages provided. Please provide at least one of input_dataset or input_messages." + ) + combinations.append((ds, cp, im, etk)) + + return combinations + + +async def rollout_processor_with_retry( + rollout_processor: BaseRolloutProcessor, + fresh_dataset: List[EvaluationRow], + config: RolloutProcessorConfig, + max_retry: int, +): + """ + Wrapper around rollout_processor that handles retry logic internally. + Uses async queue pattern to yield results immediately as they become available. + Yields both successful and failed results, leaving it up to the user to handle them in test_func. 
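# Consumption sketch (annotation, not part of this patch): the wrapper is an
# async generator, so rows stream back as each rollout, or its retry, finishes.
# Permanently failed rows are yielded with rollout_status.status == "error"
# only when EP_FAIL_ON_PERMANENT_FAILURE=false; otherwise the wrapper raises
# RuntimeError once a row exhausts max_retry attempts.
async def gather_rollouts(processor, rows, config):
    finished = []
    async for row in rollout_processor_with_retry(processor, rows, config, max_retry=3):
        finished.append(row)
    return finished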
+ """ + + try: + queue = asyncio.Queue() + retry_counts = {r.execution_metadata.rollout_id: 0 for r in fresh_dataset} + failed_permanently = [] + + async def retry_handler(failed_row: EvaluationRow): + rollout_id = failed_row.execution_metadata.rollout_id + current_attempts = retry_counts.get(rollout_id, 0) + + if current_attempts >= max_retry: + assert ( + failed_row.rollout_status and failed_row.rollout_status.status == "error" + ), f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status" + failed_permanently.append(failed_row) + await queue.put(failed_row) # put failed row on queue + return + + retry_counts[rollout_id] = current_attempts + 1 + + # add kwargs start_server=False to config so we don't start new MCP server + retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False}) + + retry_tasks = rollout_processor([failed_row], retry_config) + + try: + retry_result = await retry_tasks[0] + retry_result.rollout_status.status = "finished" + await queue.put(retry_result) + except Exception as e: + failed_row.rollout_status.status = "error" + failed_row.rollout_status.termination_reason = str(e) + asyncio.create_task(retry_handler(failed_row)) # retry failed, spawn another retry + + async def initial_processor(): + """Process initial batch and spawn retries for failures""" + base_tasks = rollout_processor(fresh_dataset, config) + pending = set(base_tasks) + + while pending: + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + + for task in done: + task_index = base_tasks.index(task) + + try: + result = await task + result.rollout_status.status = "finished" + await queue.put(result) + except Exception as e: + failed_row = fresh_dataset[task_index] + failed_row.rollout_status.status = "error" + failed_row.rollout_status.termination_reason = str(e) + asyncio.create_task(retry_handler(failed_row)) # rollout errored, spawn retry task + + processor_task = asyncio.create_task(initial_processor()) + + # yield results as they become available + completed_count = 0 + total_expected = len(fresh_dataset) + + while completed_count < total_expected: + finished_row = await queue.get() + + # only permanent failure rows are put on the queue, so we can check for them here + if finished_row.rollout_status and finished_row.rollout_status.status == "error": + if os.getenv("EP_FAIL_ON_PERMANENT_FAILURE", "true") != "false": + raise RuntimeError( + f"Rollout {finished_row.execution_metadata.rollout_id} failed after {max_retry} retries. Errors: {finished_row.rollout_status.termination_reason}" + ) + + completed_count += 1 + yield finished_row + + await processor_task # explicitly wait for task completion and catch any exceptions + + finally: + rollout_processor.cleanup() + + +def sanitize_filename(text: str) -> str: + """Sanitize text for use in filenames by replacing special characters with dashes.""" + safe = re.sub(r"[^A-Za-z0-9._-]+", "-", text.strip()) + return safe[:120] + + +def extract_effort_tag(params: dict) -> Optional[str]: + """ + Extract effort tag from completion parameters for use in file naming. 
+ + Args: + params: Completion parameters dictionary + + Returns: + Effort tag string if found, None otherwise + """ + try: + if not isinstance(params, dict): + return None + # Common locations + if "extra_body" in params and isinstance(params["extra_body"], dict): + eb = params["extra_body"] + if isinstance(eb.get("reasoning"), dict) and "effort" in eb["reasoning"]: + return str(eb["reasoning"]["effort"]).lower() + if "reasoning_effort" in eb: + return str(eb["reasoning_effort"]).lower() + if "reasoning" in params and isinstance(params["reasoning"], dict) and "effort" in params["reasoning"]: + return str(params["reasoning"]["effort"]).lower() + except Exception: + return None + return None diff --git a/examples/gpqa/tests/test_gpqa.py b/examples/gpqa/tests/test_gpqa.py index dcbf7b53..d67e64a1 100644 --- a/examples/gpqa/tests/test_gpqa.py +++ b/examples/gpqa/tests/test_gpqa.py @@ -7,7 +7,7 @@ from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult from eval_protocol.pytest.default_single_turn_rollout_process import ( - default_single_turn_rollout_processor, + SingleTurnRolloutProcessor, ) from eval_protocol.pytest.evaluation_test import evaluation_test @@ -66,7 +66,7 @@ def _load_gpqa_messages_from_csv() -> List[List[Message]]: completion_params=[ {"extra_body": {"reasoning_effort": "low"}, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"} ], # default to low effort; override via CLI plugin - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", passed_threshold=None, num_runs=8, diff --git a/examples/healthbench/tests/test_evaluation.py b/examples/healthbench/tests/test_evaluation.py index a40c5d96..e0c7917b 100644 --- a/examples/healthbench/tests/test_evaluation.py +++ b/examples/healthbench/tests/test_evaluation.py @@ -3,7 +3,7 @@ from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult from eval_protocol.pytest.default_single_turn_rollout_process import ( - default_single_turn_rollout_processor, + SingleTurnRolloutProcessor, ) from eval_protocol.pytest.evaluation_test import evaluation_test @@ -51,7 +51,7 @@ completion_params=[ {"temperature": 0.2, "max_tokens": 512, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"} ], - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), aggregation_method="mean", passed_threshold=None, num_runs=1, diff --git a/tests/pytest/test_apps_coding.py b/tests/pytest/test_apps_coding.py index 7cb976ac..9350a381 100644 --- a/tests/pytest/test_apps_coding.py +++ b/tests/pytest/test_apps_coding.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List from eval_protocol.models import EvaluateResult, EvaluationRow, Message -from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test +from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test from eval_protocol.rewards.apps_coding_reward import evaluate_apps_solution @@ -30,7 +30,7 @@ def apps_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluatio {"temperature": 0.0, "max_tokens": 4096, "model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"} ], passed_threshold=0.33, - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=SingleTurnRolloutProcessor(), num_runs=1, mode="pointwise", ) diff --git a/tests/pytest/test_basic_coding.py b/tests/pytest/test_basic_coding.py index 2b1c2a4a..4945d378 100644 --- 
diff --git a/tests/pytest/test_basic_coding.py b/tests/pytest/test_basic_coding.py
index 2b1c2a4a..4945d378 100644
--- a/tests/pytest/test_basic_coding.py
+++ b/tests/pytest/test_basic_coding.py
@@ -8,7 +8,7 @@
 from typing import Any, Dict, List
 
 from eval_protocol.models import EvaluateResult, EvaluationRow, Message
-from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
+from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
 from eval_protocol.rewards.code_execution import execute_python_code, extract_code_blocks
 
@@ -32,7 +32,7 @@ def coding_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluat
         {"temperature": 0.0, "max_tokens": 4096, "model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}
     ],
     passed_threshold=0.8,
-    rollout_processor=default_single_turn_rollout_processor,
+    rollout_processor=SingleTurnRolloutProcessor(),
     num_runs=1,
     mode="pointwise",
 )
diff --git a/tests/pytest/test_frozen_lake.py b/tests/pytest/test_frozen_lake.py
index bea42bed..24e32b56 100644
--- a/tests/pytest/test_frozen_lake.py
+++ b/tests/pytest/test_frozen_lake.py
@@ -9,7 +9,7 @@
 from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message, MetricResult
 from eval_protocol.pytest import evaluation_test
-from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
+from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
 
 
 def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
@@ -41,7 +41,7 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation
     completion_params=[
         {"temperature": 0.0, "max_tokens": 4096, "model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}
     ],
-    rollout_processor=default_mcp_gym_rollout_processor,
+    rollout_processor=MCPGymRolloutProcessor(),
     passed_threshold=0.66,
     num_runs=1,
     max_concurrent_rollouts=3,
diff --git a/tests/pytest/test_hallucination.py b/tests/pytest/test_hallucination.py
index b29fb53c..fe8f32f0 100644
--- a/tests/pytest/test_hallucination.py
+++ b/tests/pytest/test_hallucination.py
@@ -12,7 +12,7 @@
 import litellm
 
 from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
-from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
+from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
 
 # Configure the judge model for LiteLLM
 JUDGE_MODEL = "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"
@@ -35,7 +35,7 @@ def hallucination_dataset_adapter(data: List[Dict[str, Any]]) -> List[Evaluation
     completion_params=[
         {"temperature": 0.0, "max_tokens": 512, "model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}
     ],
-    rollout_processor=default_single_turn_rollout_processor,
+    rollout_processor=SingleTurnRolloutProcessor(),
     passed_threshold=0.33,
     num_runs=1,
     mode="pointwise",
diff --git a/tests/pytest/test_lunar_lander.py b/tests/pytest/test_lunar_lander.py
index 3fddac62..00f966a5 100644
--- a/tests/pytest/test_lunar_lander.py
+++ b/tests/pytest/test_lunar_lander.py
@@ -9,7 +9,7 @@
 from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message
 from eval_protocol.pytest import evaluation_test
-from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
+from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
 
 
 def lunar_lander_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
@@ -39,7 +39,7 @@ def lunar_lander_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluatio
     input_dataset=["tests/pytest/data/lunar_lander_dataset.jsonl"],
     dataset_adapter=lunar_lander_to_evaluation_row,
     completion_params=[{"temperature": 0.0, "max_tokens": 4096, "model": "gpt-4.1"}],
-    rollout_processor=default_mcp_gym_rollout_processor,
+    rollout_processor=MCPGymRolloutProcessor(),
     passed_threshold=0.0,
     num_runs=1,
     mode="pointwise",
diff --git a/tests/pytest/test_markdown_highlighting.py b/tests/pytest/test_markdown_highlighting.py
index 9c70721f..c393ee60 100644
--- a/tests/pytest/test_markdown_highlighting.py
+++ b/tests/pytest/test_markdown_highlighting.py
@@ -8,7 +8,7 @@
 from typing import Any, Dict, List
 
 from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message
-from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
+from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
 
 
 def markdown_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
@@ -32,7 +32,7 @@ def markdown_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
         {"temperature": 0.0, "max_tokens": 4096, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}
     ],
     passed_threshold=0.5,
-    rollout_processor=default_single_turn_rollout_processor,
+    rollout_processor=SingleTurnRolloutProcessor(),
     num_runs=1,
     mode="pointwise",
 )
diff --git a/tests/pytest/test_pytest_default_agent_rollout_processor.py b/tests/pytest/test_pytest_default_agent_rollout_processor.py
index 8320ec8a..bfabe35c 100644
--- a/tests/pytest/test_pytest_default_agent_rollout_processor.py
+++ b/tests/pytest/test_pytest_default_agent_rollout_processor.py
@@ -2,7 +2,7 @@
 from typing import List
 
 from eval_protocol.models import EvaluationRow, Message
-from eval_protocol.pytest import default_agent_rollout_processor, evaluation_test
+from eval_protocol.pytest import AgentRolloutProcessor, evaluation_test
 
 
 @evaluation_test(
@@ -16,7 +16,7 @@
             )
         ]
     ],
-    rollout_processor=default_agent_rollout_processor,
+    rollout_processor=AgentRolloutProcessor(),
     completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}],
 )
 def test_pytest_default_agent_rollout_processor(rows: List[EvaluationRow]) -> List[EvaluationRow]:
diff --git a/tests/pytest/test_pytest_ensure_logging.py b/tests/pytest/test_pytest_ensure_logging.py
index 4300e1b4..3288204d 100644
--- a/tests/pytest/test_pytest_ensure_logging.py
+++ b/tests/pytest/test_pytest_ensure_logging.py
@@ -5,7 +5,7 @@
 from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
 from eval_protocol.dataset_logger.sqlite_evaluation_row_store import SqliteEvaluationRowStore
 from eval_protocol.models import EvaluationRow
-from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor
+from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor
 from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row
 
 
@@ -42,7 +42,7 @@ def read(self, rollout_id=None) -> List[EvaluationRow]:
     ],
     completion_params=[{"temperature": 0.0, "model": "dummy/local-model"}],
     dataset_adapter=markdown_dataset_to_evaluation_row,
-    rollout_processor=default_no_op_rollout_processor,
+    rollout_processor=BaseRolloutProcessor(),
     mode="pointwise",
     combine_datasets=False,
     num_runs=2,
diff --git a/tests/pytest/test_pytest_flaky_sometimes.py b/tests/pytest/test_pytest_flaky_sometimes.py
index 65e1e63d..730a0a81 100644
--- a/tests/pytest/test_pytest_flaky_sometimes.py
+++ b/tests/pytest/test_pytest_flaky_sometimes.py
@@ -5,7 +5,7 @@
 import pytest
 
 from eval_protocol.models import EvaluateResult, EvaluationRow, Message
-from eval_protocol.pytest import default_no_op_rollout_processor, evaluation_test
+from eval_protocol.pytest import BaseRolloutProcessor, evaluation_test
 
 
 # skip in CI since it will intentionally fail. This is useful for local generation of logs
@@ -13,7 +13,7 @@
 @evaluation_test(
     input_messages=[[Message(role="user", content="Return HEADS or TAILS at random.")]],
     completion_params=[{"model": "dummy/local-model"}],
-    rollout_processor=default_no_op_rollout_processor,
+    rollout_processor=BaseRolloutProcessor(),
     mode="pointwise",
     num_runs=5,
 )
diff --git a/tests/pytest/test_pytest_function_calling.py b/tests/pytest/test_pytest_function_calling.py
index 63488dbe..60f38b0d 100644
--- a/tests/pytest/test_pytest_function_calling.py
+++ b/tests/pytest/test_pytest_function_calling.py
@@ -2,7 +2,7 @@
 from typing import Any, Dict, List
 
 from eval_protocol.models import EvaluationRow
-from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
+from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
 from eval_protocol.rewards.function_calling import exact_tool_match_reward
 
 
@@ -23,7 +23,7 @@ def function_calling_to_evaluation_row(rows: List[Dict[str, Any]]) -> List[Evalu
     completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
     mode="pointwise",
     dataset_adapter=function_calling_to_evaluation_row,
-    rollout_processor=default_single_turn_rollout_processor,
+    rollout_processor=SingleTurnRolloutProcessor(),
 )
 async def test_pytest_function_calling(row: EvaluationRow) -> EvaluationRow:
     """Run pointwise evaluation on sample dataset using pytest interface."""
diff --git a/tests/pytest/test_pytest_ids.py b/tests/pytest/test_pytest_ids.py
index 045d2a19..69068ab9 100644
--- a/tests/pytest/test_pytest_ids.py
+++ b/tests/pytest/test_pytest_ids.py
@@ -3,7 +3,7 @@
 import eval_protocol.dataset_logger as dataset_logger
 from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
 from eval_protocol.models import EvaluationRow
-from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor
+from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor
 from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row
 
 
@@ -30,7 +30,7 @@ async def test_evaluation_test_decorator(monkeypatch):
     ],
     completion_params=[{"temperature": 0.0, "model": "dummy/local-model"}],
     dataset_adapter=markdown_dataset_to_evaluation_row,
-    rollout_processor=default_no_op_rollout_processor,
+    rollout_processor=BaseRolloutProcessor(),
     mode="pointwise",
     combine_datasets=False,
     num_runs=2,
@@ -71,7 +71,7 @@ async def test_evaluation_test_decorator_ids_single(monkeypatch):
         {"temperature": 1.0, "model": "dummy/local-model"},
     ],
     dataset_adapter=markdown_dataset_to_evaluation_row,
-    rollout_processor=default_no_op_rollout_processor,
+    rollout_processor=BaseRolloutProcessor(),
     mode="pointwise",
     combine_datasets=False,
     num_runs=5,
diff --git a/tests/pytest/test_pytest_input_messages.py b/tests/pytest/test_pytest_input_messages.py
index dc460aa5..7b4f8d9e 100644
--- a/tests/pytest/test_pytest_input_messages.py
+++ b/tests/pytest/test_pytest_input_messages.py
@@ -1,7 +1,7 @@
 from typing import List
 
 from eval_protocol.models import EvaluationRow, Message
-from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
+from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
 
 
 @evaluation_test(
@@ -11,7 +11,7 @@
         ]
     ],
     completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
-    rollout_processor=default_single_turn_rollout_processor,
+    rollout_processor=SingleTurnRolloutProcessor(),
 )
 def test_input_messages_in_decorator(rows: List[EvaluationRow]) -> List[EvaluationRow]:
     """Run math evaluation on sample dataset using pytest interface."""
diff --git a/tests/pytest/test_pytest_json_schema.py b/tests/pytest/test_pytest_json_schema.py
index 158874f1..c5a20c5d 100644
--- a/tests/pytest/test_pytest_json_schema.py
+++ b/tests/pytest/test_pytest_json_schema.py
@@ -2,7 +2,7 @@
 from typing import Any, Dict, List
 
 from eval_protocol.models import EvaluationRow
-from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
+from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
 from eval_protocol.rewards.json_schema import json_schema_reward
 
 
@@ -26,7 +26,7 @@ def json_schema_to_evaluation_row(rows: List[Dict[str, Any]]) -> List[Evaluation
     input_dataset=["tests/pytest/data/json_schema.jsonl"],
     completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
     mode="pointwise",
-    rollout_processor=default_single_turn_rollout_processor,
+    rollout_processor=SingleTurnRolloutProcessor(),
     dataset_adapter=json_schema_to_evaluation_row,
 )
 async def test_pytest_function_calling(row: EvaluationRow) -> EvaluationRow:
diff --git a/tests/pytest/test_pytest_math_example.py b/tests/pytest/test_pytest_math_example.py
index 23010797..55c525be 100644
--- a/tests/pytest/test_pytest_math_example.py
+++ b/tests/pytest/test_pytest_math_example.py
@@ -1,5 +1,5 @@
 from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult
-from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
+from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
 from eval_protocol.rewards.math import math_reward
 from examples.math_example.main import check_think_answer_format
 from tests.pytest.helper.gsm8k_to_evaluation_row import gsm8k_to_evaluation_row
@@ -11,7 +11,7 @@
     completion_params=[{"temperature": 0.0, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
     max_dataset_rows=5,
     passed_threshold=0.0,
-    rollout_processor=default_single_turn_rollout_processor,
+    rollout_processor=SingleTurnRolloutProcessor(),
     mode="pointwise",
     evaluation_test_kwargs=[
         {"math_reward_kwargs": {"tolerance": 0.001, "absolute_tolerance": 1e-8, "require_units": False}}
diff --git a/tests/pytest/test_pytest_math_format_length.py b/tests/pytest/test_pytest_math_format_length.py
index 5bba5c0e..3da732a0 100644
--- a/tests/pytest/test_pytest_math_format_length.py
+++ b/tests/pytest/test_pytest_math_format_length.py
@@ -1,7 +1,7 @@
 import math
 
 from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult
-from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
+from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
 from eval_protocol.rewards.length import count_tokens
 from eval_protocol.rewards.math import math_reward
 from examples.math_with_format_and_length.main import check_think_answer_format
@@ -14,7 +14,7 @@
     completion_params=[{"temperature": 0.0, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
     max_dataset_rows=5,
     passed_threshold=0.0,
-    rollout_processor=default_single_turn_rollout_processor,
+    rollout_processor=SingleTurnRolloutProcessor(),
     mode="pointwise",
     evaluation_test_kwargs=[
         {
diff --git a/tests/pytest/test_pytest_mcp_config.py b/tests/pytest/test_pytest_mcp_config.py
index dde15aa9..c578d07c 100644
--- a/tests/pytest/test_pytest_mcp_config.py
+++ b/tests/pytest/test_pytest_mcp_config.py
@@ -2,7 +2,7 @@
 from typing import List
 
 from eval_protocol.models import EvaluateResult, EvaluationRow, Message
-from eval_protocol.pytest import default_agent_rollout_processor, evaluation_test
+from eval_protocol.pytest import AgentRolloutProcessor, evaluation_test
 
 
 @evaluation_test(
@@ -19,7 +19,7 @@
             )
         ]
     ],
-    rollout_processor=default_agent_rollout_processor,
+    rollout_processor=AgentRolloutProcessor(),
     completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-20b"}],
     mode="pointwise",
     mcp_config_path="tests/pytest/mcp_configurations/mock_discord_mcp_config.json",
diff --git a/tests/pytest/test_pytest_mcp_url.py b/tests/pytest/test_pytest_mcp_url.py
index 01c06c45..ce265da5 100644
--- a/tests/pytest/test_pytest_mcp_url.py
+++ b/tests/pytest/test_pytest_mcp_url.py
@@ -1,5 +1,5 @@
 from eval_protocol.models import EvaluateResult, EvaluationRow, Message
-from eval_protocol.pytest import default_agent_rollout_processor, evaluation_test
+from eval_protocol.pytest import AgentRolloutProcessor, evaluation_test
 
 
 @evaluation_test(
@@ -18,7 +18,7 @@
         ),
         ]
     ],
-    rollout_processor=default_agent_rollout_processor,
+    rollout_processor=AgentRolloutProcessor(),
     completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}],
     mode="pointwise",
     mcp_config_path="tests/pytest/mcp_configurations/docs_mcp_config.json",
diff --git a/tests/pytest/test_pytest_word_count_example.py b/tests/pytest/test_pytest_word_count_example.py
index 339c5152..72c9bc2f 100644
--- a/tests/pytest/test_pytest_word_count_example.py
+++ b/tests/pytest/test_pytest_word_count_example.py
@@ -1,7 +1,7 @@
 from haikus import haikus
 
 from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult
-from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
+from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
 from tests.pytest.helper.word_count_to_evaluation_row import word_count_to_evaluation_row
 
 
@@ -11,7 +11,7 @@
     completion_params=[{"temperature": 0.0, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
     max_dataset_rows=5,
     passed_threshold=0.3,  # Reasonable threshold for word count evaluation
-    rollout_processor=default_single_turn_rollout_processor,
+    rollout_processor=SingleTurnRolloutProcessor(),
     mode="pointwise",  # Use pointwise mode for elegant row-by-row evaluation
 )
 def test_word_count_evaluate(row: EvaluationRow) -> EvaluationRow:
diff --git a/tests/pytest/test_tau_bench_airline.py b/tests/pytest/test_tau_bench_airline.py
index 0eeba626..f3a7c65f 100644
--- a/tests/pytest/test_tau_bench_airline.py
+++ b/tests/pytest/test_tau_bench_airline.py
@@ -12,7 +12,7 @@
 from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message
 from eval_protocol.pytest import evaluation_test
-from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
+from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
 from vendor.tau2.data_model.message import (
     AssistantMessage,
     SystemMessage,
@@ -72,7 +72,7 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Eval
             "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
         }
     ],
-    rollout_processor=default_mcp_gym_rollout_processor,
+    rollout_processor=MCPGymRolloutProcessor(),
     passed_threshold={"success": 0.4, "standard_deviation": 0.1},
     num_runs=8,
     mode="pointwise",
diff --git a/tests/test_retry_mechanism.py b/tests/test_retry_mechanism.py
index a937192a..9529e80e 100644
--- a/tests/test_retry_mechanism.py
+++ b/tests/test_retry_mechanism.py
@@ -10,6 +10,7 @@
 from typing import AsyncIterator, List
 
 from eval_protocol.models import EvaluateResult, EvaluationRow, Message, RolloutStatus
+from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor
 from eval_protocol.pytest.evaluation_test import evaluation_test
 from eval_protocol.pytest.types import RolloutProcessorConfig
 
@@ -19,38 +20,40 @@
 timing_results = []  # Collect timing data for assertions
 
 
-def mock_rollout_processor_with_retries(
-    rows: List[EvaluationRow], config: RolloutProcessorConfig
-) -> List[asyncio.Task[EvaluationRow]]:
+class MockRolloutProcessorWithRetries(BaseRolloutProcessor):
     """Mock rollout processor that fails second task alphabetically on first attempt, succeeds on retry"""
-    row_setup = {
-        0: {"delay": 3.0, "should_fail": False},
-        1: {"delay": 3.0, "should_fail": True},
-        2: {"delay": 5.0, "should_fail": False},
-        3: {"delay": 5.0, "should_fail": False},
-        4: {"delay": 5.0, "should_fail": False},
-    }
 
-    async def process_single_row(row: EvaluationRow, delay: float, should_fail: bool = False) -> EvaluationRow:
-        await asyncio.sleep(delay)
+    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
+        row_setup = {
+            0: {"delay": 3.0, "should_fail": False},
+            1: {"delay": 3.0, "should_fail": True},
+            2: {"delay": 5.0, "should_fail": False},
+            3: {"delay": 5.0, "should_fail": False},
+            4: {"delay": 5.0, "should_fail": False},
+        }
 
-        elapsed = time.time() - start_time
-        print(
-            f"šŸŽ‰ FINISHED {'error' if should_fail else 'finished'} at {elapsed:.2f}s: {row.execution_metadata.rollout_id}"
-        )
+        async def process_single_row(row: EvaluationRow, delay: float, should_fail: bool = False) -> EvaluationRow:
+            await asyncio.sleep(delay)
 
-        if should_fail:
-            raise Exception("Simulated failure for testing")
+            elapsed = time.time() - start_time
+            print(
+                f"šŸŽ‰ FINISHED {'error' if should_fail else 'finished'} at {elapsed:.2f}s: {row.execution_metadata.rollout_id}"
+            )
 
-        return row
+            if should_fail:
+                raise Exception("Simulated failure for testing")
 
-    # Create and return tasks (let evaluation_test handle them)
-    tasks = [
-        asyncio.create_task(process_single_row(row, row_setup[i]["delay"], row_setup[i]["should_fail"]))
-        for i, row in enumerate(rows)
-    ]
+            return row
 
-    return tasks
+        # Create and return tasks (let evaluation_test handle them)
+        tasks = [
+            asyncio.create_task(process_single_row(row, row_setup[i]["delay"], row_setup[i]["should_fail"]))
+            for i, row in enumerate(rows)
+        ]
+
+        return tasks
+
+    # Inherits cleanup() from BaseRolloutProcessor - no override needed
 
 
 @evaluation_test(
@@ -62,7 +65,7 @@ async def process_single_row(row: EvaluationRow, delay: float, should_fail: bool
         [Message(role="user", content="Task D")],
         [Message(role="user", content="Task E")],
     ],
-    rollout_processor=mock_rollout_processor_with_retries,
+    rollout_processor=MockRolloutProcessorWithRetries(),
     num_runs=1,
     mode="pointwise",
 )
diff --git a/tests/test_tau_bench_airline_smoke.py b/tests/test_tau_bench_airline_smoke.py
index 200f7ca8..044447b7 100644
--- a/tests/test_tau_bench_airline_smoke.py
+++ b/tests/test_tau_bench_airline_smoke.py
@@ -13,7 +13,7 @@
 from eval_protocol.models import CompletionParams, EvaluateResult, EvaluationRow, InputMetadata, Message
 from eval_protocol.pytest import evaluation_test
-from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
+from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
 from vendor.tau2.data_model.message import (
     AssistantMessage,
     SystemMessage,
@@ -72,7 +72,7 @@ def tau_bench_airline_smoke_to_evaluation_row(data: List[Dict[str, Any]]) -> Lis
             "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
         }
     ],
-    rollout_processor=default_mcp_gym_rollout_processor,
+    rollout_processor=MCPGymRolloutProcessor(),
    passed_threshold=0.36,
     num_runs=1,  # Smoke test: single run for quick feedback
     mode="pointwise",

From 99bcb4d0bf2999628cab41b93acdc10d366a8511 Mon Sep 17 00:00:00 2001
From: Derek Xu
Date: Thu, 14 Aug 2025 01:57:55 -0700
Subject: [PATCH 4/9] cleaning up mcp gym

---
 .../default_mcp_gym_rollout_processor.py | 39 ++++++++-----------
 1 file changed, 16 insertions(+), 23 deletions(-)

diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py
index 491e4298..a06121ae 100644
--- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py
+++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py
@@ -204,7 +204,8 @@ class MCPGymRolloutProcessor(BaseRolloutProcessor):
     """
 
     def __init__(self):
-        self.current_run_state: Dict[str, Any] = {}
+        self.server = None
+        self.policy = None
 
     def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
         """Process evaluation rows with MCP gym environments."""
@@ -215,52 +216,43 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
             if config.server_script_path is None:
                 raise ValueError("server_script_path is required for MCPGymRolloutProcessor")
 
-            server = MCPServerManager(config.server_script_path, port=9700, **(config.kwargs or {}))
+            self.server = MCPServerManager(config.server_script_path, port=9700, **(config.kwargs or {}))
 
             try:
-                server.start()
+                self.server.start()
 
-                policy = ep.LiteLLMPolicy(
+                self.policy = ep.LiteLLMPolicy(
                     model_id=config.completion_params.get("model", None),
                     temperature=config.completion_params.get("temperature", 0.0),
                     max_tokens=config.completion_params.get("max_tokens", 4096),
                     reasoning_effort=config.completion_params.get("reasoning_effort", None),
                 )
 
-                # Store in instance state for cleanup
-                self.current_run_state.update(
-                    {
-                        "server": server,
-                        "policy": policy,
-                    }
-                )
-
             except Exception as e:
-                server.stop()
-                self.current_run_state.clear()
+                if self.server:
+                    self.server.stop()
+                    self.server = None
+                self.policy = None
                 raise e
 
         else:
             # Reuse existing MCP environments for retry
-            if not self.current_run_state:
+            if not self.server or not self.policy:
                 raise RuntimeError(
                     "Cannot retry without existing server/environments. Call with start_server=True first."
                )
 
-        server = self.current_run_state["server"]
-        policy = self.current_run_state["policy"]
-
         # Create MCP environments directly from evaluation_rows
         envs = ep.make(
             "http://localhost:9700/mcp/",
             evaluation_rows=rows,
-            model_id=policy.model_id,
+            model_id=self.policy.model_id,
         )
 
         # Get rollout tasks from ep.rollout
         tasks = ep.rollout(
             envs,
-            policy=policy,
+            policy=self.policy,
             evaluation_rows=rows,
             steps=config.steps,
             max_concurrent_rollouts=config.max_concurrent_rollouts,
@@ -269,6 +261,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
 
     def cleanup(self) -> None:
         """Cleanup MCP server and environments."""
-        if self.current_run_state and "server" in self.current_run_state:
-            self.current_run_state["server"].stop()
-        self.current_run_state.clear()
+        if self.server:
+            self.server.stop()
+            self.server = None
+            self.policy = None

From 57a28d04e5848d806633c999323f88be187c50dd Mon Sep 17 00:00:00 2001
From: Derek Xu
Date: Thu, 14 Aug 2025 02:10:45 -0700
Subject: [PATCH 5/9] remove import

---
 tests/pytest/test_pytest_ensure_logging.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/pytest/test_pytest_ensure_logging.py b/tests/pytest/test_pytest_ensure_logging.py
index 518cb67c..22a11338 100644
--- a/tests/pytest/test_pytest_ensure_logging.py
+++ b/tests/pytest/test_pytest_ensure_logging.py
@@ -25,7 +25,6 @@ async def test_ensure_logging(monkeypatch):
         "eval_protocol.dataset_logger.sqlite_dataset_logger_adapter.SqliteEvaluationRowStore", return_value=mock_store
     ):
         from eval_protocol.models import EvaluationRow
-        from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor
         from eval_protocol.pytest.evaluation_test import evaluation_test
         from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row
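Taken together, the refactor in patch 4 implies the following driver lifecycle for the processor; a minimal sketch, assuming rows, failed_rows, and a RolloutProcessorConfig named config already exist:

    from dataclasses import replace

    from eval_protocol.pytest import MCPGymRolloutProcessor

    processor = MCPGymRolloutProcessor()
    try:
        # start_server defaults to True: boots the MCP server and builds the policy
        tasks = processor(rows, config)
        # on a retry pass, reuse the same server/policy instead of booting a new one
        retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False})
        retry_tasks = processor(failed_rows, retry_config)
    finally:
        processor.cleanup()  # stops the server and resets self.server / self.policy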
From b1eaf1ec6680351f39bbe08abbe8e89b574c12f0 Mon Sep 17 00:00:00 2001
From: Derek Xu
Date: Thu, 14 Aug 2025 09:47:59 -0700
Subject: [PATCH 6/9] Update

---
 eval_protocol/benchmarks/suites/gpqa.py       |  4 ++--
 eval_protocol/pytest/__init__.py              |  6 ++++--
 .../pytest/default_agent_rollout_processor.py |  4 ++--
 .../default_mcp_gym_rollout_processor.py      | 10 ++++-----
 ....py => default_no_op_rollout_processor.py} | 16 ++++++--------
 .../default_single_turn_rollout_process.py    |  7 +++----
 eval_protocol/pytest/evaluation_test.py       |  7 ++++---
 eval_protocol/pytest/rollout_processor.py     | 21 +++++++++++++++++++
 eval_protocol/pytest/utils.py                 |  4 ++--
 tests/pytest/test_pytest_ensure_logging.py    |  4 ++--
 tests/pytest/test_pytest_flaky_sometimes.py   |  4 ++--
 tests/pytest/test_pytest_ids.py               |  6 +++---
 tests/test_retry_mechanism.py                 |  6 ++----
 13 files changed, 57 insertions(+), 42 deletions(-)
 rename eval_protocol/pytest/{default_base_rollout_process.py => default_no_op_rollout_processor.py} (57%)
 create mode 100644 eval_protocol/pytest/rollout_processor.py

diff --git a/eval_protocol/benchmarks/suites/gpqa.py b/eval_protocol/benchmarks/suites/gpqa.py
index 531dfcaa..ced8ac9f 100644
--- a/eval_protocol/benchmarks/suites/gpqa.py
+++ b/eval_protocol/benchmarks/suites/gpqa.py
@@ -8,11 +8,11 @@
 from eval_protocol.benchmarks.registry import export_benchmark
 from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
-from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor
 from eval_protocol.pytest.default_single_turn_rollout_process import (
     SingleTurnRolloutProcessor,
 )
 from eval_protocol.pytest.evaluation_test import evaluation_test
+from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig
 
 SYSTEM_PROMPT = (
@@ -63,7 +63,7 @@ def _strip_gt_messages(msgs: List[Message]) -> List[Message]:
     return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]
 
 
-class GPQAStripGTRolloutProcessor(BaseRolloutProcessor):
+class GPQAStripGTRolloutProcessor(RolloutProcessor):
     """Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to SingleTurnRolloutProcessor."""
 
     def __init__(self):
diff --git a/eval_protocol/pytest/__init__.py b/eval_protocol/pytest/__init__.py
index 7b5de324..171fa3dc 100644
--- a/eval_protocol/pytest/__init__.py
+++ b/eval_protocol/pytest/__init__.py
@@ -1,16 +1,18 @@
 from .default_agent_rollout_processor import AgentRolloutProcessor
-from .default_base_rollout_process import BaseRolloutProcessor
 from .default_dataset_adapter import default_dataset_adapter
 from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
+from .default_no_op_rollout_processor import NoOpRolloutProcessor
 from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
 from .evaluation_test import evaluation_test
+from .rollout_processor import RolloutProcessor
 from .types import RolloutProcessorConfig
 
 __all__ = [
     "AgentRolloutProcessor",
     "MCPGymRolloutProcessor",
-    "BaseRolloutProcessor",
+    "RolloutProcessor",
     "SingleTurnRolloutProcessor",
+    "NoOpRolloutProcessor",
     "default_dataset_adapter",
     "RolloutProcessorConfig",
     "evaluation_test",
diff --git a/eval_protocol/pytest/default_agent_rollout_processor.py b/eval_protocol/pytest/default_agent_rollout_processor.py
index f89a888a..65428b4b 100644
--- a/eval_protocol/pytest/default_agent_rollout_processor.py
+++ b/eval_protocol/pytest/default_agent_rollout_processor.py
@@ -13,7 +13,7 @@
 from eval_protocol.mcp.execution.policy import LiteLLMPolicy
 from eval_protocol.mcp.mcp_multi_client import MCPMultiClient
 from eval_protocol.models import EvaluationRow, Message
-from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor
+from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig
 
 logger = logging.getLogger(__name__)
@@ -116,7 +116,7 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult) -> List[Tex
         return tool_result.content
 
 
-class AgentRolloutProcessor(BaseRolloutProcessor):
+class AgentRolloutProcessor(RolloutProcessor):
     """Agent rollout processor for tool-calling agents."""
 
     def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py
index a06121ae..b7376e9c 100644
--- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py
+++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py
@@ -6,15 +6,13 @@
 import subprocess
 import time
 from pathlib import Path
-from typing import Any, AsyncIterator, Dict, List, Optional
+from typing import List, Optional
 
 import eval_protocol as ep
-from eval_protocol.models import EvaluationRow, Message
-from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor
+from eval_protocol.models import EvaluationRow
+from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig
 
-CURRENT_RUN_STATE: Dict[str, Any] = {}
-
 
 class MCPServerManager:
     """Manages MCP server lifecycle for testing."""
@@ -195,7 +193,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         return False  # Don't suppress exceptions
 
 
-class MCPGymRolloutProcessor(BaseRolloutProcessor):
+class MCPGymRolloutProcessor(RolloutProcessor):
     """
     Rollout processor for tau bench environments.
diff --git a/eval_protocol/pytest/default_base_rollout_process.py b/eval_protocol/pytest/default_no_op_rollout_processor.py
similarity index 57%
rename from eval_protocol/pytest/default_base_rollout_process.py
rename to eval_protocol/pytest/default_no_op_rollout_processor.py
index be3597e6..973d6083 100644
--- a/eval_protocol/pytest/default_base_rollout_process.py
+++ b/eval_protocol/pytest/default_no_op_rollout_processor.py
@@ -2,18 +2,16 @@
 from typing import List
 
 from eval_protocol.models import EvaluationRow
+from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig
 
 
-class BaseRolloutProcessor:
+class NoOpRolloutProcessor(RolloutProcessor):
     """
-    Base rollout processor - minimal implementation that all others inherit from.
+    No-op rollout processor that passes input dataset through unchanged.
 
-    This is the Strategy pattern base class. It provides:
-    1. __call__(rows, config) -> tasks (the main interface)
-    2. cleanup() -> None (resource cleanup)
-
-    All other processors inherit from this and override as needed.
+    Simply returns the input rows as completed tasks. This is useful for testing
+    or when you want to handle rollout processing manually.
     """
 
     def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
@@ -26,6 +24,4 @@ async def return_row(row: EvaluationRow) -> EvaluationRow:
         tasks = [asyncio.create_task(return_row(row)) for row in rows]
         return tasks
 
-    def cleanup(self) -> None:
-        """No-op cleanup - override in subclasses if needed."""
-        pass
+    # Inherits cleanup() from RolloutProcessor - no override needed
diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py
index cb7ea347..bf43b7da 100644
--- a/eval_protocol/pytest/default_single_turn_rollout_process.py
+++ b/eval_protocol/pytest/default_single_turn_rollout_process.py
@@ -2,21 +2,20 @@
 import logging
 import os
 import time
-from typing import AsyncIterator, List
+from typing import List
 
-import litellm
 from litellm import acompletion
 from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
 
 from eval_protocol.dataset_logger import default_logger
 from eval_protocol.models import EvaluationRow, Message
-from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor
+from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig
 
 logger = logging.getLogger(__name__)
 
 
-class SingleTurnRolloutProcessor(BaseRolloutProcessor):
+class SingleTurnRolloutProcessor(RolloutProcessor):
     """Single turn rollout processor for direct LLM calls."""
 
     def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py
index 70f976fd..c24118bb 100644
--- a/eval_protocol/pytest/evaluation_test.py
+++ b/eval_protocol/pytest/evaluation_test.py
@@ -24,8 +24,9 @@
     InputMetadata,
     Message,
 )
-from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor
 from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
+from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
+from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import (
     Dataset,
     DatasetPathParam,
@@ -61,7 +62,7 @@ def evaluation_test(  # noqa: C901
     input_messages: Optional[List[InputMessagesParam]] = None,
     input_dataset: Optional[List[DatasetPathParam]] = None,
     dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter,
-    rollout_processor: BaseRolloutProcessor = BaseRolloutProcessor(),
+    rollout_processor: RolloutProcessor = NoOpRolloutProcessor(),
     evaluation_test_kwargs: Optional[List[EvaluationInputParam]] = None,
     rollout_processor_kwargs: Optional[RolloutProcessorInputParam] = None,
     aggregation_method: AggregationMethod = "mean",
@@ -749,7 +750,7 @@ def run_evaluation_test_direct(
     input_dataset: Optional[List[DatasetPathParam]] = None,
     dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter,
     completion_params: Optional[CompletionParams] = None,
-    rollout_processor: BaseRolloutProcessor = BaseRolloutProcessor(),
+    rollout_processor: RolloutProcessor = NoOpRolloutProcessor(),
     rollout_processor_kwargs: Optional[RolloutProcessorInputParam] = None,
     aggregation_method: AggregationMethod = "mean",
     passed_threshold: Optional[Union[EvaluationThreshold, float]] = None,
diff --git a/eval_protocol/pytest/rollout_processor.py b/eval_protocol/pytest/rollout_processor.py
new file mode 100644
index 00000000..824dd015
--- /dev/null
+++ b/eval_protocol/pytest/rollout_processor.py
@@ -0,0 +1,21 @@
+import asyncio
+from abc import ABC, abstractmethod
+from typing import List
+
+from eval_protocol.models import EvaluationRow
+from eval_protocol.pytest.types import RolloutProcessorConfig
+
+
+class RolloutProcessor(ABC):
+    """
+    Abstract base class for all rollout processor strategies.
+    """
+
+    @abstractmethod
+    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
+        """Process evaluation rows and return async tasks. Must be implemented by subclasses."""
+        pass
+
+    def cleanup(self) -> None:
+        """Cleanup resources. Override in subclasses if cleanup is needed."""
+        pass
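A minimal sketch of implementing this strategy interface; the EchoRolloutProcessor name is hypothetical and not part of the patch:

    import asyncio
    from typing import List

    from eval_protocol.models import EvaluationRow
    from eval_protocol.pytest.rollout_processor import RolloutProcessor
    from eval_protocol.pytest.types import RolloutProcessorConfig


    class EchoRolloutProcessor(RolloutProcessor):
        """Hypothetical processor: resolves each row unchanged after yielding control once."""

        def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
            async def echo(row: EvaluationRow) -> EvaluationRow:
                await asyncio.sleep(0)  # yield to the event loop, as a real rollout would
                return row

            return [asyncio.create_task(echo(row)) for row in rows]

        def cleanup(self) -> None:
            pass  # nothing to release in this sketch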
diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py
index 0d3f926b..24b60028 100644
--- a/eval_protocol/pytest/utils.py
+++ b/eval_protocol/pytest/utils.py
@@ -7,7 +7,7 @@
 from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
 from eval_protocol.models import EvalMetadata, EvaluationRow
-from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor
+from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import (
     CompletionParams,
     DatasetPathParam,
@@ -234,7 +234,7 @@ def generate_parameter_combinations(
 
 
 async def rollout_processor_with_retry(
-    rollout_processor: BaseRolloutProcessor,
+    rollout_processor: RolloutProcessor,
     fresh_dataset: List[EvaluationRow],
     config: RolloutProcessorConfig,
     max_retry: int,
diff --git a/tests/pytest/test_pytest_ensure_logging.py b/tests/pytest/test_pytest_ensure_logging.py
index 22a11338..afcafd23 100644
--- a/tests/pytest/test_pytest_ensure_logging.py
+++ b/tests/pytest/test_pytest_ensure_logging.py
@@ -5,7 +5,7 @@
 from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
 from eval_protocol.dataset_logger.sqlite_evaluation_row_store import SqliteEvaluationRowStore
 from eval_protocol.models import EvaluationRow
-from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor
+from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
 from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row
 
 
@@ -34,7 +34,7 @@ async def test_ensure_logging(monkeypatch):
     ],
     completion_params=[{"temperature": 0.0, "model": "dummy/local-model"}],
     dataset_adapter=markdown_dataset_to_evaluation_row,
-    rollout_processor=BaseRolloutProcessor(),
+    rollout_processor=NoOpRolloutProcessor(),
     mode="pointwise",
     combine_datasets=False,
     num_runs=2,
diff --git a/tests/pytest/test_pytest_flaky_sometimes.py b/tests/pytest/test_pytest_flaky_sometimes.py
index 730a0a81..bde5e34c 100644
--- a/tests/pytest/test_pytest_flaky_sometimes.py
+++ b/tests/pytest/test_pytest_flaky_sometimes.py
@@ -5,7 +5,7 @@
 import pytest
 
 from eval_protocol.models import EvaluateResult, EvaluationRow, Message
-from eval_protocol.pytest import BaseRolloutProcessor, evaluation_test
+from eval_protocol.pytest import NoOpRolloutProcessor, evaluation_test
 
 
 # skip in CI since it will intentionally fail. This is useful for local generation of logs
@@ -13,7 +13,7 @@
 @evaluation_test(
     input_messages=[[Message(role="user", content="Return HEADS or TAILS at random.")]],
     completion_params=[{"model": "dummy/local-model"}],
-    rollout_processor=BaseRolloutProcessor(),
+    rollout_processor=NoOpRolloutProcessor(),
     mode="pointwise",
     num_runs=5,
 )
diff --git a/tests/pytest/test_pytest_ids.py b/tests/pytest/test_pytest_ids.py
index 69068ab9..b6bb4a35 100644
--- a/tests/pytest/test_pytest_ids.py
+++ b/tests/pytest/test_pytest_ids.py
@@ -3,7 +3,7 @@
 import eval_protocol.dataset_logger as dataset_logger
 from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
 from eval_protocol.models import EvaluationRow
-from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor
+from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
 from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row
 
 
@@ -30,7 +30,7 @@ async def test_evaluation_test_decorator(monkeypatch):
     ],
     completion_params=[{"temperature": 0.0, "model": "dummy/local-model"}],
     dataset_adapter=markdown_dataset_to_evaluation_row,
-    rollout_processor=BaseRolloutProcessor(),
+    rollout_processor=NoOpRolloutProcessor(),
     mode="pointwise",
     combine_datasets=False,
     num_runs=2,
@@ -71,7 +71,7 @@ async def test_evaluation_test_decorator_ids_single(monkeypatch):
         {"temperature": 1.0, "model": "dummy/local-model"},
     ],
     dataset_adapter=markdown_dataset_to_evaluation_row,
-    rollout_processor=BaseRolloutProcessor(),
+    rollout_processor=NoOpRolloutProcessor(),
     mode="pointwise",
     combine_datasets=False,
     num_runs=5,
diff --git a/tests/test_retry_mechanism.py b/tests/test_retry_mechanism.py
index 9529e80e..73937c8a 100644
--- a/tests/test_retry_mechanism.py
+++ b/tests/test_retry_mechanism.py
@@ -10,8 +10,8 @@
 from typing import AsyncIterator, List
 
 from eval_protocol.models import EvaluateResult, EvaluationRow, Message, RolloutStatus
-from eval_protocol.pytest.default_base_rollout_process import BaseRolloutProcessor
 from eval_protocol.pytest.evaluation_test import evaluation_test
+from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig
 
 os.environ["EP_MAX_RETRY"] = "2"  # Allow up to 2 retries
@@ -20,7 +20,7 @@
 timing_results = []  # Collect timing data for assertions
 
 
-class MockRolloutProcessorWithRetries(BaseRolloutProcessor):
+class MockRolloutProcessorWithRetries(RolloutProcessor):
     """Mock rollout processor that fails second task alphabetically on first attempt, succeeds on retry"""
 
     def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
@@ -53,8 +53,6 @@ async def process_single_row(row: EvaluationRow, delay: float, should_fail: bool
 
         return tasks
 
-    # Inherits cleanup() from BaseRolloutProcessor - no override needed
-
 
 @evaluation_test(
     completion_params=[{"model": "gpt-4o-mini", "temperature": 0}],

From 0b637de8192e245740b460048383a9e67b7a1152 Mon Sep 17 00:00:00 2001
From: Derek Xu
Date: Thu, 14 Aug 2025 10:15:12 -0700
Subject: [PATCH 7/9] failing test

---
 eval_protocol/pytest/evaluation_test.py    |  3 +-
 tests/pytest/test_pytest_ensure_logging.py | 78 ++++++++++------------
 2 files changed, 37 insertions(+), 44 deletions(-)
diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py
index c24118bb..6127c7b9 100644
--- a/eval_protocol/pytest/evaluation_test.py
+++ b/eval_protocol/pytest/evaluation_test.py
@@ -471,11 +471,10 @@ async def _execute_with_semaphore(row):
 
                 passed = success_passed and std_passed
 
-                # Update eval metadata status and passed field for all results
+                # Update eval metadata passed field for all results
                 for result in all_results:
                     for r in result:
                         if r.eval_metadata is not None:
-                            r.eval_metadata.status = "finished"  # TODO: might not be needed
                             r.eval_metadata.passed = passed
                             active_logger.log(r)
diff --git a/tests/pytest/test_pytest_ensure_logging.py b/tests/pytest/test_pytest_ensure_logging.py
index afcafd23..e57b3c8c 100644
--- a/tests/pytest/test_pytest_ensure_logging.py
+++ b/tests/pytest/test_pytest_ensure_logging.py
@@ -1,13 +1,6 @@
 import os
 from unittest.mock import Mock, patch
 
-import eval_protocol.dataset_logger as dataset_logger
-from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
-from eval_protocol.dataset_logger.sqlite_evaluation_row_store import SqliteEvaluationRowStore
-from eval_protocol.models import EvaluationRow
-from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
-from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row
-
 
 async def test_ensure_logging(monkeypatch):
     """
@@ -25,41 +18,42 @@ async def test_ensure_logging(monkeypatch):
         "eval_protocol.dataset_logger.sqlite_dataset_logger_adapter.SqliteEvaluationRowStore", return_value=mock_store
     ):
         from eval_protocol.models import EvaluationRow
+        from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
         from eval_protocol.pytest.evaluation_test import evaluation_test
         from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row
 
-    @evaluation_test(
-        input_dataset=[
-            "tests/pytest/data/markdown_dataset.jsonl",
-        ],
-        completion_params=[{"temperature": 0.0, "model": "dummy/local-model"}],
-        dataset_adapter=markdown_dataset_to_evaluation_row,
-        rollout_processor=NoOpRolloutProcessor(),
-        mode="pointwise",
-        combine_datasets=False,
-        num_runs=2,
-        # Don't pass logger parameter - let it use the default_logger (which we've replaced)
-    )
-    def eval_fn(row: EvaluationRow) -> EvaluationRow:
-        return row
-
-    await eval_fn(
-        dataset_path=["tests/pytest/data/markdown_dataset.jsonl"],
-        completion_params={"temperature": 0.0, "model": "dummy/local-model"},
-    )
-
-    # Verify that the store's upsert_row method was called
-    assert mock_store.upsert_row.called, "SqliteEvaluationRowStore.upsert_row should have been called"
-
-    # Check that it was called multiple times (once for each row)
-    call_count = mock_store.upsert_row.call_count
-    assert call_count > 0, f"Expected upsert_row to be called at least once, but it was called {call_count} times"
-
-    # Verify the calls were made with proper data structure
-    for call in mock_store.upsert_row.call_args_list:
-        args, kwargs = call
-        data = args[0] if args else kwargs.get("data")
-        assert data is not None, "upsert_row should be called with data parameter"
-        assert isinstance(data, dict), "data should be a dictionary"
-        assert "execution_metadata" in data, "data should contain execution_metadata"
-        assert "rollout_id" in data["execution_metadata"], "data should contain rollout_id in execution_metadata"
+        @evaluation_test(
+            input_dataset=[
+                "tests/pytest/data/markdown_dataset.jsonl",
+            ],
+            completion_params=[{"temperature": 0.0, "model": "dummy/local-model"}],
+            dataset_adapter=markdown_dataset_to_evaluation_row,
+            rollout_processor=NoOpRolloutProcessor(),
+            mode="pointwise",
+            combine_datasets=False,
+            num_runs=2,
+            # Don't pass logger parameter - let it use the default_logger (which we've replaced)
+        )
+        def eval_fn(row: EvaluationRow) -> EvaluationRow:
+            return row
+
+        await eval_fn(
+            dataset_path=["tests/pytest/data/markdown_dataset.jsonl"],
+            completion_params={"temperature": 0.0, "model": "dummy/local-model"},
+        )
+
+        # Verify that the store's upsert_row method was called
+        assert mock_store.upsert_row.called, "SqliteEvaluationRowStore.upsert_row should have been called"
+
+        # Check that it was called multiple times (once for each row)
+        call_count = mock_store.upsert_row.call_count
+        assert call_count > 0, f"Expected upsert_row to be called at least once, but it was called {call_count} times"
+
+        # Verify the calls were made with proper data structure
+        for call in mock_store.upsert_row.call_args_list:
+            args, kwargs = call
+            data = args[0] if args else kwargs.get("data")
+            assert data is not None, "upsert_row should be called with data parameter"
+            assert isinstance(data, dict), "data should be a dictionary"
+            assert "execution_metadata" in data, "data should contain execution_metadata"
+            assert "rollout_id" in data["execution_metadata"], "data should contain rollout_id in execution_metadata"

From 8d0a57d1612a27214a3ce8bfae22f81b47f8e3f2 Mon Sep 17 00:00:00 2001
From: Derek Xu
Date: Thu, 14 Aug 2025 10:40:21 -0700
Subject: [PATCH 8/9] fixing flaky test

---
 tests/test_retry_mechanism.py | 170 +++++++++++++++++++---------------
 1 file changed, 95 insertions(+), 75 deletions(-)

diff --git a/tests/test_retry_mechanism.py b/tests/test_retry_mechanism.py
index 73937c8a..a483f0e1 100644
--- a/tests/test_retry_mechanism.py
+++ b/tests/test_retry_mechanism.py
@@ -5,9 +5,11 @@
 
 import asyncio
 import os
-import time
-from dataclasses import dataclass
-from typing import AsyncIterator, List
+from collections import Counter
+from typing import List
+from unittest.mock import Mock
+
+import pytest
 
 from eval_protocol.models import EvaluateResult, EvaluationRow, Message, RolloutStatus
 from eval_protocol.pytest.evaluation_test import evaluation_test
@@ -16,29 +18,47 @@
 os.environ["EP_MAX_RETRY"] = "2"  # Allow up to 2 retries
 
-start_time = time.time()
-timing_results = []  # Collect timing data for assertions
-
 
 class MockRolloutProcessorWithRetries(RolloutProcessor):
     """Mock rollout processor that fails second task alphabetically on first attempt, succeeds on retry"""
 
+    def __init__(self):
+        self.mock_tracker = Mock()
+
     def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
+        # Track this batch call
+        self.mock_tracker.batch_call(len(rows))
+
         row_setup = {
-            0: {"delay": 3.0, "should_fail": False},
-            1: {"delay": 3.0, "should_fail": True},
-            2: {"delay": 5.0, "should_fail": False},
-            3: {"delay": 5.0, "should_fail": False},
-            4: {"delay": 5.0, "should_fail": False},
+            0: {"delay": 0.01, "should_fail": False},
+            1: {"delay": 0.01, "should_fail": True},  # Will be adjusted based on attempt number
+            2: {"delay": 0.01, "should_fail": False},
+            3: {"delay": 0.01, "should_fail": False},
+            4: {"delay": 0.01, "should_fail": False},
         }
 
-        async def process_single_row(row: EvaluationRow, delay: float, should_fail: bool = False) -> EvaluationRow:
-            await asyncio.sleep(delay)
+        async def process_single_row(
+            row: EvaluationRow, delay: float, base_should_fail: bool = False
+        ) -> EvaluationRow:
+            rollout_id = row.execution_metadata.rollout_id
+
+            # Track individual row processing call
+            self.mock_tracker.process_row_call(rollout_id)
+
+            # Determine attempt number by counting previous calls for this rollout_id
+            previous_calls = [
+                call for call in self.mock_tracker.process_row_call.call_args_list if call[0][0] == rollout_id
+            ]
+            attempt_number = len(previous_calls)
 
-            elapsed = time.time() - start_time
-            print(
-                f"šŸŽ‰ FINISHED {'error' if should_fail else 'finished'} at {elapsed:.2f}s: {row.execution_metadata.rollout_id}"
-            )
+            # Determine if this specific attempt should fail
+            # Row 1 fails on first attempt (attempt_number == 1), succeeds on retry (attempt_number == 2)
+            should_fail = base_should_fail and attempt_number == 1
+
+            print(f"šŸ”„ ATTEMPTING rollout_id={rollout_id}, attempt={attempt_number}, will_fail={should_fail}")
+
+            await asyncio.sleep(delay)
+            print(f"šŸŽ‰ FINISHED {'error' if should_fail else 'finished'}: {row.execution_metadata.rollout_id}")
 
             if should_fail:
                 raise Exception("Simulated failure for testing")
@@ -54,6 +74,10 @@ async def process_single_row(row: EvaluationRow, delay: float, should_fail: bool
         return tasks
 
 
+# Create a shared processor instance for testing
+shared_processor = MockRolloutProcessorWithRetries()
+
+
 @evaluation_test(
     completion_params=[{"model": "gpt-4o-mini", "temperature": 0}],
     input_messages=[
@@ -63,16 +87,14 @@ async def process_single_row(row: EvaluationRow, delay: float, should_fail: bool
         [Message(role="user", content="Task D")],
         [Message(role="user", content="Task E")],
     ],
-    rollout_processor=MockRolloutProcessorWithRetries(),
+    rollout_processor=shared_processor,
     num_runs=1,
     mode="pointwise",
 )
 def test_retry_mechanism(row: EvaluationRow) -> EvaluationRow:
-    """MOCK TEST: first 2 rows take 3s, last 3 take 5s, second row fails on first attempt, succeeds on retry. Should take around 6s total."""
-    # Just print the timing - we'll parse it from output
-    elapsed = time.time() - start_time
+    """MOCK TEST: Tests that the retry mechanism works - one task fails on first attempt, succeeds on retry."""
     print(
-        f"šŸ“Š EVALUATED at {elapsed:.2f}s: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})"
+        f"šŸ“Š EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})"
     )
 
     # Assign a score based on success/failure
@@ -82,56 +104,54 @@
     return row
 
 
-def test_timing_assertions():
-    """Validate that timing results match expected pipeline behavior"""
-    global start_time
-
-    # Reset and run the evaluation test
-    start_time = time.time()
-
-    # Capture pytest output
-    import subprocess
-    import sys
-
-    result = subprocess.run(
-        [sys.executable, "-m", "pytest", __file__ + "::test_retry_mechanism", "-v", "-s"],
-        capture_output=True,
-        text=True,
-        cwd=os.getcwd(),
-    )
-
-    print(result.stdout)  # Show the original output
-
-    # Parse timing from output
-    import re
-
-    timing_results = []
-    for line in result.stdout.split("\n"):
-        match = re.search(r"šŸ“Š EVALUATED at (\d+\.\d+)s:", line)
-        if match:
-            timing_results.append(float(match.group(1)))
-
-    print(f"\nšŸ“Š PIPELINE TIMING ANALYSIS:")
-    print(f"   Results received at: {[f'{t:.2f}s' for t in sorted(timing_results)]}")
-
-    # Assertions for expected timing behavior
-    sorted_times = sorted(timing_results)
-
-    assert len(sorted_times) == 5, f"Expected 5 evaluation results, got {len(sorted_times)}"
-
-    # First result should be around 3s (row 0 success)
-    assert 2.5 <= sorted_times[0] <= 3.5, f"First result at {sorted_times[0]:.2f}s, expected ~3s"
-
-    # Next three results should be around 5s (rows 2,3,4)
-    assert 4.5 <= sorted_times[1] <= 5.5, f"Second result at {sorted_times[1]:.2f}s, expected ~5s"
-    assert 4.5 <= sorted_times[2] <= 5.5, f"Third result at {sorted_times[2]:.2f}s, expected ~5s"
-    assert 4.5 <= sorted_times[3] <= 5.5, f"Fourth result at {sorted_times[3]:.2f}s, expected ~5s"
-
-    # Last result should be around 6s (row 1 retry success)
-    assert 5.5 <= sorted_times[4] <= 6.5, f"Fifth result at {sorted_times[4]:.2f}s, expected ~6s (retry success)"
-
-    print("āœ… All timing assertions passed! Pipeline behavior is correct.")
-
-
-if __name__ == "__main__":
-    test_timing_assertions()
+def test_retry_mechanism_mock_verification():
+    """Test that verifies the retry mechanism worked by checking the mock calls"""
+    # Get our mock tracker
+    mock_tracker = shared_processor.mock_tracker
+
+    print(f"\nšŸ”„ MOCK CALL ANALYSIS:")
+    print(f"   Batch calls made: {mock_tracker.batch_call.call_count}")
+    print(f"   Total row processing calls: {mock_tracker.process_row_call.call_count}")
+
+    if mock_tracker.process_row_call.call_count == 0:
+        print("āš ļø  No calls recorded yet. The evaluation test may not have run or completed.")
+        return
+
+    # Get all rollout_ids that were processed
+    call_args = mock_tracker.process_row_call.call_args_list
+    rollout_ids = [call[0][0] for call in call_args]
+
+    # Count calls per rollout_id
+    call_counts = Counter(rollout_ids)
+
+    print(f"   Call counts per rollout_id: {dict(call_counts)}")
+    print(f"   Individual calls:")
+    for i, call_arg in enumerate(call_args, 1):
+        rollout_id = call_arg[0][0]
+        attempt_num = rollout_ids[:i].count(rollout_id)
+        print(f"      {i}. rollout_id={rollout_id}, attempt={attempt_num}")
+
+    # ASSERTIONS USING MOCK DATA
+    # Should have exactly 6 total row processing calls (5 initial + 1 retry)
+    assert (
+        mock_tracker.process_row_call.call_count == 6
+    ), f"Expected 6 total calls, got {mock_tracker.process_row_call.call_count}"
+
+    # Should have exactly 2 batch calls (initial batch + retry batch)
+    assert mock_tracker.batch_call.call_count == 2, f"Expected 2 batch calls, got {mock_tracker.batch_call.call_count}"
+
+    # First batch should have 5 rows, second batch should have 1 row (the retry)
+    batch_call_args = mock_tracker.batch_call.call_args_list
+    assert batch_call_args[0][0][0] == 5, f"Expected first batch to have 5 rows, got {batch_call_args[0][0][0]}"
+    assert batch_call_args[1][0][0] == 1, f"Expected second batch to have 1 row, got {batch_call_args[1][0][0]}"
+
+    # Exactly one rollout_id should be called twice, others called once
+    call_count_values = list(call_counts.values())
+    assert (
+        call_count_values.count(2) == 1
+    ), f"Expected exactly 1 rollout_id to be called twice, got counts: {dict(call_counts)}"
+    assert (
+        call_count_values.count(1) == 4
+    ), f"Expected exactly 4 rollout_ids to be called once, got counts: {dict(call_counts)}"
+
+    print("āœ… All mock-based assertions passed! Retry mechanism is working correctly.")
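The verification above leans on two generic facts about unittest.mock.Mock: attribute access auto-creates child mocks, and every call is recorded in call_args_list. A standalone sketch of that tracking pattern (illustrative, independent of eval_protocol):

    from collections import Counter
    from unittest.mock import Mock

    tracker = Mock()
    for rollout_id in ["a", "b", "b"]:
        tracker.process_row_call(rollout_id)  # auto-created child mock records each call

    # call[0] is the positional-args tuple; [0] is the first argument
    counts = Counter(call[0][0] for call in tracker.process_row_call.call_args_list)
    assert counts == {"a": 1, "b": 2}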
From 2d4bfc5a6c756f149076abbdc1258c5811cad819 Mon Sep 17 00:00:00 2001
From: Derek Xu
Date: Thu, 14 Aug 2025 22:34:55 +0000
Subject: [PATCH 9/9] update comments

---
 eval_protocol/mcp_env.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/eval_protocol/mcp_env.py b/eval_protocol/mcp_env.py
index 5dc77c48..f5d09ba0 100644
--- a/eval_protocol/mcp_env.py
+++ b/eval_protocol/mcp_env.py
@@ -278,10 +278,10 @@ def rollout(
 
     Example:
         # Live mode
-        tasks = await ep.rollout(envs, policy)
+        tasks = ep.rollout(envs, policy)
 
         # Create environments automatically
-        tasks = await ep.rollout(
+        tasks = ep.rollout(
             "http://localhost:8000/mcp/",
             policy,
             evaluation_rows=my_evaluation_rows,
@@ -290,10 +290,10 @@ def rollout(
 
         # Recording mode
        os.environ["EP_PLAYBACK_FILE"] = "record.jsonl"
-        tasks = await ep.rollout(envs, policy, openai_format_log_file="sft_data.jsonl")
+        tasks = ep.rollout(envs, policy, openai_format_log_file="sft_data.jsonl")
 
         # Playback mode (after recording file exists)
-        tasks = await ep.rollout(envs, policy)
+        tasks = ep.rollout(envs, policy)
     """
     # Automatically create environments if a base URL is provided
     if isinstance(envs, str):
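Since ep.rollout now returns a list of tasks rather than an awaitable, callers gather the tasks themselves; a minimal sketch under that assumption, with envs and policy constructed as in the docstring above:

    import asyncio

    import eval_protocol as ep

    async def run_all(envs, policy):
        tasks = ep.rollout(envs, policy)     # returns tasks, not a coroutine: do not await the call itself
        return await asyncio.gather(*tasks)  # await the tasks to collect completed evaluation rows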