Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ jobs:
--ignore=tests/pytest/test_frozen_lake.py \
--ignore=tests/pytest/test_lunar_lander.py \
--ignore=tests/pytest/test_tau_bench_airline.py \
--ignore=tests/pytest/test_apps_coding.py \
--ignore=tests/test_tau_bench_airline_smoke.py \
--cov=eval_protocol --cov-append --cov-report=xml --cov-report=term-missing -v --durations=10

Expand Down
4 changes: 2 additions & 2 deletions eval_protocol/benchmarks/suites/aime25.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from eval_protocol.benchmarks.registry import export_benchmark
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
from eval_protocol.pytest.default_single_turn_rollout_process import (
default_single_turn_rollout_processor,
SingleTurnRolloutProcessor,
)
from eval_protocol.pytest.evaluation_test import evaluation_test

Expand Down Expand Up @@ -72,7 +72,7 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
}
],
rollout_processor=default_single_turn_rollout_processor,
rollout_processor=SingleTurnRolloutProcessor(),
aggregation_method="mean",
passed_threshold=None,
num_runs=8,
Expand Down
43 changes: 29 additions & 14 deletions eval_protocol/benchmarks/suites/gpqa.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import csv
import io
import re
Expand All @@ -8,9 +9,11 @@
from eval_protocol.benchmarks.registry import export_benchmark
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
from eval_protocol.pytest.default_single_turn_rollout_process import (
default_single_turn_rollout_processor,
SingleTurnRolloutProcessor,
)
from eval_protocol.pytest.evaluation_test import evaluation_test
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig

SYSTEM_PROMPT = (
"You are a helpful assistant. Read the question and options carefully. "
Expand Down Expand Up @@ -60,19 +63,31 @@ def _strip_gt_messages(msgs: List[Message]) -> List[Message]:
return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]


class GPQAStripGTRolloutProcessor(RolloutProcessor):
    """Preprocess rows to set ground_truth and strip __GT__ marker messages,
    then delegate rollout to SingleTurnRolloutProcessor."""

    def __init__(self):
        super().__init__()
        # One delegate instance reused across calls.
        self.single_turn_processor = SingleTurnRolloutProcessor()

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Preprocess rows and delegate to SingleTurnRolloutProcessor.

        For each row, system messages whose content starts with "__GT__:"
        are treated as ground-truth markers: the last marker's text after
        the first ":" becomes ``row.ground_truth``, and every marker message
        is removed from ``row.messages`` before rollout.

        Args:
            rows: Evaluation rows, possibly carrying __GT__ marker messages.
            config: Rollout processor configuration forwarded to the delegate.

        Returns:
            List of asyncio.Task objects produced by the delegate processor.
        """
        processed: List[EvaluationRow] = []

        for r in rows:
            gt_tokens = [
                m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:")
            ]
            if gt_tokens:
                # Last marker wins; everything after the first ":" is the answer.
                gt_val = gt_tokens[-1].split(":", 1)[1].strip()
                r.ground_truth = gt_val
                r.messages = [
                    m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:"))
                ]
            processed.append(r)

        # Delegate to SingleTurnRolloutProcessor
        return self.single_turn_processor(processed, config)


@export_benchmark("gpqa")
Expand All @@ -81,7 +96,7 @@ async def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) ->
completion_params=[
{"extra_body": {"reasoning_effort": "low"}, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}
],
rollout_processor=gpqa_strip_gt_rollout_processor,
rollout_processor=GPQAStripGTRolloutProcessor(),
aggregation_method="mean",
passed_threshold=None,
num_runs=8,
Expand Down
8 changes: 4 additions & 4 deletions eval_protocol/benchmarks/suites/livebench_data_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from eval_protocol.benchmarks.registry import export_benchmark, register_composite_benchmark
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
from eval_protocol.pytest.default_single_turn_rollout_process import (
default_single_turn_rollout_processor,
SingleTurnRolloutProcessor,
)
from eval_protocol.pytest.evaluation_test import evaluation_test

Expand Down Expand Up @@ -375,7 +375,7 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]:
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
input_messages=[[m for m in r.messages] for r in _CTA_ROWS],
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
rollout_processor=default_single_turn_rollout_processor,
rollout_processor=SingleTurnRolloutProcessor(),
aggregation_method="mean",
passed_threshold=None,
num_runs=4,
Expand Down Expand Up @@ -418,7 +418,7 @@ def livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS],
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
rollout_processor=default_single_turn_rollout_processor,
rollout_processor=SingleTurnRolloutProcessor(),
aggregation_method="mean",
passed_threshold=None,
num_runs=4,
Expand Down Expand Up @@ -462,7 +462,7 @@ def livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow:
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS],
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
rollout_processor=default_single_turn_rollout_processor,
rollout_processor=SingleTurnRolloutProcessor(),
aggregation_method="mean",
passed_threshold=None,
num_runs=4,
Expand Down
4 changes: 2 additions & 2 deletions eval_protocol/benchmarks/suites/tau_bench_retail.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from eval_protocol.benchmarks.registry import export_benchmark
from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
from vendor.tau2.data_model.message import (
AssistantMessage,
SystemMessage,
Expand Down Expand Up @@ -73,7 +73,7 @@ def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
}
],
rollout_processor=default_mcp_gym_rollout_processor,
rollout_processor=MCPGymRolloutProcessor(),
rollout_processor_kwargs={"domain": "retail"},
num_runs=8,
mode="pointwise",
Expand Down
21 changes: 5 additions & 16 deletions eval_protocol/mcp/execution/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ class ExecutionManager:
Manage rollout for MCP environments.
"""

async def execute_rollouts(
def execute_rollouts(
self,
envs: "GeneralMCPVectorEnv",
policy: Union["LLMBasePolicy", Callable],
steps: int = 512,
openai_format_log_file: Optional[str] = None,
max_concurrent_rollouts: int = 8,
evaluation_rows: Optional[List[EvaluationRow]] = None,
) -> AsyncIterator[EvaluationRow]:
) -> List[asyncio.Task[EvaluationRow]]:
"""
Execute general rollouts using tool calling interface with automatic record/playback.

Expand All @@ -66,7 +66,7 @@ async def execute_rollouts(
- Set and file exists: Playback mode (uses recorded data)

Returns:
AsyncIterator of EvaluationRow objects with unified evaluation data format
List of asyncio.Task objects for external handling
"""
start_time = time.time()

Expand Down Expand Up @@ -138,7 +138,7 @@ async def _execute_with_semaphore(idx):
if trajectory.terminated:
if trajectory.termination_reason == TerminationReason.ERROR:
evaluation_row.rollout_status.status = "error"
evaluation_row.rollout_status.error_message = trajectory.control_plane_summary.get(
evaluation_row.rollout_status.termination_reason = trajectory.control_plane_summary.get(
"error_message", None
)
else:
Expand All @@ -151,18 +151,7 @@ async def _execute_with_semaphore(idx):

# Create all tasks
tasks = [asyncio.create_task(_execute_with_semaphore(i)) for i in range(envs.n)]

# Yield results as they complete (note that they're not necessarily in original order)
try:
for task in asyncio.as_completed(tasks):
try:
yield await task
except Exception:
logger.exception("Error processing rollout")
finally:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
return tasks

async def _execute_rollout(
self,
Expand Down
24 changes: 12 additions & 12 deletions eval_protocol/mcp_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def make(
return mcp_envs


async def rollout(
def rollout(
envs: GeneralMCPVectorEnv,
policy: Union[FireworksPolicy, LLMBasePolicy, Callable],
*,
Expand All @@ -246,7 +246,7 @@ async def rollout(
steps: int = 512,
openai_format_log_file: Optional[str] = None,
max_concurrent_rollouts: int = 8,
) -> AsyncIterator[EvaluationRow]:
) -> List[asyncio.Task[EvaluationRow]]:
"""
Execute general rollouts using tool calling interface with automatic record/playback.

Expand Down Expand Up @@ -274,14 +274,14 @@ async def rollout(
- Set and file exists: Playback mode (uses recorded data)

Returns:
List of EvaluationRow objects
List of asyncio.Task objects for external handling

Example:
# Live mode
evaluation_rows = await ep.rollout(envs, policy)
tasks = ep.rollout(envs, policy)

# Create environments automatically
trajectories = await ep.rollout(
tasks = ep.rollout(
"http://localhost:8000/mcp/",
policy,
evaluation_rows=my_evaluation_rows,
Expand All @@ -290,26 +290,26 @@ async def rollout(

# Recording mode
os.environ["EP_PLAYBACK_FILE"] = "record.jsonl"
evaluation_rows = await ep.rollout(envs, policy, openai_format_log_file="sft_data.jsonl")
tasks = ep.rollout(envs, policy, openai_format_log_file="sft_data.jsonl")

# Playback mode (after recording file exists)
evaluation_rows = await ep.rollout(envs, policy)
tasks = ep.rollout(envs, policy)
"""
# Automatically create environments if a base URL is provided
if isinstance(envs, str):
if evaluation_rows is None and dataset is None:
raise ValueError("Either 'evaluation_rows' or 'dataset' must be provided when envs is a URL")

auto_model_id = model_id or getattr(policy, "model_id", "unknown")
envs = await make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id)
envs = make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id)

# Use the new ExecutionManager for execution
execution_manager = ExecutionManager()

async for evaluation_row in execution_manager.execute_rollouts(
tasks = execution_manager.execute_rollouts(
envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows
):
yield evaluation_row
)
return tasks


async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]:
Expand All @@ -336,7 +336,7 @@ async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]:
policy = FireworksPolicy("test-model")

# Run short rollout
evaluation_rows = await rollout(envs, policy=policy, steps=10)
evaluation_rows = rollout(envs, policy=policy, steps=10)

if evaluation_rows and len(evaluation_rows[0].messages) > 1:
results["successful"] += 1
Expand Down
21 changes: 11 additions & 10 deletions eval_protocol/pytest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
# Public API of eval_protocol.pytest: rollout processors, dataset adapter,
# configuration types, and the evaluation_test decorator.
from .default_agent_rollout_processor import AgentRolloutProcessor
from .default_dataset_adapter import default_dataset_adapter
from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
from .default_no_op_rollout_processor import NoOpRolloutProcessor
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
from .evaluation_test import evaluation_test
from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig

__all__ = [
    "AgentRolloutProcessor",
    "MCPGymRolloutProcessor",
    "NoOpRolloutProcessor",
    "RolloutProcessor",
    "RolloutProcessorConfig",
    "SingleTurnRolloutProcessor",
    "default_dataset_adapter",
    "evaluation_test",
]
69 changes: 30 additions & 39 deletions eval_protocol/pytest/default_agent_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from eval_protocol.mcp.execution.policy import LiteLLMPolicy
from eval_protocol.mcp.mcp_multi_client import MCPMultiClient
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -115,46 +116,36 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult) -> List[Tex
return tool_result.content


class AgentRolloutProcessor(RolloutProcessor):
    """Agent rollout processor for tool-calling agents."""

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Create agent rollout tasks and return them for external handling.

        Concurrency is bounded by ``config.max_concurrent_rollouts``
        (default 8) via a semaphore shared by all tasks.

        Args:
            rows: Evaluation rows to roll out, one agent per row.
            config: Rollout configuration (model, MCP config path, logger,
                concurrency limit).

        Returns:
            List of asyncio.Task objects, one per input row, in input order.
        """
        max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
        semaphore = asyncio.Semaphore(max_concurrent)

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Process a single row with agent rollout."""
            agent = Agent(
                model=config.completion_params["model"],
                row=row,
                config_path=config.mcp_config_path,
                logger=config.logger,
            )
            try:
                await agent.setup()
                await agent.call_agent()
                return agent.evaluation_row
            finally:
                # Always release MCP client resources, even on failure.
                if agent.mcp_client:
                    await agent.mcp_client.cleanup()

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            # Gate each rollout on the shared semaphore.
            async with semaphore:
                return await process_row(r)

        # Create and return tasks for external handling
        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]
Loading
Loading