Merged
8 changes: 0 additions & 8 deletions src/inference_endpoint/cli.py
@@ -165,8 +165,6 @@ def create_parser() -> argparse.ArgumentParser:
required=True,
help="Template type",
)
init_parser.add_argument("--output", "-o", type=str, help="Output filename")

return parser


@@ -260,12 +258,6 @@ def _add_auxiliary_args(parser):
Args:
parser: The argument parser to add arguments to.
"""
parser.add_argument(
"--output",
"-o",
type=Path,
help="Path to save additional output data (not benchmark report)",
)
parser.add_argument(
"--timeout", type=float, help="Timeout in seconds (default: 300)", default=300
)
85 changes: 44 additions & 41 deletions src/inference_endpoint/commands/benchmark.py
@@ -24,7 +24,6 @@
import shutil
import signal
import tempfile
import time
import uuid
from pathlib import Path
from urllib.parse import urljoin
@@ -275,8 +274,6 @@ def _build_config_from_cli(
)
timeout = getattr(args, "timeout", None)
verbose_level = getattr(args, "verbose", 0)
output = getattr(args, "output", None)

# Build BenchmarkConfig from CLI params
return BenchmarkConfig(
name=f"cli_{benchmark_mode}",
@@ -328,7 +325,6 @@ def _build_config_from_cli(
metrics=Metrics(),
baseline=None, # CLI mode doesn't use baseline
report_dir=report_dir,
output=output,
timeout=timeout,
verbose=verbose_level > 0,
)
@@ -573,7 +569,6 @@ def _run_benchmark(

# Run benchmark
logger.info("Running...")
start_time = time.time()

sess = None
try:
@@ -602,15 +597,26 @@ def signal_handler(signum, frame):
# Always restore original handler
signal.signal(signal.SIGINT, old_handler)

elapsed_time = time.time() - start_time
success_count = response_collector.count - len(response_collector.errors)
estimated_qps = success_count / elapsed_time if elapsed_time > 0 else 0
# Prefer authoritative metrics from the session report
report = getattr(sess, "report", None)
if report is None:
logger.error(
"Session report missing — benchmark reporter failed to produce results"
)
raise ExecutionError(
"Session report missing — cannot produce benchmark results"
)

elapsed_time = report.duration_ns / 1e9
total = report.n_samples_issued
success_count = report.n_samples_completed

# qps will be None if duration was 0, so fall back to 0.0
estimated_qps = report.qps or 0.0

# Report results
logger.info(f"Completed in {elapsed_time:.1f}s")
logger.info(
f"Results: {success_count}/{scheduler.total_samples_to_issue} successful"
)
logger.info(f"Results: {success_count}/{total} successful")
logger.info(f"Estimated QPS: {estimated_qps:.1f}")

if response_collector.errors:
@@ -621,36 +627,33 @@ def signal_handler(signum, frame):
if len(response_collector.errors) > 3:
logger.warning(f" ... +{len(response_collector.errors) - 3} more")

# Save results if requested
if config.output:
try:
results = {
"config": {
"endpoint": endpoint,
"mode": test_mode,
"target_qps": target_qps,
},
"results": {
"total": scheduler.total_samples_to_issue,
"successful": success_count,
"failed": len(response_collector.errors),
"elapsed_time": elapsed_time,
"qps": estimated_qps,
},
}

if collect_responses:
results["responses"] = response_collector.responses

# Always save all errors (useful for debugging)
if response_collector.errors:
results["errors"] = response_collector.errors

with open(config.output, "w") as f:
json.dump(results, f, indent=2)
logger.info(f"Saved: {config.output}")
except Exception as e:
logger.error(f"Save failed: {e}")
try:
results = {
"config": {
"endpoint": endpoint,
"mode": test_mode,
"target_qps": target_qps,
},
"results": {
"total": total,
"successful": success_count,
"failed": total - success_count,
"elapsed_time": elapsed_time,
"qps": estimated_qps,
},
}
if collect_responses:
results["responses"] = response_collector.responses
# Always save all errors (useful for debugging)
if response_collector.errors:
results["errors"] = response_collector.errors
# Save results to JSON file
results_path = report_dir / "results.json"
with open(results_path, "w") as f:
json.dump(results, f, indent=2)
logger.info(f"Saved: {results_path}")
except Exception as e:
logger.error(f"Save failed: {e}")

except KeyboardInterrupt:
logger.warning("Benchmark interrupted by user")
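For reference, a minimal sketch (not part of the diff) of consuming the results.json file that the new code writes into report_dir. The field names match the payload built in _run_benchmark above; the report_dir value here is a placeholder.

import json
from pathlib import Path

# Placeholder path; the real value comes from BenchmarkConfig.report_dir
report_dir = Path("reports")

# results.json is written by _run_benchmark with "config", "results", and
# optional "responses"/"errors" sections
with open(report_dir / "results.json") as f:
    results = json.load(f)

summary = results["results"]
print(f"{summary['successful']}/{summary['total']} successful, QPS {summary['qps']:.1f}")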
3 changes: 1 addition & 2 deletions src/inference_endpoint/commands/utils.py
@@ -251,14 +251,13 @@ async def run_init_command(args: argparse.Namespace) -> None:
Args:
args: Command arguments.
Required: --template TYPE (offline/online/eval/submission)
Optional: --output PATH (default: <type>_template.yaml)

Raises:
InputValidationError: If template type is unknown.
SetupError: If template generation/writing fails.
"""
template_type = args.template
output_path = getattr(args, "output", None) or f"{template_type}_template.yaml"
output_path = f"{template_type}_template.yaml"

if template_type not in TEMPLATE_FILES:
logger.error(f"Unknown template: {template_type}")
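Note on the behavior change: the init output filename is now derived solely from the template type, with no flag to override it. A minimal sketch of the resulting names (template types taken from the docstring above; files land in the current working directory, as the tests below assume):

for template_type in ("offline", "online", "eval", "submission"):
    # e.g. "offline_template.yaml"
    print(f"{template_type}_template.yaml")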
1 change: 0 additions & 1 deletion src/inference_endpoint/config/schema.py
@@ -322,7 +322,6 @@ class BenchmarkConfig(BaseModel):
settings: Settings = Field(default_factory=Settings)
metrics: Metrics = Field(default_factory=Metrics)
endpoint_config: EndpointConfig = Field(default_factory=EndpointConfig)
output: Path | None = None
report_dir: Path | None = None
timeout: int | None = None
verbose: bool = False
5 changes: 5 additions & 0 deletions src/inference_endpoint/load_generator/session.py
@@ -58,6 +58,8 @@ def __init__(
self.event_recorder = EventRecorder(
session_id=self.session_id, notify_idle=self.end_event
)
# Will be populated after the test finishes by _run_test
self.report = None

self.sample_uuid_map = None

@@ -153,6 +155,9 @@ def _run_test(
tokenizer = None
report = reporter.create_report(tokenizer)

# Store report on session so external callers can use it
self.report = report

# Consolidate UUID->index mappings
perf_name = (
perf_test_generator.name
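Since the session now stores its report, external callers can read the metrics directly instead of timing the run themselves. A minimal sketch, assuming sess is a finished session and using only the report attributes referenced in benchmark.py above:

report = sess.report
if report is not None:
    # duration is reported in nanoseconds
    elapsed_s = report.duration_ns / 1e9
    completed = report.n_samples_completed
    issued = report.n_samples_issued
    # qps is None when duration was 0, so fall back to 0.0
    qps = report.qps or 0.0
    print(f"{completed}/{issued} completed in {elapsed_s:.1f}s ({qps:.1f} QPS)")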
16 changes: 11 additions & 5 deletions tests/integration/commands/test_benchmark_command.py
@@ -116,7 +116,11 @@ async def test_benchmark_with_output_file(
self, mock_http_echo_server, ds_pickle_dataset_path, tmp_path
):
"""Test benchmark saves results to JSON file."""
output_file = tmp_path / "benchmark_results.json"
# The benchmark command writes results to `results.json` inside the
# configured `report_dir`. Pass `report_dir=tmp_path` so the command
# will write output into this temporary directory and we can assert on
# the produced file.
report_dir = tmp_path

args = argparse.Namespace(
benchmark_mode="offline",
@@ -133,17 +137,19 @@
min_output_tokens=None,
max_output_tokens=None,
mode=None,
output=output_file,
report_dir=report_dir,
verbose=0,
model="echo-server",
timeout=None,
)

await run_benchmark_command(args)

# Verify file was created
assert output_file.exists()
# Verify file was created at <report_dir>/results.json
results_path = report_dir / "results.json"
assert results_path.exists()

with open(output_file) as f:
with open(results_path) as f:
results = json.load(f)

assert "config" in results
54 changes: 30 additions & 24 deletions tests/unit/commands/test_utils.py
@@ -142,56 +142,62 @@ async def test_init_unknown_template(self):
"""Test init with unknown template type."""
args = MagicMock()
args.template = "unknown"
args.output = None

with pytest.raises(InputValidationError, match="Unknown template"):
await run_init_command(args)

@pytest.mark.asyncio
async def test_init_success(self, tmp_path):
async def test_init_success(self):
"""Test successful template generation."""
output_file = tmp_path / "test_template.yaml"

args = MagicMock()
args.template = "offline"
args.output = str(output_file)

await run_init_command(args)
output_file = Path(f"{args.template}_template.yaml")

assert output_file.exists()
content = output_file.read_text()
assert "offline-benchmark" in content
assert "max_throughput" in content
try:
await run_init_command(args)

assert output_file.exists()
content = output_file.read_text()
assert "offline-benchmark" in content
assert "max_throughput" in content
finally:
output_file.unlink(missing_ok=True)

@pytest.mark.asyncio
async def test_init_warns_on_overwrite(self, tmp_path, caplog):
async def test_init_warns_on_overwrite(self, caplog):
"""Test warning when file already exists."""
output_file = tmp_path / "existing.yaml"
output_file.write_text("existing content")

args = MagicMock()
args.template = "online"
args.output = str(output_file)

await run_init_command(args)
output_file = Path(f"{args.template}_template.yaml")
output_file.write_text("existing content")

try:
await run_init_command(args)

assert "will be overwritten" in caplog.text
# File should be replaced
assert "online-benchmark" in output_file.read_text()
assert "will be overwritten" in caplog.text
# File should be replaced
assert "online-benchmark" in output_file.read_text()
finally:
output_file.unlink(missing_ok=True)

@pytest.mark.asyncio
async def test_init_all_templates(self, tmp_path):
async def test_init_all_templates(self):
"""Test generating all template types."""
templates = ["offline", "online", "eval", "submission"]

for template_type in templates:
output_file = tmp_path / f"{template_type}_test.yaml"

output_file = Path(f"{template_type}_template.yaml")
args = MagicMock()
args.template = template_type
args.output = str(output_file)

await run_init_command(args)
try:
await run_init_command(args)

assert output_file.exists()
assert output_file.stat().st_size > 0
assert output_file.exists()
assert output_file.stat().st_size > 0
finally:
output_file.unlink(missing_ok=True)