[Cache] Add Cache.snapshot() / Cache.restore(snapshot) for tentative-forward rollback #45846
Open
kashif wants to merge 4 commits into huggingface:main from
…rollback
Speculative decoding (and any "forward then maybe undo" pattern) needs to
roll a cache back to a previous state. For full-attention layers the existing
`crop()` API works, but for linear-attention layers — whose `crop()` is a
documented no-op because the recurrent state has already absorbed every
observed token — there has been no supported way to undo a forward pass.
Block-diffusion speculative decoders silently corrupt their target cache
on hybrid-attention models (e.g. Qwen3.5) as a result.
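
To make the asymmetry concrete, here is a minimal toy sketch. It does not use the transformers classes: `ToyFullAttentionLayer` and `ToyLinearAttentionLayer` are hypothetical stand-ins, and plain floats stand in for tensors. It only illustrates why truncation can undo a per-token cache but cannot undo a running state.

```python
# Toy illustration: plain floats stand in for tensors, and both classes are
# hypothetical stand-ins, not the transformers cache layers.


class ToyFullAttentionLayer:
    """Stores one entry per token, so truncating the list undoes a forward."""

    def __init__(self):
        self.keys = []

    def update(self, token_value):
        self.keys.append(token_value)

    def crop(self, max_length):
        del self.keys[max_length:]


class ToyLinearAttentionLayer:
    """Folds each token into a single running state; no per-token history."""

    def __init__(self, decay=0.5):
        self.decay = decay
        self.state = 0.0

    def update(self, token_value):
        self.state = self.decay * self.state + token_value

    def crop(self, max_length):
        pass  # no-op: there is nothing to slice off

    def snapshot(self):
        return self.state

    def restore(self, snapshot):
        self.state = snapshot


full, linear = ToyFullAttentionLayer(), ToyLinearAttentionLayer()
for token in (1.0, 2.0):  # committed prefix
    full.update(token)
    linear.update(token)

snap = linear.snapshot()
for token in (3.0, 4.0):  # speculative tokens that end up rejected
    full.update(token)
    linear.update(token)

full.crop(2)    # full attention: back to the two committed tokens
linear.crop(2)  # linear attention: state still includes 3.0 and 4.0
state_after_crop = linear.state
linear.restore(snap)  # only the snapshot recovers the committed state
```

In this toy, `crop(2)` leaves the linear layer's state contaminated by the rejected tokens, while `restore(snap)` brings it back to the value it had after the committed prefix.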
This change adds:
* `CacheLayerMixin.snapshot()` / `restore(snapshot)` covering full-attn
layers (`DynamicLayer`, `DynamicSlidingWindowLayer`).
* `LinearAttentionCacheLayerMixin.snapshot()` / `restore(snapshot)` for
linear-attn layers. `restore()` uses `.copy_()` into existing tensors
when shapes match so cudagraph static-address assumptions survive.
* `LinearAttentionAndFullAttentionLayer` overrides that merge both parents.
* `Cache.snapshot()` / `Cache.restore(snapshot)` that delegate per-layer
and reject layout mismatches.
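
The two behaviours in the list above (in-place restore when shapes match, and per-layer delegation with a layout check) can be sketched roughly as follows. This is an assumption-laden toy, not the PR's code: `ToyLayer` and `ToyCache` are hypothetical names, Python lists stand in for tensors, and slice assignment plays the role that `.copy_()` plays for a tensor.

```python
# Hypothetical toy classes; lists stand in for tensors.


class ToyLayer:
    def __init__(self, size):
        self.state = [0.0] * size  # stands in for a preallocated tensor

    def snapshot(self):
        return {"state": list(self.state)}  # independent copy, like clone()

    def restore(self, snapshot):
        saved = snapshot["state"]
        if len(saved) == len(self.state):
            # Write into the existing buffer so its identity is preserved,
            # analogous to tensor.copy_() keeping a static address.
            self.state[:] = saved
        else:
            self.state = list(saved)


class ToyCache:
    def __init__(self, layers):
        self.layers = layers

    def snapshot(self):
        return [layer.snapshot() for layer in self.layers]

    def restore(self, snapshot):
        if len(snapshot) != len(self.layers):
            raise ValueError(
                f"Snapshot has {len(snapshot)} layers but cache has {len(self.layers)}"
            )
        for layer, layer_snapshot in zip(self.layers, snapshot):
            layer.restore(layer_snapshot)


cache = ToyCache([ToyLayer(2), ToyLayer(2)])
buffer_before = cache.layers[0].state
snap = cache.snapshot()
cache.layers[0].state[0] = 9.0  # tentative forward mutates the buffer
cache.restore(snap)
same_buffer = cache.layers[0].state is buffer_before
```

After `restore`, the layer holds the original values in the same list object, which is the property the static-address test in the PR is described as checking for tensors.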
The shape of a snapshot is opaque to callers — the return value is meant
to be passed back to `restore()` on the same cache. The intended pattern is
    snap = cache.snapshot()
    run_target_on_speculative_block(...)
    if partial_accept:
        cache.restore(snap)
        run_target_on_accepted_prefix(...)
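
A runnable toy version of this pattern, under stated assumptions: `ToyRecurrentCache` and `verify_block` are hypothetical names, integers stand in for tokens and state, and acceptance is passed in as a count rather than computed by a verifier. The point is only that rollback followed by a replay of the accepted prefix should leave the cache as if the rejected tokens had never been seen.

```python
# Hypothetical toy; not the transformers or diffusers API.


class ToyRecurrentCache:
    def __init__(self):
        self.state = 0
        self.tokens_seen = 0

    def forward(self, tokens):
        # Each token is folded irreversibly into the running state.
        for token in tokens:
            self.state = (self.state * 31 + token) % 1_000_003
            self.tokens_seen += 1

    def snapshot(self):
        return (self.state, self.tokens_seen)

    def restore(self, snapshot):
        self.state, self.tokens_seen = snapshot


def verify_block(cache, draft_tokens, num_accepted):
    snap = cache.snapshot()
    cache.forward(draft_tokens)  # tentative forward over the whole block
    if num_accepted < len(draft_tokens):  # partial accept
        cache.restore(snap)
        cache.forward(draft_tokens[:num_accepted])


cache = ToyRecurrentCache()
cache.forward([5, 7])  # committed prompt
verify_block(cache, draft_tokens=[11, 13, 17], num_accepted=1)

# Reference: a cache that only ever saw the committed and accepted tokens.
reference = ToyRecurrentCache()
reference.forward([5, 7, 11])
```

The rolled-back cache and the reference cache end up with identical snapshots, which is the invariant a speculative decoder relies on.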
Tests in `tests/utils/test_cache_utils.py::CacheSnapshotRestoreTest` cover
DynamicLayer round-trip, LinearAttentionLayer round-trip (including
static-address preservation), DynamicCache aggregate round-trip, and a
layout-mismatch rejection.
Verified end-to-end with diffusers' DFlashPipeline on z-lab/Qwen3.5-4B-DFlash
+ Qwen/Qwen3.5-4B (a hybrid-attention target whose target cache contains
linear-attention layers): the pipeline drives generation purely through the
new `cache.snapshot()` / `cache.restore()` calls and produces "2 + 2 equals 4."
Pull request overview
Adds a snapshot/restore rollback primitive to `Cache` and its layer types to support speculative decoding and other “tentative forward then undo” workflows, including for linear-attention layers where `crop()` is a no-op.
Changes:
- Add `snapshot()`/`restore(snapshot)` to full-attention cache layers (`CacheLayerMixin`, `DynamicSlidingWindowLayer`) and to linear-attention cache layers (`LinearAttentionCacheLayerMixin`), plus hybrid merge behavior in `LinearAttentionAndFullAttentionLayer`.
- Add `Cache.snapshot()`/`Cache.restore(snapshot)` delegating per-layer snapshots with a basic layout-size check.
- Add unit tests covering round-trips for `DynamicLayer`, `LinearAttentionLayer`, `DynamicCache`, and a layer-count mismatch rejection.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `src/transformers/cache_utils.py` | Introduces snapshot/restore APIs on cache layers and `Cache` to enable rollback for speculative decoding, including linear-attention support with static-address-preserving restores. |
| `tests/utils/test_cache_utils.py` | Adds unit tests validating snapshot/restore round-trips and a basic mismatch rejection. |
Comment on lines +116 to 128

            return {
                "is_initialized": self.is_initialized,
                "keys": self.keys.clone() if self.is_initialized and self.keys is not None else None,
                "values": self.values.clone() if self.is_initialized and self.values is not None else None,
            }

        def restore(self, snapshot: dict) -> None:
            """Restore this layer to the state captured by a previous `snapshot()` call."""
            self.is_initialized = bool(snapshot["is_initialized"])
            self.keys = snapshot["keys"]
            self.values = snapshot["values"]

Comment on lines +821 to +829

            elif self.conv_states is not None and self.conv_states.shape == snapshot["conv_states"].shape:
                self.conv_states.copy_(snapshot["conv_states"])
            else:
                self.conv_states = snapshot["conv_states"].clone()

            if snapshot["recurrent_states"] is None:
                self.recurrent_states = None
            elif self.recurrent_states is not None and self.recurrent_states.shape == snapshot["recurrent_states"].shape:
                self.recurrent_states.copy_(snapshot["recurrent_states"])

                raise ValueError(
                    f"Snapshot has {len(snapshot)} layers but cache has {len(self.layers)}; the snapshot was likely "
                    "taken from a different cache."
                )

            self.assertEqual(cache.layers[1].get_seq_length(), 3)
            self.assertTrue(torch.equal(cache.layers[0].keys, keys_a))
            self.assertTrue(torch.equal(cache.layers[1].keys, keys_b))

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
`StaticLayer` (and `StaticSlidingWindowLayer`) preallocate K/V tensors with `mark_static_address` and track `cumulative_length` separately, so the base class's reassign-and-no-cumulative-length default produced silently-wrong rollbacks: `get_seq_length()` was off and the static addresses were lost.

Override snapshot/restore on those layers to capture `cumulative_length` (and `cumulative_length_int` on the sliding variant) and to `.copy_()` into the existing tensors so cudagraph captures survive.

Replaces the previously over-broad fast tests with a tighter set: LinearAttentionLayer round-trip + static-address preservation, DynamicCache aggregate round-trip, the hybrid LinearAttentionAndFullAttentionLayer round-trip, and a StaticLayer round-trip that asserts both length bookkeeping and tensor identity preservation.
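
As a rough illustration of why a preallocated layer needs its own override, consider the toy below. `ToyStaticLayer` is a hypothetical stand-in (a fixed-size list plays the preallocated tensor); it is not the `StaticLayer` implementation. The sketch assumes only what the commit message states: length is tracked in a separate counter, and the buffer must keep its identity across a restore.

```python
# Hypothetical toy; a fixed-size list stands in for a preallocated K/V tensor.


class ToyStaticLayer:
    def __init__(self, max_len):
        self.keys = [0.0] * max_len  # allocated once; identity must not change
        self.cumulative_length = 0   # tracked separately from the buffer

    def update(self, values):
        for value in values:
            self.keys[self.cumulative_length] = value
            self.cumulative_length += 1

    def get_seq_length(self):
        return self.cumulative_length

    def snapshot(self):
        return {
            "keys": list(self.keys),
            "cumulative_length": self.cumulative_length,
        }

    def restore(self, snapshot):
        self.keys[:] = snapshot["keys"]  # in place, like tensor.copy_()
        self.cumulative_length = snapshot["cumulative_length"]


layer = ToyStaticLayer(max_len=4)
layer.update([1.0, 2.0])
buffer_before = layer.keys
snap = layer.snapshot()

layer.update([3.0])  # tentative token
layer.restore(snap)
```

If `restore` reassigned `self.keys` or skipped the counter, the toy would show the same two symptoms the commit message describes: a new buffer object and a sequence length that is off by the rejected tokens.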