4 changes: 4 additions & 0 deletions docs/source/en/_toctree.yml
@@ -648,6 +648,8 @@
title: Z-Image
title: Image
- sections:
- local: api/pipelines/dflash
title: DFlash
- local: api/pipelines/llada2
title: LLaDA2
title: Text
@@ -711,6 +713,8 @@
title: DDPMScheduler
- local: api/schedulers/deis
title: DEISMultistepScheduler
- local: api/schedulers/dflash_token_diffusion
title: DFlashTokenDiffusionScheduler
- local: api/schedulers/multistep_dpm_solver_inverse
title: DPMSolverMultistepInverse
- local: api/schedulers/multistep_dpm_solver
88 changes: 88 additions & 0 deletions docs/source/en/api/pipelines/dflash.md
@@ -0,0 +1,88 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# DFlash

[DFlash](https://huggingface.co/collections/z-lab/dflash) is a block-diffusion speculative decoding scheme. A small
diffusion *draft* model proposes a block of tokens conditioned on hidden features extracted from intermediate layers
of a frozen *target* causal LM; the target then verifies the proposed block in a single forward pass and accepts the
longest matching prefix. The draft model shares the target's tokenizer, so no extra calibration is needed.

`DFlashPipeline` ties the two models together: prefill on the target, draft a block, verify against the target's
posterior via [`DFlashTokenDiffusionScheduler`], commit the accepted prefix and the next-token resample, and repeat
until `max_new_tokens` or a stop token. Compatible draft/target pairs include `z-lab/Qwen3-8B-DFlash-b16` with
`Qwen/Qwen3-8B`, and `z-lab/Qwen3.5-4B-DFlash` with `Qwen/Qwen3.5-4B` (the latter is a hybrid-attention target — see
the rollback note below).
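
In outline, each iteration drafts a block, verifies it with one target forward, and commits the accepted prefix plus the resampled token. The sketch below is schematic rather than the pipeline's actual internals; `draft_block` and `verify` are illustrative stand-ins for the draft forward pass and the scheduler's accept/resample step:

```py
import torch


def decode_step(draft_block, verify, context_ids: torch.Tensor, block_size: int) -> torch.Tensor:
    """One DFlash iteration: draft a block, verify it, commit the accepted prefix.

    `draft_block` and `verify` are stand-ins for the draft forward pass and the
    scheduler's accept/resample step; both names are illustrative, not the real API.
    """
    block_ids, draft_logits = draft_block(context_ids, block_size)  # propose a block
    accepted, next_token = verify(draft_logits, block_ids)          # one target forward
    committed = torch.cat([block_ids[:accepted], next_token.view(1)])
    return torch.cat([context_ids, committed])                      # extended context
```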

## Usage

```py
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

from diffusers import DFlashPipeline

# Draft ships custom modeling code via `auto_map` — `trust_remote_code=True` is required.
draft = AutoModel.from_pretrained(
"z-lab/Qwen3.5-4B-DFlash", trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
)
target = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-4B", dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-4B")

pipe = DFlashPipeline(draft_model=draft, target_model=target, tokenizer=tokenizer)
output = pipe(
prompt="What is 2 + 2? Answer in one sentence.",
max_new_tokens=128,
temperature=0.0,
chat_template_kwargs={"enable_thinking": False},
)
print(output.texts[0])
```

`DFlashPipeline` currently runs `batch_size=1` only. Multi-prompt batching requires per-row partial-accept tracking
and is not yet supported.

## Hybrid-attention targets

For target models with linear-attention layers (e.g. Qwen3.5's gated-delta-net), `DynamicCache.crop()` is a
documented no-op on those layers (see `transformers.cache_utils.LinearAttentionCacheLayerMixin.crop`), so a
partial-accept block would otherwise leak rejected speculative tokens into the recurrent state. The pipeline
detects linear-attention caches via [`DFlashTokenDiffusionScheduler.cache_has_linear_attention`] and uses a
snapshot/restore + accepted-prefix re-forward pattern to advance both layer types cleanly. This adds one extra
target forward per partial-accept block on hybrid targets; full-attention targets use a plain `cache.crop()`.
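
A minimal sketch of that rollback pattern, assuming a standard `transformers` causal LM with a `DynamicCache`-style `past_key_values`; here `copy.deepcopy` stands in for the scheduler's `snapshot_cache`/`restore_cache` helpers:

```py
import copy


def verify_block_with_rollback(target_model, cache, block_ids, accepted_len):
    """Advance the cache by a drafted block, then roll back to the accepted prefix.

    `copy.deepcopy` is an illustrative stand-in for the scheduler's
    `snapshot_cache`/`restore_cache` helpers.
    """
    snapshot = copy.deepcopy(cache)  # clone the full per-layer state up front
    logits = target_model(input_ids=block_ids, past_key_values=cache, use_cache=True).logits
    if accepted_len < block_ids.shape[-1]:  # partial accept
        # `cache.crop()` silently no-ops on linear-attention layers, so restore
        # the snapshot and re-advance on just the accepted prefix instead.
        cache = copy.deepcopy(snapshot)
        if accepted_len > 0:
            target_model(input_ids=block_ids[:, :accepted_len], past_key_values=cache, use_cache=True)
    return logits, cache
```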

## Callbacks

Callbacks run after each block-verify step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are
included in `callback_kwargs`. Allowed keys: `block_output_ids` (the drafted block), `draft_logits`,
`accepted_length`, `next_token`, and `output_ids` (the running output buffer). Return `{"output_ids": ...}` from the
callback to replace the buffer.

```py
def on_step_end(pipe, step, timestep, callback_kwargs):
    # `output_ids` is the running output buffer after this block-verify step.
    output_ids = callback_kwargs["output_ids"]
    # Inspect or modify the buffer here; returning it replaces the pipeline's copy.
    return {"output_ids": output_ids}

out = pipe(
prompt="...",
callback_on_step_end=on_step_end,
callback_on_step_end_tensor_inputs=["output_ids"],
)
```

## DFlashPipeline
[[autodoc]] DFlashPipeline
- all
- __call__

## DFlashPipelineOutput
[[autodoc]] pipelines.DFlashPipelineOutput
33 changes: 33 additions & 0 deletions docs/source/en/api/schedulers/dflash_token_diffusion.md
@@ -0,0 +1,33 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# DFlashTokenDiffusionScheduler

[`DFlashTokenDiffusionScheduler`] implements the verification step for DFlash-style block-diffusion speculative
decoding. It samples a posterior block from the target logits, computes the acceptance length as the longest prefix
where the draft proposal matches the posterior, and exposes the resampled `next_token` for the first rejected
position. Used by [`DFlashPipeline`].
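
In greedy terms the acceptance rule reduces to a longest-matching-prefix comparison. A minimal, self-contained sketch (illustrative, not the scheduler's exact implementation):

```py
import torch


def acceptance_length(draft_ids: torch.Tensor, posterior_ids: torch.Tensor):
    """Longest prefix where the drafted block matches the posterior sample.

    Returns (accepted_length, next_token); `next_token` is the posterior's
    resampled token at the first rejected position, or None if the whole
    block was accepted.
    """
    matches = (draft_ids == posterior_ids).long()
    accepted = int(matches.cumprod(dim=-1).sum())  # prefix match length
    next_token = None if accepted == draft_ids.numel() else posterior_ids[accepted]
    return accepted, next_token


draft = torch.tensor([5, 9, 2, 7])
posterior = torch.tensor([5, 9, 4, 7])
print(acceptance_length(draft, posterior))  # (2, tensor(4))
```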

The scheduler also owns three helpers used by the pipeline's verify loop on hybrid-attention targets:

- `cache_has_linear_attention(cache)` — detect whether a `DynamicCache` contains any linear-attention layers.
- `snapshot_cache(cache)` / `restore_cache(cache, snapshot)` — clone and restore the full per-layer state so a
partial-accept block can be rolled back and the target re-advanced on just the accepted prefix.

These exist because `DynamicCache.crop()` silently no-ops on linear-attention layers, which would otherwise let
rejected speculative tokens permanently contaminate the recurrent state.

## DFlashTokenDiffusionScheduler
[[autodoc]] DFlashTokenDiffusionScheduler

## DFlashTokenDiffusionSchedulerOutput
[[autodoc]] schedulers.scheduling_dflash_token_diffusion.DFlashTokenDiffusionSchedulerOutput
41 changes: 41 additions & 0 deletions examples/discrete_diffusion/README.md
@@ -48,3 +48,44 @@ python examples/discrete_diffusion/sample_llada2.py \
--use_chat_template \
--add_generation_prompt
```

## DFlash

[DFlash](https://huggingface.co/collections/z-lab/dflash) is a block-diffusion **speculative decoding** scheme: a small diffusion *draft* model, conditioned on hidden features from a frozen *target* causal LM, proposes a block of tokens that the target verifies in a single forward pass. The pipeline accepts the longest matching prefix and resamples the next token at the rejection point.

### Sample

The published draft pairs with a stock target (no `trust_remote_code` for the target):

```bash
python examples/discrete_diffusion/sample_dflash.py \
--draft_model_id z-lab/Qwen3.5-4B-DFlash \
--target_model_id Qwen/Qwen3.5-4B \
--prompt "How many positive whole-number divisors does 196 have?" \
--max_new_tokens 4096
```

The draft ships a custom `DFlashDraftModel` class via `auto_map`, so the sample script loads it with `trust_remote_code=True`; the target loads as a stock Qwen3 / Qwen3.5 model. The per-draft thinking-mode defaults below come from the upstream model cards:

| Draft | `enable_thinking` |
|---|---|
| `z-lab/Qwen3.5-*-DFlash` | `True` |
| `z-lab/Qwen3-*-DFlash-b16` | `False` (these drafts were trained without thinking mode) |

### Train

The training loop conditions the draft on intermediate target hidden states and predicts the next `block_size − 1` tokens of each block:

```bash
accelerate launch examples/discrete_diffusion/train_dflash.py \
--draft_model_id z-lab/Qwen3-4B-DFlash-b16 \
--target_model_id Qwen/Qwen3-4B \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--text_column text \
--output_dir dflash-output \
--max_train_steps 1000 \
--learning_rate 2e-5
```

`--block_size 0` (default) reads the block size from the draft model's config (16 for both the b16 drafts and `Qwen3.5-*-DFlash`).
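
Schematically, the objective is a shifted cross-entropy within each block: the draft's predictions for positions `1..block_size - 1` are scored against the block's next tokens. A hedged sketch under assumed shapes (the script's actual masking and conditioning details may differ):

```py
import torch
import torch.nn.functional as F


def block_ce_loss(draft_logits: torch.Tensor, block_ids: torch.Tensor) -> torch.Tensor:
    """Cross-entropy of the draft's predictions against the next tokens of a block.

    draft_logits: (batch, block_size - 1, vocab) predictions for positions 1..block_size - 1
    block_ids:    (batch, block_size) ground-truth token IDs for the block
    """
    labels = block_ids[:, 1:]  # the next block_size - 1 tokens
    return F.cross_entropy(draft_logits.reshape(-1, draft_logits.size(-1)), labels.reshape(-1))
```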
145 changes: 145 additions & 0 deletions examples/discrete_diffusion/sample_dflash.py
@@ -0,0 +1,145 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Sample script for DFlash speculative decoding.

Example:
python sample_dflash.py \
--draft_model_id z-lab/Qwen3-8B-DFlash-b16 \
--target_model_id Qwen/Qwen3-8B \
--prompt "How many positive whole-number divisors does 196 have?" \
--max_new_tokens 256
"""

import argparse

import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

from diffusers import DFlashPipeline


def main():
parser = argparse.ArgumentParser(description="Run DFlash speculative decoding.")
parser.add_argument(
"--draft_model_id",
type=str,
default="z-lab/Qwen3-8B-DFlash-b16",
help="Draft model ID or local path.",
)
parser.add_argument(
"--target_model_id",
type=str,
default="Qwen/Qwen3-8B",
help="Target model ID or local path.",
)
parser.add_argument(
"--prompt",
type=str,
default="How many positive whole-number divisors does 196 have?",
help="Prompt text to generate from.",
)
parser.add_argument(
"--max_new_tokens",
type=int,
default=2048,
help="Maximum number of new tokens to generate.",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="Sampling temperature.",
)
parser.add_argument(
"--use_chat_template",
action="store_true",
help="Use the tokenizer chat template for the prompt.",
)
parser.add_argument(
"--add_generation_prompt",
action="store_true",
help="Add the generation prompt when using the chat template.",
)
parser.add_argument(
"--enable_thinking",
action="store_true",
help="Enable chat-template thinking mode if supported by the tokenizer.",
)
parser.add_argument(
"--mask_token",
type=str,
default="<|MASK|>",
help="Mask token to add if the tokenizer does not define one.",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to run inference on.",
)
parser.add_argument(
"--dtype",
type=str,
default="auto",
choices=["auto", "float32", "float16", "bfloat16"],
help="Model dtype.",
)

args = parser.parse_args()

dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
torch_dtype = dtype_map.get(args.dtype)

print(f"Loading draft model: {args.draft_model_id}")
print(f"Loading target model: {args.target_model_id}")
dtype_arg = torch_dtype if torch_dtype is not None else "auto"
# Draft model is a custom DFlashDraftModel; trust_remote_code routes to the class in `auto_map`.
draft_model = AutoModel.from_pretrained(
args.draft_model_id,
trust_remote_code=True,
dtype=dtype_arg,
device_map=args.device,
)
target_model = AutoModelForCausalLM.from_pretrained(
args.target_model_id,
dtype=dtype_arg,
device_map=args.device,
)
tokenizer = AutoTokenizer.from_pretrained(args.target_model_id)
if tokenizer.mask_token is None:
tokenizer.add_special_tokens({"mask_token": args.mask_token})
pipe = DFlashPipeline(draft_model=draft_model, target_model=target_model, tokenizer=tokenizer)

chat_kwargs = {"enable_thinking": args.enable_thinking}

print(f"\nPrompt: {args.prompt}")
output = pipe(
prompt=args.prompt,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
use_chat_template=args.use_chat_template,
add_generation_prompt=args.add_generation_prompt,
chat_template_kwargs=chat_kwargs,
)

print("\nGenerated text:")
print(output.texts[0])
print(f"\nGenerated {output.sequences.shape[1]} tokens")


if __name__ == "__main__":
main()