Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/inference_endpoint/commands/benchmark/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
BenchmarkConfig,
OfflineBenchmarkConfig,
OnlineBenchmarkConfig,
ProfilerEngine,
TestMode,
TestType,
)
Expand Down Expand Up @@ -98,6 +99,13 @@ def from_config(
config: Annotated[Path, cyclopts.Parameter(name=["--config", "-c"])],
timeout: float | None = None,
mode: TestMode | None = None,
profile: Annotated[
ProfilerEngine | None,
cyclopts.Parameter(
name="--profile",
help="Profile the named inference engine around the performance phase",
),
] = None,
):
"""Run benchmark from YAML config file."""
try:
Expand All @@ -106,6 +114,14 @@ def from_config(
raise InputValidationError(f"Config error: {e}") from e
if timeout is not None:
resolved = resolved.with_updates(timeout=timeout)
if profile is not None:
new_profiling = resolved.settings.profiling.model_copy(
update={"engine": profile}
)
new_settings = resolved.settings.model_copy(
update={"profiling": new_profiling}
)
resolved = resolved.with_updates(settings=new_settings)
test_mode = mode or (
TestMode.BOTH if resolved.type == TestType.SUBMISSION else TestMode.PERF
)
Expand Down
190 changes: 185 additions & 5 deletions src/inference_endpoint/commands/benchmark/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@
import shutil
import signal
import tempfile
import time
import uuid
from collections.abc import Callable
from dataclasses import dataclass, field
from dataclasses import replace as dataclass_replace
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import Any, TextIO
from urllib import error as urllib_error
from urllib import request as urllib_request
from urllib.parse import urljoin

import msgspec
Expand Down Expand Up @@ -69,6 +72,7 @@
DatasetType,
LoadPattern,
LoadPatternType,
ProfilerEngine,
StreamingMode,
TestMode,
TestType,
Expand Down Expand Up @@ -140,6 +144,10 @@ class BenchmarkResult:
collector: ResponseCollector
report: Report | None
tmpfs_dir: Path
# Profile trigger payload {engine: str, starts: [...], stops: [...]} when
# settings.profiling.engine is set; None otherwise. Rendered into
# report.txt and a sibling profiling.json by finalize_benchmark.
profiling: dict[str, Any] | None = None


@dataclass
Expand Down Expand Up @@ -548,6 +556,110 @@ def _load_final_snapshot_from_disk(path: Path) -> dict[str, Any] | None:
return None


# (start_path, stop_path) for each supported inference engine's profiling
# protocol. Add a row when introducing a new ProfilerEngine variant.
_PROFILE_PATHS: dict[ProfilerEngine, tuple[str, str]] = {
ProfilerEngine.VLLM: ("/start_profile", "/stop_profile"),
}


def _derive_profile_urls(
endpoints: list[str], engine: ProfilerEngine, action: str
) -> list[str]:
"""One profile URL per endpoint, derived from the engine's HTTP protocol.

For vLLM: strip a trailing ``/v1`` from each endpoint and append
``/{start,stop}_profile``. ``action`` is ``"start"`` or ``"stop"``.
"""
if not endpoints:
raise ValueError(
f"profiling.engine={engine.value} but endpoint_config.endpoints "
f"is empty; cannot derive {action} URLs"
)
start_path, stop_path = _PROFILE_PATHS[engine]
path = start_path if action == "start" else stop_path
urls: list[str] = []
for ep in endpoints:
base = ep.rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
urls.append(f"{base.rstrip('/')}{path}")
return urls


def _post_profile(url: str) -> dict[str, Any]:
"""POST {url} with empty body; never raises. Returns a record dict suitable
for report.txt rendering and profiling.json serialization."""
record: dict[str, Any] = {
"url": url,
"sent_at_ns": time.monotonic_ns(),
"sent_at_iso": datetime.now().isoformat(timespec="milliseconds"),
"status": None,
"error": None,
}
req = urllib_request.Request(url, method="POST", data=b"")
try:
with urllib_request.urlopen(req, timeout=2) as resp:
record["status"] = resp.status
except urllib_error.HTTPError as e:
record["status"] = e.code
record["error"] = f"{e.code} {e.reason}"
except Exception as e: # noqa: BLE001 — profile failures must never abort a run
record["error"] = f"{type(e).__name__}: {e}"
return record


def _render_profile_status(rec: dict[str, Any]) -> str:
status = rec.get("status")
error = rec.get("error")
if status == 200:
return "200 OK"
if status == 404:
return (
"404 (profiling not enabled on server — pass "
"--profiler-config.profiler=... to server)"
)
if error:
return error
if status is not None:
return str(status)
return "ERROR"


def _write_profiling_section(f: TextIO, profiling: dict[str, Any]) -> None:
"""Append the Profiling section to report.txt (called after report.display)."""
starts = profiling.get("starts", [])
stops = profiling.get("stops", [])
f.write("\n------------------- Profiling -------------------\n")
f.write(f"Engine: {profiling.get('engine', 'unknown')}\n")
f.write("Start:\n")
for rec in starts:
f.write(
f" POST {rec['url']} @ {rec['sent_at_iso']} → "
f"{_render_profile_status(rec)}\n"
)
if stops:
f.write("Stop:\n")
for rec in stops:
suffix = (
" (from abort handler)" if rec.get("stop_reason") == "abort" else ""
)
f.write(
f" POST {rec['url']} @ {rec['sent_at_iso']} → "
f"{_render_profile_status(rec)}{suffix}\n"
)
if starts and stops:
first_start = min(r["sent_at_ns"] for r in starts)
last_stop = max(r["sent_at_ns"] for r in stops)
f.write(f"Trigger span: {(last_stop - first_start) / 1e9:.2f} s\n")
f.write(
"\nNote: actual trace window is bounded by server-side "
"--profiler-config.delay_iterations and "
"--profiler-config.max_iterations.\n"
"Trace artifact path is in server stdout.\n"
)


async def _run_benchmark_async(
ctx: BenchmarkContext,
loop: asyncio.AbstractEventLoop,
Expand Down Expand Up @@ -736,6 +848,22 @@ def _on_sample_complete(result: QueryResult) -> None:
_timeout_done = False
max_duration_ms = ctx.rt_settings.max_duration_ms

# Profile trigger state. Pre-derive URLs once so a bad config
# (engine set but no endpoints) fails before the run.
profiling_cfg = config.settings.profiling
profile_start_urls: list[str] = []
profile_stop_urls: list[str] = []
profile_starts: list[dict[str, Any]] = []
profile_stops: list[dict[str, Any]] = []
if profiling_cfg.engine is not None:
profile_start_urls = _derive_profile_urls(
config.endpoint_config.endpoints, profiling_cfg.engine, "start"
)
profile_stop_urls = _derive_profile_urls(
config.endpoint_config.endpoints, profiling_cfg.engine, "stop"
)
session_completed_normally = False

def _on_global_timeout() -> None:
if not _timeout_done:
logger.warning(
Expand All @@ -746,24 +874,58 @@ def _on_global_timeout() -> None:

def _on_phase_start(phase: PhaseConfig) -> None:
nonlocal global_timeout_handle
if (
phase.phase_type == PhaseType.PERFORMANCE
and max_duration_ms is not None
):
if phase.phase_type != PhaseType.PERFORMANCE:
return
if max_duration_ms is not None:
global_timeout_handle = loop.call_later(
max_duration_ms / 1000.0, _on_global_timeout
)
# Fire /start_profile sequentially before any perf request is
# issued, so the server is armed when traffic begins. Blocks
# the loop briefly (sub-100ms per URL); strategy task hasn't
# been created yet so nothing is starved.
for url in profile_start_urls:
rec = _post_profile(url)
if rec["status"] == 200:
logger.info("Profile start: %s -> 200 OK", url)
else:
logger.warning(
"Profile start: %s -> %s",
url,
rec["error"] or rec["status"],
)
profile_starts.append(rec)

loop.add_signal_handler(signal.SIGINT, session.stop)
try:
result = await session.run(phases, on_phase_start=_on_phase_start)
session_completed_normally = True
except Exception as e:
raise ExecutionError(f"Benchmark execution failed: {e}") from e
finally:
_timeout_done = True
if global_timeout_handle is not None:
global_timeout_handle.cancel()
loop.remove_signal_handler(signal.SIGINT)
# Fire /stop_profile for URLs whose /start_profile succeeded.
# Unifies the clean phase-end path and the abort path —
# both reach this block, both fire stops.
if profile_starts:
stop_reason = "phase_end" if session_completed_normally else "abort"
for i, start_rec in enumerate(profile_starts):
if start_rec["status"] != 200 or i >= len(profile_stop_urls):
continue
rec = _post_profile(profile_stop_urls[i])
rec["stop_reason"] = stop_reason
if rec["status"] == 200:
logger.info("Profile stop: %s -> 200 OK", profile_stop_urls[i])
else:
logger.warning(
"Profile stop: %s -> %s",
profile_stop_urls[i],
rec["error"] or rec["status"],
)
profile_stops.append(rec)
Comment on lines +913 to +928
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The profiling stop requests are synchronous blocking network calls (_post_profile uses urllib.request.urlopen with a 2-second timeout). Running these sequentially in the main event loop thread inside the finally block can block the event loop for several seconds if the endpoints are slow or unresponsive.

Furthermore, if the user interrupts the cleanup process (e.g., by pressing Ctrl+C a second time because it appears hung), a KeyboardInterrupt will be raised during these blocking calls, which will bypass the rest of the critical cleanup (such as shutting down the HTTP client, closing the publisher, and terminating launcher services).

To prevent this, we should run the profiling stop requests asynchronously in a thread pool using loop.run_in_executor and wrap the entire block in a try...except BaseException to ensure that any network failures or interrupts during profiling stop do not prevent the rest of the cleanup from executing.

            if profile_starts:
                try:
                    stop_reason = "phase_end" if session_completed_normally else "abort"
                    for i, start_rec in enumerate(profile_starts):
                        if start_rec["status"] != 200 or i >= len(profile_stop_urls):
                            continue
                        try:
                            rec = await loop.run_in_executor(
                                None, _post_profile, profile_stop_urls[i]
                            )
                            rec["stop_reason"] = stop_reason
                            if rec["status"] == 200:
                                logger.info("Profile stop: %s -> 200 OK", profile_stop_urls[i])
                            else:
                                logger.warning(
                                    "Profile stop: %s -> %s",
                                    profile_stop_urls[i],
                                    rec["error"] or rec["status"],
                                )
                            profile_stops.append(rec)
                        except Exception as e:
                            logger.warning(
                                "Failed to stop profile for %s: %s",
                                profile_stop_urls[i],
                                e,
                            )
                except BaseException as e:
                    logger.warning("Profiling stop cleanup was interrupted: %s", e)

logger.info("Cleaning up...")
try:
if http_client:
Expand Down Expand Up @@ -816,11 +978,20 @@ def _on_phase_start(phase: PhaseConfig) -> None:
metrics_subscriber.close()
pbar.close()

profiling_payload: dict[str, Any] | None = None
if profiling_cfg.engine is not None:
profiling_payload = {
"engine": profiling_cfg.engine.value,
"starts": profile_starts,
"stops": profile_stops,
}

return BenchmarkResult(
session=result,
collector=collector,
report=report,
tmpfs_dir=tmpfs_dir,
profiling=profiling_payload,
)


Expand Down Expand Up @@ -889,8 +1060,17 @@ def finalize_benchmark(ctx: BenchmarkContext, bench: BenchmarkResult) -> None:
report_txt = ctx.report_dir / "report.txt"
with report_txt.open("w") as f:
report.display(fn=lambda s: print(s, file=f))
if bench.profiling is not None:
_write_profiling_section(f, bench.profiling)
logger.info(f"Report written to {report_txt}")

# Sibling profiling.json — kept separate so Report stays a pure
# snapshot-derived struct.
if bench.profiling is not None:
(ctx.report_dir / "profiling.json").write_text(
json.dumps(bench.profiling, indent=2)
)

# Write scoring artifacts + copy event log from tmpfs to disk
_write_scoring_artifacts(ctx, result, bench.tmpfs_dir)

Expand Down
39 changes: 39 additions & 0 deletions src/inference_endpoint/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,44 @@ class WarmupConfig(BaseModel):
] = Field(42, description="RNG seed for warmup scheduling and sample ordering")


class ProfilerEngine(str, Enum):
"""Inference engine whose profiling protocol the client should drive.

Selects the HTTP path layout used to derive start/stop URLs from
``endpoint_config.endpoints``. Each value corresponds to one server-side
profiling protocol; add a new variant + ``_PROFILE_PATHS`` row to support
another engine.
"""

VLLM = "vllm"


@cyclopts.Parameter(name="*")
class ProfilingConfig(BaseModel):
"""Client-side trigger for the server's profiler.

When ``engine`` is set, fires POST ``<start_path>`` at performance-phase
begin and POST ``<stop_path>`` at performance-phase end. URLs are derived
from ``endpoint_config.endpoints`` using the engine-specific protocol.
Server must be launched with profiling enabled (e.g. vLLM's
``--profiler-config.profiler=cuda|torch``); the schedule
(``delay_iterations``, ``max_iterations``) is set there, not here.
"""

model_config = ConfigDict(extra="forbid", frozen=True)

engine: Annotated[
ProfilerEngine | None,
cyclopts.Parameter(
alias="--profile",
help="Profile the named inference engine around the performance phase",
),
] = Field(
None,
description="Profile the named inference engine around the performance phase",
)


@cyclopts.Parameter(name="*")
class Settings(BaseModel):
"""Test settings."""
Expand All @@ -493,6 +531,7 @@ class Settings(BaseModel):
load_pattern: LoadPattern = Field(default_factory=LoadPattern)
client: HTTPClientConfig = Field(default_factory=HTTPClientConfig)
warmup: WarmupConfig = Field(default_factory=WarmupConfig)
profiling: ProfilingConfig = Field(default_factory=ProfilingConfig)


class OfflineSettings(Settings):
Expand Down
Loading