feat(logging,trace): cuda-graph-compatible level-5/10 logging + fi_trace template additions/fixes #3172

Merged
yyihuang merged 19 commits into flashinfer-ai:main from yyihuang:cuda-graph-api-logging on May 15, 2026

@yyihuang
Collaborator

@yyihuang yyihuang commented Apr 25, 2026

Summary

Two related changes to @flashinfer_api:

  1. CUDA-graph compatibility for FLASHINFER_LOGLEVEL=5 and =10. Both levels previously had host-side paths that were unsafe inside torch.cuda.graph(...) capture (tensor.min().item(), tensor.cpu(), tensor repr on nested CUDA tensors, etc.). Level 5 skipped stats under capture; Level 10 could invalidate graph capture or only preserve the latest replay state.
  2. fi_trace template additions/fixes uncovered while validating (1) end-to-end against sglang DSR1 FP8 TP=8 with --attention-backend trtllm_mla. Several flashinfer entry points exercised by that workload had bare @flashinfer_api (no trace template) or had templates whose declared shapes did not match the actual runtime layout.

Part A: CUDA-graph-compatible logging

Level 5: stats under CUDA graph capture

Replaces the [statistics skipped: CUDA graph capture in progress] path for supported dtypes with a small captured CUDA kernel that computes min/max/mean/nan/inf and emits one line via device-side printf. The launch is captured into the graph, so the printf fires on every g.replay(). The host log records a correlation marker so the kernel-emitted line can be matched back to the API call/argument.

Supported dtypes: float32, float16, bfloat16, int32, int64, uint8. Other dtypes (for example fp8/fp4) fall back to the legacy skip message.

Files: csrc/api_log_stats.cu, csrc/flashinfer_api_log_stats_binding.cu, flashinfer/jit/api_log_stats.py, flashinfer/api_logging.py, flashinfer/aot.py.

Example output inside torch.cuda.graph(...):

FlashInfer API Call: my_op
  arg[0]:
    Tensor(
      shape=(64, 64) ...
      [stats deferred to GPU kernel: id=1; look for '[flashinfer stats] id=1 ...' in graph replay output]
    )

On g.replay():

[flashinfer stats] id=1 numel=4096 min=-3.42 max=3.59 mean=0.01 nan=0 inf=0
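
To match replay output back to API calls, the host marker and the device line can be joined on the id. The sketch below assumes both lines look exactly like the two examples above; the regexes and the `collect_replay_stats` helper are illustrative and not part of FlashInfer.

```python
import re

# Patterns mirror the two example lines above; the real format may differ.
MARKER_RE = re.compile(r"\[stats deferred to GPU kernel: id=(\d+);")
STATS_RE = re.compile(
    r"\[flashinfer stats\] id=(\d+) numel=(\d+) "
    r"min=(\S+) max=(\S+) mean=(\S+) nan=(\d+) inf=(\d+)"
)


def collect_replay_stats(host_log: str, device_out: str) -> dict:
    """Group device-side stats lines under the ids announced in the host log."""
    stats = {int(m.group(1)): [] for m in MARKER_RE.finditer(host_log)}
    for m in STATS_RE.finditer(device_out):
        tensor_id = int(m.group(1))
        if tensor_id not in stats:
            continue  # stats line without a matching host marker
        stats[tensor_id].append(
            {
                "numel": int(m.group(2)),
                "min": float(m.group(3)),
                "max": float(m.group(4)),
                "mean": float(m.group(5)),
                "nan": int(m.group(6)),
                "inf": int(m.group(7)),
            }
        )
    return stats
```

Each replay appends one entry per id, so the list length for an id is the number of replays observed in the captured stdout.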

Level 10: tensor dumps under CUDA graph capture

Current behavior under CUDA graph capture:

  • Capture does not call .cpu(), allocate pinned host memory, or insert D2H copy nodes into the captured graph.
  • Captured level-10 calls record tensor references plus metadata and defer disk writes.
  • torch.cuda.CUDAGraph.capture_begin / capture_end are wrapped to tag deferred dumps with the owning graph id.
  • torch.cuda.CUDAGraph.replay() is wrapped so FlashInfer automatically flushes graph dumps after every replay for that graph. No sglang code injection is needed.
  • Each replay flush synchronizes, materializes current tensor values to CPU, writes root inputs.pt / outputs.pt compatibility files, and also writes immutable graph_flushes/flush_XXXX/ snapshots.
  • Nested tensors inside list / tuple / dict inputs are recursively extracted and dumped with stable keys such as arg_2__0, arg_2__1, while metadata records tensor_key links so replay_from_dump() can reconstruct containers.
  • A process-exit / SIGTERM flush remains as a fallback, but under normal operation the replay hook already flushes every replay automatically.

This means root inputs.pt / outputs.pt contain the latest flushed replay, while graph_flushes/flush_XXXX/ preserves per-replay snapshots.
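
The nested-key scheme can be pictured with a small recursive flattener. This is a minimal sketch, not the code in flashinfer/api_logging.py: `FakeTensor`, `flatten_tensor_args`, and the shape of the returned structure are assumptions, and only the `arg_2__0` / `arg_2__1` key pattern and the `tensor_key` name come from the description above. The key format for dict entries is a guess.

```python
class FakeTensor:
    """Stand-in for a CUDA tensor so the sketch runs without torch."""

    def __init__(self, name):
        self.name = name


def flatten_tensor_args(args, is_tensor):
    """Return ({key: tensor}, structure); structure links back via tensor_key."""
    tensors = {}

    def visit(value, key):
        if is_tensor(value):
            tensors[key] = value
            return {"tensor_key": key}
        if isinstance(value, (list, tuple)):
            items = [visit(v, f"{key}__{i}") for i, v in enumerate(value)]
            return {"type": type(value).__name__, "items": items}
        if isinstance(value, dict):
            # Assumed key format for dict entries; not confirmed by the PR text.
            items = {k: visit(v, f"{key}__{k}") for k, v in value.items()}
            return {"type": "dict", "items": items}
        return {"value": value}

    structure = [visit(a, f"arg_{i}") for i, a in enumerate(args)]
    return tensors, structure


args = ("self", FakeTensor("q"), (FakeTensor("k"), FakeTensor("v")))
tensors, structure = flatten_tensor_args(args, lambda v: isinstance(v, FakeTensor))
print(sorted(tensors))  # ['arg_1', 'arg_2__0', 'arg_2__1']
```

Keeping the container shape in metadata and the tensors in a flat map is what lets a replay helper rebuild the original tuple from separately stored tensor files.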

Caveats:

  • Level 10 graph dumping is intentionally expensive. It adds synchronization and disk writes to every graph replay.
  • Dumping nested KV-cache tensors can create large dumps quickly. Use FLASHINFER_DUMP_INCLUDE, FLASHINFER_DUMP_MAX_COUNT, and FLASHINFER_DUMP_MAX_SIZE_GB for short targeted debug runs.
  • Deferred graph tensor references are retained until clear_graph_dumps() or process exit.
  • kill -9 can still lose pending deferred writes.
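
A short targeted run can be scoped through the environment before the server process starts. The helper below is a sketch: `targeted_dump_env` is a hypothetical name, and the values follow the validation run described later in this PR. FLASHINFER_DUMP_MAX_SIZE_GB is left unset here because its accepted value format is not shown in this description.

```python
import os


def targeted_dump_env(include: str, max_count: int, base=None) -> dict:
    """Build an environment for a short, scoped level-10 debug run."""
    env = dict(os.environ if base is None else base)
    env["FLASHINFER_LOGLEVEL"] = "10"
    env["FLASHINFER_DUMP_INCLUDE"] = include
    env["FLASHINFER_DUMP_MAX_COUNT"] = str(max_count)
    return env


env = targeted_dump_env("BatchDecodeWithPagedKVCacheWrapper.run", 64)
# Pass `env` to subprocess.run(..., env=env) when launching the target.
```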

Files: flashinfer/api_logging.py, docs/logging.rst, tests/utils/test_logging.py.

Part B: fi_trace template additions / fixes

End-to-end validation under sglang DSR1 + trtllm_mla showed these gaps:

  1. gemm_fp8_nt_groupwise had no trace= template. The op fires heavily in DSR1 + flashinfer_trtllm MoE but produced no trace JSON.
  2. trtllm_batch_decode_mla_trace and xqa_batch_decode_mla_trace declared kv_cache as rank-3 [num_pages, page_size, head_dim_qk], but the kernel accepts and sglang passes the rank-4 [num_pages, 1, page_size, head_dim_qk] form. They were also missing skip_softmax_threshold_scale_factor, and workspace dtype was corrected from int8 to uint8.
  3. mla_rope_quantize_fp8_trace inherited the rank-3 GQA rope quant axes with num_k_heads, but MLA passes rank-2 K tensors with num_k_heads=1 collapsed.
  4. gemm_fp8_nt_groupwise_trace.b_scale axes were corrected to [N_div_block, K_div_block] to match the trtllm path and sglang runtime layout.
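
The rank mismatch in item 2 is easy to see once declared axes are bound to a runtime shape. The checker below is a hypothetical illustration, not the fi_trace template API; `match_axes` and the example sizes are assumptions, and only the axis names come from the list above.

```python
def match_axes(declared, shape):
    """Bind axis names to sizes; integer entries are fixed sizes that must match."""
    if len(declared) != len(shape):
        raise ValueError(
            f"template declares rank {len(declared)}, runtime shape {tuple(shape)} "
            f"has rank {len(shape)}"
        )
    bound = {}
    for axis, size in zip(declared, shape):
        if isinstance(axis, int):
            if axis != size:
                raise ValueError(f"expected fixed size {axis}, got {size}")
        else:
            bound[axis] = size
    return bound


runtime_shape = (8, 1, 64, 576)  # illustrative sizes only
rank4 = ["num_pages", 1, "page_size", "head_dim_qk"]
print(match_axes(rank4, runtime_shape))
```

With the old rank-3 declaration the same call raises, which corresponds to the kind of template/runtime disagreement that the end-to-end run surfaced.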

Files: flashinfer/gemm/gemm_base.py, flashinfer/trace/templates/{gemm,attention,rope}.py, tests/trace/example.py, regenerated JSONs in tests/trace/fi_trace_out/.

Test plan

Unit / formatting

  • pre-commit run --files docs/logging.rst flashinfer/api_logging.py tests/utils/test_logging.py passed.
  • CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/averyh/flashinfer-pr3172 pytest -q tests/utils/test_logging.py passed: 20 passed, 2 warnings.
  • python -m compileall -q flashinfer/api_logging.py tests/utils/test_logging.py passed.
  • git diff --check passed.
  • pytest tests/trace/ passed previously: 440 passed, 8 skipped.
  • tests/utils/test_logging_replay.py partial local run: 14 passed, 2 failed due to local environment/JIT setup before replay validation:
    • test_bmm_fp8_replay: cuDNN reported multiple CUDA runtime libraries, libcudart.so.12 and libcudart.so.13.
    • test_mm_fp4_replay: local JIT build failed because cutlass/arch/barrier.h was missing.

Manual CUDA graph smoke tests

  • Manual eager dispatch over fp32 (with NaN + Inf), bf16, int32: values match expected stats.
  • Manual capture/replay at level 5: g.replay() after mutating the input shows updated stats; multiple replays work.
  • Manual capture/replay at level 10: captured graph plus repeated replays with different inputs; replay hook writes root files and immutable graph_flushes/flush_XXXX/ snapshots.
  • Manual level-10 nested tensor capture test: tuple tensors are dumped as separate tensor files and reconstructed by replay_from_dump().

SGLang single-GPU validation, May 13 2026

Environment:

  • GPU: NVIDIA B200
  • PyTorch: 2.9.1+cu128
  • sglang: 0.5.10.post1
  • FlashInfer import path: /home/averyh/flashinfer-pr3172/flashinfer/__init__.py
  • Model: meta-llama/Llama-3.2-3B-Instruct
  • Command shape: python -m sglang.launch_server --attention-backend flashinfer --sampling-backend flashinfer --cuda-graph-bs 1 --cuda-graph-max-bs 1 --context-length 512 --max-total-tokens 1024 --mem-fraction-static 0.55 --dtype bfloat16

Level 5 result:

  • Server arguments showed attention_backend='flashinfer', sampling_backend='flashinfer', disable_cuda_graph=False, cuda_graph_bs=[1].
  • SGLang completed Capture cuda graph bs [1] and Capture cuda graph end.
  • Completion request returned successfully: " Paris. The capital".
  • SGLang request log showed cuda graph: True.
  • FlashInfer level-5 log had 322726 lines, including 283 [stats deferred to GPU kernel: id=...] markers.
  • Server stdout contained device-side [flashinfer stats] id=... lines.
  • No [statistics skipped: CUDA graph capture in progress] lines were observed for this run.

Level 10 result:

  • Server arguments showed attention_backend='flashinfer', sampling_backend='flashinfer', disable_cuda_graph=False, cuda_graph_bs=[1].
  • SGLang completed CUDA graph capture and served the same completion request successfully.
  • Final recursive-dump validation used FLASHINFER_DUMP_INCLUDE='BatchDecodeWithPagedKVCacheWrapper.run' and FLASHINFER_DUMP_MAX_COUNT=64.
  • FlashInfer log showed 16 graph-deferred input/output sections and 12 automatic flush_graph_dumps (CUDAGraph.replay): wrote ... replay flushes.
  • Dump tree contained 64 dump dirs, 8 dirs with graph_flushes, and 96 immutable replay snapshot dirs.
  • Example snapshot input keys were arg_1, arg_2__0, arg_2__1, proving both the query tensor and nested K/V cache tuple tensors were dumped.
  • Example snapshot output key was result.
  • Consecutive flush_0001 to flush_0004 snapshots had different input/output sums, confirming replay snapshots are not just the last buffer state.
  • Dump size for this targeted run was about 647 MB, which confirms correctness and also the expected cost of recursive K/V tensor dumping.
  • No Capture cuda graph failed, cudaErrorStreamCaptureInvalidated, or BatchDecodeWithPagedKVCache failed errors in the final level-10 run.
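
For inspecting a dump tree like the one above, the per-replay snapshots can be enumerated with the standard library. This sketch assumes the graph_flushes/flush_XXXX/ layout described in this PR and zero-padded indices; `list_flush_snapshots` is a hypothetical helper.

```python
from pathlib import Path


def list_flush_snapshots(dump_dir):
    """Return flush_XXXX snapshot directories under dump_dir/graph_flushes, in order."""
    root = Path(dump_dir) / "graph_flushes"
    if not root.is_dir():
        return []
    # Zero-padded names sort correctly as plain strings.
    return sorted(
        p for p in root.iterdir() if p.is_dir() and p.name.startswith("flush_")
    )
```

Loading inputs.pt / outputs.pt from consecutive snapshots and comparing sums is then a short loop over the returned paths.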

Larger integration validation

  • End-to-end sglang DSR1 FP8 TP=8 on 8x B200 at level 3, full bench: 40/40 requests in 99 s, output throughput 416 tok/s, mean TPOT 9.3 ms/token, no CUDA error, every (api, axes) tuple produced a trace JSON.
  • End-to-end sglang at level 5: deferred-stats kernel JIT-builds without racing graph capture, captures into graph, fires on replay. No cudaErrorStreamCaptureInvalidated. Runtime is dominated by unthrottled device-side printf rate.

PR state

  • Head commit after latest update: 008a2836
  • Base refreshed against upstream main commit: 103fcf86
  • Branch pushed to yyihuang/flashinfer:cuda-graph-api-logging

@coderabbitai
Contributor

coderabbitai Bot commented Apr 25, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Adds a device-side CUDA kernel and TVM FFI binding to compute per-tensor min/max/mean and NaN/Inf counts during CUDA-graph capture/replay; integrates the kernel into JIT/AOT packaging, changes capture-time pinned-buffer staging and deferred dump semantics for log levels 5 and 10, and exposes dump flush/clear APIs.

Changes

Cohort (File(s)): Summary

  • GPU Kernel & FFI Binding (csrc/api_log_stats.cu, csrc/flashinfer_api_log_stats_binding.cu): Adds api_log_print_tensor_stats CUDA kernel that reduces tensor values on-device (handles fp16/bf16, counts NaN/Inf, device-side printf) and exports it as a TVM FFI function.
  • JIT / AOT Integration (flashinfer/jit/api_log_stats.py, flashinfer/aot.py): Adds gen_api_log_stats_module() JitSpec and wires it into AOT misc module generation when add_misc is enabled.
  • Python API: capture & dump logic (flashinfer/api_logging.py): Level-5: attempt to JIT-load/launch GPU stats kernel for supported dtypes during CUDA-graph capture, otherwise fall back to legacy skip marker. Level-10: switch to pinned-buffer staging for D2H in capture, defer disk writes, update metadata with pending flush state, and add flush_graph_dumps() / clear_graph_dumps() APIs.
  • Tests (tests/utils/test_logging.py): Updates CUDA-graph compatibility test to accept deferred GPU-correlation marker and adds tests covering level-10 graph dumps and warmup requirements.
  • Docs (docs/logging.rst): Documents captured-kernel behavior for level-5, supported dtypes, deferred dump semantics, pinned-buffer staging, and warmup expectations for level-10.
  • CLI / Tooling (tools/dump_with_cuda_graph.py): New wrapper to force loglevel 10, set dump environment, install a CUDAGraph.replay monkey-patch that calls flush_graph_dumps() after replays, and run target program preserving in-process behavior.
  • Misc packaging (flashinfer/jit/..., flashinfer/aot.py): Build/package inclusion of the new api_log_stats JIT module into generated artifacts.

Sequence Diagram(s)

sequenceDiagram
    participant Host as Host (Python)
    participant JIT as JIT/AOT
    participant TVM as TVM FFI
    participant GPU as GPU Kernel
    participant DevIO as Device printf

    Host->>JIT: build/load api_log_stats module
    JIT-->>Host: module (or build failure)
    alt module built & dtype supported
        Host->>Host: emit correlation marker id=N (deferred-to-GPU)
        Host->>TVM: api_log_print_tensor_stats(tensor, id) on capture stream
        TVM->>GPU: launch kernel (device-stream)
        GPU->>GPU: per-thread convert/reduce (min/max/sum/nan/inf)
        GPU->>DevIO: printf("[flashinfer stats] id=N ...")
    else unsupported dtype or load/launch failure
        Host->>Host: emit "statistics skipped: CUDA graph capture in progress"
    end
    Host->>Host: end capture
    Host->>Host: cuda_graph.replay()
    GPU->>DevIO: printf outputs on replay (if kernel ran)
    Host->>Host: flush_graph_dumps() to persist pinned buffers (level 10)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

ready, run-ci

Suggested reviewers

  • nvmbreughe
  • jimmyzho
  • kahyunnam
  • bkryu
  • cyx-6
  • jiahanc
  • yzh119

Poem

🐇 I hop through kernels, bytes and stats align,
I count NaNs and Infs and print an ID line,
Shared threads gather sums in warmed pinned space,
Replays echo metrics back from GPU's place,
Flush the dumps, warm the pins — telemetry's fine!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 55.56%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Linked Issues check (✅ Passed): Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check (✅ Passed): Check skipped because no linked issues were found for this pull request.
  • Title check (✅ Passed): The pull request title accurately reflects the main changes: adding CUDA-graph-compatible logging for levels 5 and 10, plus fi_trace template fixes.
  • Description check (✅ Passed): The PR description is comprehensive, detailed, and well-structured, covering objectives, technical approach, testing, and validation results.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a CUDA-graph-friendly mechanism for logging tensor statistics at Level 5. By utilizing a device-side kernel and printf, statistics can be captured and reported during graph replay without requiring stream synchronization. Feedback focuses on ensuring robustness and accuracy: specifically, adding a check for tensor contiguity to prevent incorrect memory access, using double precision for intermediate calculations to avoid data loss with large integers, and aligning the handling of infinite values in min/max reductions with the existing eager logging implementation.

Comment thread flashinfer/api_logging.py
Comment on lines +1188 to +1189
    if tensor.dtype not in _GPU_STATS_SUPPORTED_DTYPES:
        return None
Contributor

high

The GPU statistics kernel performs a linear scan of the tensor data using data[i], which assumes the tensor is contiguous in memory. If a non-contiguous tensor (e.g., a slice or a transposed tensor) is passed, the kernel will read incorrect data or potentially access memory out of bounds. A check for tensor.is_contiguous() should be added here to ensure the kernel is only launched for supported layouts, falling back to the "skipped" message otherwise.

Suggested change
-    if tensor.dtype not in _GPU_STATS_SUPPORTED_DTYPES:
-        return None
+    if tensor.dtype not in _GPU_STATS_SUPPORTED_DTYPES or not tensor.is_contiguous():
+        return None
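
A toy model of strided storage shows why the linear scan is only valid for contiguous layouts. The helpers below are illustrative pure Python with hypothetical names, not FlashInfer code; they compare what a data[i] loop reads against the logical contents of a sliced 2-D view.

```python
def is_contiguous(shape, strides):
    """True when row-major traversal visits storage offsets 0, 1, 2, ... in order."""
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size != 1 and stride != expected:
            return False
        expected *= size
    return True


def logical_values(storage, shape, strides):
    """Elements of a 2-D view in logical row-major order."""
    rows, cols = shape
    return [
        storage[r * strides[0] + c * strides[1]]
        for r in range(rows)
        for c in range(cols)
    ]


def linear_scan(storage, numel):
    """What a kernel reading data[i] for i < numel would see."""
    return storage[:numel]


storage = list(range(6))             # backing 2x3 matrix: [[0, 1, 2], [3, 4, 5]]
shape, strides = (2, 2), (3, 1)      # view of the first two columns
print(is_contiguous(shape, strides))            # False
print(logical_values(storage, shape, strides))  # [0, 1, 3, 4]
print(linear_scan(storage, 4))                  # [0, 1, 2, 3]
```

The view's true maximum is 4, while the linear scan would report 3, so falling back to the skip message for non-contiguous tensors avoids printing wrong stats.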

Comment thread csrc/api_log_stats.cu Outdated
Comment on lines +38 to +40
__device__ inline float to_float_impl(T x) {
  return static_cast<float>(x);
}
Contributor

medium

Using float as the intermediate type for statistics causes precision loss for int32_t and int64_t types when values exceed the 24-bit mantissa limit (approx. 16.7 million). Since the final output is formatted as a double and the sum is already tracked as a double, it is better to use double for the to_float_impl return type and the thread_min/thread_max accumulators to preserve precision for integer types.

template <typename T>
__device__ inline double to_double_impl(T x) {
  return static_cast<double>(x);
}
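
The precision loss can be reproduced on the host by round-tripping integers through a 32-bit float. This is a small standalone check with a hypothetical helper name; it uses `struct` to emulate the float cast and says nothing about the kernel beyond that.

```python
import struct


def through_float32(x: int) -> int:
    """Round an integer to the nearest float32 value and convert it back."""
    return int(struct.unpack("f", struct.pack("f", float(x)))[0])


print(through_float32(2**24))      # 16777216 (exact)
print(through_float32(2**24 + 1))  # 16777216 (off by one)
print(through_float32(2**31 - 1))  # 2147483648 (int32 max rounds up)
print(int(float(2**24 + 1)))       # 16777217 (double keeps the value)
```

A double represents every integer up to 2^53 exactly, which covers all int32 values and the int64 index ranges that usually show up in practice.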

Comment thread csrc/api_log_stats.cu
Comment on lines +77 to +81
if (is_nan) {
  thread_nan += 1;
} else if (is_inf) {
  thread_inf += 1;
} else {
Contributor

medium

The current logic excludes infinite values from the min and max calculations, which differs from the behavior of the eager path where torch.min() and torch.max() include Infs. This inconsistency can be confusing for users. For example, a tensor containing [1.0, inf] would show max=1.0 in the GPU log but max=inf in the eager log. Consider including Infs in the min/max reduction while still counting them separately to match the eager logging output.
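
The two policies can be compared side by side on the host. This is a sketch of the reduction semantics only, with a hypothetical `min_max` helper; NaN is skipped in both modes here, which is an assumption of the sketch rather than a statement about either logging path.

```python
import math


def min_max(values, include_inf: bool):
    """Return (min, max, nan_count, inf_count); NaN never participates."""
    nan = sum(math.isnan(v) for v in values)
    inf = sum(math.isinf(v) for v in values)
    kept = [
        v
        for v in values
        if not math.isnan(v) and (include_inf or not math.isinf(v))
    ]
    if not kept:
        return None, None, nan, inf
    return min(kept), max(kept), nan, inf


data = [1.0, math.inf]
print(min_max(data, include_inf=False))  # (1.0, 1.0, 0, 1)
print(min_max(data, include_inf=True))   # (1.0, inf, 0, 1)
```

Counting Infs separately while still letting them reach min/max keeps the count useful and makes the printed range agree with what an eager reduction would show.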

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (3)
csrc/api_log_stats.cu (1)

69-86: Integer dtypes lose precision because reduction goes through float.

For int32, int64, and uint8, to_float_impl casts each element to float before min/max/sum. Float can only represent integers exactly up to 2^24, so for large int32/int64 tensors (e.g. token-id tensors, paged-KV indices, cu_seqlens with large offsets) the printed min/max and especially mean will be inaccurate by potentially many ULPs.

Since int64 indexing tensors are exactly the kind of inputs users are most likely to be debugging at level 5, this is worth fixing. One approach: keep a double accumulator and dispatch min/max via a per-T traits struct so integer types reduce in their native domain.

♻️ Sketch of a precision-preserving variant
+template <typename T>
+struct StatsAccum {
+  using Acc = float;
+};
+template <> struct StatsAccum<int32_t>  { using Acc = double; };
+template <> struct StatsAccum<int64_t>  { using Acc = double; };
+template <> struct StatsAccum<uint8_t>  { using Acc = double; };
+
+template <typename T>
+__device__ inline typename StatsAccum<T>::Acc to_acc(T x) {
+  return static_cast<typename StatsAccum<T>::Acc>(x);
+}
+__device__ inline float to_acc(nv_half x)     { return __half2float(x); }
+__device__ inline float to_acc(nv_bfloat16 x) { return __bfloat162float(x); }

Then template PrintTensorStatsKernel so thread_min/thread_max/thread_sum use the appropriate accumulator type for T.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/api_log_stats.cu` around lines 69 - 86, The reduction currently casts
every element with to_float_impl inside PrintTensorStatsKernel so integer dtypes
lose precision; change the kernel to use a type-traits dispatch (e.g., a
template AccumTypeFor<T>) that selects double for integer inputs and
float/double for float-like types, keep per-type min/max comparisons in the
native domain (avoid isnan/isinf for non-floats by using IsFloatLike<T>), and
make thread_min/thread_max/thread_sum use the chosen accumulator type (sum as
double for integers) so min/max/mean are computed without truncation for
int32/int64/uint8 while preserving existing float handling.
docs/logging.rst (1)

253-256: Minor: device printf flushing wording.

The phrasing "PyTorch routes device printf to the host stream" is slightly inaccurate. Device-side printf is buffered by the CUDA runtime and flushed on sync points (e.g. cudaDeviceSynchronize/stream sync) — it is not specifically routed by PyTorch. Consider rewording, e.g. "the CUDA runtime flushes the device printf buffer to stdout on stream synchronization."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/logging.rst` around lines 253 - 256, Reword the sentence about device
printf: replace "PyTorch routes device printf to the host stream" with wording
that attributes flushing to the CUDA runtime and synchronization (e.g., mention
that the CUDA runtime flushes the device printf buffer to stdout on stream or
device synchronization). Update the line describing cuda_graph.replay() so it
states the captured kernel prints statistics to stdout because the CUDA runtime
flushes device printf on sync points (not that PyTorch routes it).
flashinfer/api_logging.py (1)

1161-1199: First-call JIT build inside graph capture is safe but slow; consider pre-warming as optional optimization.

_get_api_log_stats_kernel() with @functools.cache may run the build_and_load() chain on first level-5 capture. The build itself is host-only (nvcc subprocess + dlopen of .so) so it will not poison the capture stream. However, the build can take seconds, which may introduce unexpected latency on first use.

Two options if latency is a concern:

  • Pre-warm _get_api_log_stats_kernel() once at module import or on first CUDA tensor observed outside graph capture.
  • Add a note to the level-5 docstring that a short warmup pass is recommended before entering graph capture.

This is an optional optimization; the current design is correct and works under capture.
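
The pre-warm option relies on `functools.cache` paying the build cost only once. The sketch below uses a stand-in builder; `get_stats_kernel`, `prewarm`, and `BUILD_CALLS` are hypothetical names, and the real function would compile and load a module instead of returning a lambda.

```python
import functools

BUILD_CALLS = []  # records how many times the expensive build actually ran


@functools.cache
def get_stats_kernel():
    """Stand-in for a slow JIT build; cached after the first call."""
    BUILD_CALLS.append("build")
    return lambda tensor_id: f"launched id={tensor_id}"


def prewarm():
    """Call once outside graph capture so capture-time lookups are cache hits."""
    get_stats_kernel()


prewarm()
kernel = get_stats_kernel()  # cache hit, no second build
print(kernel(1), len(BUILD_CALLS))
```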

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/api_log_stats.cu`:
- Around line 130-135: The kernel launch line in launch_print_tensor_stats is
not formatted to match the project's clang-format rules; update the formatting
of the PrintTensorStatsKernel<<<...>>> call inside launch_print_tensor_stats
(the call to PrintTensorStatsKernel<T><<<1, kBlockSize, 0,
stream>>>(static_cast<const T*>(data_ptr), numel, tensor_id)) to match
clang-format, then run the pre-commit formatter and commit the change (e.g., run
pre-commit run clang-format --files csrc/api_log_stats.cu) so CI passes.
- Around line 113-126: The current printf in the tid==0 block prints sentinel
s_min/s_max and a misleading mean when numel>0 but valid==0; modify the tid==0
handling to check valid (computed from numel - s_nan[0] - s_inf[0]) and when
valid==0 print an explicit message such as "[flashinfer stats] id=%lld
numel=%lld all_nan_or_inf nan=%lld inf=%lld" (include tensor_id, numel,
s_nan[0], s_inf[0]) instead of printing s_min/s_max/mean, otherwise keep the
existing min/max/mean printing; ensure mean and use of s_min/s_max only occur
when valid>0.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 233331c7-9d12-4bef-8735-8cf70cef6557

📥 Commits

Reviewing files that changed from the base of the PR and between 24c4aee and c53c25a.

📒 Files selected for processing (7)
  • csrc/api_log_stats.cu
  • csrc/flashinfer_api_log_stats_binding.cu
  • docs/logging.rst
  • flashinfer/aot.py
  • flashinfer/api_logging.py
  • flashinfer/jit/api_log_stats.py
  • tests/utils/test_logging.py

@yyihuang
Collaborator Author

Follow-up commit b66a2f64 extends the same approach to Level 10 (tensor dumping).

What's new

tensor.cpu() in the dump path also synchronizes the captured stream, so before this change Level 10 just crashed under torch.cuda.graph(...). The fix:

  • Cached pinned host buffers: every dump tensor is staged through a buffer cached by (func_name, key, shape, dtype). The cache warms up on the first eager call; under capture we only do a pinned.copy_(t, non_blocking=True) (graph-safe), never cudaHostAlloc (forbidden under capture).
  • Deferred writes: during capture we register the pinned buffers + dump dirs in _PENDING_GRAPH_DUMPS and skip the immediate inputs.pt / outputs.pt write. Metadata is recorded with execution_status: "graph_capture_pending_flush".
  • New API: flashinfer.api_logging.flush_graph_dumps(synchronize=True) syncs the stream and writes the current pinned-buffer contents to disk; clear_graph_dumps() releases the held buffers.
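
The cache-then-copy pattern from the first bullet can be modeled without a GPU. In this sketch a `bytearray` stands in for a pinned host buffer and a slice assignment stands in for `pinned.copy_(t, non_blocking=True)`; `StagingCache` is a hypothetical class, and only the cache key and the warmup rule come from the description above.

```python
class StagingCache:
    """Buffers keyed by (func_name, key, shape, dtype); allocate only outside capture."""

    def __init__(self):
        self._buffers = {}

    def stage(self, func_name, key, shape, dtype, data: bytes, capturing: bool):
        cache_key = (func_name, key, tuple(shape), dtype)
        buf = self._buffers.get(cache_key)
        if buf is None:
            if capturing:
                # Allocation is not allowed under capture, so fail loudly.
                raise RuntimeError(
                    f"no staging buffer for {cache_key}; "
                    "run one eager warmup call before capture"
                )
            buf = bytearray(len(data))  # stand-in for a pinned allocation
            self._buffers[cache_key] = buf
        buf[:] = data  # stand-in for the graph-safe async copy
        return buf
```

Because the same buffer object is reused on every call, a later flush sees whatever the most recent replay copied into it.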

Usage

import torch

from flashinfer.api_logging import flush_graph_dumps, clear_graph_dumps

# `wrapper`, `q`, `kv_cache`, and `new_q` are assumed to exist already.
# Eager warmup primes the pinned-buffer cache.
out = wrapper.run(q, kv_cache)
torch.cuda.synchronize()

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    wrapper.run(q, kv_cache)

# Per replay, flush to capture this replay's tensor values.
g.replay()
flush_graph_dumps()

q.copy_(new_q)
g.replay()
flush_graph_dumps()  # dump now reflects the new inputs

clear_graph_dumps()

Caveats

  • Warmup required. The pinned-buffer cache must be primed by at least one eager call. If capture is the first call, _stage_tensor_to_pinned raises a clear RuntimeError explaining why.
  • Strides not preserved. Captured copy_() materializes data into a contiguous pinned buffer (matches existing safetensors-mode behavior).

Files

  • flashinfer/api_logging.py: pinned-buffer cache, _extract_tensors_and_metadata_pinned, capture-aware branches in _dump_function_inputs/_outputs, flush_graph_dumps(), clear_graph_dumps().
  • tests/utils/test_logging.py: test_level_10_cuda_graph_dumps (capture + 2 replays with mutated inputs, asserts on-disk dump contents) and test_level_10_cuda_graph_requires_warmup.
  • docs/logging.rst: new "Level 10 (Tensor Dumping) under CUDA Graph" section.

Test plan

  • pytest tests/utils/test_logging.py — 18/18 pass (16 prior + 2 new).
  • Manual smoke test: capture once, replay twice with different inputs, verify each flush_graph_dumps() produces dump files reflecting that replay.
  • Pre-existing failures in test_logging_replay.py::{test_bmm_fp8_replay, test_mm_fp4_replay} are environmental (libcudart.so.12 vs .so.13 mismatch on this box) and reproduce on the prior commit too.

@yyihuang yyihuang changed the title feat(logging): make level-5 stats work under CUDA graph capture feat(logging): make level-5 stats and level-10 dumps work under CUDA graph capture Apr 25, 2026
@yyihuang
Collaborator Author

Added tools/dump_with_cuda_graph.py (commit fdce0664) so users can drive level-10 dumps under CUDA graphs without modifying the host program — useful when the program is something like sglang that you can't easily patch.

What it does

  • Sets the FLASHINFER_LOGLEVEL=10 and FLASHINFER_DUMP_* env vars from CLI flags.
  • Monkey-patches torch.cuda.CUDAGraph.replay so flashinfer.api_logging.flush_graph_dumps() fires automatically after every replay (idempotent — safe to install twice).
  • Runs the target via runpy when it's python … so the patch stays alive in-process; falls back to os.execvp and warns for non-Python targets.
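
An idempotent post-call hook of the kind described in the second bullet can be sketched against a dummy class. `install_autoflush` and `FakeGraph` are hypothetical stand-ins; in the real tool the patched class would be torch.cuda.CUDAGraph and the callback would be flush_graph_dumps.

```python
import functools


def install_autoflush(graph_cls, flush):
    """Wrap graph_cls.replay so flush() runs after each replay; safe to call twice."""
    original = graph_cls.replay
    if getattr(original, "_autoflush_installed", False):
        return False  # already patched, do nothing

    @functools.wraps(original)
    def replay(self, *args, **kwargs):
        result = original(self, *args, **kwargs)
        flush()
        return result

    replay._autoflush_installed = True
    graph_cls.replay = replay
    return True


class FakeGraph:
    """Stand-in for a CUDA graph object so the sketch runs without a GPU."""

    def replay(self):
        return "replayed"
```

Marking the wrapper with an attribute is one simple way to get idempotency; a second install sees the marker and leaves the method alone, so the flush runs once per replay.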

Usage

python tools/dump_with_cuda_graph.py \
    --dump-dir /tmp/fi_dumps \
    --include '*decode*' \
    --max-count 10 \
    -- \
    python -m sglang.launch_server --model meta-llama/Llama-3-8B ...

Anything before -- configures the wrapper; anything after -- is the command to exec.
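
The split at `--` can be done with plain `argparse`. This is a sketch of the convention only, using the three flags from the usage example; `split_command` and `parse_wrapper_args` are hypothetical names, and the real script may parse its arguments differently.

```python
import argparse


def split_command(argv):
    """Split argv at the first '--' into (wrapper_args, target_command)."""
    if "--" not in argv:
        return list(argv), []
    i = argv.index("--")
    return list(argv[:i]), list(argv[i + 1:])


def parse_wrapper_args(wrapper_args):
    """Parse only the wrapper's own flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dump-dir", required=True)
    parser.add_argument("--include", default=None)
    parser.add_argument("--max-count", type=int, default=None)
    return parser.parse_args(wrapper_args)


argv = ["--dump-dir", "/tmp/fi_dumps", "--include", "*decode*", "--max-count", "10",
        "--", "python", "-m", "sglang.launch_server"]
wrapper_args, command = split_command(argv)
opts = parse_wrapper_args(wrapper_args)
print(opts.dump_dir, opts.max_count, command)
```

Splitting before parsing keeps target flags such as --model from being interpreted by the wrapper's parser.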

Caveats called out in the docstring

  • Eager warmup must happen before capture (sglang already does this for graph correctness, so it works in practice). If the capture path sees a tensor it never saw eagerly, the level-10 dump path raises a clear RuntimeError.
  • Every replay overwrites the same dump files (latest replay wins). Always set --include and a small --max-count for high-QPS workloads — without scoping, each decode step rewrites every captured dump.
  • If you launch sglang via a non-Python entry point (custom binary), the patch doesn't propagate to the child. Prefer python -m sglang....

Smoke-tested

$ python tools/dump_with_cuda_graph.py --dump-dir /tmp/fi --include 'my_op' -- python /tmp/target.py
[dump_with_cuda_graph] FLASHINFER_LOGLEVEL=10, dump_dir=/tmp/fi, include='my_op', ...
DUMP_DIR=/tmp/fi/.../my_op_call0002

$ python -c "import torch; d='...call0002'; \
    print(torch.load(d+'/inputs.pt')['arg_0'].mean().item(), \
          torch.load(d+'/outputs.pt')['result'].mean().item())"
10.0 30.0   # values from the SECOND replay (mutated inputs), proving auto-flush works

Idempotency check also passes: a second install_replay_autoflush() call is a no-op.

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tests/utils/test_logging.py (1)

46-68: ⚠️ Potential issue | 🟡 Minor

setup_and_teardown doesn't restore FLASHINFER_DUMP_DIR, leaking state into later tests.

The new level-10 graph tests (test_level_10_cuda_graph_dumps, test_level_10_cuda_graph_requires_warmup) set FLASHINFER_DUMP_DIR (and re-set FLASHINFER_LOGDEST), but the autouse fixture only saves/restores FLASHINFER_LOGLEVEL and FLASHINFER_LOGDEST. Any test running after these will see a FLASHINFER_DUMP_DIR pointing at a now-deleted tmp_path, which can change behavior of _warn_dump() and the dump count tracking on module reimport.

🧪 Extend the fixture to also restore the dump-related env vars
     @pytest.fixture(autouse=True)
     def setup_and_teardown(self):
         """Reset environment and reimport logging module for each test."""
-        # Store original environment
-        original_level = os.environ.get("FLASHINFER_LOGLEVEL")
-        original_dest = os.environ.get("FLASHINFER_LOGDEST")
+        # Store original environment
+        keys = (
+            "FLASHINFER_LOGLEVEL",
+            "FLASHINFER_LOGDEST",
+            "FLASHINFER_DUMP_DIR",
+        )
+        original = {k: os.environ.get(k) for k in keys}
 
         yield
 
-        # Restore original environment
-        if original_level is not None:
-            os.environ["FLASHINFER_LOGLEVEL"] = original_level
-        elif "FLASHINFER_LOGLEVEL" in os.environ:
-            del os.environ["FLASHINFER_LOGLEVEL"]
-
-        if original_dest is not None:
-            os.environ["FLASHINFER_LOGDEST"] = original_dest
-        elif "FLASHINFER_LOGDEST" in os.environ:
-            del os.environ["FLASHINFER_LOGDEST"]
+        # Restore original environment
+        for k, v in original.items():
+            if v is not None:
+                os.environ[k] = v
+            elif k in os.environ:
+                del os.environ[k]
 
         # Force reimport to pick up new environment variables
         if "flashinfer.api_logging" in sys.modules:
             del sys.modules["flashinfer.api_logging"]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/utils/test_logging.py` around lines 46 - 68, The autouse fixture
setup_and_teardown fails to save/restore FLASHINFER_DUMP_DIR (and any
dump-related env vars), leaking state into later tests; update
setup_and_teardown to capture original values for FLASHINFER_DUMP_DIR (and any
other dump-related env vars you add), restore them in the teardown branch
(mirroring the pattern used for FLASHINFER_LOGLEVEL and FLASHINFER_LOGDEST), and
ensure the module reimport logic for flashinfer.api_logging still runs so
functions like _warn_dump() and dump-count tracking see the restored
environment.
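
The save/restore pattern described above can also be sketched as a small context manager, which keeps the key list in one place. This is a minimal sketch, not the fixture in the PR: `preserved_env` and `ENV_KEYS` are hypothetical names, and only the three environment variables named in this comment are assumed.

```python
import os
from contextlib import contextmanager

# Env vars named in the review comment; extend if more dump-related vars exist.
ENV_KEYS = ("FLASHINFER_LOGLEVEL", "FLASHINFER_LOGDEST", "FLASHINFER_DUMP_DIR")


@contextmanager
def preserved_env(keys=ENV_KEYS):
    """Snapshot the given env vars and restore them on exit (hypothetical helper)."""
    original = {k: os.environ.get(k) for k in keys}
    try:
        yield
    finally:
        for k, v in original.items():
            if v is None:
                os.environ.pop(k, None)  # was unset before: unset it again
            else:
                os.environ[k] = v
```

Wrapping the `yield` of an autouse fixture in `with preserved_env():` would restore all three variables even when a test fails partway through.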
flashinfer/api_logging.py (1)

327-648: ⚠️ Potential issue | 🟠 Major

Pre-commit ruff-format is failing on this hunk; CI is currently blocked.

The pre-commit job reports formatting changes required across the new dump-staging code in this file. Run the formatter locally before pushing:

pre-commit run --all-files
# or, scoped:
ruff format flashinfer/api_logging.py
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/api_logging.py` around lines 327 - 648, The new dump-staging code
(see functions _stage_tensor_to_pinned, _extract_tensors_and_metadata_pinned and
_dump_function_inputs) is failing the pre-commit ruff-format hook; run the
project formatter and re-stage the file to fix whitespace/formatting issues
(e.g. run `pre-commit run --all-files` or `ruff format
flashinfer/api_logging.py`) and then amend the commit so CI passes.
🧹 Nitpick comments (1)
flashinfer/api_logging.py (1)

308-316: Optional: narrow the blanket except Exception (BLE001).

torch.cuda.is_current_stream_capturing() is documented to either return a bool or raise on missing CUDA context; swallowing all exceptions hides genuine bugs (e.g., a CUDA driver failure manifesting as "stats look fine" until something else explodes). The hasattr check above already covers older PyTorch, so the try/except here is mostly defensive. Consider scoping it to RuntimeError (or removing it) so unexpected failures aren't masked.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/api_logging.py` around lines 308 - 316, The blanket except in
_is_current_stream_capturing hides unexpected errors; replace the broad "except
Exception" around the call to torch.cuda.is_current_stream_capturing() with a
narrower catch (e.g., "except RuntimeError") or remove the try/except entirely
so only the documented missing-CUDA-context error is caught while other failures
surface; update the exception handling around
torch.cuda.is_current_stream_capturing() accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/api_logging.py`:
- Around line 318-365: The pinned-buffer cache (_PINNED_DUMP_BUFFER_CACHE)
aliases multiple invocations of the same (func_name, key, shape, dtype) inside a
single torch.cuda.graph capture, causing different dump entries to share and
overwrite the same pinned tensor; update _stage_tensor_to_pinned to avoid
aliasing by either (A) when _is_current_stream_capturing() is true, append a
per-call unique discriminator (e.g., a _dump_call_counter[func_name] or
ephemeral UUID) to the cache_key so each in-graph call gets its own pinned
buffer, or (B) detect an existing cache entry during capture and raise a
RuntimeError mirroring the warmup check; ensure you reference and update the
same cache key logic in _stage_tensor_to_pinned and initialize/maintain the
per-call counter/state (or explicit error path) so _PENDING_GRAPH_DUMPS entries
will point to distinct buffers rather than aliasing.
- Around line 1403-1441: The level-5 stats path lazily JIT-compiles in
_get_api_log_stats_kernel (called by _launch_gpu_stats_kernel) which can trigger
illegal module loads during cudaStreamCaptureModeGlobal; fix by adding an eager
warmup call to _get_api_log_stats_kernel() during initialization (for example
invoke it from _warn_dump or _log_system_info when FLASHINFER_LOGLEVEL >= 5) so
the kernel is built before any capture, or alternatively update the docstring
where level-10 warmup is documented (around the level-10 note) to clearly state
that callers must warm up _get_api_log_stats_kernel() before starting captures
at level 5; reference _get_api_log_stats_kernel, _launch_gpu_stats_kernel,
_warn_dump, and _log_system_info to locate the changes.

In `@tests/utils/test_logging.py`:
- Around line 685-717: Update test_level_10_cuda_graph_requires_warmup to assert
that a RuntimeError is raised instead of accepting the "no exception" branch:
replace the try/except and the silent-success fallback with a
pytest.raises(RuntimeError) context around the with torch.cuda.graph(graph):
_id(x) block, import pytest at top of the test, and remove the broad bare except
to both make the failure deterministic (matching _stage_tensor_to_pinned and
_is_current_stream_capturing behavior) and satisfy Ruff BLE001 by narrowing the
exception expectation; keep existing uses of flashinfer_api and
_PINNED_DUMP_BUFFER_CACHE as-is.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c91c3f8a-903f-4af1-9e7a-75ac8587bff9

📥 Commits

Reviewing files that changed from the base of the PR and between c53c25a and b66a2f6.

📒 Files selected for processing (3)
  • docs/logging.rst
  • flashinfer/api_logging.py
  • tests/utils/test_logging.py
✅ Files skipped from review due to trivial changes (1)
  • docs/logging.rst

Comment thread flashinfer/api_logging.py Outdated
Comment thread flashinfer/api_logging.py
Comment thread tests/utils/test_logging.py Outdated
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
tools/dump_with_cuda_graph.py (2)

139-158: Idempotent branch returns the patched callable, not the original — contradicts the docstring.

The docstring promises "the original replay callable (so callers can reverse the patch in tests)", but on the already-patched path you return torch.cuda.CUDAGraph.replay, which is the wrapper itself; the true original is no longer accessible. Either stash the original on the wrapper at install time and return it here, or update the docstring to say it returns the currently-installed callable.

♻️ Proposed fix (preserve and return the real original)
     if getattr(torch.cuda.CUDAGraph.replay, "_flashinfer_autoflush", False):
         # Already patched (idempotent).
-        return torch.cuda.CUDAGraph.replay
+        return getattr(torch.cuda.CUDAGraph.replay, "_flashinfer_original", None)

     original = torch.cuda.CUDAGraph.replay

     def replay_with_flush(self, *args, **kwargs):
         ...

     replay_with_flush._flashinfer_autoflush = True  # type: ignore[attr-defined]
+    replay_with_flush._flashinfer_original = original  # type: ignore[attr-defined]
     torch.cuda.CUDAGraph.replay = replay_with_flush  # type: ignore[assignment]
     return original
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/dump_with_cuda_graph.py` around lines 139 - 158, The idempotent branch
in install_replay_autoflush currently returns torch.cuda.CUDAGraph.replay (the
patched wrapper), which contradicts the docstring promise to return the original
replay callable; fix by storing the real original replay when you first patch
(e.g., attach it to the wrapper under a unique attribute name like
_flashinfer_original_replay when wrapping in install_replay_autoflush) and in
the early-return path return that stored original (check _flashinfer_autoflush
to detect patching and return the attached _flashinfer_original_replay), leaving
the wrapper flag _flashinfer_autoflush to indicate idempotency.
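
A CUDA-free sketch of the stash-and-return pattern, for illustration only. `FakeGraph` stands in for `torch.cuda.CUDAGraph`, and `install_autoflush`, `_autoflush` and `_original` are hypothetical names; calling `flush()` after the original replay is also an assumption, since the wrapper body is elided in the diff above.

```python
class FakeGraph:
    """Stand-in for torch.cuda.CUDAGraph so the sketch runs without CUDA."""

    def replay(self):
        return "replayed"


def install_autoflush(cls, flush):
    """Patch cls.replay once; always return the true original callable."""
    current = cls.replay
    if getattr(current, "_autoflush", False):
        # Already patched: hand back the original stashed on the wrapper.
        return current._original

    original = current

    def replay_with_flush(self, *args, **kwargs):
        result = original(self, *args, **kwargs)
        flush()  # assumed ordering: flush after the replay returns
        return result

    replay_with_flush._autoflush = True
    replay_with_flush._original = original
    cls.replay = replay_with_flush
    return original
```

With this shape, a test can undo the patch with `FakeGraph.replay = install_autoflush(FakeGraph, flush)` regardless of how many times the installer has already run.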

200-200: Nit: use unpacking instead of list concatenation (ruff RUF005).

-            sys.argv = [module] + rest[2:]
+            sys.argv = [module, *rest[2:]]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/dump_with_cuda_graph.py` at line 200, Replace the list concatenation
that builds sys.argv (currently using [module] + rest[2:]) with list unpacking
to satisfy ruff RUF005; locate the assignment to sys.argv that references the
variables module and rest and change it to use the unpacking form (module
followed by the unpacked slice of rest) so the result is the same but uses
unpacking instead of concatenation.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tools/dump_with_cuda_graph.py`:
- Around line 194-217: The wrapper's Python detection and -m/script handling are
too narrow: update the logic around head/target_argv (symbols: head,
target_argv, rest) to treat any executable whose basename contains "python" (or
equals sys.executable) as a Python interpreter, and scan rest while stripping
leading interpreter flags (flags beginning with "-" and their potential values,
e.g., -W warnings, -X args) to find the first non-flag token; if that token is
"-m" call runpy.run_module(module, run_name="__main__", alter_sys=True) with
sys.argv set to the module + its args (symbols: runpy.run_module, sys.argv),
otherwise treat the first non-flag token as a script path and call
runpy.run_path(script, run_name="__main__") with sys.argv set to script + its
args (symbol: runpy.run_path); only fall back to os.execvp(head, target_argv)
when head is not a Python interpreter (symbol: os.execvp).
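
The detection logic described above could look roughly like the following. This is a sketch under assumptions, not the code in `tools/dump_with_cuda_graph.py`: `classify_target` and `FLAGS_WITH_VALUE` are hypothetical names, only `-W` and `-X` are treated as value-taking flags, and `-c` or `-` targets are simply handed to the exec fallback.

```python
import os
import sys

# Interpreter flags that consume the following token as their value.
FLAGS_WITH_VALUE = {"-W", "-X"}


def classify_target(argv):
    """Return ("module" | "script" | "exec", new_argv) for a target command."""
    head, rest = argv[0], argv[1:]
    if "python" not in os.path.basename(head) and head != sys.executable:
        return "exec", list(argv)  # not Python: caller would use os.execvp

    i = 0
    # Skip leading interpreter flags, stopping at -m, -c, or "-" (stdin).
    while i < len(rest) and rest[i].startswith("-") and rest[i] not in ("-m", "-c", "-"):
        i += 2 if rest[i] in FLAGS_WITH_VALUE else 1

    if i >= len(rest) or rest[i] in ("-c", "-"):
        return "exec", list(argv)  # nothing this sketch can run in-process
    if rest[i] == "-m":
        if i + 1 >= len(rest):
            return "exec", list(argv)
        return "module", rest[i + 1:]  # becomes sys.argv for runpy.run_module
    return "script", rest[i:]  # becomes sys.argv for runpy.run_path
```

The returned list is what `sys.argv` would be set to before calling `runpy.run_module(new_argv[0], run_name="__main__", alter_sys=True)` or `runpy.run_path(new_argv[0], run_name="__main__")`.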


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 662546b0-9164-4a5c-b4a9-45c1edccdb48

📥 Commits

Reviewing files that changed from the base of the PR and between b66a2f6 and fdce066.

📒 Files selected for processing (1)
  • tools/dump_with_cuda_graph.py

Comment thread tools/dump_with_cuda_graph.py Outdated
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/api_logging.py`:
- Around line 797-827: flush_graph_dumps currently writes tensor files but never
promotes the per-dump or session JSONL records from execution_status
"graph_capture_pending_flush" to a terminal state, so consumers still see them
as pending; modify flush_graph_dumps to, after each successful write (inside the
loop, when a tensor file is saved and before incrementing flushed), append a
small completion record to both the per-dump metadata.jsonl and the central
session.jsonl indicating execution_status="completed" (or the same terminal
state used by eager mode), include identifying fields like func_name, kind,
dump_dir and a timestamp, and ensure you reference the same keys used by
_dump_function_inputs/_dump_function_outputs so readers can correlate entries.
- Around line 472-481: The change unconditionally stages tensors into pinned
contiguous buffers via
_extract_tensors_and_metadata_pinned/_stage_tensor_to_pinned which destroys
source strides for eager (non-capture) dumps; revert to using the original
CPU-path (_extract_tensors_and_metadata using arg.cpu()) for the
non-capturing/eager path and only call _extract_tensors_and_metadata_pinned when
_is_current_stream_capturing() is true (apply same fix in
_dump_function_outputs), ensure tensor_details["stride"] is recorded from the
source tensor before any staging/copy occurs (not from the pinned buffer), and
update the public decorator and _extract_tensors_and_metadata docstrings to
accurately reflect when stride/contiguity is preserved versus lost.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7e5f3b69-362a-40e5-9c48-c83212209614

📥 Commits

Reviewing files that changed from the base of the PR and between fdce066 and 6551545.

📒 Files selected for processing (2)
  • csrc/api_log_stats.cu
  • flashinfer/api_logging.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/api_log_stats.cu

Comment thread flashinfer/api_logging.py Outdated
Comment thread flashinfer/api_logging.py Outdated
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (4)
flashinfer/api_logging.py (4)

472-481: ⚠️ Potential issue | 🟠 Major

Eager Level-10 dumps now silently lose CUDA tensor strides (duplicate).

_extract_tensors_and_metadata_pinned is invoked unconditionally for both capture and eager paths (also at lines 660–668). In eager mode, _stage_tensor_to_pinned allocates a contiguous pinned buffer and pinned.copy_(t, ...) produces a contiguous tensor regardless of source stride, so inputs.pt/outputs.pt and tensor_details["stride"] (lines 566, 720–726) no longer reflect the original layout. The decorator docstring at line 1922 and _extract_tensors_and_metadata's docstring (lines 276/282) still advertise stride preservation. Restrict pinned staging to the capture path (or pre-warm the cache during the first eager call without losing the source stride), and record tensor_details["stride"] from the source tensor before staging.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/api_logging.py` around lines 472 - 481, The eager-mode path is
losing original CUDA tensor strides because _extract_tensors_and_metadata_pinned
(which calls _stage_tensor_to_pinned and uses pinned.copy_) is being invoked
unconditionally; change the call site around _is_current_stream_capturing() so
that you only stage to pinned buffers when capturing (use
_extract_tensors_and_metadata for eager), and when you must pre-warm the pinned
cache during the first eager call, explicitly record tensor_details["stride"]
from the source tensor before calling _stage_tensor_to_pinned so the original
stride is preserved in the metadata; update both call sites (the one around
_is_current_stream_capturing and the duplicate at lines 660–668) and ensure
_extract_tensors_and_metadata_pinned documents/returns original stride if
staging occurs.
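
To illustrate why the stride has to be read from the source tensor, here is a torch-free sketch. `contiguous_strides` and `tensor_details` are hypothetical helpers, and `staged_contiguous` is an invented metadata key; only the `"stride"` key comes from the comment above.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) element strides for a given shape."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def tensor_details(shape, source_stride, staged):
    """Build metadata from the SOURCE layout, even if the data was staged."""
    return {
        "shape": tuple(shape),
        "stride": tuple(source_stride),  # recorded before any pinned copy
        "staged_contiguous": staged,  # hypothetical marker, not in the PR
    }


# A transposed view of a (2, 3) tensor has shape (3, 2) and stride (1, 3),
# while a contiguous pinned copy of the same shape would have stride (2, 1).
details = tensor_details((3, 2), (1, 3), staged=True)
```

Here `details["stride"]` stays `(1, 3)`, whereas reading the stride back from a contiguous staging buffer would report `contiguous_strides((3, 2))`, which is `(2, 1)`.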

318-365: ⚠️ Potential issue | 🟠 Major

Pinned-buffer aliasing on repeated in-graph calls (duplicate).

The (func_name, key, shape, dtype) cache key still aliases when the same @flashinfer_api is invoked more than once during a single capture (e.g., a method called inside a captured loop, or two decode calls back-to-back). All such call sites end up sharing the same pinned tensor and _PENDING_GRAPH_DUMPS will hold multiple entries pointing at the same buffer; after replay every aliased dump_dir flushes identical content. Either disambiguate the cache key during capture (e.g., fold in _dump_call_counter[func_name]), refuse to alias and raise like the warmup case, or extend the comment to make the in-graph repeat case explicit.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/api_logging.py` around lines 318 - 365, The pinned-buffer cache
currently aliases repeated in-graph calls and must be disambiguated: modify
_stage_tensor_to_pinned to include a capture-specific call identifier in the
cache key when inside a capture (instead of the current (func_name, key, shape,
dtype) only). Add/consume a per-function-in-capture counter (e.g.,
_dump_call_counter[func_name] or similar) incremented on each dump invocation
during capture and fold that counter into cache_key when
_is_current_stream_capturing() is true; ensure the counter is incremented before
building cache_key and persisted/cleared appropriately so multiple in-graph
calls get distinct pinned buffers and update the top comment to document this
behavior.
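
One way to picture the disambiguated key is the toy model below. It uses `bytearray` buffers instead of pinned tensors, increments the counter once per staged tensor rather than once per dump invocation, and ignores the warmup check entirely; `pinned_cache`, `capture_call_counter` and `stage` are hypothetical names.

```python
from collections import defaultdict

pinned_cache = {}  # stands in for _PINNED_DUMP_BUFFER_CACHE
capture_call_counter = defaultdict(int)  # hypothetical per-function counter


def stage(func_name, key, shape, dtype, data, capturing):
    """Copy `data` into a cached buffer and return that buffer."""
    cache_key = (func_name, key, tuple(shape), dtype)
    if capturing:
        # Fold a per-call index into the key so repeated in-graph calls
        # do not end up sharing one buffer.
        cache_key += (capture_call_counter[func_name],)
        capture_call_counter[func_name] += 1
    buf = pinned_cache.setdefault(cache_key, bytearray(len(data)))
    buf[:] = data
    return buf
```

Two captured calls with the same `(func_name, key, shape, dtype)` now get distinct buffers, while eager calls keep reusing a single one. A real implementation would also have to reset the counter between captures and pre-allocate the indexed buffers during eager warmup.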

1405-1422: ⚠️ Potential issue | 🟠 Major

Level-5 stats also need eager warmup for capture (duplicate).

_get_api_log_stats_kernel() is @functools.cache-decorated and triggers gen_api_log_stats_module().build_and_load() on first use. Because _launch_gpu_stats_kernel is only invoked when is_capturing is True, a "first run is the captured run" workflow under FLASHINFER_LOGLEVEL=5 will JIT-build (and cuModuleLoadData) inside the capture region — which is prohibited under cudaStreamCaptureModeGlobal and aborts capture. Either eagerly prime the kernel from _warn_dump/_log_system_info when level ≥ 5, or document the warmup requirement alongside the level-10 note in the docstring (lines 1970–1983).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/api_logging.py` around lines 1405 - 1422, The stats JIT kernel
(_get_api_log_stats_kernel) can be first-built during a CUDA stream capture
because _launch_gpu_stats_kernel runs only when is_capturing is true; to avoid
illegal JIT/cudaModuleLoadData inside capture, eagerly warm up the kernel when
logging level >= 5 by calling _get_api_log_stats_kernel() from the startup/info
path (e.g., inside _warn_dump and/or _log_system_info) so the build/load happens
before any capture; ensure you handle the None return (build failure) as the
existing callers do and keep the existing level-10 docstring note for
FLASHINFER_LOGLEVEL while adding a short comment about the warmup at level-5.

797-827: ⚠️ Potential issue | 🟡 Minor

flush_graph_dumps doesn't promote execution_status to a terminal state (duplicate).

After successful tensor writes here, neither the per-dump metadata.jsonl nor the central session.jsonl gets a follow-up record — they retain execution_status="graph_capture_pending_flush" forever. Consumers filtering by execution_status == "completed" (the eager-mode terminal state used at line 714) will treat flushed dumps as still pending. Append a small completion record with execution_status="completed" (and ideally a graph_capture_flushed=True marker plus timestamp) after each successful write so on-disk state matches reality.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/api_logging.py` around lines 797 - 827, In flush_graph_dumps,
after each successful tensor write (inside the try block that writes tensors for
entries from _PENDING_GRAPH_DUMPS and after incrementing flushed), append a
small completion record to the per-dump metadata stream and the central session
stream indicating execution_status="completed", graph_capture_flushed=True, and
a timestamp (use entry["dump_dir"], entry["kind"], entry.get("func_name") to
populate context); ensure this write happens only on success and does not
swallow exceptions from the tensor save step, and reuse any existing helper(s)
used elsewhere for writing metadata/session JSONL to keep format consistent.
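
A minimal sketch of such a completion record, assuming the per-dump file is `metadata.jsonl` inside `dump_dir` and the session log path is passed in by the caller. `append_completion_record` is a hypothetical helper and the exact record schema is an assumption; the real code should reuse whatever JSONL writer the eager path already uses.

```python
import json
import time
from pathlib import Path


def append_completion_record(dump_dir, session_path, func_name, kind):
    """Append a terminal-status line to metadata.jsonl and the session log."""
    record = {
        "func_name": func_name,
        "kind": kind,
        "dump_dir": str(dump_dir),
        "execution_status": "completed",
        "graph_capture_flushed": True,
        "timestamp": time.time(),
    }
    line = json.dumps(record) + "\n"
    for path in (Path(dump_dir) / "metadata.jsonl", Path(session_path)):
        # Append only: the earlier "pending" line stays as history.
        with open(path, "a", encoding="utf-8") as f:
            f.write(line)
    return record
```

Because the record is appended rather than rewritten, a reader that takes the last line per `dump_dir` sees `completed`, and the earlier `graph_capture_pending_flush` line is preserved.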
🧹 Nitpick comments (1)
flashinfer/api_logging.py (1)

2042-2055: Verify intent: include/exclude filter now also gates level-3 console logging.

Level 3+ logging (the _log_function_inputs/_log_function_outputs console output) is now gated by _should_dump_function(func_name) along with level-10 disk dumps. This is consistent with the linked commit message ("gate level-3+ logging path by include/exclude filter"), but worth confirming: users running with only FLASHINFER_LOGLEVEL=3 and FLASHINFER_DUMP_INCLUDE/FLASHINFER_DUMP_EXCLUDE set previously got logs for every API; they will now see logs only for filtered APIs. Consider documenting this side-effect in the env-var reference around lines 1944–1948 (the docstring still describes FLASHINFER_DUMP_INCLUDE/EXCLUDE purely as dump filters), and/or rename _should_dump_function to something like _should_log_function to reflect the broader scope.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/api_logging.py` around lines 2042 - 2055, The current change also
gates level-3 console logging by _should_dump_function(func_name), which
unintentionally prevents normal FLASHINFER_LOGLEVEL=3 console logs for
unfiltered APIs; revert that behavior by making console logging (calls to
_log_function_inputs and _log_function_outputs when _API_LOG_LEVEL >= 3) not
depend on _should_dump_function, while keeping _should_dump_function gating only
for disk dumps (level 10 or explicit dump paths). Concretely, update the pre-
and post-execution checks so that if _API_LOG_LEVEL >= 3 you call
_log_function_inputs/_log_function_outputs unconditionally (or based on a
separate _should_log_function predicate if you prefer), and keep
_should_dump_function checks only for the level-10 dump branch; also update the
env-var docstring around the FLASHINFER_DUMP_INCLUDE/EXCLUDE lines to state
these filters apply to dump files (not console level-3 logging) or rename
_should_dump_function to _should_log_function and adjust doc accordingly if you
intended to change semantics.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/api_logging.py`:
- Around line 830-839: clear_graph_dumps currently only clears
_PENDING_GRAPH_DUMPS but the docstring says it "Releases the pinned host
buffers"; either evict the corresponding entries from _PINNED_DUMP_BUFFER_CACHE
for the keys removed from _PENDING_GRAPH_DUMPS (so buffers are freed and future
captures will perform eager warmup), or update the clear_graph_dumps docstring
to accurately state that only the pending-write registry is cleared while pinned
buffers in _PINNED_DUMP_BUFFER_CACHE remain cached for reuse; locate the
function clear_graph_dumps and change its implementation to iterate removed keys
and pop from _PINNED_DUMP_BUFFER_CACHE, or modify its docstring text to the
corrected behavior.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7cebb1cd-d36a-48ba-9f56-356d4dcf0d04

📥 Commits

Reviewing files that changed from the base of the PR and between 6551545 and 8710c20.

📒 Files selected for processing (1)
  • flashinfer/api_logging.py

Comment thread flashinfer/api_logging.py
@yyihuang yyihuang marked this pull request as draft April 26, 2026 02:04
@yyihuang yyihuang changed the title from "feat(logging): make level-5 stats and level-10 dumps work under CUDA graph capture" to "feat(logging,trace): cuda-graph-compatible level-5/10 logging + fi_trace template additions/fixes" May 2, 2026
yyihuang pushed a commit to yyihuang/flashinfer that referenced this pull request May 2, 2026
Addresses inline review comments on PR flashinfer-ai#3172. Each item below maps to
one or more comments from gemini-code-assist or coderabbit.

csrc/api_log_stats.cu (level-5 stats kernel):
- Use ``double`` (not ``float``) for the reduction accumulators and
  per-thread min/max. Float's 24-bit mantissa drops precision past
  ~16.7M for int32_t/int64_t inputs; the kernel always emits ``%.6f``
  anyway. Drops the ``CUDART_INF_F`` sentinels for ``CUDART_INF``.
- Include ``+/-Inf`` in the min/max reduction (still counted separately
  in ``inf=N``). Pre-fix the GPU path showed e.g. ``min=1 max=1`` for
  ``[1.0, +inf]`` while eager ``torch.min/max`` showed ``max=+inf``;
  the inconsistency was confusing.
- New explicit "all non-finite" branch (``valid == 0``) so a tensor of
  pure NaN/Inf doesn't print the misleading sentinel
  ``min=inf max=-inf mean=0.000000``; instead we say
  ``(all non-finite) nan=N inf=M``.
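
The semantics listed above can be sketched as a host-side reference in plain Python. This is an illustration only, not the CUDA kernel: `to_float32` and `reference_stats` are hypothetical names, the result is a dict rather than the kernel's printf line, and it assumes that NaN is ignored by min/max and that `valid` counts finite elements only.

```python
import math
import struct


def to_float32(x: float) -> float:
    """Round a Python float (a double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def reference_stats(values):
    """Reference for the min/max/mean/nan/inf semantics described above.

    Assumption (not taken from the kernel source): NaN never participates
    in min/max, and only finite values contribute to the mean.
    """
    nan = sum(1 for v in values if math.isnan(v))
    inf = sum(1 for v in values if math.isinf(v))
    finite = [float(v) for v in values if math.isfinite(v)]
    if not finite:
        # "All non-finite" branch: report counts only, no min/max/mean sentinels.
        return {"all_non_finite": True, "nan": nan, "inf": inf}
    non_nan = [float(v) for v in values if not math.isnan(v)]
    return {
        "all_non_finite": False,
        "min": min(non_nan),                # +/-Inf participates in min/max
        "max": max(non_nan),
        "mean": sum(finite) / len(finite),  # Inf is excluded from the mean
        "nan": nan,
        "inf": inf,
    }
```

`to_float32(16777217.0)` returns `16777216.0`, which is the precision loss past 2^24 that motivates accumulating in `double`.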

flashinfer/api_logging.py:
- ``_launch_gpu_stats_kernel`` now early-returns ``None`` when the
  tensor is non-contiguous. The kernel does a linear scan via
  ``data[i]`` and would otherwise read garbage / out-of-bounds memory
  for transposed views or slices.
- Eager warm-up of the level-5 stats kernel at import time when
  ``FLASHINFER_LOGLEVEL>=5``. Without this, the first stats call inside
  ``torch.cuda.graph(...)`` triggers ``cuModuleLoadData`` via
  ``build_and_load()``, which is forbidden under
  ``cudaStreamCaptureModeGlobal`` and aborts the capture.
- ``_dump_function_inputs``/``_dump_function_outputs``: restrict the
  pinned-buffer staging path to capture mode and keep the legacy
  ``.cpu()`` extraction in eager mode. Pre-fix, eager dumps silently
  lost CUDA tensor strides because the pinned destination is contiguous,
  contradicting the docstring promise of stride/contiguity preservation.
  In eager we now also call a new ``_prime_pinned_buffer(...)`` that
  allocates (but doesn't copy into) the pinned cache so a subsequent
  captured call still finds a pre-allocated buffer.
- New ``_DumpWarmupRequired(RuntimeError)`` subclass; ``_stage_tensor_to_pinned``
  raises it (instead of bare ``RuntimeError``) when capture finds a
  cache miss. Both ``_dump_function_inputs`` and the ``flashinfer_api``
  decorator now special-case this subclass and let it propagate to user
  code, while still swallowing other dump failures via the generic
  ``Exception`` branch. Pre-fix, the broad ``except Exception`` blocks
  silently swallowed the warmup error so the contract was un-enforceable
  from a user-test perspective.
- ``flush_graph_dumps``: after a successful tensor-file write, append
  a completion record to per-dump ``metadata.jsonl`` and the central
  ``session.jsonl`` promoting ``execution_status`` from
  ``graph_capture_pending_flush`` to ``completed`` (or ``inputs_saved``
  for the inputs half). Consumers that filter by terminal state now see
  flushed dumps as completed instead of stuck in pending.
- ``clear_graph_dumps``: docstring rewritten to honestly describe
  current behavior — only the deferred-write registry is cleared; the
  pinned host buffers in ``_PINNED_DUMP_BUFFER_CACHE`` are intentionally
  retained so subsequent replays can reuse them without
  ``cudaHostAlloc`` (illegal under capture).
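
The propagate-one-subclass, swallow-the-rest contract described for `_DumpWarmupRequired` can be sketched as follows. The names `DumpWarmupRequired`, `api_with_dump`, and the two dump stubs are hypothetical stand-ins, and the error message text is an assumption chosen only so that it contains "pinned host memory".

```python
class DumpWarmupRequired(RuntimeError):
    """Stand-in for the `_DumpWarmupRequired` subclass named above."""


def api_with_dump(func, dump_inputs):
    """Wrap `func` so dump failures are swallowed, except the warmup error."""

    def wrapper(*args, **kwargs):
        try:
            dump_inputs(args, kwargs)
        except DumpWarmupRequired:
            raise  # contract violation: surface it to the caller
        except Exception as exc:
            # Any other dump failure must not break the wrapped API call.
            print(f"dump failed, continuing: {exc!r}")
        return func(*args, **kwargs)

    return wrapper


def broken_dump(args, kwargs):
    """Hypothetical dump path that fails for an unrelated reason."""
    raise OSError("disk full")


def cold_cache_dump(args, kwargs):
    """Hypothetical dump path that hits a cache miss during capture."""
    raise DumpWarmupRequired("no pinned host memory pre-allocated; run an eager warmup call")
```

Ordering the `except` clauses from most to least specific is what lets the subclass escape while the generic branch keeps logging failures non-fatal.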

tests/utils/test_logging.py:
- ``test_level_10_cuda_graph_requires_warmup`` now asserts the
  ``RuntimeError`` explicitly via ``pytest.raises(...,
  match=r"(?i)pinned host memory")``. Pre-fix, the test accepted both
  the "exception" and "no exception" branches, so a regression that
  silently swallowed the warmup error would still leave it green.

All 18 tests in ``tests/utils/test_logging.py`` pass.

Skipped (out of scope or stale):
- clang-format complaint on csrc/api_log_stats.cu was already addressed
  in commit 6551545.
- ``tools/dump_with_cuda_graph.py`` was deleted in commit 67d066b,
  so the interpreter-detection comment is moot.

Deferred (separate follow-ups):
- Inf-vs-eager parity in the *exclusion* of Inf from the *sum* — the
  comment is right that this is debatable; we keep the current behavior
  (Inf affects min/max, excluded from mean) and document explicitly in
  the kernel.
- Pinned-buffer aliasing within a single graph: same ``(func, key,
  shape, dtype)`` captured twice in one graph (e.g. inside a loop
  body) still aliases the same pinned buffer. Worth a follow-up that
  either disambiguates with a per-call counter or detects-and-errors.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@yyihuang yyihuang marked this pull request as ready for review May 2, 2026 02:39
@yyihuang yyihuang requested a review from dhiraj113 as a code owner May 2, 2026 02:39
@yongwww yongwww added the run-ci label May 4, 2026
@yongwww
Member

yongwww commented May 4, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !623 has been created, and the CI pipeline #50227464 is currently running. I'll report back once the pipeline job completes.

@yyihuang yyihuang force-pushed the cuda-graph-api-logging branch 2 times, most recently from e6924c3 to e9ad9c9 Compare May 13, 2026 19:39
averyhNV and others added 17 commits May 13, 2026 20:02
At FLASHINFER_LOGLEVEL=5 the host path (.min().item() etc.) cannot run
inside torch.cuda.graph(...) because .item() synchronizes the stream, so
statistics were silently skipped. Replace the skip with a single-block
CUDA kernel that computes min/max/mean/nan/inf and emits one printf line
per tensor — the launch is captured into the graph and the printf fires
on every replay. The host log records a correlation id so kernel output
can be matched back to the API call/argument that produced it.

Supported dtypes: float32, float16, bfloat16, int32, int64, uint8.
Other dtypes fall back to the legacy skip message.

Approach mirrors flashinfer-ai/debug-print#2.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
At FLASHINFER_LOGLEVEL=10 the dump path called .cpu() on every input/output,
which synchronizes the captured stream and is illegal under
torch.cuda.graph(...). Stage all dump tensors through cached pinned host
buffers (allocated lazily during eager warmup, since cudaHostAlloc is
forbidden under capture) and issue captured non_blocking copy_() ops so
each replay refreshes the buffers in place.

The actual inputs.pt/outputs.pt writes are deferred to a new
flush_graph_dumps() API that the user calls after each g.replay() — that
function synchronizes the stream, then writes the buffer's current
contents to disk so the dump always reflects the most recent replay.
clear_graph_dumps() releases the held pinned buffers.

Caveats documented in flashinfer/api_logging.py:
- requires at least one eager warmup call before capture (to populate the
  pinned-buffer cache);
- the captured buffers are contiguous so original strides are not
  preserved (matches the existing safetensors path).
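
A toy model of the warmup-then-capture contract, runnable without a GPU. Everything here is a hypothetical stand-in: `StagingCache` uses bytearrays in place of pinned host tensors, `capturing` stands in for "a CUDA graph capture is in progress", and `flush` returns a snapshot where the real API writes files.

```python
class WarmupRequired(RuntimeError):
    """Raised when a buffer would have to be allocated during capture."""


class StagingCache:
    """Toy pinned-buffer cache; bytearrays stand in for pinned host tensors."""

    def __init__(self):
        self._buffers = {}
        self.capturing = False

    def stage(self, key, payload: bytes) -> bytearray:
        buf = self._buffers.get(key)
        if buf is None:
            if self.capturing:
                # Allocation is not allowed during capture, so a miss is an error.
                raise WarmupRequired(f"no buffer for {key!r}; run one eager call first")
            buf = self._buffers[key] = bytearray(len(payload))
        buf[:] = payload  # in-place refresh, like a captured copy into the same buffer
        return buf

    def flush(self):
        """Snapshot the current contents of every buffer."""
        return {key: bytes(buf) for key, buf in self._buffers.items()}
```

Because each replay overwrites the same buffer, a flush only ever sees the most recent replay, which matches the behavior described above.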

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Lets users run an unmodifiable Python program (e.g. sglang) under
FLASHINFER_LOGLEVEL=10 + torch.cuda.graph(...) without touching the
program's source. The wrapper:

* sets the FLASHINFER_DUMP_* env vars from CLI flags;
* monkey-patches torch.cuda.CUDAGraph.replay so flush_graph_dumps()
  fires automatically after every replay (idempotent);
* runs the target via runpy when it is `python ...` so the patch stays
  alive in-process; for non-Python targets it falls back to execvp and
  warns.

Usage:

  python tools/dump_with_cuda_graph.py \
      --dump-dir /tmp/fi_dumps --include '*decode*' --max-count 10 \
      -- python -m sglang.launch_server ...

Strongly recommend setting --include and --max-count: without scoping,
every replay rewrites every captured dump, which is heavy for high-QPS
workloads.
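
The idempotent replay patch can be sketched on a stand-in class, so it runs without torch. `FakeGraph`, `fake_flush`, `install_replay_autoflush`, and the `_autoflush_installed` marker are all hypothetical; the wrapper's actual patching code is not shown in this thread.

```python
import functools


class FakeGraph:
    """Stand-in for torch.cuda.CUDAGraph so the sketch runs without a GPU."""

    def replay(self):
        return "replayed"


flush_calls = []


def fake_flush():
    """Stand-in for flush_graph_dumps(); records that it was called."""
    flush_calls.append("flush")


def install_replay_autoflush(graph_cls, flush):
    """Wrap graph_cls.replay so `flush` runs after each replay; safe to call twice."""
    original = graph_cls.replay
    if getattr(original, "_autoflush_installed", False):
        return  # already patched, do not wrap the wrapper again

    @functools.wraps(original)
    def replay(self, *args, **kwargs):
        result = original(self, *args, **kwargs)
        flush()
        return result

    replay._autoflush_installed = True
    graph_cls.replay = replay
```

The marker attribute on the wrapper is what keeps a second install from stacking a second flush per replay.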

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Setting FLASHINFER_LOGLEVEL=10 inside a real workload (sglang DSR1 TP=8
warmup observed) used to emit ~28M [flashinfer stats] lines across all
flashinfer API calls, regardless of FLASHINFER_DUMP_INCLUDE: the
include/exclude filter only gated _dump_function_inputs/outputs, not
_log_function_inputs/outputs. The volume can overrun upstream HTTP
health-check polling and abort the engine before the dump path even
runs.

Apply _should_dump_function in the level-3+ logging branches too, so
include/exclude narrows BOTH log emission AND tensor dumps.

Verified with sglang DSR1 FP8 TP=8 + dump_with_cuda_graph.py
--include='BatchMLAPagedAttentionWrapper.run,top_k_renorm_probs,...':
apilog dropped from 653 MB to 128 MB and the server reached graph
capture + replay successfully (was previously timing out warmup).
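
A minimal sketch of gating console logging on the same predicate as dumps. It assumes the include/exclude values are comma-separated glob patterns, which is inferred from the `--include` examples in this thread; `parse_patterns`, `should_dump`, and `call_with_logging` are hypothetical names, not the functions in flashinfer/api_logging.py.

```python
from fnmatch import fnmatchcase


def parse_patterns(raw):
    """Split a comma-separated pattern list, dropping empty entries."""
    return [p.strip() for p in raw.split(",") if p.strip()]


def should_dump(name, include="", exclude=""):
    """Return True if `name` passes the include filter and no exclude pattern."""
    inc = parse_patterns(include)
    exc = parse_patterns(exclude)
    if inc and not any(fnmatchcase(name, p) for p in inc):
        return False
    return not any(fnmatchcase(name, p) for p in exc)


def call_with_logging(name, func, log, include="", exclude=""):
    """Emit pre/post log lines only for calls selected by the filter."""
    selected = should_dump(name, include, exclude)
    if selected:
        log.append(f"call {name}")
    result = func()
    if selected:
        log.append(f"return {name}")
    return result
```

With an include list set, unselected calls still execute but leave no log lines, which is the volume reduction described above.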

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
A long real-workload run (e.g. sglang under InferenceX) issues the same
flashinfer API at the same shape thousands of times during decode replay
and emits a fresh dump every time, which inflates raw dump volume well
past anything sanitize_dumps.py needs as input. The existing global
FLASHINFER_DUMP_MAX_COUNT cap kicks in too uniformly: it stops after N
dumps total, so rarely-seen shapes get dropped before they fire.

Add a per-(func_name, input-shape-signature) cap. When set:

  FLASHINFER_DUMP_PER_SHAPE_LIMIT=5

every distinct shape that the workload actually exercises gets up to 5
sample dumps; further calls with the same shape are skipped. No shape
synthesis — the set of shapes captured is exactly what the workload
exercised. Useful as a passive prune feeding the existing
flashinfer-bench / flashinfer-trace workload pipeline.

Implementation:
- new _compute_input_shape_signature(args, kwargs) — string-only,
  hashable, host-side-metadata-only (no GPU sync, safe under cuda
  graph capture). Tensor inputs contribute (shape, dtype); scalar
  kwargs contribute their repr() so e.g. different block_size counts
  as a distinct shape.
- new _dump_shape_counter dict keyed by (func_name, signature),
  per-process. TP ranks see identical shapes and each contribute their
  own sample (fine for the typical 8-rank case).
- gate runs after the existing include/exclude filter and global
  count cap, so all three caps compose.

Default 0 = disabled (back-compat with all existing collectors).

Verified with a synthetic 15-call test (3 shapes × 5 calls): with cap=2
the dump dir contains exactly 6 entries.
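
The per-shape cap can be sketched with a counter keyed by (function name, signature). `FakeTensor`, `shape_signature`, and `PerShapeLimiter` are hypothetical stand-ins for illustration; in particular, using repr() for non-tensor positional arguments is an assumption, since the description above only specifies it for scalar kwargs.

```python
from collections import Counter
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Carries only the host-side metadata the signature needs."""
    shape: tuple
    dtype: str


def _describe(value):
    """Tensors contribute (shape, dtype); everything else contributes repr()."""
    if isinstance(value, FakeTensor):
        return f"tensor{value.shape}:{value.dtype}"
    return repr(value)


def shape_signature(args, kwargs):
    """Build a hashable string from metadata only (no device reads)."""
    parts = [_describe(v) for v in args]
    parts += [f"{name}={_describe(kwargs[name])}" for name in sorted(kwargs)]
    return "|".join(parts)


class PerShapeLimiter:
    """Allow at most `limit` dumps per (func_name, signature); 0 disables the cap."""

    def __init__(self, limit):
        self.limit = limit
        self.counts = Counter()

    def allow(self, func_name, args, kwargs):
        if self.limit <= 0:
            return True
        key = (func_name, shape_signature(args, kwargs))
        if self.counts[key] >= self.limit:
            return False
        self.counts[key] += 1
        return True
```

Running 3 shapes × 5 calls through a limiter with `limit=2` admits 6 dumps, the same count as the synthetic test reported above.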

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two fixes that together let LOGLEVEL=10 dumps land for real-workload
collection through an inference engine that uses both
torch.inference_mode() and torch.cuda.graph(...) capture (sglang DSR1
TP=8 was the proving ground).

1. _stage_tensor_to_pinned now wraps both the pinned-buffer alloc and
   the copy_() in `with torch.inference_mode(False):`. Without this,
   sglang's outer torch.inference_mode() marks the cached pinned buffer
   as an "inference tensor" on first allocation, which causes every
   subsequent in-place copy_() to raise:
     RuntimeError: Inplace update to inference tensor outside
     InferenceMode is not allowed.
   Net effect was every dump silently failing once requests started
   flowing.

2. _install_cuda_graph_replay_autoflush patches torch.cuda.CUDAGraph
   .replay (idempotently) to call flush_graph_dumps() after every
   replay, gated on FLASHINFER_LOGLEVEL >= 10. Without this in-process
   patch, the previous mechanism (tools/dump_with_cuda_graph.py)
   monkey-patched only the parent process. sglang spawns TP worker
   processes via multiprocessing.spawn, which gives each worker a fresh
   Python interpreter that re-imports flashinfer; the wrapper's patch
   never reached the workers, so captured D2H copies happened but the
   inputs.{pt,safetensors} files were never written.

Verified end-to-end: with --attention-backend flashinfer, LOGLEVEL=10,
SAFETENSORS=1, and cuda graphs enabled, sglang reaches "fired up" and
external workload drivers (e.g. InferenceX benchmark_serving.py) drive
real requests whose dumps land on disk through the captured graph
replay path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Three changes that together let LOGLEVEL=10 capture real-workload shapes
under sglang's cuda-graph decode without tanking decode throughput by 1-3
orders of magnitude.

1. Replace per-CUDAGraph.replay autoflush with atexit + SIGTERM hooks.
   The previous per-replay flush (commit 570ea4c) ran a few hundred
   save_file() calls after every captured graph replay, dropping decode
   from ~100 tok/s/rank to ~1 tok/s/rank when DUMP_DIR was on NFS, and
   was the cause of all-requests-failed under sustained sglang
   inference. Replays now refresh the pinned host buffers via captured
   D2H copies; one shutdown-time flush serializes the latest values for
   every captured shape — exactly what's needed for workload collection.

2. Add FLASHINFER_DISABLE_GRAPH_STATS=1 to skip the captured printf-
   from-graph stats path. _launch_gpu_stats_kernel embeds a device-side
   printf into every captured graph; under sustained replay this floods
   host stdout (~tens of thousands of [flashinfer stats] lines/s with
   ~122 traced ops × 8 TP × per-replay), saturating sglang's stdout
   pipe and stalling the inference scheduler. Set the new env var when
   collecting workloads under cuda graphs to suppress the printf.

3. Make flush_graph_dumps mkdir(parents=True, exist_ok=True) before
   each save_file. This lets the orchestrator wipe DUMP_DIR at the
   warmup→inference boundary (dropping eager-mode warmup dumps) while
   still letting the captured-graph deferred dumps land at flush time:
   the registry's dump_dir paths just get re-created.

Verified end-to-end on B200 + DSR1 FP8 TP=8 + InferenceX:
- v9 (per-replay flush, NFS dump): all 16 requests timed out; 0 success
- v10 (shutdown-only flush + DISABLE_GRAPH_STATS, plan-only filter):
  8/8 successful, 1590 tok/s total throughput, 367 ms TTFT
- post-warmup wipe verified safe via flush self-heal

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two general improvements to make level-5 stats and level-10 dumps work
on real workloads under cuda graph capture:

1. Drop the per-filtered-call debug log. At FLASHINFER_LOGLEVEL>=10 the
   logger is at DEBUG, so emitting one line per filtered call (every
   gemm / rmsnorm / etc in a sustained inference run) saturates stderr
   at ~325k lines/sec and drops decode throughput by ~30x when an
   INCLUDE/EXCLUDE filter is set.

2. Add FLASHINFER_DUMP_MAX_TENSOR_MB env var (default 0 = old
   behaviour). Tensors over the cap are recorded as safe attrs
   (shape/dtype/device/stride) only — never via _serialize_value, which
   falls into the generic branch and calls str(t) → tensor.__repr__(),
   reading device memory and triggering cudaErrorStreamCaptureInvalidated
   under capture. Also avoids the multi-GB-per-replay D2H tax for giant
   KV-cache args (the canonical workload schemas treat such inputs as
   "type": "random" downstream).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Document FLASHINFER_DUMP_PER_SHAPE_LIMIT, FLASHINFER_DUMP_MAX_TENSOR_MB,
  and FLASHINFER_DISABLE_GRAPH_STATS env vars in the dump-config table.
- Rewrite the CUDA Graph Compatibility section: replace the obsolete
  per-replay flush_graph_dumps() pattern with the atexit/SIGTERM
  auto-flush model that ships with the cuda-graph-aware logging.
- Note when to flip FLASHINFER_DISABLE_GRAPH_STATS=1 (sustained-replay
  scenarios where device printf saturates stdout).
- Note FLASHINFER_DUMP_MAX_TENSOR_MB rationale (avoids multi-GB D2H
  per replay AND the cudaErrorStreamCaptureInvalidated trap from
  str(tensor) under capture).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…FER_DUMP_MAX_TENSOR_MB

Both env vars were added for a workload-collection use case that
isn't part of this PR's intended scope. Reverting to keep the
cuda-graph + level 5/10 logging implementation minimal:

- Remove FLASHINFER_DUMP_PER_SHAPE_LIMIT (introduced in fa8affd):
  drops the per-(func, shape) cap, _compute_input_shape_signature(),
  and the _dump_shape_counter registry.
- Remove FLASHINFER_DUMP_MAX_TENSOR_MB (introduced in d237a0e):
  drops the safe-metadata-only path for oversized tensors and the
  _tensor_nbytes_mb() helper. _extract_tensors_and_metadata_pinned()
  reverts to the simple "stage every tensor" form.
- Update docs/logging.rst accordingly.

Net delta vs the previous tip: -2 features, -124 LoC.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Removes the env-var escape hatch and its branch in
_log_tensor_statistics. Under cuda-graph capture, the path now always
delegates stats to the captured GPU kernel (or falls back to the legacy
"[statistics skipped]" message for unsupported dtypes), with no opt-out.

- Drop FLASHINFER_DISABLE_GRAPH_STATS env var declaration.
- Drop the if _DISABLE_GRAPH_STATS: branch in the capture-path stats code.
- Drop the corresponding rows / paragraphs in docs/logging.rst.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The tool's premise — monkey-patch CUDAGraph.replay so flush_graph_dumps()
runs after every replay — is now obsolete. The atexit/SIGTERM hook in
api_logging._install_cuda_graph_dump_autoflush() flushes once at process
shutdown, which is the recommended pattern; per-replay flushing is
explicitly avoided because it drops decode throughput from ~100 tok/s/rank
to ~1 tok/s/rank when DUMP_DIR is on NFS.

Beyond the obsolete monkey-patch, the wrapper does nothing that an env-var
prefix can't do:
  FLASHINFER_LOGLEVEL=10 FLASHINFER_DUMP_DIR=... python -m my_program

The patch also wouldn't propagate into multiprocessing.spawn TP workers (the
intended target), so even at its peak it didn't solve its own use case.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Add @flashinfer_api(trace=gemm_fp8_nt_groupwise_trace) on the FP8
  group-wise GEMM (cutlass + trtllm backends). Captures the trtllm
  canonical scale layout (a_scale=[M, K//bk], b_scale=[K//bk, N//bk])
  used by sglang's --moe-runner-backend flashinfer_trtllm DSR1 path.
  Validated against a real cuda-graph stage-1 run on 8x B200: this op
  fired ~28k times per TP worker but produced no trace JSON because
  the function had a bare @flashinfer_api with no template attached.

- Fix trtllm_batch_decode_mla_trace and xqa_batch_decode_mla_trace to
  model the rank-4 [num_pages, 1, page_size, head_dim_qk] kv_cache
  layout via a kv_pad_dim const, switch workspace_buffer dtype int8
  → uint8, and add the missing skip_softmax_threshold_scale_factor
  scalar.

- Fix mla_rope_quantize_fp8_trace to use rank-2 K tensors
  (num_k_heads=1 collapsed) instead of inheriting the rank-3 GQA
  template _ROPE_QUANT_AXES / _ROPE_QUANT_INPUTS.

- Add example invocation in tests/trace/example.py and regenerate
  tests/trace/fi_trace_out/ accordingly. Tests pass: 440 passed,
  8 skipped.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…K_div_block]

The trtllm backend of gemm_fp8_nt_groupwise actually expects b_scale as
[N//bk, K//bk] — the transposed form of the layout described in
flashinfer/gemm/gemm_base.py:5681-5683. tests/gemm/test_groupwise_scaled_gemm_fp8.py:128-129
proves this:

    if backend == "trtllm":
        b_scale = b_scale.t().contiguous()

sglang's layers/quantization/fp8_utils.py produces the same transposed
layout, and a cuda-graph stage-1 run on 8x B200 shows runtime
b_scale.shape = (17, 56) for K=7168, N=2112, block=128 (i.e. (N//bk,
K//bk), with stride (56, 1) confirming row-major contiguous storage —
not a transposed view).

The previous template declared b_scale as [K_div_block, N_div_block],
which made the matcher in flashinfer-bench's sanitize_fi_log.py reject
every real gemm_fp8_nt_groupwise call: the K_div_block axis bound to
K//bk via a_scale at step 3, then b_scale dim-0 (which is actually
N//bk) tried to bind it to N//bk at step 4 → conflict → no template
match → no workload extracted.

Update the template, the example call in tests/trace/example.py, and
the regenerated JSON. Trace tests still pass: 440 passed, 8 skipped.
The flashinfer source-code docstring is left as-is (separate issue).
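
The binding conflict described above can be reproduced with a small axis matcher. This is a hypothetical sketch, not the matcher in flashinfer-bench's sanitize_fi_log.py; `match_template` and the value `M = 4` are assumptions for illustration, while K, N, the block size, and the (17, 56) runtime shape come from the message above.

```python
def match_template(template, shapes):
    """Bind named axes to sizes; return the bindings, or None on a conflict."""
    bindings = {}
    for name, axes in template.items():
        shape = shapes[name]
        if len(shape) != len(axes):
            return None
        for axis, size in zip(axes, shape):
            # setdefault binds the axis on first sight and returns the
            # existing binding afterwards, so a mismatch is a conflict.
            if bindings.setdefault(axis, size) != size:
                return None
    return bindings


M, K, N, BLOCK = 4, 7168, 2112, 128
shapes = {
    "a_scale": (M, K // BLOCK),                # (4, 56)
    "b_scale": (-(-N // BLOCK), K // BLOCK),   # ceil(2112 / 128) = 17, so (17, 56)
}
old_template = {
    "a_scale": ["M", "K_div_block"],
    "b_scale": ["K_div_block", "N_div_block"],
}
new_template = {
    "a_scale": ["M", "K_div_block"],
    "b_scale": ["N_div_block", "K_div_block"],
}
```

With the old template, `K_div_block` is bound to 56 by `a_scale` and then sees 17 in `b_scale`, so the match fails; the corrected template binds all three axes.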

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
pre-commit's `fix-end-of-files` hook found 11 fi_trace JSONs in this
PR's diff missing a trailing newline. The fi_trace auto-dump writes
each JSON via `json.dumps(...)` without an explicit `"\n"` at EOF; the
hook adds one. No semantic change to any trace template.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@yyihuang yyihuang force-pushed the cuda-graph-api-logging branch 2 times, most recently from 1a7e30c to 6331423 Compare May 13, 2026 21:06
@yyihuang yyihuang force-pushed the cuda-graph-api-logging branch from 6331423 to 008a283 Compare May 13, 2026 21:13
@yyihuang
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

@yyihuang is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@yongwww
Member

yongwww commented May 15, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !623 has been updated with latest changes, and the CI pipeline #51416583 is currently running. I'll report back once the pipeline job completes.

@yyihuang yyihuang merged commit 7d1d46e into flashinfer-ai:main May 15, 2026
30 of 35 checks passed

5 participants