
fix(vllm): off-by-one in last_hidden_states aux layer id (#87) #89

Merged
yubofredwang merged 3 commits into main from ywang/fix-vllm-aux-final-layer-id
Apr 28, 2026

Conversation


@yubofredwang (Collaborator) commented Apr 28, 2026

Summary

Fixes #87 — vLLM aux-layer hidden-state capture was off by one for the last_hidden_states slot.

For Qwen3-8B (36 layers) the captured "last hidden state" was the output of layer 34 instead of the output of layer 35 (= input to model.norm), one full transformer block earlier than what training needs for target logits via lm_head(norm(h)). Inner aux layers were unaffected.

Why

vLLM's `_maybe_add_hidden_state` is called with `layer_idx + 1` after each layer runs:

```python
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for layer_idx, layer in enumerate(islice(self.layers, self.start_layer, self.end_layer), start=self.start_layer):
    hidden_states, residual = layer(positions, hidden_states, residual)
    self._maybe_add_hidden_state(aux_hidden_states, layer_idx + 1, hidden_states, residual)
```

So valid capture indices are `[0, num_hidden_layers]`, and index `num_hidden_layers` is the pre-norm slot. `VllmEngine.init()` was appending `num_hidden_layers - 1` and silently dropping a user-passed final post-layer via the `lid + 1 < num_layers` filter.
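
For illustration, a minimal sketch of the fixed resolution (the helper name `resolve_aux_layer_ids` and its exact shape are assumptions, not the actual `VllmEngine.init()` code):

```python
def resolve_aux_layer_ids(user_ids: list[int], num_layers: int) -> list[int]:
    # User ids are post-layer indices; vLLM's capture index k holds the hidden
    # state recorded after layer k-1 runs, so shift each id by +1.
    # The old filter `lid + 1 < num_layers` dropped the final post-layer;
    # the fixed filter `lid < num_layers` keeps it.
    ids = [lid + 1 for lid in user_ids if lid < num_layers]
    # Old code appended `num_layers - 1` (one block early); the fix appends
    # `num_layers`, the pre-norm capture slot.
    final_layer_id = num_layers
    if final_layer_id not in ids:
        ids.append(final_layer_id)
    return ids

# Qwen3-8B (36 layers): the default now resolves to slot 36 (input to
# model.norm) rather than the old, off-by-one slot 35.
assert resolve_aux_layer_ids([], 36) == [36]
```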

Changes

Commit 1 — fix(vllm): off-by-one in last_hidden_states aux layer id (#87)

  • `lid + 1 < num_layers` → `lid < num_layers` (≡ `lid + 1 <= num_layers`), so a user-passed final post-layer is preserved.
  • `final_layer_id = num_hidden_layers - 1` → `num_hidden_layers`, so the appended slot matches vLLM's pre-norm capture index.
  • Updates the stale rationale fragments in the surrounding comments; all unrelated comments are preserved.
  • Adds `TestAuxLayerIdResolution` (4 cases): 36-layer regression, explicit-final-layer pass-through, out-of-range filtering, and the mid-layer +1 shift (sketched below).
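
Roughly what those four cases look like, reusing the illustrative `resolve_aux_layer_ids` sketch above (the real tests drive `VllmEngine.init()` directly, so names and shapes here are assumptions):

```python
import pytest

@pytest.mark.parametrize(
    "user_ids,num_layers,expected",
    [
        ([], 36, [36]),        # 36-layer regression: default final slot is 36, not 35
        ([35], 36, [36]),      # explicit final post-layer passes through
        ([99], 36, [36]),      # out-of-range ids are filtered out
        ([17], 36, [18, 36]),  # mid-layer ids get the +1 shift
    ],
)
def test_aux_layer_id_resolution(user_ids, num_layers, expected):
    assert resolve_aux_layer_ids(user_ids, num_layers) == expected
```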

Commit 2 — test(vllm): drop stale TestChunkedPrefillSingleWrite (orphaned by #68)
Independent cleanup that surfaced while running the suite. PR #68 refactored `MooncakeHiddenStatesConnector` (removed `_ReqMeta.make` and the `slot_mapping` field, dropped `add_request(block_size=...)`, added `_cache_layer_group_id`) and added `tests/test_connector_slot_mapping.py` with full coverage of the new API, but it left the old `TestChunkedPrefillSingleWrite` class behind in `tests/test_vllm_engine.py`. Those tests have been failing ever since against an API surface that no longer exists. This commit removes them; every case has an equivalent in `test_connector_slot_mapping.py` or `test_vllm_engine_integration.py`.

Test plan

  • `docker run … torchspec-local:vllm-v0.19.1 … pytest tests/test_vllm_engine.py tests/test_connector_slot_mapping.py` against real vLLM v0.19.1 + the project's patches: 56 passed, 1 skipped (GPU-only).
  • New `TestAuxLayerIdResolution` regression cases pass; explicit assertion that index 35 is not in the resolved list for a 36-layer model.
  • Recommended: a follow-up integration spot-check comparing the captured `last_hidden_states` against `outputs.hidden_states[-1]` from a plain HF forward, to nail down the alignment end-to-end (one possible shape is sketched below). Not in scope for this PR.
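
A possible shape for that spot-check, as a sketch rather than this PR's code: since the captured slot is pre-norm while HF's `outputs.hidden_states[-1]` is post-norm for Llama/Qwen-style models, comparing through `lm_head(norm(h))` (the exact quantity training consumes) sidesteps the norm-placement mismatch. `captured_last_hidden` is a placeholder for whatever the engine returns for the final aux slot:

```python
import torch

def check_last_hidden_alignment(
    captured_last_hidden: torch.Tensor,  # [seq, hidden]; engine-captured final aux slot
    model: torch.nn.Module,              # matching HF AutoModelForCausalLM, e.g. Qwen3-8B
    input_ids: torch.Tensor,             # [1, seq]; same prompt fed to both stacks
) -> bool:
    with torch.no_grad():
        # Reference target logits from a plain HF forward.
        hf_logits = model(input_ids).logits[0]
        # Captured slot is pre-norm, so apply norm + lm_head exactly as training does.
        # (`model.model.norm` assumes a Llama/Qwen-style module layout.)
        cap_logits = model.lm_head(model.model.norm(captured_last_hidden))
    return torch.allclose(cap_logits.float(), hf_logits.float(), atol=1e-2)
```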

Reported-by: @shadowpa0327. Thanks for the precise repro and the printed-tensor screenshot, which made the off-by-one trivial to confirm.

Closes #87

Copilot AI review requested due to automatic review settings April 28, 2026 06:41
Commit 1 full message:

vLLM's `_maybe_add_hidden_state` is called with `layer_idx + 1` *after*
each layer runs, so valid capture indices are `[0, num_hidden_layers]`
and index `num_hidden_layers` is the pre-`norm` slot consumed by the
training pipeline as `last_hidden_states` for target logit computation
via `lm_head(norm(h))`.

`VllmEngine.init()` was appending `num_hidden_layers - 1` instead, which
in vLLM's convention is the input to the last transformer block (= output
of the second-to-last block), one full layer earlier than intended. For
Qwen3-8B (36 layers) the captured "last hidden state" was the output of
layer 34 instead of the output of layer 35.

The companion filter `lid + 1 < num_layers` also silently dropped a
user-passed final post-layer, so `--aux-hidden-states-layers` with the
final id was being remapped to the wrong slot.

This commit:

- Widens the filter to `lid < num_layers` (≡ `lid + 1 <= num_layers`) so
  the final post-layer is preserved when the user passes it explicitly.
- Sets `final_layer_id = num_hidden_layers` so the appended slot matches
  vLLM's pre-`norm` capture index.
- Updates the now-incorrect rationale fragments in the surrounding
  comments; all unrelated comments preserved.
- Adds `TestAuxLayerIdResolution` (4 cases) covering the 36-layer
  Qwen3-8B regression, explicit-final-layer pass-through, out-of-range
  filtering, and mid-layer +1 shift.

Reported-by: shadowpa0327
Closes: #87
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
PR #68 ("Fix Mooncake connector for hybrid KV cache models and chunked
prefill") refactored `MooncakeHiddenStatesConnector`:

- `_ReqMeta.make()` and the `slot_mapping` field were removed; slot
  mapping is now computed lazily via `_slot_mapping_from_block_ids`
  (sketched after this list).
- `MooncakeConnectorMetadata.add_request()` no longer takes `block_size`.
- `_cache_layer_group_id` is now set in `__init__` and consumed by
  `build_connector_meta`.
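
For context, a hypothetical sketch of what a lazy slot-mapping computation looks like in vLLM-style paged KV caches (the real `_slot_mapping_from_block_ids` signature is not shown in this PR): token `i` lives at slot `block_ids[i // block_size] * block_size + i % block_size`.

```python
import torch

def slot_mapping_from_block_ids(block_ids: list[int], num_tokens: int, block_size: int) -> torch.Tensor:
    # Map each token position to its flat slot in the paged KV cache.
    blocks = torch.tensor(block_ids, dtype=torch.long)
    positions = torch.arange(num_tokens, dtype=torch.long)
    return blocks[positions // block_size] * block_size + positions % block_size

# Two non-contiguous blocks of size 4 holding 6 tokens:
print(slot_mapping_from_block_ids([7, 2], 6, 4))  # tensor([28, 29, 30, 31,  8,  9])
```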

That PR added a new test file `tests/test_connector_slot_mapping.py`
(10 tests) covering the new API end-to-end (single/multiple/partial/
non-contiguous blocks, GPU, chunked-prefill partial-skip + full-store,
HMA mismatch, `_extract_from_kv_cache`) but the old
`TestChunkedPrefillSingleWrite` class in `tests/test_vllm_engine.py` was
left behind and has been failing ever since against an API surface that
no longer exists.

This commit removes the orphaned class and its `_import_connector_internals`
helper. No coverage is lost: each removed test maps onto an existing
case in `test_connector_slot_mapping.py` or the integration suite in
`test_vllm_engine_integration.py`.

Signed-off-by: Yubo Wang <yubowang2019@gmail.com>

Copilot AI left a comment


Pull request overview

Fixes an off-by-one error in vLLM aux hidden-state layer-id resolution so last_hidden_states captures the post-last-layer / pre-norm slot (needed for correct target-logit computation), and updates tests accordingly while removing obsolete chunked-prefill tests.

Changes:

  • Adjust vLLM aux layer-id shifting/filtering to allow capturing index num_hidden_layers and append num_hidden_layers as the final (pre-norm) capture slot.
  • Add regression tests covering aux-layer id resolution for defaults and user-provided values.
  • Remove stale chunked-prefill tests that no longer match the connector API (superseded by newer coverage elsewhere).

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File descriptions:

  • torchspec/inference/engine/vllm_engine.py: fix aux-layer id filtering and the final capture id to correctly target vLLM's pre-norm slot.
  • tests/test_vllm_engine.py: add regression tests for aux-layer id resolution; remove the obsolete chunked-prefill test suite.


Comment on lines +177 to 179
```python
final_layer_id = num_layers
if final_layer_id not in self.aux_hidden_state_layer_ids:
    self.aux_hidden_state_layer_ids.append(final_layer_id)
```

Copilot AI Apr 28, 2026


`final_layer_id` is assumed to correspond to `last_hidden_states` (see `MooncakeHiddenStatesConnector`: it treats the last aux layer as the final layer). With the new filter, a user-supplied last post-layer id (`num_layers - 1`) now shifts to `num_layers` and is preserved in place; if the user list isn't sorted (e.g. `[num_layers - 1, 1]`), `final_layer_id` can end up not being the last entry, and the connector will silently split out the wrong `last_hidden_states`. Consider normalizing the list after shifting so the final slot is always last (e.g., remove any existing `final_layer_id`, then append it, optionally also dedup/sort).
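
One way the suggested normalization could look, as a sketch rather than the repo's code (the helper name is an assumption):

```python
def normalize_aux_layer_ids(shifted_ids: list[int], num_layers: int) -> list[int]:
    # After the +1 shift, dedup and make the final (pre-norm) slot
    # unconditionally the LAST entry, so the connector's
    # "last slice = last_hidden_states" assumption holds.
    final_layer_id = num_layers
    ids = [i for i in dict.fromkeys(shifted_ids) if i != final_layer_id]
    ids.append(final_layer_id)
    return ids

# The problematic order normalizes cleanly:
assert normalize_aux_layer_ids([36, 2], 36) == [2, 36]
```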

Commit 3 full message:

Pure formatting: ruff-format collapses multi-line function signatures
and call sites that fit within the project's line-length=100 limit.
No behavior change.

Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
@yubofredwang yubofredwang merged commit ec1eeab into main Apr 28, 2026
2 checks passed
@yubofredwang yubofredwang deleted the ywang/fix-vllm-aux-final-layer-id branch April 28, 2026 06:51

@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c106eeaa08


Comment on lines 178 to 179
```python
if final_layer_id not in self.aux_hidden_state_layer_ids:
    self.aux_hidden_state_layer_ids.append(final_layer_id)
```

P1: Ensure final aux layer id is always last

When a user provides `aux_hidden_states_layers` in non-ascending order and includes the final post-layer (e.g. `[35, 1]` for a 36-layer model), this branch keeps the shifted final id (36) in its original position and skips re-appending it, yielding `[36, 2]`. `MooncakeHiddenStatesConnector.save_kv_layer` splits tensors by taking the last aux slice as `last_hidden_states`, so this ordering causes `last_hidden_states` to come from a non-final layer. The regression comes from preserving the user-passed final layer while only appending when absent; the final slot should be moved or appended to the end unconditionally.
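
The failure mode is easy to reproduce in plain Python (hypothetical values for a 36-layer model, mirroring the resolution logic as described in this PR):

```python
num_layers = 36
user_ids = [35, 1]  # final post-layer listed first, out of ascending order

# Merged logic: shift by +1, then append the final slot only if absent.
shifted = [lid + 1 for lid in user_ids if lid < num_layers]  # [36, 2]
final_layer_id = num_layers
if final_layer_id not in shifted:  # 36 is already present, so no append
    shifted.append(final_layer_id)

print(shifted)  # [36, 2]: the LAST entry is 2 (output of layer 1),
                # not the pre-norm slot 36 the connector expects last
```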


yubofredwang added a commit that referenced this pull request Apr 28, 2026: …) (#90)

Signed-off-by: Yubo Wang <yubowang2019@gmail.com>

Merging this pull request may close: [Question] vLLM Hidden States Extraction (#87)