
llada2 model/pipeline review #13598

@hlky

Description

Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Reviewed pipeline imports/lazy-loading, pipeline runtime behavior, scheduler coupling, callbacks, docs/examples, fast tests, and slow-test coverage. Existing fast tests pass: 42 passed. Duplicate search found existing LLaDA2 issue #13357 and PRs #12911/#13226/#13333, but no duplicate for the findings below.

Issue 1: Tokenizer padding masks are discarded

Affected code:

encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list))
return encoded["input_ids"]

# 2D attention mask (no padding) — the model handles backend-specific conversion internally.
attn_mask = torch.ones((batch_size, total_length), device=device, dtype=torch.long)
position_ids = torch.arange(total_length, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)

Problem:
_prepare_input_ids() drops the tokenizer attention_mask, and __call__() replaces it with an all-ones mask. Batched prompts with padding, and non-block-aligned padded tail positions, are treated as real context tokens.

Impact:
Shorter prompts in a batch attend to pad tokens as prompt content. Final block padding beyond prompt_length + gen_length is also exposed as valid masked tokens, which can change logits for returned tokens.

Reproduction:

from types import SimpleNamespace
import torch
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("anchor", torch.empty(0))
        self.masks = []
    @property
    def device(self): return self.anchor.device
    @property
    def dtype(self): return torch.float32
    def forward(self, input_ids, attention_mask=None, position_ids=None):
        self.masks.append(attention_mask.cpu())
        logits = torch.zeros(input_ids.shape[0], input_ids.shape[1], 128)
        logits[..., 1] = 10
        return SimpleNamespace(logits=logits)

class Tok:
    eos_token_id = None
    mask_token_id = 99
    chat_template = None
    def __call__(self, prompt, return_tensors=None, padding=False):
        return {
            "input_ids": torch.tensor([[11, 12, 0], [21, 0, 0]]),
            "attention_mask": torch.tensor([[1, 1, 0], [1, 0, 0]]),
        }

model = Model()
pipe = LLaDA2Pipeline(model=model, scheduler=BlockRefinementScheduler(), tokenizer=Tok())
pipe.set_progress_bar_config(disable=True)
pipe(prompt=["long", "short"], use_chat_template=False, gen_length=1, block_length=4, num_inference_steps=1, output_type="seq")
print(model.masks[0].tolist())  # [[1, 1, 1, 1], [1, 1, 1, 1]]

Relevant precedent:
QwenImage preserves tokenizer masks through prompt encoding:

txt_tokens = self.tokenizer(
    txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
).to(device)
encoder_hidden_states = self.text_encoder(
    input_ids=txt_tokens.input_ids,
    attention_mask=txt_tokens.attention_mask,
    output_hidden_states=True,
)
hidden_states = encoder_hidden_states.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack(
    [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
encoder_attention_mask = torch.stack(
    [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
return prompt_embeds, encoder_attention_mask

Suggested fix:
Carry tokenizer attention_mask out of prompt encoding, add an attention_mask argument for pre-tokenized input_ids, build a valid mask for prompt plus requested generated positions only, and prevent the scheduler from committing padded tail positions.
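One possible shape of that change, as a sketch (the tuple return from _prepare_input_ids and the prompt_attention_mask name are illustrative, not the pipeline's current API):

# In _prepare_input_ids(): keep the tokenizer's mask instead of discarding it.
encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list))
return encoded["input_ids"], encoded.get("attention_mask")

# In __call__(): mark prompt tokens and the requested generation window as valid,
# leaving block-alignment padding beyond prompt_length + gen_length masked out.
attn_mask = torch.zeros((batch_size, total_length), device=device, dtype=torch.long)
attn_mask[:, prompt_length : prompt_length + gen_length] = 1
if prompt_attention_mask is not None:
    attn_mask[:, :prompt_length] = prompt_attention_mask.to(device)
else:
    attn_mask[:, :prompt_length] = 1

The same mask (or its inverse) can then be used to keep the scheduler from selecting padded tail positions for transfer.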

Issue 2: block_length argument does not control the scheduler transfer schedule

Affected code:

num_inference_steps = min(num_inference_steps, gen_length // minimal_topk)
self.scheduler.set_timesteps(num_inference_steps, device=device)

def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
    if num_inference_steps <= 0:
        raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
    self.num_inference_steps = num_inference_steps
    self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)
    self._transfer_schedule = self.get_num_transfer_tokens(self.config.block_length, self.num_inference_steps).to(
        device=device if device is not None else "cpu"
    )

Problem:
The pipeline accepts block_length, but BlockRefinementScheduler.set_timesteps() computes _transfer_schedule from self.config.block_length. With the default scheduler, pipe(..., block_length=8) still uses a 32-token transfer schedule.

Impact:
Custom block sizes refine at the wrong rate. For block_length=8, num_inference_steps=8, the first step commits 4 tokens instead of 1.

Reproduction:

from types import SimpleNamespace
import torch
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("anchor", torch.empty(0))
    @property
    def device(self): return self.anchor.device
    @property
    def dtype(self): return torch.float32
    def forward(self, input_ids, attention_mask=None, position_ids=None):
        logits = torch.zeros(input_ids.shape[0], input_ids.shape[1], 128)
        logits[..., 1] = 10
        return SimpleNamespace(logits=logits)

commits = []
def cb(pipe, step, timestep, kwargs):
    commits.append(int(kwargs["transfer_index"].sum()))
    return {}

pipe = LLaDA2Pipeline(model=Model(), scheduler=BlockRefinementScheduler())
pipe.set_progress_bar_config(disable=True)
pipe(
    input_ids=torch.empty((1, 0), dtype=torch.long),
    gen_length=8,
    block_length=8,
    num_inference_steps=8,
    threshold=2.0,
    mask_token_id=127,
    output_type="seq",
    eos_early_stop=False,
    callback_on_step_end=cb,
    callback_on_step_end_tensor_inputs=["transfer_index"],
)
print(commits[:3])  # [4, 4, 0]

Relevant precedent:
Schedulers should derive runtime schedules from the pipeline call parameters passed into set_timesteps, not stale constructor defaults.

Suggested fix:

def set_timesteps(self, num_inference_steps: int, device=None, block_length: int | None = None) -> None:
    if num_inference_steps <= 0:
        raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
    block_length = self.config.block_length if block_length is None else block_length
    self.num_inference_steps = num_inference_steps
    self.timesteps = torch.arange(num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)
    self._transfer_schedule = self.get_num_transfer_tokens(block_length, num_inference_steps).to(device=device or "cpu")

Then call self.scheduler.set_timesteps(num_inference_steps, device=device, block_length=block_length).
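With that change, a quick sanity check (assuming the patched signature above) would be:

scheduler = BlockRefinementScheduler()
scheduler.set_timesteps(num_inference_steps=8, device="cpu", block_length=8)
# Expect one token per step for an 8-token block, rather than the 4-per-step
# schedule derived from the constructor's default block_length.
print(scheduler._transfer_schedule.tolist())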

Issue 3: Advertised callback tensor inputs raise KeyError

Affected code:

_callback_tensor_inputs = ["block_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"]

if callback_on_step_end is not None:
    callback_kwargs = {}
    for k in callback_on_step_end_tensor_inputs:
        callback_kwargs[k] = locals()[k]
    callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs)
    block_x = callback_outputs.pop("block_x", block_x)

Problem:
_callback_tensor_inputs allows x0, x0_p, confidence, and active_block, but those names are not locals at callback collection time. Requesting them passes validation and then crashes.

Impact:
The callback API is unreliable for documented introspection and debugging.

Reproduction:

from types import SimpleNamespace
import torch
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("anchor", torch.empty(0))
    @property
    def device(self): return self.anchor.device
    @property
    def dtype(self): return torch.float32
    def forward(self, input_ids, attention_mask=None, position_ids=None):
        logits = torch.zeros(input_ids.shape[0], input_ids.shape[1], 16)
        logits[..., 1] = 10
        return SimpleNamespace(logits=logits)

pipe = LLaDA2Pipeline(model=Model(), scheduler=BlockRefinementScheduler())
pipe.set_progress_bar_config(disable=True)
try:
    pipe(
        input_ids=torch.empty((1, 0), dtype=torch.long),
        gen_length=1,
        block_length=1,
        num_inference_steps=1,
        mask_token_id=15,
        output_type="seq",
        callback_on_step_end=lambda *args: {},
        callback_on_step_end_tensor_inputs=["confidence"],
    )
except Exception as e:
    print(type(e).__name__, e)  # KeyError 'confidence'

Relevant precedent:
QwenImage keeps _callback_tensor_inputs aligned with actual locals:

_callback_tensor_inputs = ["latents", "prompt_embeds"]

if callback_on_step_end is not None:
    callback_kwargs = {}
    for k in callback_on_step_end_tensor_inputs:
        callback_kwargs[k] = locals()[k]
    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
    latents = callback_outputs.pop("latents", latents)

Suggested fix:

_callback_tensor_inputs = ["block_x", "transfer_index"]

# Or define the extra values explicitly before callback collection:
active_block = block_tokens == mask_token_id
confidence = scheduler_output.sampled_probs

Avoid keeping x0 / x0_p unless their semantics are implemented and tested.
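After either change, a callback that requests names actually present at collection time runs cleanly; a quick check against the reproduction's dummy pipeline (sketch):

def cb(pipeline, step, timestep, callback_kwargs):
    # Both names are locals when the callback fires, so no KeyError.
    print(step, callback_kwargs["transfer_index"].sum().item(), callback_kwargs["block_x"].shape)
    return {}

pipe(
    input_ids=torch.empty((1, 0), dtype=torch.long),
    gen_length=1,
    block_length=1,
    num_inference_steps=1,
    mask_token_id=15,
    output_type="seq",
    callback_on_step_end=cb,
    callback_on_step_end_tensor_inputs=["block_x", "transfer_index"],
)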

Issue 4: EOS at the first generated position is ignored

Affected code:

eos_pos = (cur_x[b] == eos_token_id).nonzero(as_tuple=True)
if len(eos_pos[0]) == 0:
    continue
eos_pos = int(eos_pos[0][0].item())
if prompt_length >= eos_pos:
    continue
if (cur_x[b, prompt_length:eos_pos] != mask_token_id).all().item():
    finished[b] = True

Problem:
check_eos_finished() skips EOS when prompt_length >= eos_pos. The first generated token has index prompt_length, so an EOS immediately after the prompt is treated as if it were inside the prompt.

Impact:
eos_early_stop=True fails for the most common early-stop case: the model ending generation on the first generated token.

Reproduction:

import torch
from diffusers import BlockRefinementScheduler

finished = BlockRefinementScheduler.check_eos_finished(
    cur_x=torch.tensor([[10, 2, 99]]),
    sampled_tokens=torch.tensor([[0, 2]]),
    final_transfer=torch.tensor([[False, True]]),
    finished=torch.tensor([False]),
    eos_token_id=2,
    mask_token_id=99,
    prompt_length=1,
)
print(finished.tolist())  # [False], expected [True]

Relevant precedent:
Prompt positions are indices < prompt_length; the first generated position is exactly prompt_length.

Suggested fix:

if eos_pos < prompt_length:
    continue

Issue 5: Finished rows in a batch keep being refined after EOS

Affected code:

finished = torch.zeros((batch_size,), device=device, dtype=torch.bool)
editing_enabled = editing_threshold is not None and editing_threshold > 0.0

should_continue = self.scheduler.check_block_should_continue(
    step_idx=step_idx,
    masks_remaining=masks_remaining,
    editing_enabled=editing_enabled,
    editing_transfer_index=editing_transfer_index,
    post_steps=post_steps,
    max_post_steps=max_post_steps,
    finished=finished,
)
progress_bar.close()
x[:, :current_window_end] = block_x
if eos_early_stop and finished.all():
    break

if eos_token_id is not None and batch_size == 1:
    eos_positions = (sequences[0] == eos_token_id).nonzero(as_tuple=True)[0]
    if len(eos_positions) > 0:
        sequences = sequences[:, : int(eos_positions[0].item()) + 1]

Problem:
finished is only used to stop when all batch rows finish. Rows already marked finished remain eligible for later block updates, and sequence trimming only runs for batch_size == 1.

Impact:
In mixed-length batches, text after EOS can be generated and decoded for rows that should have stopped.

Reproduction:

from types import SimpleNamespace
import torch
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("anchor", torch.empty(0))
    @property
    def device(self): return self.anchor.device
    @property
    def dtype(self): return torch.float32
    def forward(self, input_ids, attention_mask=None, position_ids=None):
        b, s = input_ids.shape
        logits = torch.zeros(b, s, 128)
        if s <= 3:
            logits[0, :, 5] = 10
            logits[0, 2, 2] = 20
            logits[1, :, 6] = 10
        else:
            logits[0, :, 7] = 10
            logits[1, :, 6] = 10
        return SimpleNamespace(logits=logits)

pipe = LLaDA2Pipeline(model=Model(), scheduler=BlockRefinementScheduler())
pipe.set_progress_bar_config(disable=True)
out = pipe(
    input_ids=torch.tensor([[10], [20]]),
    gen_length=5,
    block_length=3,
    num_inference_steps=3,
    threshold=2.0,
    mask_token_id=127,
    eos_token_id=2,
    eos_early_stop=True,
    output_type="seq",
)
print(out.sequences.tolist())  # row 0 has tokens after EOS: [5, 2, 7, 7, 7]

Relevant precedent:
Generation APIs should freeze or mask rows after EOS and decode per-row only up to EOS.

Suggested fix:
Freeze finished rows before applying future transfers, and trim per-row decode inputs:

if finished.any():
    final_transfer = final_transfer & ~finished[:, None]

decode_sequences = sequences
if eos_token_id is not None:
    decode_sequences = [
        seq[: int((seq == eos_token_id).nonzero(as_tuple=True)[0][0]) + 1]
        if (seq == eos_token_id).any()
        else seq
        for seq in sequences
    ]
texts = self.tokenizer.batch_decode(decode_sequences, skip_special_tokens=True)

Issue 6: Inner progress bars ignore disable=True

Affected code:

# 5. Block-wise refinement loop
block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()
block_progress_bar_config["position"] = 0
block_progress_bar_config["desc"] = "Blocks"
for num_block in tqdm(range(prefill_blocks, num_blocks), **block_progress_bar_config):
    current_window_end = (num_block + 1) * block_length
    block_x = x[:, :current_window_end]
    block_attn_mask = attn_mask[:, :current_window_end]
    block_position_ids = position_ids[:, :current_window_end]
    # Identify which positions in the block are prompt (non-editable).
    block_start_pos = num_block * block_length
    prompt_mask_in_block = torch.zeros(block_length, device=device, dtype=torch.bool)
    if block_start_pos < prompt_length:
        prompt_end_in_block = min(prompt_length - block_start_pos, block_length)
        prompt_mask_in_block[:prompt_end_in_block] = True
    post_steps = 0
    step_idx = 0
    should_continue = True
    self.set_progress_bar_config(position=1, leave=False, desc=f"Block {num_block} Inference Steps")
    progress_bar = self.progress_bar(total=num_inference_steps)

Problem:
The pipeline copies the outer progress config, but then calls self.set_progress_bar_config(position=1, leave=False, desc=...) inside the block loop. This replaces the user’s existing config, including disable=True.

Impact:
Users who disable progress bars still get inner progress output, and the pipeline leaves _progress_bar_config mutated after the call.

Reproduction:

from types import SimpleNamespace
import torch
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("anchor", torch.empty(0))
    @property
    def device(self): return self.anchor.device
    @property
    def dtype(self): return torch.float32
    def forward(self, input_ids, attention_mask=None, position_ids=None):
        logits = torch.zeros(input_ids.shape[0], input_ids.shape[1], 16)
        logits[..., 1] = 10
        return SimpleNamespace(logits=logits)

pipe = LLaDA2Pipeline(model=Model(), scheduler=BlockRefinementScheduler())
pipe.set_progress_bar_config(disable=True)
pipe(input_ids=torch.empty((1, 0), dtype=torch.long), gen_length=1, block_length=1, num_inference_steps=1, mask_token_id=15)
print(pipe._progress_bar_config)  # {'position': 1, 'leave': False, 'desc': ...}

Relevant precedent:
DiffusionPipeline.progress_bar() already merges the current config and supplies distributed defaults:

def progress_bar(self, iterable=None, total=None):
    if not hasattr(self, "_progress_bar_config"):
        self._progress_bar_config = {}
    elif not isinstance(self._progress_bar_config, dict):
        raise ValueError(
            f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
        )
    progress_bar_config = dict(self._progress_bar_config)
    if "disable" not in progress_bar_config:
        progress_bar_config["disable"] = not is_torch_dist_rank_zero()
    if iterable is not None:
        return tqdm(iterable, **progress_bar_config)
    elif total is not None:
        return tqdm(total=total, **progress_bar_config)
    else:
        raise ValueError("Either `total` or `iterable` has to be defined.")

def set_progress_bar_config(self, **kwargs):
    self._progress_bar_config = kwargs

Suggested fix:
Preserve and restore the prior config, or build a local config without calling set_progress_bar_config() inside __call__.
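For example, building the inner config locally (a sketch; the step_progress_bar_config name is illustrative):

# Respect the user's settings (including disable=True) instead of overwriting
# self._progress_bar_config for the inner per-block bar.
step_progress_bar_config = dict(getattr(self, "_progress_bar_config", {}))
step_progress_bar_config.update(position=1, leave=False, desc=f"Block {num_block} Inference Steps")
progress_bar = tqdm(total=num_inference_steps, **step_progress_bar_config)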

Issue 7: Missing slow/integration tests for LLaDA2

Affected code:

class LLaDA2PipelineTest(unittest.TestCase):
    def test_pipeline_runs(self):
        pipe = _make_pipeline().to("cpu")
        input_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]], dtype=torch.long)
        out = pipe(
            input_ids=input_ids,
            use_chat_template=False,
            gen_length=24,
            block_length=8,
            num_inference_steps=8,
            temperature=0.0,
            threshold=2.0,  # force top-k commits
            minimal_topk=1,
            eos_early_stop=False,
            mask_token_id=31,
            eos_token_id=None,
            output_type="seq",
        )
        self.assertEqual(out.sequences.shape, (2, 24))
        self.assertFalse((out.sequences == 31).any().item())

    def test_pipeline_return_tuple(self):
        pipe = _make_pipeline().to("cpu")
        input_ids = torch.tensor([[5, 6, 7, 8]], dtype=torch.long)
        sequences, texts = pipe(
            input_ids=input_ids,
            use_chat_template=False,
            gen_length=16,
            block_length=8,
            num_inference_steps=4,
            temperature=0.0,
            threshold=2.0,
            minimal_topk=1,
            eos_early_stop=False,
            mask_token_id=31,
            output_type="seq",
            return_dict=False,
        )
        self.assertEqual(sequences.shape, (1, 16))
        self.assertIsNone(texts)

    def test_output_type_seq(self):
        """output_type='seq' should return sequences but no texts."""
        pipe = _make_pipeline().to("cpu")
        out = pipe(
            input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
            use_chat_template=False,
            gen_length=16,
            block_length=8,
            num_inference_steps=4,
            temperature=0.0,
            threshold=2.0,
            minimal_topk=1,
            eos_early_stop=False,
            mask_token_id=31,
            output_type="seq",
        )
        self.assertIsNotNone(out.sequences)
        self.assertEqual(out.sequences.shape, (1, 16))
        self.assertIsNone(out.texts)

    def test_output_type_text_without_tokenizer(self):
        """output_type='text' without a tokenizer should return texts=None."""
        pipe = _make_pipeline(tokenizer=None).to("cpu")
        out = pipe(
            input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
            use_chat_template=False,
            gen_length=16,
            block_length=8,
            num_inference_steps=4,
            temperature=0.0,
            threshold=2.0,
            minimal_topk=1,
            eos_early_stop=False,
            mask_token_id=31,
            output_type="text",
        )
        self.assertIsNotNone(out.sequences)
        self.assertIsNone(out.texts)

    def test_output_type_text_with_tokenizer(self):
        """output_type='text' with a tokenizer should return decoded texts."""
        tok = type(
            "Tok",
            (),
            {
                "eos_token_id": None,
                "mask_token_id": 31,
                "batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs],
            },
        )()
        pipe = _make_pipeline(tokenizer=tok).to("cpu")
        out = pipe(
            input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
            use_chat_template=False,
            gen_length=16,
            block_length=8,
            num_inference_steps=4,
            temperature=0.0,
            threshold=2.0,
            minimal_topk=1,
            eos_early_stop=False,
            output_type="text",
        )
        self.assertIsNotNone(out.sequences)
        self.assertIsNotNone(out.texts)
        self.assertEqual(len(out.texts), 1)
        self.assertTrue(out.texts[0].startswith("decoded_"))

    def test_output_type_invalid_raises(self):
        """Invalid output_type should raise ValueError."""
        pipe = _make_pipeline().to("cpu")
        with self.assertRaises(ValueError):
            pipe(
                input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
                use_chat_template=False,
                gen_length=16,
                block_length=8,
                num_inference_steps=4,
                mask_token_id=31,
                output_type="invalid",
            )

    def test_prepare_input_ids_from_tensor(self):
        pipe = _make_pipeline()
        ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
        result = pipe._prepare_input_ids(
            prompt=None,
            messages=None,
            input_ids=ids,
            use_chat_template=False,
            add_generation_prompt=False,
            chat_template_kwargs=None,
        )
        self.assertTrue(torch.equal(result, ids))

    def test_prepare_input_ids_from_1d_tensor(self):
        pipe = _make_pipeline()
        ids = torch.tensor([1, 2, 3], dtype=torch.long)
        result = pipe._prepare_input_ids(
            prompt=None,
            messages=None,
            input_ids=ids,
            use_chat_template=False,
            add_generation_prompt=False,
            chat_template_kwargs=None,
        )
        self.assertEqual(result.shape, (1, 3))

    def test_prepare_input_ids_no_tokenizer_raises(self):
        pipe = _make_pipeline(tokenizer=None)
        with self.assertRaises(ValueError):
            pipe._prepare_input_ids(
                prompt="hello",
                messages=None,
                input_ids=None,
                use_chat_template=False,
                add_generation_prompt=False,
                chat_template_kwargs=None,
            )

    def test_prepare_input_ids_both_prompt_and_messages_raises(self):
        pipe = _make_pipeline()
        # Manually set tokenizer to a simple object so _prepare_input_ids doesn't short-circuit
        pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
        with self.assertRaises(ValueError):
            pipe._prepare_input_ids(
                prompt="hello",
                messages=[{"role": "user", "content": "hi"}],
                input_ids=None,
                use_chat_template=False,
                add_generation_prompt=False,
                chat_template_kwargs=None,
            )

    def test_prepare_input_ids_neither_raises(self):
        pipe = _make_pipeline()
        pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
        with self.assertRaises(ValueError):
            pipe._prepare_input_ids(
                prompt=None,
                messages=None,
                input_ids=None,
                use_chat_template=False,
                add_generation_prompt=False,
                chat_template_kwargs=None,
            )

Problem:
Only fast dummy-model tests exist under tests/pipelines/llada2. There is no slow test that loads an actual or tiny Hub fixture, exercises tokenizer chat templating/padding, or checks a realistic LLaDA2Pipeline call.

Impact:
The fast suite misses real integration risks, including the tokenizer mask and EOS/callback failures above.

Reproduction:

from pathlib import Path

hits = []
for path in Path("tests/pipelines/llada2").rglob("*.py"):
    text = path.read_text()
    if "@slow" in text or "hf-internal-testing" in text or "inclusionAI/" in text:
        hits.append(str(path))
print(hits)  # []

Relevant precedent:
The repo generally pairs pipeline fast tests with slow or tiny-fixture coverage when Hub loading/tokenizer behavior is part of the public path.

Suggested fix:
Add a slow test using a hf-internal-testing/ tiny LLaDA2-style fixture if available, or create one. At minimum, cover from_pretrained/manual construction with a real tokenizer, chat template path, batched padded prompts, eos_early_stop, and callback tensor selection.
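A possible skeleton (the fixture repo id "hf-internal-testing/tiny-llada2" is a placeholder until a tiny LLaDA2-style checkpoint exists on the Hub):

import unittest

from diffusers import LLaDA2Pipeline
from diffusers.utils.testing_utils import slow, torch_device


@slow
class LLaDA2PipelineSlowTests(unittest.TestCase):
    def test_batched_chat_generation(self):
        # Placeholder repo id for a tiny LLaDA2-style fixture.
        pipe = LLaDA2Pipeline.from_pretrained("hf-internal-testing/tiny-llada2").to(torch_device)
        out = pipe(
            prompt=["Write a haiku about rivers.", "Hi"],
            use_chat_template=True,
            gen_length=32,
            block_length=8,
            num_inference_steps=8,
            eos_early_stop=True,
            output_type="text",
        )
        self.assertEqual(out.sequences.shape[0], 2)
        self.assertEqual(len(out.texts), 2)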
