From 7dd7d9c686d21973d6179d0261f50a6890bde0b1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 8 May 2026 10:03:14 +0000 Subject: [PATCH 1/6] [Discrete Diffusion] Add DFlash pipeline Adds DFlashPipeline + DFlashTokenDiffusionScheduler for block-diffusion speculative decoding with a draft DFlash model and a target causal LM. Verified against the six bug patterns surfaced in the LLaDA2 review (#13598). DFlash sidesteps most of them by being batch_size=1 only and relying on the causal default for attention; the applicable patterns (#3 callback bindings, #4 EOS at first generated position, #6 inner progress-bar config preservation) are pinned by regression tests. Public surface mirrors the LLaDA2 / SDAR / IDLM conventions: lazy import, dummy objects, scheduler + output dataclass, pipeline + output dataclass, fast tests for both, scheduler doc page, pipeline doc page. Sample/train scripts under examples/discrete_diffusion/. --- docs/source/en/_toctree.yml | 4 + docs/source/en/api/pipelines/dflash.md | 24 + .../api/schedulers/dflash_token_diffusion.md | 22 + examples/discrete_diffusion/sample_dflash.py | 145 +++++ examples/discrete_diffusion/train_dflash.py | 319 ++++++++++ src/diffusers/__init__.py | 8 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/dflash/__init__.py | 47 ++ .../pipelines/dflash/pipeline_dflash.py | 552 ++++++++++++++++++ src/diffusers/schedulers/__init__.py | 8 + .../scheduling_dflash_token_diffusion.py | 277 +++++++++ src/diffusers/utils/dummy_pt_objects.py | 30 + .../dummy_torch_and_transformers_objects.py | 30 + tests/pipelines/dflash/__init__.py | 1 + tests/pipelines/dflash/test_dflash.py | 448 ++++++++++++++ .../test_scheduler_dflash_token_diffusion.py | 310 ++++++++++ 16 files changed, 2227 insertions(+) create mode 100644 docs/source/en/api/pipelines/dflash.md create mode 100644 docs/source/en/api/schedulers/dflash_token_diffusion.md create mode 100644 examples/discrete_diffusion/sample_dflash.py create mode 100644 examples/discrete_diffusion/train_dflash.py create mode 100644 src/diffusers/pipelines/dflash/__init__.py create mode 100644 src/diffusers/pipelines/dflash/pipeline_dflash.py create mode 100644 src/diffusers/schedulers/scheduling_dflash_token_diffusion.py create mode 100644 tests/pipelines/dflash/__init__.py create mode 100644 tests/pipelines/dflash/test_dflash.py create mode 100644 tests/schedulers/test_scheduler_dflash_token_diffusion.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 8e8776d4a8c2..dc934c832b8e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -648,6 +648,8 @@ title: Z-Image title: Image - sections: + - local: api/pipelines/dflash + title: DFlash - local: api/pipelines/llada2 title: LLaDA2 title: Text @@ -711,6 +713,8 @@ title: DDPMScheduler - local: api/schedulers/deis title: DEISMultistepScheduler + - local: api/schedulers/dflash_token_diffusion + title: DFlashTokenDiffusionScheduler - local: api/schedulers/multistep_dpm_solver_inverse title: DPMSolverMultistepInverse - local: api/schedulers/multistep_dpm_solver diff --git a/docs/source/en/api/pipelines/dflash.md b/docs/source/en/api/pipelines/dflash.md new file mode 100644 index 000000000000..95847e1fdd82 --- /dev/null +++ b/docs/source/en/api/pipelines/dflash.md @@ -0,0 +1,24 @@ + + +# DFlash + +`DFlashPipeline` performs block-diffusion speculative decoding using a diffusion draft model and a target causal LM. 
+The draft model is conditioned on target hidden features extracted during prefill and verification steps. + +## DFlashPipeline +[[autodoc]] DFlashPipeline + - all + - __call__ + +## DFlashPipelineOutput +[[autodoc]] pipelines.DFlashPipelineOutput diff --git a/docs/source/en/api/schedulers/dflash_token_diffusion.md b/docs/source/en/api/schedulers/dflash_token_diffusion.md new file mode 100644 index 000000000000..c98b11bc9963 --- /dev/null +++ b/docs/source/en/api/schedulers/dflash_token_diffusion.md @@ -0,0 +1,22 @@ + + +# DFlashTokenDiffusionScheduler + +`DFlashTokenDiffusionScheduler` implements the acceptance and posterior sampling logic used in DFlash-style block +diffusion speculative decoding. + +## DFlashTokenDiffusionScheduler +[[autodoc]] DFlashTokenDiffusionScheduler + +## DFlashTokenDiffusionSchedulerOutput +[[autodoc]] schedulers.scheduling_dflash_token_diffusion.DFlashTokenDiffusionSchedulerOutput diff --git a/examples/discrete_diffusion/sample_dflash.py b/examples/discrete_diffusion/sample_dflash.py new file mode 100644 index 000000000000..a10899a0d052 --- /dev/null +++ b/examples/discrete_diffusion/sample_dflash.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sample script for DFlash speculative decoding. + +Example: + python sample_dflash.py \ + --draft_model_id z-lab/Qwen3-8B-DFlash-b16 \ + --target_model_id Qwen/Qwen3-8B \ + --prompt "How many positive whole-number divisors does 196 have?" 
\ + --max_new_tokens 256 +""" + +import argparse + +import torch +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + +from diffusers import DFlashPipeline + + +def main(): + parser = argparse.ArgumentParser(description="Run DFlash speculative decoding.") + parser.add_argument( + "--draft_model_id", + type=str, + default="z-lab/Qwen3-8B-DFlash-b16", + help="Draft model ID or local path.", + ) + parser.add_argument( + "--target_model_id", + type=str, + default="Qwen/Qwen3-8B", + help="Target model ID or local path.", + ) + parser.add_argument( + "--prompt", + type=str, + default="How many positive whole-number divisors does 196 have?", + help="Prompt text to generate from.", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=2048, + help="Maximum number of new tokens to generate.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Sampling temperature.", + ) + parser.add_argument( + "--use_chat_template", + action="store_true", + help="Use the tokenizer chat template for the prompt.", + ) + parser.add_argument( + "--add_generation_prompt", + action="store_true", + help="Add the generation prompt when using the chat template.", + ) + parser.add_argument( + "--enable_thinking", + action="store_true", + help="Enable chat-template thinking mode if supported by the tokenizer.", + ) + parser.add_argument( + "--mask_token", + type=str, + default="<|MASK|>", + help="Mask token to add if the tokenizer does not define one.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to run inference on.", + ) + parser.add_argument( + "--dtype", + type=str, + default="auto", + choices=["auto", "float32", "float16", "bfloat16"], + help="Model dtype.", + ) + + args = parser.parse_args() + + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(args.dtype) + + print(f"Loading draft model: {args.draft_model_id}") + print(f"Loading target model: {args.target_model_id}") + dtype_arg = torch_dtype if torch_dtype is not None else "auto" + # Draft model is a custom DFlashDraftModel; use AutoModel so trust_remote_code routes to the class in `auto_map`. 
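+    # The pipeline projects the draft model's hidden states through the target model's output
+    # embeddings, so draft and target are expected to share the target tokenizer's vocabulary.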
+ draft_model = AutoModel.from_pretrained( + args.draft_model_id, + trust_remote_code=True, + dtype=dtype_arg, + device_map=args.device, + ) + target_model = AutoModelForCausalLM.from_pretrained( + args.target_model_id, + dtype=dtype_arg, + device_map=args.device, + ) + tokenizer = AutoTokenizer.from_pretrained(args.target_model_id) + if tokenizer.mask_token is None: + tokenizer.add_special_tokens({"mask_token": args.mask_token}) + pipe = DFlashPipeline(draft_model=draft_model, target_model=target_model, tokenizer=tokenizer) + + chat_kwargs = {"enable_thinking": args.enable_thinking} + + print(f"\nPrompt: {args.prompt}") + output = pipe( + prompt=args.prompt, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + use_chat_template=args.use_chat_template, + add_generation_prompt=args.add_generation_prompt, + chat_template_kwargs=chat_kwargs, + ) + + print("\nGenerated text:") + print(output.texts[0]) + print(f"\nGenerated {output.sequences.shape[1]} tokens") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/train_dflash.py b/examples/discrete_diffusion/train_dflash.py new file mode 100644 index 000000000000..673a2173a058 --- /dev/null +++ b/examples/discrete_diffusion/train_dflash.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
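+
+"""
+Training script for a DFlash draft model. The target causal LM stays frozen; the draft model learns
+to predict the tokens of a masked block conditioned on target hidden features from the preceding
+context.
+
+Example (illustrative invocation; the model and dataset IDs are the script defaults):
+    accelerate launch train_dflash.py \
+        --draft_model_id z-lab/Qwen3-4B-DFlash-b16 \
+        --target_model_id Qwen/Qwen3-4B \
+        --dataset_name wikitext \
+        --dataset_config_name wikitext-2-raw-v1 \
+        --output_dir dflash-output \
+        --max_train_steps 1000
+"""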
+ +import argparse +import math +import os +from dataclasses import asdict, dataclass +from typing import Dict, Optional + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, get_scheduler + + +logger = get_logger(__name__) + + +@dataclass +class TrainConfig: + draft_model_id: str + target_model_id: str + dataset_name: str + dataset_config_name: Optional[str] + text_column: str + + output_dir: str + seed: int + max_train_steps: int + checkpointing_steps: int + logging_steps: int + + per_device_train_batch_size: int + gradient_accumulation_steps: int + learning_rate: float + weight_decay: float + lr_scheduler: str + lr_warmup_steps: int + + max_length: int + block_size: int + mask_token: str + + +def parse_args() -> TrainConfig: + parser = argparse.ArgumentParser(description="Fine-tune a DFlash draft model with target-conditioned blocks.") + + parser.add_argument("--draft_model_id", type=str, default="z-lab/Qwen3-4B-DFlash-b16") + parser.add_argument("--target_model_id", type=str, default="Qwen/Qwen3-4B") + parser.add_argument("--dataset_name", type=str, default="wikitext") + parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1") + parser.add_argument("--text_column", type=str, default="text") + + parser.add_argument("--output_dir", type=str, default="dflash-output") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--max_train_steps", type=int, default=1000) + parser.add_argument("--checkpointing_steps", type=int, default=500) + parser.add_argument("--logging_steps", type=int, default=50) + + parser.add_argument("--per_device_train_batch_size", type=int, default=2) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=2e-5) + parser.add_argument("--weight_decay", type=float, default=0.0) + parser.add_argument( + "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"] + ) + parser.add_argument("--lr_warmup_steps", type=int, default=100) + + parser.add_argument("--max_length", type=int, default=512) + parser.add_argument( + "--block_size", type=int, default=0, help="Override draft block size (0 uses the model config)." 
+ ) + parser.add_argument("--mask_token", type=str, default="<|MASK|>") + + args = parser.parse_args() + return TrainConfig(**vars(args)) + + +def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int): + texts = examples[text_column] + texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0] + return tokenizer(texts, truncation=True, padding=False, max_length=max_length) + + +def build_target_layer_ids(num_target_layers: int, num_draft_layers: int): + if num_draft_layers == 1: + return [int(num_target_layers // 2)] + start = 1 + end = int(num_target_layers) - 3 + span = end - start + return [int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(int(num_draft_layers))] + + +def extract_context_feature(hidden_states, layer_ids): + offset = 1 + selected_states = [hidden_states[layer_id + offset] for layer_id in layer_ids] + return torch.cat(selected_states, dim=-1) + + +def get_target_input_embeddings(model: torch.nn.Module) -> torch.nn.Module: + embeddings = model.get_input_embeddings() + if embeddings is None: + base = getattr(model, "model", None) + embeddings = getattr(base, "embed_tokens", None) + if embeddings is None: + raise ValueError("Target model must expose input embeddings.") + return embeddings + + +def get_target_output_embeddings(model: torch.nn.Module) -> torch.nn.Module: + embeddings = model.get_output_embeddings() + if embeddings is None: + embeddings = getattr(model, "lm_head", None) + if embeddings is None: + raise ValueError("Target model must expose output embeddings.") + return embeddings + + +def main(): + cfg = parse_args() + + project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs")) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + project_config=project_config, + ) + + if accelerator.is_main_process: + os.makedirs(cfg.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + set_seed(cfg.seed) + logger.info("Training configuration: %s", asdict(cfg)) + + tokenizer = AutoTokenizer.from_pretrained(cfg.target_model_id, use_fast=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.mask_token_id is None: + tokenizer.add_special_tokens({"mask_token": cfg.mask_token}) + + draft_model = AutoModel.from_pretrained(cfg.draft_model_id, trust_remote_code=True) + target_model = AutoModelForCausalLM.from_pretrained(cfg.target_model_id) + target_model.eval() + target_model.requires_grad_(False) + + mask_token_id = tokenizer.mask_token_id + if mask_token_id is None: + raise ValueError("Tokenizer must define a mask token for DFlash training.") + + input_embeddings = get_target_input_embeddings(target_model) + output_embeddings = get_target_output_embeddings(target_model) + + block_size = int(cfg.block_size) + if block_size <= 0: + block_size = getattr(draft_model, "block_size", None) or getattr( + getattr(draft_model, "config", None), "block_size", None + ) + if block_size is None: + raise ValueError("Draft model must define `block_size` or pass --block_size.") + block_size = int(block_size) + if block_size < 2: + raise ValueError("`block_size` must be at least 2 for DFlash training.") + + layer_ids = getattr(draft_model, "target_layer_ids", None) + if layer_ids is None: + cfg_draft = getattr(draft_model, "config", None) + num_target_layers = getattr(cfg_draft, "num_target_layers", None) + num_hidden_layers = getattr(cfg_draft, "num_hidden_layers", None) + if num_target_layers 
is None or num_hidden_layers is None: + raise ValueError("Draft model must expose `target_layer_ids` or `num_target_layers` in config.") + layer_ids = build_target_layer_ids(int(num_target_layers), int(num_hidden_layers)) + + raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name) + if "train" not in raw_datasets: + raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.") + + with accelerator.main_process_first(): + tokenized = raw_datasets["train"].map( + lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length), + batched=True, + remove_columns=raw_datasets["train"].column_names, + desc="Tokenizing", + ) + + collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt") + train_dataloader = DataLoader( + tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True + ) + + optimizer = torch.optim.AdamW(draft_model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps) + num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=cfg.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.lr_warmup_steps, + num_training_steps=cfg.max_train_steps, + ) + + draft_model, optimizer, train_dataloader, lr_scheduler, target_model = accelerator.prepare( + draft_model, optimizer, train_dataloader, lr_scheduler, target_model + ) + input_embeddings = get_target_input_embeddings(target_model) + output_embeddings = get_target_output_embeddings(target_model) + + global_step = 0 + draft_model.train() + + for epoch in range(num_train_epochs): + for batch in train_dataloader: + with accelerator.accumulate(draft_model): + input_ids = batch["input_ids"] + attention_mask = batch.get("attention_mask", torch.ones_like(input_ids)) + + valid_lengths = attention_mask.sum(dim=1) + min_valid = int(valid_lengths.min().item()) + if min_valid <= block_size: + continue + + max_start = min_valid - block_size + start = torch.randint(1, max_start + 1, (1,), device=input_ids.device).item() + + block_output_ids = torch.full( + (input_ids.shape[0], block_size), + int(mask_token_id), + device=input_ids.device, + dtype=torch.long, + ) + block_output_ids[:, 0] = input_ids[:, start] + block_targets = input_ids[:, start + 1 : start + block_size] + block_mask = attention_mask[:, start + 1 : start + block_size] + + position_ids = torch.arange(start, start + block_size, device=input_ids.device).unsqueeze(0) + position_ids = position_ids.expand(input_ids.shape[0], -1) + + with torch.no_grad(): + target_out = target_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + target_hidden = extract_context_feature(target_out.hidden_states, layer_ids) + target_hidden = target_hidden[:, :start, :] + + noise_embedding = input_embeddings(block_output_ids) + draft_hidden = draft_model( + target_hidden=target_hidden, + noise_embedding=noise_embedding, + position_ids=position_ids, + use_cache=False, + is_causal=False, + ) + if not torch.is_tensor(draft_hidden): + draft_hidden = getattr(draft_hidden, "last_hidden_state", draft_hidden[0]) + + logits = output_embeddings(draft_hidden[:, -block_size + 1 :, :]) + vocab_size = logits.shape[-1] + loss = F.cross_entropy(logits.view(-1, vocab_size), block_targets.reshape(-1), reduction="none") + loss = loss.view(block_targets.shape[0], -1) + loss = (loss * 
block_mask.to(loss.dtype)).sum() / block_mask.sum().clamp_min(1) + + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + global_step += 1 + + if global_step % cfg.logging_steps == 0 and accelerator.is_main_process: + logger.info("step=%d loss=%.4f lr=%.6g", global_step, loss.item(), lr_scheduler.get_last_lr()[0]) + + if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}") + os.makedirs(save_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(draft_model) + unwrapped.save_pretrained(save_dir, save_function=accelerator.save) + tokenizer.save_pretrained(save_dir) + + if global_step >= cfg.max_train_steps: + break + + if global_step >= cfg.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + final_dir = os.path.join(cfg.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(draft_model) + unwrapped.save_pretrained(final_dir, save_function=accelerator.save) + tokenizer.save_pretrained(final_dir) + + logger.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1b1f6b3032b3..f8e683bcd76a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -371,6 +371,8 @@ "DDPMScheduler", "DDPMWuerstchenScheduler", "DEISMultistepScheduler", + "DFlashTokenDiffusionScheduler", + "DFlashTokenDiffusionSchedulerOutput", "DPMSolverMultistepInverseScheduler", "DPMSolverMultistepScheduler", "DPMSolverSinglestepScheduler", @@ -539,6 +541,8 @@ "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", "CycleDiffusionPipeline", + "DFlashPipeline", + "DFlashPipelineOutput", "EasyAnimateControlPipeline", "EasyAnimateInpaintPipeline", "EasyAnimatePipeline", @@ -1189,6 +1193,8 @@ DDPMScheduler, DDPMWuerstchenScheduler, DEISMultistepScheduler, + DFlashTokenDiffusionScheduler, + DFlashTokenDiffusionSchedulerOutput, DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, @@ -1336,6 +1342,8 @@ CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline, CycleDiffusionPipeline, + DFlashPipeline, + DFlashPipelineOutput, EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, EasyAnimatePipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index f0fc7585bf31..cd47185bb6a6 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -319,6 +319,7 @@ ] ) _import_structure["latte"] = ["LattePipeline"] + _import_structure["dflash"] = ["DFlashPipeline", "DFlashPipelineOutput"] _import_structure["llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"] _import_structure["ltx"] = [ "LTXPipeline", @@ -693,6 +694,7 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) + from .dflash import DFlashPipeline, DFlashPipelineOutput from .easyanimate import ( EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, diff --git a/src/diffusers/pipelines/dflash/__init__.py b/src/diffusers/pipelines/dflash/__init__.py new file mode 100644 index 000000000000..c5d0f5fae4cd --- /dev/null +++ b/src/diffusers/pipelines/dflash/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + 
is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_dflash"] = ["DFlashPipeline", "DFlashPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_dflash import DFlashPipeline, DFlashPipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/dflash/pipeline_dflash.py b/src/diffusers/pipelines/dflash/pipeline_dflash.py new file mode 100644 index 000000000000..e8b0276db109 --- /dev/null +++ b/src/diffusers/pipelines/dflash/pipeline_dflash.py @@ -0,0 +1,552 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +import torch +from tqdm.auto import tqdm +from transformers import DynamicCache + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...schedulers import DFlashTokenDiffusionScheduler +from ...utils import BaseOutput, logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import DFlashPipeline + >>> from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + + >>> draft = AutoModel.from_pretrained( + ... "z-lab/Qwen3-8B-DFlash-b16", trust_remote_code=True, torch_dtype=torch.bfloat16 + ... 
) + >>> target = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B", torch_dtype=torch.bfloat16) + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") + >>> pipe = DFlashPipeline(draft_model=draft, target_model=target, tokenizer=tokenizer) + >>> out = pipe(prompt="How many positive whole-number divisors does 196 have?") + >>> print(out.texts[0]) + ``` +""" + + +@dataclass +class DFlashPipelineOutput(BaseOutput): + sequences: torch.LongTensor + texts: list[str] | None = None + + +def _build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> list[int]: + if num_draft_layers == 1: + return [int(num_target_layers // 2)] + start = 1 + end = int(num_target_layers) - 3 + span = end - start + return [int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(int(num_draft_layers))] + + +def _extract_context_feature(hidden_states: list[torch.Tensor], layer_ids: list[int]) -> torch.Tensor: + offset = 1 + selected_states = [hidden_states[layer_id + offset] for layer_id in layer_ids] + return torch.cat(selected_states, dim=-1) + + +class DFlashPipeline(DiffusionPipeline): + r""" + Block diffusion pipeline for speculative decoding with a DFlash draft model and a target causal LM. + """ + + draft_model: Any + target_model: Any + tokenizer: Any + scheduler: DFlashTokenDiffusionScheduler + _callback_tensor_inputs = ["block_output_ids", "draft_logits", "accepted_length", "next_token", "output_ids"] + + def __init__( + self, + draft_model: torch.nn.Module, + target_model: torch.nn.Module, + tokenizer: Any | None = None, + scheduler: DFlashTokenDiffusionScheduler | None = None, + ): + super().__init__() + if scheduler is None: + scheduler = DFlashTokenDiffusionScheduler() + self.register_modules( + draft_model=draft_model, target_model=target_model, tokenizer=tokenizer, scheduler=scheduler + ) + + # --- Prompt encoding --- + + def _prepare_input_ids( + self, + *, + prompt: str | list[str] | None, + messages: list[dict[str, str]] | None, + input_ids: torch.LongTensor | None, + use_chat_template: bool, + add_generation_prompt: bool, + chat_template_kwargs: dict[str, Any] | None, + ) -> torch.LongTensor: + """Convert prompt/messages/input_ids to a `[batch, seq]` LongTensor.""" + if input_ids is not None: + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + if input_ids.ndim != 2: + raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.") + if input_ids.dtype != torch.long: + raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.") + return input_ids + + if self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + + if messages is not None and prompt is not None: + raise ValueError("Provide either `prompt` or `messages`, not both.") + if messages is None and prompt is None: + raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.") + + chat_template_kwargs = chat_template_kwargs or {} + + if messages is not None: + encoded = self.tokenizer.apply_chat_template( + messages, + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + **chat_template_kwargs, + ) + return encoded["input_ids"] + + if use_chat_template and getattr(self.tokenizer, "chat_template", None): + if isinstance(prompt, list): + raise ValueError("`prompt` must be a string when `use_chat_template=True`.") + encoded = self.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + 
add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + **chat_template_kwargs, + ) + return encoded["input_ids"] + + encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list)) + return encoded["input_ids"] + + def check_inputs( + self, + prompt: str | list[str] | None, + messages: list[dict[str, str]] | None, + input_ids: torch.LongTensor | None, + max_new_tokens: int, + output_type: str, + callback_on_step_end: Callable | PipelineCallback | MultiPipelineCallbacks | None, + callback_on_step_end_tensor_inputs: list[str] | None, + ): + # Input source validation + if prompt is None and messages is None and input_ids is None: + raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.") + if prompt is not None and messages is not None: + raise ValueError("Provide either `prompt` or `messages`, not both.") + if input_ids is not None: + if input_ids.ndim not in (1, 2): + raise ValueError(f"`input_ids` must be 1D or 2D, got shape {tuple(input_ids.shape)}.") + if input_ids.dtype != torch.long: + raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.") + if prompt is not None and input_ids is None and self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + if messages is not None and input_ids is None and self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + + # Generation parameter validation + if max_new_tokens <= 0: + raise ValueError(f"`max_new_tokens` must be > 0, got {max_new_tokens}.") + if output_type not in {"seq", "text"}: + raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.") + + # Callback validation + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + def prepare_latents( + self, + max_length: int, + block_size: int, + mask_token_id: int, + device: torch.device, + ) -> torch.LongTensor: + return torch.full( + (1, max_length + int(block_size)), + int(mask_token_id), + dtype=torch.long, + device=device, + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + messages: list[dict[str, str]] | None = None, + input_ids: torch.LongTensor | None = None, + max_new_tokens: int = 2048, + temperature: float = 0.0, + stop_token_ids: list[int] | None = None, + mask_token_id: int | None = None, + use_chat_template: bool = True, + add_generation_prompt: bool = True, + chat_template_kwargs: dict[str, object] | None = None, + output_type: str = "text", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, dict], None] + | PipelineCallback + | MultiPipelineCallbacks + | None = None, + callback_on_step_end_tensor_inputs: list[str] | None = None, + ) -> DFlashPipelineOutput | tuple[torch.LongTensor, list[str] | None]: + """ + Generate text using block-diffusion speculative decoding. + + Args: + prompt (`str` or `list[str]`, *optional*): + Prompt text. 
When `use_chat_template` is `True` (default) and a tokenizer with a chat template is + available, the prompt is wrapped in a chat message before tokenization. + messages (`list[dict[str, str]]`, *optional*): + Chat messages to encode. Takes precedence over `prompt` when provided. + input_ids (`torch.LongTensor`, *optional*): + Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`. + max_new_tokens (`int`): + Maximum number of new tokens to generate. + temperature (`float`): + Sampling temperature. + stop_token_ids (`list[int]`, *optional*): + Token IDs that signal generation should stop. + mask_token_id (`int`, *optional*): + Mask token ID for the draft model. + use_chat_template (`bool`, defaults to `True`): + Whether to wrap the prompt in a chat template. + add_generation_prompt (`bool`, defaults to `True`): + Whether to add the generation prompt when using chat templates. + chat_template_kwargs (`dict[str, object]`, *optional*): + Additional keyword arguments for the chat template. + output_type (`str`, defaults to `"text"`): + Output format. `"text"` decodes sequences into strings (requires a tokenizer). `"seq"` returns raw + token ID sequences only. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`DFlashPipelineOutput`] instead of a tuple. + callback_on_step_end (`Callable` or `PipelineCallback`, *optional*): + Callback executed after each speculative decoding step. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + Tensor keys to pass to the callback. + + Examples: + """ + # 1. Check inputs early + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["block_output_ids"] + + self.check_inputs( + prompt=prompt, + messages=messages, + input_ids=input_ids, + max_new_tokens=max_new_tokens, + output_type=output_type, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + # 2. 
Prepare input IDs from prompt/messages/input_ids + input_ids = self._prepare_input_ids( + prompt=prompt, + messages=messages, + input_ids=input_ids, + use_chat_template=use_chat_template, + add_generation_prompt=add_generation_prompt, + chat_template_kwargs=chat_template_kwargs, + ) + + if mask_token_id is None: + mask_token_id = getattr(getattr(self, "tokenizer", None), "mask_token_id", None) + if mask_token_id is None: + # DFlash models store mask_token_id in config.dflash_config + dflash_config = getattr(getattr(self.draft_model, "config", None), "dflash_config", None) + if dflash_config is not None: + mask_token_id = dflash_config.get("mask_token_id", None) + if mask_token_id is None: + raise ValueError("`mask_token_id` must be provided (or available on the tokenizer/model config).") + if input_ids.shape[0] != 1: + raise ValueError("DFlashPipeline currently supports batch_size=1 input_ids.") + + target_params = list(self.target_model.parameters()) if hasattr(self.target_model, "parameters") else [] + device = target_params[0].device if len(target_params) > 0 else torch.device("cpu") + input_ids = input_ids.to(device=device) + draft_params = list(self.draft_model.parameters()) if hasattr(self.draft_model, "parameters") else [] + draft_device = draft_params[0].device if len(draft_params) > 0 else device + if draft_device != device: + logger.warning( + "Draft model is on %s while target model is on %s. For best performance, place both on the same device.", + draft_device, + device, + ) + + if stop_token_ids is None: + eos_token_id = getattr(getattr(self, "tokenizer", None), "eos_token_id", None) + stop_token_ids = [int(eos_token_id)] if eos_token_id is not None else None + if stop_token_ids is not None: + stop_token_ids = [int(token_id) for token_id in stop_token_ids] + + # 3. Setup scheduler and resolve model attributes + self.scheduler.set_timesteps(1, device=device) + + block_size = self._get_block_size() + + # Resolve target layer IDs from draft model config + layer_ids = getattr(self.draft_model, "target_layer_ids", None) + if layer_ids is not None: + target_layer_ids = list(layer_ids) + else: + cfg = getattr(self.draft_model, "config", None) + num_target_layers = getattr(cfg, "num_target_layers", None) + num_hidden_layers = getattr(cfg, "num_hidden_layers", None) + if num_target_layers is None or num_hidden_layers is None: + raise ValueError( + "`draft_model` must define `target_layer_ids` or expose `num_target_layers` in config." + ) + target_layer_ids = _build_target_layer_ids(int(num_target_layers), int(num_hidden_layers)) + + input_embeddings = self.target_model.get_input_embeddings() + output_embeddings = self.target_model.get_output_embeddings() + + num_input_tokens = input_ids.shape[1] + max_length = num_input_tokens + int(max_new_tokens) + + output_ids = self.prepare_latents(max_length, block_size, int(mask_token_id), device) + position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0) + + target_config = getattr(self.target_model, "config", None) + draft_config = getattr(self.draft_model, "config", None) + + # Fast path: some draft models (e.g. z-lab/Qwen3-8B-DFlash-b16) ship a self-contained + # `spec_generate` method. Delegate when available — it's the upstream-canonical loop and + # avoids re-implementing rollback. Newer drafts (Qwen3.5-4B-DFlash) drop this method, so + # fall back to the explicit pipeline loop below. 
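+        # The explicit loop below delegates block verification to `self.scheduler.step`, which
+        # returns the accepted prefix length and a next token resampled from the target posterior.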
+ spec_generate = getattr(self.draft_model, "spec_generate", None) + if callable(spec_generate): + generated = spec_generate( + input_ids=input_ids, + max_new_tokens=int(max_new_tokens), + temperature=float(temperature), + target=self.target_model, + stop_token_ids=stop_token_ids, + ) + sequences = generated[:, input_ids.shape[1] :] + texts = None + if output_type == "text" and getattr(self, "tokenizer", None) is not None: + texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + if not return_dict: + return sequences, texts + return DFlashPipelineOutput(sequences=sequences, texts=texts) + + # Pass `config=` only when it looks like a real PretrainedConfig — hybrid-attention models + # (Qwen3.5) need it so `DynamicCache` instantiates the right per-layer cache types + # (linear vs full), but bare dummy configs in tests don't implement `get_text_config`. + def _new_cache(cfg): + if cfg is not None and hasattr(cfg, "get_text_config"): + try: + return DynamicCache(config=cfg) + except Exception: + pass + return DynamicCache() + + past_key_values_target = _new_cache(target_config) + past_key_values_draft = _new_cache(draft_config) + + # 4. Prefill step + output = self._target_forward( + input_ids=input_ids, + position_ids=position_ids[:, :num_input_tokens], + past_key_values=past_key_values_target, + output_hidden_states=True, + logits_to_keep=1, + ) + output_ids[:, :num_input_tokens] = input_ids + output_ids[:, num_input_tokens : num_input_tokens + 1] = self.scheduler.sample( + output.logits[:, -1:], temperature=temperature + ) + target_hidden = _extract_context_feature(output.hidden_states, target_layer_ids) + + start = num_input_tokens + global_step = 0 + num_blocks = (max_length - num_input_tokens + block_size - 1) // block_size + + # 5. Block-wise speculative decoding loop + block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy() + block_progress_bar_config["position"] = 0 + block_progress_bar_config["desc"] = "Blocks" + block_iter = tqdm(range(num_blocks), **block_progress_bar_config) + + for _block_idx in block_iter: + if start >= max_length: + break + + block_output_ids = output_ids[:, start : start + int(block_size)].clone() + block_position_ids = position_ids[:, start : start + int(block_size)] + noise_embedding = input_embeddings(block_output_ids) + draft_hidden = self.draft_model( + target_hidden=target_hidden, + noise_embedding=noise_embedding, + position_ids=position_ids[:, past_key_values_draft.get_seq_length() : start + int(block_size)], + past_key_values=past_key_values_draft, + use_cache=True, + is_causal=False, + ) + if not torch.is_tensor(draft_hidden): + draft_hidden = getattr(draft_hidden, "last_hidden_state", draft_hidden[0]) + draft_logits = output_embeddings(draft_hidden[:, -int(block_size) + 1 :, :]) + past_key_values_draft.crop(start) + block_output_ids[:, 1:] = self.scheduler.sample(draft_logits, temperature=temperature) + + # For hybrid-attention targets (Qwen3.5 etc.), linear-attention cache layers silently + # no-op on `.crop()`, so rejected speculative tokens would permanently contaminate the + # recurrent state. Snapshot before the verify forward so we can roll back on partial-accept. 
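+            # Full-attention-only caches skip the snapshot; a plain `.crop()` after verification is
+            # enough for them (handled in the branch below).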
+ target_needs_rollback = self.scheduler.cache_has_linear_attention(past_key_values_target) + target_snapshot = self.scheduler.snapshot_cache(past_key_values_target) if target_needs_rollback else None + + output = self._target_forward( + input_ids=block_output_ids, + position_ids=block_position_ids, + past_key_values=past_key_values_target, + output_hidden_states=True, + logits_to_keep=None, + ) + step_output = self.scheduler.step( + model_output=output.logits, + timestep=global_step, + sample=block_output_ids, + temperature=temperature, + return_dict=True, + ) + accepted_length = step_output.accepted_length + next_token = step_output.next_token + acceptance_length = int(step_output.accepted_length[0].item()) + output_ids[:, start : start + acceptance_length + 1] = block_output_ids[:, : acceptance_length + 1] + output_ids[:, start + acceptance_length + 1] = step_output.next_token + start += acceptance_length + 1 + partial_accept = acceptance_length + 1 < int(block_size) + if target_needs_rollback and partial_accept: + # Restore linear-attn recurrent state (and full-attn KVs) to pre-verify, then re-run + # target on just the accepted prefix to advance all layer types cleanly to `start`. + self.scheduler.restore_cache(past_key_values_target, target_snapshot) + accepted_ids = block_output_ids[:, : acceptance_length + 1] + accepted_pos = block_position_ids[:, : acceptance_length + 1] + self._target_forward( + input_ids=accepted_ids, + position_ids=accepted_pos, + past_key_values=past_key_values_target, + output_hidden_states=False, + logits_to_keep=1, + ) + elif not target_needs_rollback: + # Full-attn-only cache: cheap crop is fine. + past_key_values_target.crop(start) + target_hidden = _extract_context_feature(output.hidden_states, target_layer_ids)[ + :, : acceptance_length + 1, : + ] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, global_step, 0, callback_kwargs) + output_ids = callback_outputs.pop("output_ids", output_ids) + global_step += 1 + + if self.scheduler.check_should_stop(output_ids, stop_token_ids, num_input_tokens): + break + + # 6. 
Post-process output + output_ids = output_ids[:, :max_length] + output_ids = output_ids[:, output_ids[0] != int(mask_token_id)] + if stop_token_ids is not None: + stop_tensor = torch.tensor(stop_token_ids, device=device, dtype=torch.long) + stop_positions = torch.isin(output_ids[0, num_input_tokens:], stop_tensor).nonzero(as_tuple=True)[0] + if stop_positions.numel() > 0: + output_ids = output_ids[:, : num_input_tokens + int(stop_positions[0].item()) + 1] + + prompt_len = input_ids.shape[1] + sequences = output_ids[:, prompt_len:] + + texts = None + if output_type == "text" and getattr(self, "tokenizer", None) is not None: + texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + + if not return_dict: + return sequences, texts + return DFlashPipelineOutput(sequences=sequences, texts=texts) + + def _get_block_size(self) -> int: + cfg = getattr(self.draft_model, "config", None) + block_size = getattr(cfg, "block_size", None) + if block_size is None: + raise ValueError("`draft_model.config` must define `block_size`.") + return int(block_size) + + def _target_forward( + self, + *, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + past_key_values: DynamicCache, + output_hidden_states: bool, + logits_to_keep: int | None, + ): + kwargs = { + "input_ids": input_ids, + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": True, + "output_hidden_states": output_hidden_states, + } + if logits_to_keep is not None: + try: + return self.target_model(**kwargs, logits_to_keep=logits_to_keep) + except TypeError: + pass + return self.target_model(**kwargs) + + +__all__ = ["DFlashPipeline", "DFlashPipelineOutput"] diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index b1f75bed7dc5..6321d189fc17 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -51,6 +51,10 @@ _import_structure["scheduling_ddpm_parallel"] = ["DDPMParallelScheduler"] _import_structure["scheduling_ddpm_wuerstchen"] = ["DDPMWuerstchenScheduler"] _import_structure["scheduling_deis_multistep"] = ["DEISMultistepScheduler"] + _import_structure["scheduling_dflash_token_diffusion"] = [ + "DFlashTokenDiffusionScheduler", + "DFlashTokenDiffusionSchedulerOutput", + ] _import_structure["scheduling_dpm_cogvideox"] = ["CogVideoXDPMScheduler"] _import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"] _import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"] @@ -157,6 +161,10 @@ from .scheduling_ddpm_parallel import DDPMParallelScheduler from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler from .scheduling_deis_multistep import DEISMultistepScheduler + from .scheduling_dflash_token_diffusion import ( + DFlashTokenDiffusionScheduler, + DFlashTokenDiffusionSchedulerOutput, + ) from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler diff --git a/src/diffusers/schedulers/scheduling_dflash_token_diffusion.py b/src/diffusers/schedulers/scheduling_dflash_token_diffusion.py new file mode 100644 index 000000000000..e90b7cfde09f --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dflash_token_diffusion.py @@ -0,0 +1,277 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class DFlashTokenDiffusionSchedulerOutput(BaseOutput): + """ + Output class for DFlash-style speculative token scheduling. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch_size, block_size)`): + The proposed block tokens from the draft model. + accepted_length (`torch.LongTensor` of shape `(batch_size,)`): + Number of consecutive accepted tokens from the block. + next_token (`torch.LongTensor` of shape `(batch_size,)`): + Next token sampled from the target posterior at the first rejection. + posterior (`torch.LongTensor` of shape `(batch_size, block_size)`): + Sampled tokens from the target posterior used for acceptance checks. + """ + + prev_sample: torch.LongTensor + accepted_length: torch.LongTensor + next_token: torch.LongTensor + posterior: torch.LongTensor + + +class DFlashTokenDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + Scheduler for DFlash-style block diffusion speculative decoding. + + This scheduler samples target posteriors and computes acceptance lengths for draft blocks. + """ + + order = 1 + + @register_to_config + def __init__(self): + self.num_inference_steps = 1 + self.timesteps = torch.tensor([0], dtype=torch.long) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + self.num_inference_steps = int(num_inference_steps) + self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long) + + def sample(self, logits: torch.Tensor, temperature: float = 0.0) -> torch.LongTensor: + if temperature < 1e-5: + return torch.argmax(logits, dim=-1) + bsz, seq_len, vocab_size = logits.shape + flat = logits.view(-1, vocab_size) / float(temperature) + probs = torch.softmax(flat, dim=-1) + return torch.multinomial(probs, num_samples=1).view(bsz, seq_len) + + def step( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.LongTensor, + *, + temperature: float = 0.0, + return_dict: bool = True, + ) -> ( + DFlashTokenDiffusionSchedulerOutput + | tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor] + ): + """ + Perform a single speculative decoding verification step. + + Args: + model_output (`torch.Tensor` of shape `(batch_size, block_size, vocab_size)`): + Raw logits from the target model for the current block. + timestep (`int` or `torch.Tensor`): + Current step index (unused for single-step DFlash, kept for interface compatibility). + sample (`torch.LongTensor` of shape `(batch_size, block_size)`): + Draft token IDs proposed by the draft model. 
+ temperature (`float`): + Sampling temperature for the target posterior. + return_dict (`bool`): + Whether to return a `DFlashTokenDiffusionSchedulerOutput` or a tuple. + """ + posterior = self.sample(model_output, temperature=temperature) + if sample.shape[1] > 1: + matches = sample[:, 1:] == posterior[:, :-1] + accepted_length = matches.int().cumprod(dim=1).sum(dim=1) + else: + accepted_length = torch.zeros((sample.shape[0],), device=sample.device, dtype=torch.long) + + next_token = posterior.gather(1, accepted_length.unsqueeze(1)).squeeze(1) + + if not return_dict: + return sample, accepted_length, next_token, posterior + return DFlashTokenDiffusionSchedulerOutput( + prev_sample=sample, + accepted_length=accepted_length, + next_token=next_token, + posterior=posterior, + ) + + @staticmethod + def cache_has_linear_attention(cache) -> bool: + """ + Detect whether a `DynamicCache` contains any linear-attention layers (e.g. Qwen3.5's gated-delta-net layers). + The spec-decoding loop needs this to know whether a partial-accept block requires snapshot/restore rather than + a plain `.crop()` — transformers' `DynamicCache.crop()` silently no-ops on linear-attention layers, so rejected + speculative tokens would otherwise permanently contaminate the recurrent state. + + Duck-typed on `recurrent_states`/`conv_states` attributes to avoid importing transformers. + """ + for layer in getattr(cache, "layers", []): + if hasattr(layer, "recurrent_states") and hasattr(layer, "conv_states"): + return True + return False + + @staticmethod + def snapshot_cache(cache) -> list[dict]: + """ + Clone the full per-layer cache state so a speculative target forward can be rolled back. + + Handles both full-attention `DynamicLayer` (keys/values) and linear-attention layers + (conv_states/recurrent_states plus their init flags). Mirrors upstream DFlash's MLX `_GDNStateCapture` + rollback, but via full-layer restore rather than kernel-level replay. Pair with `restore_cache()`; no-op if the + caller only ever fully-accepts. + """ + snapshots: list[dict] = [] + for layer in getattr(cache, "layers", []): + snap: dict = {"cls": type(layer)} + if hasattr(layer, "keys") and layer.keys is not None: + snap["keys"] = layer.keys.clone() + snap["values"] = layer.values.clone() + if hasattr(layer, "recurrent_states"): + snap["has_previous_state"] = bool(getattr(layer, "has_previous_state", False)) + snap["is_recurrent_states_initialized"] = bool( + getattr(layer, "is_recurrent_states_initialized", False) + ) + snap["is_conv_states_initialized"] = bool(getattr(layer, "is_conv_states_initialized", False)) + snap["recurrent_states"] = ( + layer.recurrent_states.clone() if getattr(layer, "recurrent_states", None) is not None else None + ) + snap["conv_states"] = ( + layer.conv_states.clone() if getattr(layer, "conv_states", None) is not None else None + ) + snapshots.append(snap) + return snapshots + + @staticmethod + def restore_cache(cache, snapshots: list[dict]) -> None: + """ + Restore a cache to the state captured by `snapshot_cache()`. After this call, the caller should re-advance the + cache (e.g. by re-running the target model on just the accepted prefix) so both full- and linear-attention + layers end up at the committed token count. + """ + for layer, snap in zip(cache.layers, snapshots): + if "keys" in snap: + # DynamicLayer: reassign (shapes will have grown during the verify forward, so + # in-place copy is not safe here). 
+ layer.keys = snap["keys"] + layer.values = snap["values"] + if "recurrent_states" in snap: + # LinearAttentionLayer: in-place copy preserves any static-address assumption + # (e.g. for cudagraph capture) on the live tensors. + layer.has_previous_state = snap["has_previous_state"] + layer.is_recurrent_states_initialized = snap["is_recurrent_states_initialized"] + layer.is_conv_states_initialized = snap["is_conv_states_initialized"] + if snap["recurrent_states"] is not None and getattr(layer, "recurrent_states", None) is not None: + layer.recurrent_states.copy_(snap["recurrent_states"]) + elif snap["recurrent_states"] is not None: + layer.recurrent_states = snap["recurrent_states"].clone() + if snap["conv_states"] is not None and getattr(layer, "conv_states", None) is not None: + layer.conv_states.copy_(snap["conv_states"]) + elif snap["conv_states"] is not None: + layer.conv_states = snap["conv_states"].clone() + + @staticmethod + def check_should_stop( + output_ids: torch.LongTensor, + stop_token_ids: list[int] | None, + num_input_tokens: int, + ) -> bool: + """ + Check whether any stop token has been generated in the output sequence. + + Args: + output_ids (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Current output token IDs including prompt and generated tokens. + stop_token_ids (`list[int]` or `None`): + Token IDs that signal generation should stop. + num_input_tokens (`int`): + Number of prompt tokens at the start of the sequence. + + Returns: + `bool`: `True` if generation should stop, `False` otherwise. + """ + if stop_token_ids is None: + return False + stop_tensor = torch.tensor(stop_token_ids, device=output_ids.device, dtype=torch.long) + return torch.isin(output_ids[:, num_input_tokens:], stop_tensor).any().item() + + def add_noise( + self, + original_samples: torch.LongTensor, + attention_mask: torch.LongTensor, + *, + prompt_length: int, + block_size: int, + mask_token_id: int, + generator: torch.Generator | None = None, + ) -> tuple[torch.LongTensor, torch.BoolTensor]: + """ + Apply the forward (noising) process for DFlash-style block diffusion training. + + For each block after the prompt, a random fraction of valid (non-padding) tokens are replaced with + `mask_token_id`. + + Args: + original_samples (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Clean token IDs. + attention_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Padding mask (1 for valid, 0 for padding). + prompt_length (`int`): + Number of leading prompt tokens to keep unmasked. + block_size (`int`): + Block size for masking. + mask_token_id (`int`): + Token ID to use for masked positions. + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + + Returns: + `tuple[torch.LongTensor, torch.BoolTensor]`: + `(noisy, masked)` -- the noisy sequence and the boolean mask indicating which positions were masked. 
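+
+        Example (illustrative values; `mask_token_id=103` is arbitrary):
+
+        ```python
+        >>> import torch
+        >>> from diffusers import DFlashTokenDiffusionScheduler
+
+        >>> scheduler = DFlashTokenDiffusionScheduler()
+        >>> tokens = torch.randint(0, 100, (2, 12))
+        >>> attention_mask = torch.ones_like(tokens)
+        >>> noisy, masked = scheduler.add_noise(
+        ...     tokens, attention_mask, prompt_length=4, block_size=4, mask_token_id=103
+        ... )
+        >>> noisy.shape, masked.shape
+        (torch.Size([2, 12]), torch.Size([2, 12]))
+        ```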
+ """ + batch_size, seq_len = original_samples.shape + device = original_samples.device + + noisy = original_samples.clone() + masked = torch.zeros_like(original_samples, dtype=torch.bool) + + valid = attention_mask.to(dtype=torch.bool) + for block_start in range(prompt_length, seq_len, block_size): + block_end = min(seq_len, block_start + block_size) + seg_len = block_end - block_start + if seg_len <= 0: + continue + + p_mask = torch.rand((batch_size, 1), device=device, generator=generator) + seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask + seg = seg & valid[:, block_start:block_end] + + masked[:, block_start:block_end] = seg + + noisy = torch.where(masked, torch.full_like(noisy, mask_token_id), noisy) + return noisy, masked + + +__all__ = ["DFlashTokenDiffusionScheduler", "DFlashTokenDiffusionSchedulerOutput"] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 9bfb73c1999e..792942dd1928 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2882,6 +2882,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DFlashTokenDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DFlashTokenDiffusionSchedulerOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DPMSolverMultistepInverseScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index cfa1318783f3..9ab1581c0045 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1337,6 +1337,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class DFlashPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class DFlashPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class EasyAnimateControlPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/dflash/__init__.py b/tests/pipelines/dflash/__init__.py new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/tests/pipelines/dflash/__init__.py @@ -0,0 +1 @@ + diff 
--git a/tests/pipelines/dflash/test_dflash.py b/tests/pipelines/dflash/test_dflash.py new file mode 100644 index 000000000000..7f1f11011c18 --- /dev/null +++ b/tests/pipelines/dflash/test_dflash.py @@ -0,0 +1,448 @@ +import unittest + +import torch + +from diffusers import DFlashPipeline, DFlashTokenDiffusionScheduler + + +class _DummyModelOutput: + def __init__(self, logits, hidden_states=None): + self.logits = logits + self.hidden_states = hidden_states + + +class _DummyConfig: + def __init__(self, block_size, num_target_layers, num_hidden_layers): + self.block_size = block_size + self.num_target_layers = num_target_layers + self.num_hidden_layers = num_hidden_layers + + +class _DummyTargetModel(torch.nn.Module): + """Minimal target (causal LM) model that returns logits and hidden_states.""" + + def __init__(self, vocab_size: int, hidden_dim: int, num_layers: int): + super().__init__() + self.vocab_size = vocab_size + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.embed = torch.nn.Embedding(vocab_size, hidden_dim) + self.lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False) + + def get_input_embeddings(self): + return self.embed + + def get_output_embeddings(self): + return self.lm_head + + def forward( + self, + input_ids, + position_ids=None, + past_key_values=None, + use_cache=False, + output_hidden_states=False, + logits_to_keep=None, + **kwargs, + ): + bsz, seq_len = input_ids.shape + h = self.embed(input_ids) + # Create hidden_states list: one entry per layer + 1 for the embedding layer + hidden_states = [h] * (self.num_layers + 1) if output_hidden_states else None + logits = self.lm_head(h) + # Make token 0 the most likely so acceptance is deterministic + logits[:, :, 0] = 10.0 + return _DummyModelOutput(logits=logits, hidden_states=hidden_states) + + def parameters(self): + return super().parameters() + + +class _DummyDraftModel(torch.nn.Module): + """Minimal draft model that returns hidden states of the expected shape.""" + + def __init__(self, hidden_dim: int, num_target_layers: int, block_size: int): + super().__init__() + self.block_size = block_size + self.config = _DummyConfig( + block_size=block_size, + num_target_layers=num_target_layers, + num_hidden_layers=1, + ) + # The draft model receives concatenated hidden states from num_target_layers target layers, + # each of dim hidden_dim, and produces a hidden state of dim hidden_dim. + self.proj = torch.nn.Linear(hidden_dim * num_target_layers, hidden_dim, bias=False) + self.register_buffer("_device_anchor", torch.empty(0)) + + @property + def device(self): + return self._device_anchor.device + + def forward( + self, + target_hidden, + noise_embedding, + position_ids=None, + past_key_values=None, + use_cache=False, + is_causal=False, + **kwargs, + ): + # Return a tensor with shape (batch, seq_len, hidden_dim) + bsz = noise_embedding.shape[0] + seq_len = position_ids.shape[1] if position_ids is not None else noise_embedding.shape[1] + h = torch.zeros(bsz, seq_len, self.proj.out_features, device=noise_embedding.device) + return h + + +def _make_pipeline(tokenizer=None, vocab_size=32, hidden_dim=16, num_target_layers=4, block_size=4): + target = _DummyTargetModel(vocab_size=vocab_size, hidden_dim=hidden_dim, num_layers=num_target_layers) + draft = _DummyDraftModel(hidden_dim=hidden_dim, num_target_layers=1, block_size=block_size) + # Set target_layer_ids directly so we skip the config-based computation. 
+ draft.target_layer_ids = [1] + scheduler = DFlashTokenDiffusionScheduler() + return DFlashPipeline(draft_model=draft, target_model=target, tokenizer=tokenizer, scheduler=scheduler) + + +class DFlashPipelineTest(unittest.TestCase): + # ------------------------------------------------------------------ + # Pipeline runs + # ------------------------------------------------------------------ + def test_pipeline_runs_with_input_ids(self): + pipe = _make_pipeline() + input_ids = torch.tensor([[5, 6, 7, 8]], dtype=torch.long) + + out = pipe( + input_ids=input_ids, + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + stop_token_ids=None, + output_type="seq", + ) + + self.assertIsNotNone(out.sequences) + self.assertEqual(out.sequences.ndim, 2) + self.assertEqual(out.sequences.shape[0], 1) + # Generated tokens should not be longer than max_new_tokens + self.assertLessEqual(out.sequences.shape[1], 8) + + # ------------------------------------------------------------------ + # output_type="seq" + # ------------------------------------------------------------------ + def test_output_type_seq(self): + pipe = _make_pipeline() + input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + + out = pipe( + input_ids=input_ids, + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + output_type="seq", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNone(out.texts) + + # ------------------------------------------------------------------ + # output_type="text" with mock tokenizer + # ------------------------------------------------------------------ + def test_output_type_text_with_tokenizer(self): + tok = type( + "Tok", + (), + { + "eos_token_id": None, + "mask_token_id": 31, + "batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs], + }, + )() + pipe = _make_pipeline(tokenizer=tok) + + out = pipe( + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=8, + temperature=0.0, + output_type="text", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNotNone(out.texts) + self.assertEqual(len(out.texts), 1) + self.assertTrue(out.texts[0].startswith("decoded_")) + + def test_output_type_text_without_tokenizer(self): + """output_type='text' without a tokenizer should return texts=None.""" + pipe = _make_pipeline(tokenizer=None) + + out = pipe( + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + output_type="text", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNone(out.texts) + + # ------------------------------------------------------------------ + # output_type invalid + # ------------------------------------------------------------------ + def test_output_type_invalid_raises(self): + pipe = _make_pipeline() + + with self.assertRaises(ValueError): + pipe( + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=8, + mask_token_id=31, + output_type="invalid", + ) + + # ------------------------------------------------------------------ + # return_dict=False + # ------------------------------------------------------------------ + def test_pipeline_return_tuple(self): + pipe = _make_pipeline() + input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + + result = pipe( + input_ids=input_ids, + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + output_type="seq", + return_dict=False, + ) + + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + sequences, texts = result + self.assertIsNotNone(sequences) + self.assertIsNone(texts) 
+ + # ------------------------------------------------------------------ + # check_inputs validation + # ------------------------------------------------------------------ + def test_check_inputs_no_inputs_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=None, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_both_prompt_and_messages_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt="hello", + messages=[{"role": "user", "content": "hi"}], + input_ids=None, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_invalid_input_ids_ndim_raises(self): + pipe = _make_pipeline() + bad_ids = torch.zeros(2, 3, 4, dtype=torch.long) + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=bad_ids, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_invalid_input_ids_dtype_raises(self): + pipe = _make_pipeline() + bad_ids = torch.zeros(1, 4, dtype=torch.float32) + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=bad_ids, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_invalid_max_new_tokens_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=torch.tensor([[1, 2]], dtype=torch.long), + max_new_tokens=0, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_invalid_output_type_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=torch.tensor([[1, 2]], dtype=torch.long), + max_new_tokens=16, + output_type="bad", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_prompt_without_tokenizer_raises(self): + pipe = _make_pipeline(tokenizer=None) + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt="hello", + messages=None, + input_ids=None, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_messages_without_tokenizer_raises(self): + pipe = _make_pipeline(tokenizer=None) + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=[{"role": "user", "content": "hi"}], + input_ids=None, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_valid_input_ids_passes(self): + pipe = _make_pipeline() + # Should not raise. 
+ pipe.check_inputs( + prompt=None, + messages=None, + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + # ------------------------------------------------------------------ + # _prepare_input_ids + # ------------------------------------------------------------------ + def test_prepare_input_ids_from_tensor(self): + pipe = _make_pipeline() + ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + result = pipe._prepare_input_ids( + prompt=None, + messages=None, + input_ids=ids, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + self.assertTrue(torch.equal(result, ids)) + + def test_prepare_input_ids_from_1d_tensor(self): + pipe = _make_pipeline() + ids = torch.tensor([1, 2, 3], dtype=torch.long) + result = pipe._prepare_input_ids( + prompt=None, + messages=None, + input_ids=ids, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + self.assertEqual(result.shape, (1, 3)) + + # ------------------------------------------------------------------ + # prepare_latents + # ------------------------------------------------------------------ + def test_prepare_latents(self): + pipe = _make_pipeline() + mask_token_id = 99 + latents = pipe.prepare_latents( + max_length=10, block_size=4, mask_token_id=mask_token_id, device=torch.device("cpu") + ) + self.assertEqual(latents.shape, (1, 14)) # 10 + 4 + self.assertTrue((latents == mask_token_id).all().item()) + self.assertEqual(latents.dtype, torch.long) + + +class DFlashRegressionTest(unittest.TestCase): + """Pin the bug patterns surfaced in https://github.com/huggingface/diffusers/issues/13598 + (LLaDA2 review) for any that are relevant to DFlash. + + DFlash is batch_size=1 only and does not pass an `attention_mask` to the target model, so + issues #1 (padding mask), #2 (block_length scheduler routing), and #5 (batched EOS row freeze) + don't apply. The applicable patterns are #3 (callback keys must resolve), #4 (EOS at the first + generated position), and #6 (progress-bar config must not be mutated by `__call__`). + """ + + def test_callback_tensor_inputs_advertised_keys_resolve(self): + """Issue #3: every advertised callback key must be a bound local at callback time.""" + observed: list[str] = [] + + def cb(pipe, step, timestep, kwargs): + observed.extend(sorted(kwargs.keys())) + return {} + + pipe = _make_pipeline() + keys = list(pipe._callback_tensor_inputs) + pipe( + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + stop_token_ids=None, + output_type="seq", + callback_on_step_end=cb, + callback_on_step_end_tensor_inputs=keys, + ) + self.assertEqual(set(observed), set(keys)) + + def test_stop_token_at_first_generated_position_triggers_stop(self): + """Issue #4 analogue: a stop token at index `num_input_tokens` (the first generated + position) must terminate generation. Verified at the scheduler level — `check_should_stop` + searches positions starting at `num_input_tokens`, inclusive.""" + # Sequence layout: prompt = [1, 2] (length 2), first generated token (index 2) is the stop. 
+ output_ids = torch.tensor([[1, 2, 99, 0, 0]], dtype=torch.long) + self.assertTrue(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, [99], 2)) + + def test_progress_bar_disable_is_preserved_after_call(self): + """Issue #6: calling the pipeline must not mutate `_progress_bar_config`.""" + pipe = _make_pipeline() + pipe.set_progress_bar_config(disable=True) + before = dict(pipe._progress_bar_config) + pipe( + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + stop_token_ids=None, + output_type="seq", + ) + self.assertEqual(pipe._progress_bar_config, before) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/schedulers/test_scheduler_dflash_token_diffusion.py b/tests/schedulers/test_scheduler_dflash_token_diffusion.py new file mode 100644 index 000000000000..560571907d77 --- /dev/null +++ b/tests/schedulers/test_scheduler_dflash_token_diffusion.py @@ -0,0 +1,310 @@ +import tempfile +import unittest + +import torch + +from diffusers import DFlashTokenDiffusionScheduler + + +class DFlashTokenDiffusionSchedulerTest(unittest.TestCase): + def get_scheduler(self): + return DFlashTokenDiffusionScheduler() + + # ------------------------------------------------------------------ + # set_timesteps + # ------------------------------------------------------------------ + def test_set_timesteps(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(4) + self.assertEqual(scheduler.num_inference_steps, 4) + self.assertEqual(len(scheduler.timesteps), 4) + self.assertEqual(scheduler.timesteps[0].item(), 3) + self.assertEqual(scheduler.timesteps[-1].item(), 0) + + def test_set_timesteps_single(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(1) + self.assertEqual(scheduler.num_inference_steps, 1) + self.assertEqual(len(scheduler.timesteps), 1) + self.assertEqual(scheduler.timesteps[0].item(), 0) + + def test_set_timesteps_invalid(self): + scheduler = self.get_scheduler() + with self.assertRaises(ValueError): + scheduler.set_timesteps(0) + with self.assertRaises(ValueError): + scheduler.set_timesteps(-1) + + # ------------------------------------------------------------------ + # Config round-trip + # ------------------------------------------------------------------ + def test_save_load_config_round_trip(self): + scheduler = self.get_scheduler() + with tempfile.TemporaryDirectory() as tmpdir: + scheduler.save_config(tmpdir) + loaded = DFlashTokenDiffusionScheduler.from_pretrained(tmpdir) + # The scheduler has no user-configurable params, but it should survive the round-trip. 
+ self.assertIsInstance(loaded, DFlashTokenDiffusionScheduler) + self.assertEqual(loaded.order, 1) + + def test_from_config(self): + scheduler = self.get_scheduler() + new_scheduler = DFlashTokenDiffusionScheduler.from_config(scheduler.config) + self.assertIsInstance(new_scheduler, DFlashTokenDiffusionScheduler) + self.assertEqual(new_scheduler.order, 1) + + # ------------------------------------------------------------------ + # sample() – greedy + # ------------------------------------------------------------------ + def test_sample_greedy(self): + scheduler = self.get_scheduler() + logits = torch.tensor([[[1.0, 5.0, 2.0], [3.0, 1.0, 4.0]]]) # (1, 2, 3) + tokens = scheduler.sample(logits, temperature=0.0) + self.assertEqual(tokens.shape, (1, 2)) + self.assertEqual(tokens[0, 0].item(), 1) # argmax of [1,5,2] + self.assertEqual(tokens[0, 1].item(), 2) # argmax of [3,1,4] + + def test_sample_greedy_batched(self): + scheduler = self.get_scheduler() + logits = torch.tensor( + [ + [[10.0, 0.0], [0.0, 10.0]], + [[0.0, 10.0], [10.0, 0.0]], + ] + ) # (2, 2, 2) + tokens = scheduler.sample(logits, temperature=0.0) + self.assertEqual(tokens.shape, (2, 2)) + self.assertEqual(tokens[0, 0].item(), 0) + self.assertEqual(tokens[0, 1].item(), 1) + self.assertEqual(tokens[1, 0].item(), 1) + self.assertEqual(tokens[1, 1].item(), 0) + + # ------------------------------------------------------------------ + # sample() – multinomial + # ------------------------------------------------------------------ + def test_sample_multinomial(self): + scheduler = self.get_scheduler() + # One token has overwhelming probability; multinomial should pick it. + logits = torch.tensor([[[0.0, 100.0, -100.0]]]) # (1, 1, 3) + tokens = scheduler.sample(logits, temperature=1.0) + self.assertEqual(tokens.shape, (1, 1)) + self.assertEqual(tokens[0, 0].item(), 1) + + # ------------------------------------------------------------------ + # step() – return dict + # ------------------------------------------------------------------ + def test_step_all_accepted(self): + """All draft tokens match the posterior => accepted_length == block_size - 1.""" + scheduler = self.get_scheduler() + batch_size, block_size, vocab_size = 1, 4, 8 + + # Draft tokens: [0, 3, 3, 3] + draft_tokens = torch.tensor([[0, 3, 3, 3]], dtype=torch.long) + # Target logits: make argmax = [3, 3, 3, X] so posterior[:, :-1] matches draft[:, 1:] + logits = torch.zeros(batch_size, block_size, vocab_size) + logits[:, 0, 3] = 10.0 + logits[:, 1, 3] = 10.0 + logits[:, 2, 3] = 10.0 + logits[:, 3, 5] = 10.0 # last posterior token (next_token candidate) + + out = scheduler.step(logits, 0, draft_tokens, temperature=0.0, return_dict=True) + + self.assertEqual(out.prev_sample.shape, (1, 4)) + self.assertEqual(out.accepted_length.shape, (1,)) + self.assertEqual(out.accepted_length[0].item(), 3) # all 3 comparisons match + self.assertEqual(out.next_token.shape, (1,)) + self.assertEqual(out.next_token[0].item(), 5) + self.assertEqual(out.posterior.shape, (1, 4)) + + def test_step_none_accepted(self): + """First draft token already mismatches => accepted_length == 0.""" + scheduler = self.get_scheduler() + batch_size, block_size, vocab_size = 1, 4, 8 + + draft_tokens = torch.tensor([[0, 1, 2, 3]], dtype=torch.long) + logits = torch.zeros(batch_size, block_size, vocab_size) + logits[:, 0, 5] = 10.0 # posterior[0] = 5, but draft[1] = 1 => mismatch + logits[:, 1, 2] = 10.0 + logits[:, 2, 3] = 10.0 + logits[:, 3, 4] = 10.0 + + out = scheduler.step(logits, 0, draft_tokens, temperature=0.0, 
return_dict=True) + + self.assertEqual(out.accepted_length[0].item(), 0) + self.assertEqual(out.next_token[0].item(), 5) # posterior at index 0 + + def test_step_partial_accepted(self): + """First two match, third does not => accepted_length == 2.""" + scheduler = self.get_scheduler() + batch_size, block_size, vocab_size = 1, 5, 8 + + # draft: [0, 3, 4, 7, 2] + draft_tokens = torch.tensor([[0, 3, 4, 7, 2]], dtype=torch.long) + logits = torch.zeros(batch_size, block_size, vocab_size) + logits[:, 0, 3] = 10.0 # match draft[1]=3 + logits[:, 1, 4] = 10.0 # match draft[2]=4 + logits[:, 2, 0] = 10.0 # mismatch draft[3]=7 + logits[:, 3, 2] = 10.0 + logits[:, 4, 6] = 10.0 + + out = scheduler.step(logits, 0, draft_tokens, temperature=0.0, return_dict=True) + + self.assertEqual(out.accepted_length[0].item(), 2) + self.assertEqual(out.next_token[0].item(), 0) # posterior at index 2 + + def test_step_single_token_block(self): + """Block with a single token => accepted_length == 0.""" + scheduler = self.get_scheduler() + draft_tokens = torch.tensor([[5]], dtype=torch.long) + logits = torch.zeros(1, 1, 8) + logits[:, 0, 3] = 10.0 + + out = scheduler.step(logits, 0, draft_tokens, temperature=0.0, return_dict=True) + self.assertEqual(out.accepted_length[0].item(), 0) + self.assertEqual(out.next_token[0].item(), 3) + + # ------------------------------------------------------------------ + # step() – return tuple + # ------------------------------------------------------------------ + def test_step_return_tuple(self): + scheduler = self.get_scheduler() + draft_tokens = torch.tensor([[0, 1, 2]], dtype=torch.long) + logits = torch.randn(1, 3, 8) + + result = scheduler.step(logits, 0, draft_tokens, temperature=0.0, return_dict=False) + + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 4) + prev_sample, accepted_length, next_token, posterior = result + self.assertEqual(prev_sample.shape, (1, 3)) + self.assertEqual(accepted_length.shape, (1,)) + self.assertEqual(next_token.shape, (1,)) + self.assertEqual(posterior.shape, (1, 3)) + + # ------------------------------------------------------------------ + # step() – batched + # ------------------------------------------------------------------ + def test_step_batched(self): + scheduler = self.get_scheduler() + batch_size, block_size, vocab_size = 3, 4, 16 + draft_tokens = torch.randint(0, vocab_size, (batch_size, block_size)) + logits = torch.randn(batch_size, block_size, vocab_size) + + out = scheduler.step(logits, 0, draft_tokens, temperature=0.0, return_dict=True) + + self.assertEqual(out.prev_sample.shape, (batch_size, block_size)) + self.assertEqual(out.accepted_length.shape, (batch_size,)) + self.assertEqual(out.next_token.shape, (batch_size,)) + self.assertEqual(out.posterior.shape, (batch_size, block_size)) + + # ------------------------------------------------------------------ + # check_should_stop() + # ------------------------------------------------------------------ + def test_check_should_stop_no_stop_tokens(self): + output_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long) + self.assertFalse(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, None, 2)) + + def test_check_should_stop_found(self): + # Stop token 99 is in the generated portion (after num_input_tokens=2). 
+ output_ids = torch.tensor([[1, 2, 3, 99, 5]], dtype=torch.long) + self.assertTrue(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, [99], 2)) + + def test_check_should_stop_only_in_prompt(self): + # Stop token 1 is only in the prompt portion => should NOT stop. + output_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long) + self.assertFalse(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, [1], 2)) + + def test_check_should_stop_multiple_stop_tokens(self): + output_ids = torch.tensor([[10, 20, 30, 40, 50]], dtype=torch.long) + self.assertTrue(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, [40, 99], 2)) + self.assertFalse(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, [99, 100], 2)) + + # ------------------------------------------------------------------ + # add_noise() + # ------------------------------------------------------------------ + def test_add_noise_prompt_preserved(self): + scheduler = self.get_scheduler() + original = torch.tensor([[10, 11, 12, 13, 14, 15, 16, 17]], dtype=torch.long) + attention_mask = torch.ones_like(original) + mask_token_id = 99 + prompt_length = 3 + + gen = torch.Generator().manual_seed(42) + noisy, masked = scheduler.add_noise( + original, + attention_mask, + prompt_length=prompt_length, + block_size=4, + mask_token_id=mask_token_id, + generator=gen, + ) + + # Prompt positions should never be masked. + self.assertFalse(masked[0, :prompt_length].any().item()) + # Prompt tokens should be unchanged. + self.assertTrue(torch.equal(noisy[0, :prompt_length], original[0, :prompt_length])) + + def test_add_noise_masked_positions(self): + scheduler = self.get_scheduler() + original = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long) + attention_mask = torch.ones_like(original) + mask_token_id = 99 + + gen = torch.Generator().manual_seed(0) + noisy, masked = scheduler.add_noise( + original, + attention_mask, + prompt_length=2, + block_size=3, + mask_token_id=mask_token_id, + generator=gen, + ) + + # Where masked is True, noisy should equal mask_token_id. + self.assertTrue((noisy[masked] == mask_token_id).all().item()) + # Where masked is False, noisy should equal original. + self.assertTrue(torch.equal(noisy[~masked], original[~masked])) + + def test_add_noise_respects_attention_mask(self): + scheduler = self.get_scheduler() + original = torch.tensor([[1, 2, 3, 4, 0, 0]], dtype=torch.long) + attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]], dtype=torch.long) + mask_token_id = 99 + + gen = torch.Generator().manual_seed(42) + noisy, masked = scheduler.add_noise( + original, + attention_mask, + prompt_length=1, + block_size=3, + mask_token_id=mask_token_id, + generator=gen, + ) + + # Padding positions (attention_mask=0) should never be masked. 
+        self.assertFalse(masked[0, 4].item())
+        self.assertFalse(masked[0, 5].item())
+
+    def test_add_noise_output_shapes(self):
+        scheduler = self.get_scheduler()
+        batch_size, seq_len = 2, 10
+        original = torch.randint(0, 50, (batch_size, seq_len))
+        attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
+        mask_token_id = 99
+
+        noisy, masked = scheduler.add_noise(
+            original,
+            attention_mask,
+            prompt_length=2,
+            block_size=4,
+            mask_token_id=mask_token_id,
+        )
+
+        self.assertEqual(noisy.shape, (batch_size, seq_len))
+        self.assertEqual(masked.shape, (batch_size, seq_len))
+        self.assertEqual(noisy.dtype, torch.long)
+        self.assertEqual(masked.dtype, torch.bool)
+
+
+if __name__ == "__main__":
+    unittest.main()

From 715ca5046d82f1128215cf965903fbc1f0d260fc Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Fri, 8 May 2026 10:09:55 +0000
Subject: [PATCH 2/6] [Docs] flesh out DFlash pipeline + scheduler pages

---
 docs/source/en/api/pipelines/dflash.md        | 75 ++++++++++++++++++-
 .../api/schedulers/dflash_token_diffusion.md  | 15 +++-
 2 files changed, 86 insertions(+), 4 deletions(-)

diff --git a/docs/source/en/api/pipelines/dflash.md b/docs/source/en/api/pipelines/dflash.md
index 95847e1fdd82..b1caea59f82a 100644
--- a/docs/source/en/api/pipelines/dflash.md
+++ b/docs/source/en/api/pipelines/dflash.md
@@ -12,8 +12,79 @@
 # DFlash

-`DFlashPipeline` performs block-diffusion speculative decoding using a diffusion draft model and a target causal LM.
-The draft model is conditioned on target hidden features extracted during prefill and verification steps.
+[DFlash](https://huggingface.co/collections/z-lab/dflash) is a block-diffusion speculative decoding scheme. A small
+diffusion *draft* model proposes a block of tokens conditioned on hidden features extracted from intermediate layers
+of a frozen *target* causal LM; the target then verifies the proposed block in a single forward pass and accepts the
+longest matching prefix. The draft model shares the target's tokenizer, so no calibration is needed.
+
+`DFlashPipeline` ties the two models together: prefill on the target, draft a block, verify against the target's
+posterior via [`DFlashTokenDiffusionScheduler`], commit the accepted prefix and the next-token resample, and repeat
+until `max_new_tokens` or a stop token. Compatible draft/target pairs include `z-lab/Qwen3-8B-DFlash-b16` with
+`Qwen/Qwen3-8B`, and `z-lab/Qwen3.5-4B-DFlash` with `Qwen/Qwen3.5-4B` (the latter is a hybrid-attention target — see
+the rollback note below).
+
+## Usage
+
+```py
+import torch
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+
+from diffusers import DFlashPipeline
+
+draft = AutoModel.from_pretrained(
+    "z-lab/Qwen3.5-4B-DFlash", trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
+)
+target = AutoModelForCausalLM.from_pretrained(
+    "Qwen/Qwen3.5-4B", trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-4B", trust_remote_code=True)
+
+pipe = DFlashPipeline(draft_model=draft, target_model=target, tokenizer=tokenizer)
+output = pipe(
+    prompt="What is 2 + 2? Answer in one sentence.",
+    max_new_tokens=128,
+    temperature=0.0,
+    chat_template_kwargs={"enable_thinking": False},
+)
+print(output.texts[0])
+```
+
+`DFlashPipeline` currently runs `batch_size=1` only. Multi-prompt batching requires per-row partial-accept tracking
+and is not yet supported.
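+
+The pipeline also accepts pre-tokenized input. Below is a minimal sketch of the `input_ids` path, reusing `pipe`,
+`target`, and `tokenizer` from above; the tensor is used as-is, so apply any chat template yourself before
+tokenizing, and `output_type="seq"` skips decoding and returns only the generated token IDs:
+
+```py
+prompt_ids = tokenizer("What is 2 + 2? Answer in one sentence.", return_tensors="pt").input_ids.to(target.device)
+
+out = pipe(
+    input_ids=prompt_ids,
+    max_new_tokens=128,
+    temperature=0.0,
+    output_type="seq",  # raw generated token IDs, prompt excluded
+)
+print(tokenizer.decode(out.sequences[0], skip_special_tokens=True))
+```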
+ +## Hybrid-attention targets + +For target models with linear-attention layers (e.g. Qwen3.5's gated-delta-net), `DynamicCache.crop()` silently +no-ops on those layers, so a partial-accept block would otherwise leak rejected speculative tokens into the +recurrent state. The pipeline detects linear-attention caches via +[`DFlashTokenDiffusionScheduler.cache_has_linear_attention`] and uses a snapshot/restore + accepted-prefix +re-forward pattern to advance both layer types cleanly. This adds one extra target forward per partial-accept +block but is required for correctness. + +## Fast path + +When the draft model exposes a `spec_generate(...)` method (e.g. `z-lab/Qwen3-8B-DFlash-b16`), the pipeline +delegates to it — that loop is the upstream-canonical implementation and avoids re-running the rollback bookkeeping. +Newer drafts (`z-lab/Qwen3.5-4B-DFlash`) drop `spec_generate`; the pipeline falls back to its explicit verify loop. + +## Callbacks + +Callbacks run after each block-verify step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are +included in `callback_kwargs`. Allowed keys: `block_output_ids` (the drafted block), `draft_logits`, +`accepted_length`, `next_token`, and `output_ids` (the running output buffer). Return `{"output_ids": ...}` from the +callback to replace the buffer. + +```py +def on_step_end(pipe, step, timestep, callback_kwargs): + output_ids = callback_kwargs["output_ids"] + return {"output_ids": output_ids} + +out = pipe( + prompt="...", + callback_on_step_end=on_step_end, + callback_on_step_end_tensor_inputs=["output_ids"], +) +``` ## DFlashPipeline [[autodoc]] DFlashPipeline diff --git a/docs/source/en/api/schedulers/dflash_token_diffusion.md b/docs/source/en/api/schedulers/dflash_token_diffusion.md index c98b11bc9963..faa5c2405c87 100644 --- a/docs/source/en/api/schedulers/dflash_token_diffusion.md +++ b/docs/source/en/api/schedulers/dflash_token_diffusion.md @@ -12,8 +12,19 @@ specific language governing permissions and limitations under the License. # DFlashTokenDiffusionScheduler -`DFlashTokenDiffusionScheduler` implements the acceptance and posterior sampling logic used in DFlash-style block -diffusion speculative decoding. +[`DFlashTokenDiffusionScheduler`] implements the verification step for DFlash-style block-diffusion speculative +decoding. It samples a posterior block from the target logits, computes the acceptance length as the longest prefix +where the draft proposal matches the posterior, and exposes the resampled `next_token` for the first rejected +position. Used by [`DFlashPipeline`]. + +The scheduler also owns three helpers used by the pipeline's verify loop on hybrid-attention targets: + +- `cache_has_linear_attention(cache)` — detect whether a `DynamicCache` contains any linear-attention layers. +- `snapshot_cache(cache)` / `restore_cache(cache, snapshot)` — clone and restore the full per-layer state so a + partial-accept block can be rolled back and the target re-advanced on just the accepted prefix. + +These exist because `DynamicCache.crop()` silently no-ops on linear-attention layers, which would otherwise let +rejected speculative tokens permanently contaminate the recurrent state. 
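+
+For orientation, here is a minimal sketch of how a verify loop can combine these helpers. The cache, logits, and
+length variables are illustrative placeholders, not part of the public API:
+
+```py
+from diffusers import DFlashTokenDiffusionScheduler
+
+scheduler = DFlashTokenDiffusionScheduler()
+
+# Snapshot only when the target's cache contains linear-attention layers.
+needs_snapshot = scheduler.cache_has_linear_attention(past_key_values)
+snapshot = scheduler.snapshot_cache(past_key_values) if needs_snapshot else None
+
+# ... the target verify forward over the drafted block produces `target_logits` ...
+out = scheduler.step(target_logits, 0, drafted_block, temperature=0.0, return_dict=True)
+partial_accept = out.accepted_length.item() < drafted_block.shape[1] - 1
+
+if needs_snapshot and partial_accept:
+    # Hybrid-attention target, partial accept: roll the whole cache back, then re-run
+    # the target on just the accepted prefix so both layer types stay consistent.
+    scheduler.restore_cache(past_key_values, snapshot)
+elif partial_accept:
+    # Full-attention target: a plain crop back to the committed length is enough.
+    past_key_values.crop(committed_length)
+```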
## DFlashTokenDiffusionScheduler [[autodoc]] DFlashTokenDiffusionScheduler From 4c0e3dd594c4bbe4345d6a6b9cff518a835bf66a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 8 May 2026 10:41:18 +0000 Subject: [PATCH 3/6] [DFlash] fix train_dflash position_ids + clarify trust_remote_code - Training: `position_ids` must span `[0, start + block_size)` so the draft's attention RoPE cos/sin covers both `k_ctx` (target_hidden, length `start`) and `k_noise` (noise_embedding, length `block_size`). Previously we passed only `arange(start, start + block_size)` which triggered a K-side broadcast mismatch on the very first batch. - Docs/examples: target loads as plain Qwen3 / Qwen3.5 (no remote code), but the draft's custom DFlashDraftModel class lives in the Hub repo's `auto_map`, so `trust_remote_code=True` is required for draft loads only. Updated the example docstring, pipeline doc page, sample script, train script, and the GPU verify script. Smoke-tested via srun on z-lab/Qwen3.5-4B-DFlash + Qwen/Qwen3.5-4B (H100): 3 steps complete, final checkpoint saved. --- docs/source/en/api/pipelines/dflash.md | 7 +++---- examples/discrete_diffusion/sample_dflash.py | 2 +- examples/discrete_diffusion/train_dflash.py | 5 ++++- src/diffusers/pipelines/dflash/pipeline_dflash.py | 4 ++-- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/docs/source/en/api/pipelines/dflash.md b/docs/source/en/api/pipelines/dflash.md index b1caea59f82a..ab9153f5170e 100644 --- a/docs/source/en/api/pipelines/dflash.md +++ b/docs/source/en/api/pipelines/dflash.md @@ -31,13 +31,12 @@ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer from diffusers import DFlashPipeline +# Draft ships custom modeling code via `auto_map` — `trust_remote_code=True` is required. draft = AutoModel.from_pretrained( "z-lab/Qwen3.5-4B-DFlash", trust_remote_code=True, dtype=torch.bfloat16, device_map="auto" ) -target = AutoModelForCausalLM.from_pretrained( - "Qwen/Qwen3.5-4B", trust_remote_code=True, dtype=torch.bfloat16, device_map="auto" -) -tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-4B", trust_remote_code=True) +target = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-4B", dtype=torch.bfloat16, device_map="auto") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-4B") pipe = DFlashPipeline(draft_model=draft, target_model=target, tokenizer=tokenizer) output = pipe( diff --git a/examples/discrete_diffusion/sample_dflash.py b/examples/discrete_diffusion/sample_dflash.py index a10899a0d052..b1a3088a971c 100644 --- a/examples/discrete_diffusion/sample_dflash.py +++ b/examples/discrete_diffusion/sample_dflash.py @@ -107,7 +107,7 @@ def main(): print(f"Loading draft model: {args.draft_model_id}") print(f"Loading target model: {args.target_model_id}") dtype_arg = torch_dtype if torch_dtype is not None else "auto" - # Draft model is a custom DFlashDraftModel; use AutoModel so trust_remote_code routes to the class in `auto_map`. + # Draft model is a custom DFlashDraftModel; trust_remote_code routes to the class in `auto_map`. 
draft_model = AutoModel.from_pretrained( args.draft_model_id, trust_remote_code=True, diff --git a/examples/discrete_diffusion/train_dflash.py b/examples/discrete_diffusion/train_dflash.py index 673a2173a058..6538a8db5f55 100644 --- a/examples/discrete_diffusion/train_dflash.py +++ b/examples/discrete_diffusion/train_dflash.py @@ -248,7 +248,10 @@ def main(): block_targets = input_ids[:, start + 1 : start + block_size] block_mask = attention_mask[:, start + 1 : start + block_size] - position_ids = torch.arange(start, start + block_size, device=input_ids.device).unsqueeze(0) + # The draft's attention concatenates `k_ctx` (target_hidden, length `start`) with + # `k_noise` (noise_embedding, length `block_size`); RoPE needs cos/sin covering the + # full range `[0, start + block_size)` so the K-side broadcast works. + position_ids = torch.arange(start + block_size, device=input_ids.device).unsqueeze(0) position_ids = position_ids.expand(input_ids.shape[0], -1) with torch.no_grad(): diff --git a/src/diffusers/pipelines/dflash/pipeline_dflash.py b/src/diffusers/pipelines/dflash/pipeline_dflash.py index e8b0276db109..94b81dd6b863 100644 --- a/src/diffusers/pipelines/dflash/pipeline_dflash.py +++ b/src/diffusers/pipelines/dflash/pipeline_dflash.py @@ -38,9 +38,9 @@ >>> from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer >>> draft = AutoModel.from_pretrained( - ... "z-lab/Qwen3-8B-DFlash-b16", trust_remote_code=True, torch_dtype=torch.bfloat16 + ... "z-lab/Qwen3-8B-DFlash-b16", trust_remote_code=True, dtype=torch.bfloat16 ... ) - >>> target = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B", torch_dtype=torch.bfloat16) + >>> target = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B", dtype=torch.bfloat16) >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") >>> pipe = DFlashPipeline(draft_model=draft, target_model=target, tokenizer=tokenizer) >>> out = pipe(prompt="How many positive whole-number divisors does 196 have?") From cd0ce7b6254a643d2a4b305550327cc3e9f40fe9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 8 May 2026 11:19:08 +0000 Subject: [PATCH 4/6] [DFlash] remove spec_generate fast path; explicit loop handles all targets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The pipeline previously short-circuited to `draft.spec_generate(...)` when the draft model exposed it (e.g. z-lab/Qwen3-8B-DFlash-b16). That path is the upstream `dflash_generate` loop, which calls `past_key_values_target.crop()` unconditionally — fine for full-attention targets, but on hybrid targets it silently corrupts the linear-attention recurrent state. Confirmed in transformers 5.8.0.dev0 at cache_utils.py:759-761: def crop(self, max_length: int): # We don't crop the linear attention cache, so simply do nothing here pass `LinearAttentionCacheLayerMixin.crop` is documented as a no-op, so any verify loop that relies on `cache.crop()` for rollback is wrong on hybrid attention targets. Our explicit loop already handles this via `DFlashTokenDiffusionScheduler.snapshot_cache` / `restore_cache` plus an accepted-prefix re-forward, and reduces to a plain `.crop()` on full-attn targets. Verified end-to-end on GPU after the removal: - z-lab/Qwen3.5-4B-DFlash + Qwen/Qwen3.5-4B (hybrid attn): "2 + 2 equals 4." - z-lab/Qwen3-8B-DFlash-b16 + Qwen/Qwen3-8B (full attn): "2 + 2 equals 4." Fast tests: 43 passed. 
--- docs/source/en/api/pipelines/dflash.md | 18 ++++++---------- .../pipelines/dflash/pipeline_dflash.py | 21 ------------------- 2 files changed, 6 insertions(+), 33 deletions(-) diff --git a/docs/source/en/api/pipelines/dflash.md b/docs/source/en/api/pipelines/dflash.md index ab9153f5170e..215c07530c9e 100644 --- a/docs/source/en/api/pipelines/dflash.md +++ b/docs/source/en/api/pipelines/dflash.md @@ -53,18 +53,12 @@ and is not yet supported. ## Hybrid-attention targets -For target models with linear-attention layers (e.g. Qwen3.5's gated-delta-net), `DynamicCache.crop()` silently -no-ops on those layers, so a partial-accept block would otherwise leak rejected speculative tokens into the -recurrent state. The pipeline detects linear-attention caches via -[`DFlashTokenDiffusionScheduler.cache_has_linear_attention`] and uses a snapshot/restore + accepted-prefix -re-forward pattern to advance both layer types cleanly. This adds one extra target forward per partial-accept -block but is required for correctness. - -## Fast path - -When the draft model exposes a `spec_generate(...)` method (e.g. `z-lab/Qwen3-8B-DFlash-b16`), the pipeline -delegates to it — that loop is the upstream-canonical implementation and avoids re-running the rollback bookkeeping. -Newer drafts (`z-lab/Qwen3.5-4B-DFlash`) drop `spec_generate`; the pipeline falls back to its explicit verify loop. +For target models with linear-attention layers (e.g. Qwen3.5's gated-delta-net), `DynamicCache.crop()` is a +documented no-op on those layers (see `transformers.cache_utils.LinearAttentionCacheLayerMixin.crop`), so a +partial-accept block would otherwise leak rejected speculative tokens into the recurrent state. The pipeline +detects linear-attention caches via [`DFlashTokenDiffusionScheduler.cache_has_linear_attention`] and uses a +snapshot/restore + accepted-prefix re-forward pattern to advance both layer types cleanly. This adds one extra +target forward per partial-accept block on hybrid targets; full-attention targets use a plain `cache.crop()`. ## Callbacks diff --git a/src/diffusers/pipelines/dflash/pipeline_dflash.py b/src/diffusers/pipelines/dflash/pipeline_dflash.py index 94b81dd6b863..214ee0c44052 100644 --- a/src/diffusers/pipelines/dflash/pipeline_dflash.py +++ b/src/diffusers/pipelines/dflash/pipeline_dflash.py @@ -360,27 +360,6 @@ def __call__( target_config = getattr(self.target_model, "config", None) draft_config = getattr(self.draft_model, "config", None) - # Fast path: some draft models (e.g. z-lab/Qwen3-8B-DFlash-b16) ship a self-contained - # `spec_generate` method. Delegate when available — it's the upstream-canonical loop and - # avoids re-implementing rollback. Newer drafts (Qwen3.5-4B-DFlash) drop this method, so - # fall back to the explicit pipeline loop below. 
- spec_generate = getattr(self.draft_model, "spec_generate", None) - if callable(spec_generate): - generated = spec_generate( - input_ids=input_ids, - max_new_tokens=int(max_new_tokens), - temperature=float(temperature), - target=self.target_model, - stop_token_ids=stop_token_ids, - ) - sequences = generated[:, input_ids.shape[1] :] - texts = None - if output_type == "text" and getattr(self, "tokenizer", None) is not None: - texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) - if not return_dict: - return sequences, texts - return DFlashPipelineOutput(sequences=sequences, texts=texts) - # Pass `config=` only when it looks like a real PretrainedConfig — hybrid-attention models # (Qwen3.5) need it so `DynamicCache` instantiates the right per-layer cache types # (linear vs full), but bare dummy configs in tests don't implement `get_text_config`. From a70e329fee17ba25ea83a25596d5aa033f64b685 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 9 May 2026 10:33:38 +0000 Subject: [PATCH 5/6] [DFlash] add num_timesteps property for parity with LLaDA2 --- src/diffusers/pipelines/dflash/pipeline_dflash.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/pipelines/dflash/pipeline_dflash.py b/src/diffusers/pipelines/dflash/pipeline_dflash.py index 214ee0c44052..99b5e1883360 100644 --- a/src/diffusers/pipelines/dflash/pipeline_dflash.py +++ b/src/diffusers/pipelines/dflash/pipeline_dflash.py @@ -95,6 +95,10 @@ def __init__( draft_model=draft_model, target_model=target_model, tokenizer=tokenizer, scheduler=scheduler ) + @property + def num_timesteps(self): + return self._num_timesteps + # --- Prompt encoding --- def _prepare_input_ids( @@ -391,6 +395,7 @@ def _new_cache(cfg): start = num_input_tokens global_step = 0 num_blocks = (max_length - num_input_tokens + block_size - 1) // block_size + self._num_timesteps = int(num_blocks) # 5. Block-wise speculative decoding loop block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy() From 471afa98fc7d588d7cd1e914e6e369aa012ea514 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 9 May 2026 10:36:35 +0000 Subject: [PATCH 6/6] [DFlash] document examples in discrete_diffusion README --- examples/discrete_diffusion/README.md | 41 +++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/examples/discrete_diffusion/README.md b/examples/discrete_diffusion/README.md index a3a8253b1927..0b8fe38316c5 100644 --- a/examples/discrete_diffusion/README.md +++ b/examples/discrete_diffusion/README.md @@ -48,3 +48,44 @@ python examples/discrete_diffusion/sample_llada2.py \ --use_chat_template \ --add_generation_prompt ``` + +## DFlash + +[DFlash](https://huggingface.co/collections/z-lab/dflash) is a block-diffusion **speculative decoding** scheme: a small diffusion *draft* model, conditioned on hidden features from a frozen *target* causal LM, proposes a block of tokens that the target verifies in a single forward pass. The pipeline accepts the longest matching prefix and resamples the next token at the rejection point. + +### Sample + +The published draft pairs with a stock target (no `trust_remote_code` for the target): + +```bash +python examples/discrete_diffusion/sample_dflash.py \ + --draft_model_id z-lab/Qwen3.5-4B-DFlash \ + --target_model_id Qwen/Qwen3.5-4B \ + --prompt "How many positive whole-number divisors does 196 have?" 
\ + --max_new_tokens 4096 +``` + +The draft ships a custom `DFlashDraftModel` class via `auto_map`, so the sample script loads it with `trust_remote_code=True`; the target loads as a stock Qwen3 / Qwen3.5 model. Per-draft thinking-mode defaults from the upstream model cards: + +| Draft | `enable_thinking` | +|---|---| +| `z-lab/Qwen3.5-*-DFlash` | `True` | +| `z-lab/Qwen3-*-DFlash-b16` | `False` (drafts are non-thinking-trained) | + +### Train + +The training loop conditions the draft on intermediate target hidden states and predicts the next `block_size − 1` tokens of each block: + +```bash +accelerate launch examples/discrete_diffusion/train_dflash.py \ + --draft_model_id z-lab/Qwen3-4B-DFlash-b16 \ + --target_model_id Qwen/Qwen3-4B \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --text_column text \ + --output_dir dflash-output \ + --max_train_steps 1000 \ + --learning_rate 2e-5 +``` + +`--block_size 0` (default) reads the block size from the draft model's config (16 for the b16 drafts, 16 for `Qwen3.5-*-DFlash`).
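+
+To check which block size a particular draft was trained with before overriding it, read it from the draft's config
+(a quick sketch; `block_size` is the config field name assumed here, matching the `--block_size` flag):
+
+```py
+from transformers import AutoConfig
+
+cfg = AutoConfig.from_pretrained("z-lab/Qwen3.5-4B-DFlash", trust_remote_code=True)
+print(cfg.block_size)
+```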