
[Cache] Add Cache.snapshot() / Cache.restore(snapshot) for tentative-forward rollback #45846

Open

kashif wants to merge 4 commits into huggingface:main from kashif:fix-linear-attn-cache-rollback

Conversation

@kashif (Contributor) commented May 8, 2026

Speculative decoding (and any "forward then maybe undo" pattern) needs to roll a cache back to a previous state. For full-attention layers the existing `crop()` API works, but for linear-attention layers, whose `crop()` is a documented no-op (the recurrent state has already absorbed every observed token), there has been no supported way to undo a forward pass. Block-diffusion speculative decoders silently corrupt their target cache on hybrid-attention models (e.g. Qwen3.5) as a result.

This change adds:

  • `CacheLayerMixin.snapshot()` / `restore(snapshot)` covering full-attn layers (`DynamicLayer`, `DynamicSlidingWindowLayer`).
  • `LinearAttentionCacheLayerMixin.snapshot()` / `restore(snapshot)` for linear-attn layers. `restore()` uses `.copy_()` into existing tensors when shapes match, so cudagraph static-address assumptions survive.
  • `LinearAttentionAndFullAttentionLayer` overrides that merge both parents.
  • `Cache.snapshot()` / `Cache.restore(snapshot)` that delegate per-layer and reject layout mismatches (a sketch of the delegation follows this list).
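
A minimal sketch of the cache-level delegation, assuming per-layer `snapshot()`/`restore()` as described above (illustrative, not the exact merged code):

```python
# Illustrative sketch of the delegation described above; the merged
# implementation in cache_utils.py may differ in details.
class Cache:
    def snapshot(self) -> list[dict]:
        # Each layer returns an opaque per-layer snapshot.
        return [layer.snapshot() for layer in self.layers]

    def restore(self, snapshot: list[dict]) -> None:
        # Layout check: a snapshot taken from a different cache is rejected.
        if len(snapshot) != len(self.layers):
            raise ValueError(
                f"Snapshot has {len(snapshot)} layers but cache has {len(self.layers)}; "
                "the snapshot was likely taken from a different cache."
            )
        for layer, layer_snapshot in zip(self.layers, snapshot):
            layer.restore(layer_snapshot)
```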

The shape of a snapshot is opaque to callers; the return value is meant to be passed back to `restore()` on the same cache. The intended pattern is:

```python
snap = cache.snapshot()
run_target_on_speculative_block(...)
if partial_accept:
    cache.restore(snap)
    run_target_on_accepted_prefix(...)
```

Tests in `tests/utils/test_cache_utils.py::CacheSnapshotRestoreTest` cover the `DynamicLayer` round-trip, the `LinearAttentionLayer` round-trip (including static-address preservation), the `DynamicCache` aggregate round-trip, and a layout-mismatch rejection.
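
For illustration, a round-trip along the lines these tests cover might look like this (hypothetical tensor shapes; `snapshot()`/`restore()` are the APIs added by this PR):

```python
import torch
from transformers import DynamicCache

# Hypothetical round-trip mirroring what the new tests check.
cache = DynamicCache()
cache.update(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4), layer_idx=0)

snap = cache.snapshot()  # capture state (new API in this PR)
cache.update(torch.randn(1, 2, 5, 4), torch.randn(1, 2, 5, 4), layer_idx=0)
assert cache.get_seq_length(0) == 8  # tentative tokens were appended

cache.restore(snap)  # roll the tentative forward back
assert cache.get_seq_length(0) == 3  # back at the pre-snapshot length
```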

Verified end-to-end with diffusers' `DFlashPipeline` on z-lab/Qwen3.5-4B-DFlash + Qwen/Qwen3.5-4B (a hybrid-attention target whose target cache contains linear-attention layers): the pipeline drives generation purely through the new `cache.snapshot()` / `cache.restore()` calls and produces "2 + 2 equals 4."

What does this PR do?

Fixes # (issue)

Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.

PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.

This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read CONTRIBUTING.md.

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Copilot AI left a comment


Pull request overview

Adds a snapshot/restore rollback primitive to `Cache` and its layer types to support speculative decoding and other “tentative forward then undo” workflows, including for linear-attention layers where `crop()` is a no-op.

Changes:

  • Add `snapshot()` / `restore(snapshot)` to full-attention cache layers (`CacheLayerMixin`, `DynamicSlidingWindowLayer`) and to linear-attention cache layers (`LinearAttentionCacheLayerMixin`), plus hybrid merge behavior in `LinearAttentionAndFullAttentionLayer`.
  • Add `Cache.snapshot()` / `Cache.restore(snapshot)` delegating per-layer snapshots with a basic layout-size check.
  • Add unit tests covering round-trips for `DynamicLayer`, `LinearAttentionLayer`, `DynamicCache`, and a layer-count mismatch rejection.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| `src/transformers/cache_utils.py` | Introduces snapshot/restore APIs on cache layers and `Cache` to enable rollback for speculative decoding, including linear-attention support with static-address-preserving restores. |
| `tests/utils/test_cache_utils.py` | Adds unit tests validating snapshot/restore round-trips and a basic mismatch rejection. |

Comment on lines +116 to +128

```python
        return {
            "is_initialized": self.is_initialized,
            "keys": self.keys.clone() if self.is_initialized and self.keys is not None else None,
            "values": self.values.clone() if self.is_initialized and self.values is not None else None,
        }

    def restore(self, snapshot: dict) -> None:
        """Restore this layer to the state captured by a previous `snapshot()` call."""
        self.is_initialized = bool(snapshot["is_initialized"])
        self.keys = snapshot["keys"]
        self.values = snapshot["values"]
```

Comment on lines +821 to +829

```python
        elif self.conv_states is not None and self.conv_states.shape == snapshot["conv_states"].shape:
            self.conv_states.copy_(snapshot["conv_states"])
        else:
            self.conv_states = snapshot["conv_states"].clone()

        if snapshot["recurrent_states"] is None:
            self.recurrent_states = None
        elif self.recurrent_states is not None and self.recurrent_states.shape == snapshot["recurrent_states"].shape:
            self.recurrent_states.copy_(snapshot["recurrent_states"])
```

Context from other review anchors: the layout-mismatch rejection in `Cache.restore()`:

```python
            raise ValueError(
                f"Snapshot has {len(snapshot)} layers but cache has {len(self.layers)}; the snapshot was likely "
                "taken from a different cache."
            )
```

and assertions from the new round-trip tests:

```python
        self.assertEqual(cache.layers[1].get_seq_length(), 3)
        self.assertTrue(torch.equal(cache.layers[0].keys, keys_a))
        self.assertTrue(torch.equal(cache.layers[1].keys, keys_b))
```
@kashif kashif changed the title [Cache] Add Cache.snapshot() / Cache.restore(snapshot) for tentative-forward rollback [Cache] Add Cache.snapshot() / Cache.restore(snapshot) for tentative-forward rollback May 8, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

kashif added 2 commits May 8, 2026 12:24
`StaticLayer` (and `StaticSlidingWindowLayer`) preallocate K/V tensors
with `mark_static_address` and track `cumulative_length` separately, so
the base class's reassign-and-no-cumulative-length default produced
silently-wrong rollbacks: `get_seq_length()` was off and the static
addresses were lost. Override snapshot/restore on those layers to
capture `cumulative_length` (and `cumulative_length_int` on the sliding
variant) and to `.copy_()` into the existing tensors so cudagraph
captures survive.
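
A hedged sketch of what these overrides could look like (the `keys`/`values` buffers and the `cumulative_length` counter come from the commit message; everything else is illustrative):

```python
# Illustrative sketch only; the actual overrides in cache_utils.py may differ.
class StaticLayer(CacheLayerMixin):
    def snapshot(self) -> dict:
        return {
            "keys": self.keys.clone(),
            "values": self.values.clone(),
            "cumulative_length": self.cumulative_length,
        }

    def restore(self, snapshot: dict) -> None:
        # copy_() writes into the preallocated tensors in place, so the
        # static addresses that cudagraph captured are preserved.
        self.keys.copy_(snapshot["keys"])
        self.values.copy_(snapshot["values"])
        self.cumulative_length = snapshot["cumulative_length"]
```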

Replaces the previously over-broad fast tests with a tighter set:
LinearAttentionLayer round-trip + static-address preservation,
DynamicCache aggregate round-trip, the hybrid
LinearAttentionAndFullAttentionLayer round-trip, and a StaticLayer
round-trip that asserts both length bookkeeping and tensor identity
preservation.
