diff --git a/src/inference_endpoint/cli.py b/src/inference_endpoint/cli.py index d8957dfd..d1e0ab62 100644 --- a/src/inference_endpoint/cli.py +++ b/src/inference_endpoint/cli.py @@ -165,8 +165,6 @@ def create_parser() -> argparse.ArgumentParser: required=True, help="Template type", ) - init_parser.add_argument("--output", "-o", type=str, help="Output filename") - return parser @@ -260,12 +258,6 @@ def _add_auxiliary_args(parser): Args: parser: The argument parser to add arguments to. """ - parser.add_argument( - "--output", - "-o", - type=Path, - help="Path to save additional output data (not benchmark report)", - ) parser.add_argument( "--timeout", type=float, help="Timeout in seconds (default: 300)", default=300 ) diff --git a/src/inference_endpoint/commands/benchmark.py b/src/inference_endpoint/commands/benchmark.py index e55cae15..784e6341 100644 --- a/src/inference_endpoint/commands/benchmark.py +++ b/src/inference_endpoint/commands/benchmark.py @@ -24,7 +24,6 @@ import shutil import signal import tempfile -import time import uuid from pathlib import Path from urllib.parse import urljoin @@ -275,8 +274,6 @@ def _build_config_from_cli( ) timeout = getattr(args, "timeout", None) verbose_level = getattr(args, "verbose", 0) - output = getattr(args, "output", None) - # Build BenchmarkConfig from CLI params return BenchmarkConfig( name=f"cli_{benchmark_mode}", @@ -328,7 +325,6 @@ def _build_config_from_cli( metrics=Metrics(), baseline=None, # CLI mode doesn't use baseline report_dir=report_dir, - output=output, timeout=timeout, verbose=verbose_level > 0, ) @@ -573,7 +569,6 @@ def _run_benchmark( # Run benchmark logger.info("Running...") - start_time = time.time() sess = None try: @@ -602,15 +597,26 @@ def signal_handler(signum, frame): # Always restore original handler signal.signal(signal.SIGINT, old_handler) - elapsed_time = time.time() - start_time - success_count = response_collector.count - len(response_collector.errors) - estimated_qps = success_count / elapsed_time if elapsed_time > 0 else 0 + # Prefer authoritative metrics from the session report + report = getattr(sess, "report", None) + if report is None: + logger.error( + "Session report missing — benchmark reporter failed to produce results" + ) + raise ExecutionError( + "Session report missing — cannot produce benchmark results" + ) + + elapsed_time = report.duration_ns / 1e9 + total = report.n_samples_issued + success_count = report.n_samples_completed + + # qps will be None if duration was 0, so fall back to 0.0 + estimated_qps = report.qps or 0.0 # Report results logger.info(f"Completed in {elapsed_time:.1f}s") - logger.info( - f"Results: {success_count}/{scheduler.total_samples_to_issue} successful" - ) + logger.info(f"Results: {success_count}/{total} successful") logger.info(f"Estimated QPS: {estimated_qps:.1f}") if response_collector.errors: @@ -621,36 +627,33 @@ def signal_handler(signum, frame): if len(response_collector.errors) > 3: logger.warning(f" ... 
+{len(response_collector.errors) - 3} more") - # Save results if requested - if config.output: - try: - results = { - "config": { - "endpoint": endpoint, - "mode": test_mode, - "target_qps": target_qps, - }, - "results": { - "total": scheduler.total_samples_to_issue, - "successful": success_count, - "failed": len(response_collector.errors), - "elapsed_time": elapsed_time, - "qps": estimated_qps, - }, - } - - if collect_responses: - results["responses"] = response_collector.responses - - # Always save all errors (useful for debugging) - if response_collector.errors: - results["errors"] = response_collector.errors - - with open(config.output, "w") as f: - json.dump(results, f, indent=2) - logger.info(f"Saved: {config.output}") - except Exception as e: - logger.error(f"Save failed: {e}") + try: + results = { + "config": { + "endpoint": endpoint, + "mode": test_mode, + "target_qps": target_qps, + }, + "results": { + "total": total, + "successful": success_count, + "failed": total - success_count, + "elapsed_time": elapsed_time, + "qps": estimated_qps, + }, + } + if collect_responses: + results["responses"] = response_collector.responses + # Always save all errors (useful for debugging) + if response_collector.errors: + results["errors"] = response_collector.errors + # Save results to JSON file + results_path = report_dir / "results.json" + with open(results_path, "w") as f: + json.dump(results, f, indent=2) + logger.info(f"Saved: {results_path}") + except Exception as e: + logger.error(f"Save failed: {e}") except KeyboardInterrupt: logger.warning("Benchmark interrupted by user") diff --git a/src/inference_endpoint/commands/utils.py b/src/inference_endpoint/commands/utils.py index 7e6009fc..da4c27c4 100644 --- a/src/inference_endpoint/commands/utils.py +++ b/src/inference_endpoint/commands/utils.py @@ -251,14 +251,13 @@ async def run_init_command(args: argparse.Namespace) -> None: Args: args: Command arguments. Required: --template TYPE (offline/online/eval/submission) - Optional: --output PATH (default: _template.yaml) Raises: InputValidationError: If template type is unknown. SetupError: If template generation/writing fails. 
""" template_type = args.template - output_path = getattr(args, "output", None) or f"{template_type}_template.yaml" + output_path = f"{template_type}_template.yaml" if template_type not in TEMPLATE_FILES: logger.error(f"Unknown template: {template_type}") diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 444ccb9d..062eedd4 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -322,7 +322,6 @@ class BenchmarkConfig(BaseModel): settings: Settings = Field(default_factory=Settings) metrics: Metrics = Field(default_factory=Metrics) endpoint_config: EndpointConfig = Field(default_factory=EndpointConfig) - output: Path | None = None report_dir: Path | None = None timeout: int | None = None verbose: bool = False diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 079bee26..72b39c31 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -58,6 +58,8 @@ def __init__( self.event_recorder = EventRecorder( session_id=self.session_id, notify_idle=self.end_event ) + # Will be populated after the test finishes by _run_test + self.report = None self.sample_uuid_map = None @@ -153,6 +155,9 @@ def _run_test( tokenizer = None report = reporter.create_report(tokenizer) + # Store report on session so external callers can use it + self.report = report + # Consolidate UUID->index mappings perf_name = ( perf_test_generator.name diff --git a/tests/integration/commands/test_benchmark_command.py b/tests/integration/commands/test_benchmark_command.py index b9b0d444..a68442de 100644 --- a/tests/integration/commands/test_benchmark_command.py +++ b/tests/integration/commands/test_benchmark_command.py @@ -116,7 +116,11 @@ async def test_benchmark_with_output_file( self, mock_http_echo_server, ds_pickle_dataset_path, tmp_path ): """Test benchmark saves results to JSON file.""" - output_file = tmp_path / "benchmark_results.json" + # The benchmark command writes results to `results.json` inside the + # configured `report_dir`. Pass `report_dir=tmp_path` so the command + # will write output into this temporary directory and we can assert on + # the produced file. 
+ report_dir = tmp_path args = argparse.Namespace( benchmark_mode="offline", @@ -133,17 +137,19 @@ async def test_benchmark_with_output_file( min_output_tokens=None, max_output_tokens=None, mode=None, - output=output_file, + report_dir=report_dir, verbose=0, model="echo-server", timeout=None, ) + await run_benchmark_command(args) - # Verify file was created - assert output_file.exists() + # Verify file was created at /results.json + results_path = report_dir / "results.json" + assert results_path.exists() - with open(output_file) as f: + with open(results_path) as f: results = json.load(f) assert "config" in results diff --git a/tests/unit/commands/test_utils.py b/tests/unit/commands/test_utils.py index dee267db..72d0fee0 100644 --- a/tests/unit/commands/test_utils.py +++ b/tests/unit/commands/test_utils.py @@ -142,56 +142,62 @@ async def test_init_unknown_template(self): """Test init with unknown template type.""" args = MagicMock() args.template = "unknown" - args.output = None with pytest.raises(InputValidationError, match="Unknown template"): await run_init_command(args) @pytest.mark.asyncio - async def test_init_success(self, tmp_path): + async def test_init_success(self): """Test successful template generation.""" - output_file = tmp_path / "test_template.yaml" args = MagicMock() args.template = "offline" - args.output = str(output_file) - await run_init_command(args) + output_file = Path(f"{args.template}_template.yaml") - assert output_file.exists() - content = output_file.read_text() - assert "offline-benchmark" in content - assert "max_throughput" in content + try: + await run_init_command(args) + + assert output_file.exists() + content = output_file.read_text() + assert "offline-benchmark" in content + assert "max_throughput" in content + finally: + output_file.unlink(missing_ok=True) @pytest.mark.asyncio - async def test_init_warns_on_overwrite(self, tmp_path, caplog): + async def test_init_warns_on_overwrite(self, caplog): """Test warning when file already exists.""" - output_file = tmp_path / "existing.yaml" - output_file.write_text("existing content") args = MagicMock() args.template = "online" - args.output = str(output_file) - await run_init_command(args) + output_file = Path(f"{args.template}_template.yaml") + output_file.write_text("existing content") + + try: + await run_init_command(args) - assert "will be overwritten" in caplog.text - # File should be replaced - assert "online-benchmark" in output_file.read_text() + assert "will be overwritten" in caplog.text + # File should be replaced + assert "online-benchmark" in output_file.read_text() + finally: + output_file.unlink(missing_ok=True) @pytest.mark.asyncio - async def test_init_all_templates(self, tmp_path): + async def test_init_all_templates(self): """Test generating all template types.""" templates = ["offline", "online", "eval", "submission"] for template_type in templates: - output_file = tmp_path / f"{template_type}_test.yaml" - + output_file = Path(f"{template_type}_template.yaml") args = MagicMock() args.template = template_type - args.output = str(output_file) - await run_init_command(args) + try: + await run_init_command(args) - assert output_file.exists() - assert output_file.stat().st_size > 0 + assert output_file.exists() + assert output_file.stat().st_size > 0 + finally: + output_file.unlink(missing_ok=True)
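
Reviewer note, not part of the patch: the diff above drops the ad-hoc wall-clock timing and the `--output` flag in favour of the session's own report object and a fixed `results.json` under `report_dir`. Below is a minimal, hypothetical sketch of the consumer side, assuming a stand-in `SessionReport` dataclass that carries only the fields `_run_benchmark` actually reads from `sess.report` (`duration_ns`, `n_samples_issued`, `n_samples_completed`, `qps`); the real object is whatever `reporter.create_report()` returns, and `summarize()` is an illustrative helper, not code in this PR.

    # Hypothetical sketch, not part of this PR: a stand-in report type plus a helper
    # that mirrors the command's new post-run bookkeeping. Field names match what
    # _run_benchmark reads from sess.report; everything else here is illustrative.
    import json
    from dataclasses import dataclass
    from pathlib import Path


    @dataclass
    class SessionReport:  # stand-in; the real type is whatever reporter.create_report() returns
        duration_ns: int
        n_samples_issued: int
        n_samples_completed: int
        qps: float | None  # None when the run had zero duration, as the command assumes


    def summarize(report: SessionReport | None, report_dir: Path) -> Path:
        """Write a results.json shaped like the one the CLI now produces (illustrative only)."""
        if report is None:
            # the CLI raises ExecutionError here; a plain RuntimeError keeps the sketch self-contained
            raise RuntimeError("Session report missing, cannot produce benchmark results")

        elapsed_time = report.duration_ns / 1e9
        results = {
            "results": {
                "total": report.n_samples_issued,
                "successful": report.n_samples_completed,
                "failed": report.n_samples_issued - report.n_samples_completed,
                "elapsed_time": elapsed_time,
                "qps": report.qps or 0.0,
            }
        }
        results_path = report_dir / "results.json"
        results_path.write_text(json.dumps(results, indent=2))
        return results_path

With that shape, a caller that drives the session directly (rather than through the CLI) can run the test, then call something like `summarize(sess.report, Path(report_dir))` to get the same `results.json` the command now writes, including the None-qps fallback to 0.0.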