LLaDA2 model/pipeline review
Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423
Review performed against the repository review rules.
Reviewed pipeline imports/lazy-loading, pipeline runtime behavior, scheduler coupling, callbacks, docs/examples, fast tests, and slow-test coverage. Existing fast tests pass: 42 passed. Duplicate search found existing LLaDA2 issue #13357 and PRs #12911/#13226/#13333, but no duplicate for the findings below.
Issue 1: Tokenizer padding masks are discarded
Affected code:
diffusers/src/diffusers/pipelines/llada2/pipeline_llada2.py, lines 147-148 at 0f1abc4:

encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list))
return encoded["input_ids"]

diffusers/src/diffusers/pipelines/llada2/pipeline_llada2.py, lines 362-365 at 0f1abc4:

# 2D attention mask (no padding) — the model handles backend-specific conversion internally.
attn_mask = torch.ones((batch_size, total_length), device=device, dtype=torch.long)

position_ids = torch.arange(total_length, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
Problem:
_prepare_input_ids() drops the tokenizer attention_mask, and __call__() replaces it with an all-ones mask. As a result, pad tokens in batched prompts, and the padded tail of a non-block-aligned final block, are treated as real context tokens.
Impact:
Shorter prompts in a batch attend to pad tokens as prompt content. Final block padding beyond prompt_length + gen_length is also exposed as valid masked tokens, which can change logits for returned tokens.
Reproduction:
from types import SimpleNamespace
import torch
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("anchor", torch.empty(0))
        self.masks = []

    @property
    def device(self): return self.anchor.device

    @property
    def dtype(self): return torch.float32

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        self.masks.append(attention_mask.cpu())
        logits = torch.zeros(input_ids.shape[0], input_ids.shape[1], 128)
        logits[..., 1] = 10
        return SimpleNamespace(logits=logits)

class Tok:
    eos_token_id = None
    mask_token_id = 99
    chat_template = None

    def __call__(self, prompt, return_tensors=None, padding=False):
        return {
            "input_ids": torch.tensor([[11, 12, 0], [21, 0, 0]]),
            "attention_mask": torch.tensor([[1, 1, 0], [1, 0, 0]]),
        }

model = Model()
pipe = LLaDA2Pipeline(model=model, scheduler=BlockRefinementScheduler(), tokenizer=Tok())
pipe.set_progress_bar_config(disable=True)
pipe(prompt=["long", "short"], use_chat_template=False, gen_length=1, block_length=4, num_inference_steps=1, output_type="seq")
print(model.masks[0].tolist())  # [[1, 1, 1, 1], [1, 1, 1, 1]]
Relevant precedent:
QwenImage preserves tokenizer masks through prompt encoding:
diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py, lines 202-224 at 0f1abc4:

txt_tokens = self.tokenizer(
    txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
).to(device)
encoder_hidden_states = self.text_encoder(
    input_ids=txt_tokens.input_ids,
    attention_mask=txt_tokens.attention_mask,
    output_hidden_states=True,
)
hidden_states = encoder_hidden_states.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack(
    [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
encoder_attention_mask = torch.stack(
    [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
)

prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

return prompt_embeds, encoder_attention_mask
Suggested fix:
Carry tokenizer attention_mask out of prompt encoding, add an attention_mask argument for pre-tokenized input_ids, build a valid mask for prompt plus requested generated positions only, and prevent the scheduler from committing padded tail positions.
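A minimal sketch of the mask construction, assuming right padding; build_attn_mask is an illustrative helper, not the current API:

import torch

def build_attn_mask(prompt_attn_mask: torch.Tensor, gen_length: int, total_length: int) -> torch.Tensor:
    # Valid positions = real prompt tokens + the requested generated window.
    # Prompt padding and the block-alignment tail beyond prompt_length + gen_length stay 0.
    batch_size, prompt_length = prompt_attn_mask.shape
    attn_mask = torch.zeros((batch_size, total_length), dtype=torch.long)
    attn_mask[:, :prompt_length] = prompt_attn_mask
    attn_mask[:, prompt_length : prompt_length + gen_length] = 1
    return attn_mask

# With the reproduction's tokenizer output (block_length=4 pads the window to 4):
print(build_attn_mask(torch.tensor([[1, 1, 0], [1, 0, 0]]), gen_length=1, total_length=4).tolist())
# [[1, 1, 0, 1], [1, 0, 0, 1]]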
Issue 2: block_length argument does not control the scheduler transfer schedule
Affected code:
diffusers/src/diffusers/pipelines/llada2/pipeline_llada2.py, lines 354-356 at 0f1abc4:

num_inference_steps = min(num_inference_steps, gen_length // minimal_topk)

self.scheduler.set_timesteps(num_inference_steps, device=device)

diffusers/src/diffusers/schedulers/scheduling_block_refinement.py, lines 78-85 at 0f1abc4:

def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
    if num_inference_steps <= 0:
        raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
    self.num_inference_steps = num_inference_steps
    self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)
    self._transfer_schedule = self.get_num_transfer_tokens(self.config.block_length, self.num_inference_steps).to(
        device=device if device is not None else "cpu"
    )
Problem:
The pipeline accepts block_length, but BlockRefinementScheduler.set_timesteps() computes _transfer_schedule from self.config.block_length. With the default scheduler, pipe(..., block_length=8) still uses a 32-token transfer schedule.
Impact:
Custom block sizes refine at the wrong rate. For block_length=8, num_inference_steps=8, the first step commits 4 tokens instead of 1.
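For intuition, an even-split schedule (an assumption about get_num_transfer_tokens that matches the numbers above) shows the mismatch:

# Hypothetical even-split schedule, mirroring the commit counts observed above.
def even_transfer_schedule(block_length: int, steps: int) -> list[int]:
    base, rem = divmod(block_length, steps)
    return [base + 1 if i < rem else base for i in range(steps)]

print(even_transfer_schedule(32, 8))  # [4, 4, 4, 4, 4, 4, 4, 4] <- stale config.block_length=32
print(even_transfer_schedule(8, 8))   # [1, 1, 1, 1, 1, 1, 1, 1] <- what block_length=8 should give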
Reproduction:
from types import SimpleNamespace
import torch
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("anchor", torch.empty(0))

    @property
    def device(self): return self.anchor.device

    @property
    def dtype(self): return torch.float32

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        logits = torch.zeros(input_ids.shape[0], input_ids.shape[1], 128)
        logits[..., 1] = 10
        return SimpleNamespace(logits=logits)

commits = []

def cb(pipe, step, timestep, kwargs):
    commits.append(int(kwargs["transfer_index"].sum()))
    return {}

pipe = LLaDA2Pipeline(model=Model(), scheduler=BlockRefinementScheduler())
pipe.set_progress_bar_config(disable=True)
pipe(
    input_ids=torch.empty((1, 0), dtype=torch.long),
    gen_length=8,
    block_length=8,
    num_inference_steps=8,
    threshold=2.0,
    mask_token_id=127,
    output_type="seq",
    eos_early_stop=False,
    callback_on_step_end=cb,
    callback_on_step_end_tensor_inputs=["transfer_index"],
)
print(commits[:3])  # [4, 4, 0]
Relevant precedent:
Schedulers should derive runtime schedules from the pipeline call parameters passed into set_timesteps, not stale constructor defaults.
Suggested fix:
def set_timesteps(self, num_inference_steps: int, device=None, block_length: int | None = None) -> None:
    if num_inference_steps <= 0:
        raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
    block_length = self.config.block_length if block_length is None else block_length
    self.num_inference_steps = num_inference_steps
    self.timesteps = torch.arange(num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)
    self._transfer_schedule = self.get_num_transfer_tokens(block_length, num_inference_steps).to(device=device or "cpu")
Then call self.scheduler.set_timesteps(num_inference_steps, device=device, block_length=block_length).
Issue 3: Advertised callback tensor inputs raise KeyError
Affected code:
diffusers/src/diffusers/pipelines/llada2/pipeline_llada2.py, line 74 at 0f1abc4:

_callback_tensor_inputs = ["block_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"]

diffusers/src/diffusers/pipelines/llada2/pipeline_llada2.py, lines 447-453 at 0f1abc4:

if callback_on_step_end is not None:
    callback_kwargs = {}
    for k in callback_on_step_end_tensor_inputs:
        callback_kwargs[k] = locals()[k]
    callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs)
    block_x = callback_outputs.pop("block_x", block_x)
Problem:
_callback_tensor_inputs allows x0, x0_p, confidence, and active_block, but those names are not locals at callback collection time. Requesting them passes validation and then crashes.
Impact:
The callback API is unreliable for documented introspection and debugging.
Reproduction:
from types import SimpleNamespace
import torch
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("anchor", torch.empty(0))

    @property
    def device(self): return self.anchor.device

    @property
    def dtype(self): return torch.float32

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        logits = torch.zeros(input_ids.shape[0], input_ids.shape[1], 16)
        logits[..., 1] = 10
        return SimpleNamespace(logits=logits)

pipe = LLaDA2Pipeline(model=Model(), scheduler=BlockRefinementScheduler())
pipe.set_progress_bar_config(disable=True)
try:
    pipe(
        input_ids=torch.empty((1, 0), dtype=torch.long),
        gen_length=1,
        block_length=1,
        num_inference_steps=1,
        mask_token_id=15,
        output_type="seq",
        callback_on_step_end=lambda *args: {},
        callback_on_step_end_tensor_inputs=["confidence"],
    )
except Exception as e:
    print(type(e).__name__, e)  # KeyError 'confidence'
Relevant precedent:
QwenImage keeps _callback_tensor_inputs aligned with actual locals:
diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py, line 152 at 0f1abc4:

_callback_tensor_inputs = ["latents", "prompt_embeds"]

diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py, lines 733-739 at 0f1abc4:

if callback_on_step_end is not None:
    callback_kwargs = {}
    for k in callback_on_step_end_tensor_inputs:
        callback_kwargs[k] = locals()[k]
    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

    latents = callback_outputs.pop("latents", latents)
Suggested fix:
_callback_tensor_inputs = ["block_x", "transfer_index"]
# Or define the extra values explicitly before callback collection:
active_block = block_tokens == mask_token_id
confidence = scheduler_output.sampled_probs
Avoid keeping x0 / x0_p unless their semantics are implemented and tested.
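Alternatively, if the longer allow-list is kept, collection can degrade gracefully instead of raising; a sketch for the step loop in __call__ (logger usage assumes the module-level pipeline logger):

# Sketch: only collect names that exist as locals at this point in the step loop,
# and warn on the rest instead of crashing with KeyError.
step_locals = locals()
callback_kwargs = {k: step_locals[k] for k in callback_on_step_end_tensor_inputs if k in step_locals}
missing = [k for k in callback_on_step_end_tensor_inputs if k not in step_locals]
if missing:
    logger.warning(f"Callback tensor inputs not available this step: {missing}")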
Issue 4: EOS at the first generated position is ignored
Affected code:
diffusers/src/diffusers/schedulers/scheduling_block_refinement.py, lines 342-349 at 0f1abc4:

eos_pos = (cur_x[b] == eos_token_id).nonzero(as_tuple=True)
if len(eos_pos[0]) == 0:
    continue
eos_pos = int(eos_pos[0][0].item())
if prompt_length >= eos_pos:
    continue
if (cur_x[b, prompt_length:eos_pos] != mask_token_id).all().item():
    finished[b] = True
Problem:
check_eos_finished() skips EOS when prompt_length >= eos_pos. The first generated token has index prompt_length, so an EOS immediately after the prompt is treated as if it were inside the prompt.
Impact:
eos_early_stop=True fails for the most common early-stop case: the model ending generation on the first generated token.
Reproduction:
import torch
from diffusers import BlockRefinementScheduler

finished = BlockRefinementScheduler.check_eos_finished(
    cur_x=torch.tensor([[10, 2, 99]]),
    sampled_tokens=torch.tensor([[0, 2]]),
    final_transfer=torch.tensor([[False, True]]),
    finished=torch.tensor([False]),
    eos_token_id=2,
    mask_token_id=99,
    prompt_length=1,
)
print(finished.tolist())  # [False], expected [True]
Relevant precedent:
Prompt positions are indices < prompt_length; the first generated position is exactly prompt_length.
Suggested fix:
if eos_pos < prompt_length:
    continue
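A quick check of both guards at the boundary, using the reproduction's prompt_length=1:

prompt_length = 1
for eos_pos in range(3):
    print(eos_pos, prompt_length >= eos_pos, eos_pos < prompt_length)
# 0: True, True   -> genuine prompt position, both guards skip it
# 1: True, False  -> first generated position; the old guard wrongly skips it
# 2: False, False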
Issue 5: Finished rows in a batch keep being refined after EOS
Affected code:
diffusers/src/diffusers/pipelines/llada2/pipeline_llada2.py, lines 375-376 at 0f1abc4:

finished = torch.zeros((batch_size,), device=device, dtype=torch.bool)
editing_enabled = editing_threshold is not None and editing_threshold > 0.0

diffusers/src/diffusers/pipelines/llada2/pipeline_llada2.py, lines 459-472 at 0f1abc4:

should_continue = self.scheduler.check_block_should_continue(
    step_idx=step_idx,
    masks_remaining=masks_remaining,
    editing_enabled=editing_enabled,
    editing_transfer_index=editing_transfer_index,
    post_steps=post_steps,
    max_post_steps=max_post_steps,
    finished=finished,
)

progress_bar.close()
x[:, :current_window_end] = block_x
if eos_early_stop and finished.all():
    break

diffusers/src/diffusers/pipelines/llada2/pipeline_llada2.py, lines 477-480 at 0f1abc4:

if eos_token_id is not None and batch_size == 1:
    eos_positions = (sequences[0] == eos_token_id).nonzero(as_tuple=True)[0]
    if len(eos_positions) > 0:
        sequences = sequences[:, : int(eos_positions[0].item()) + 1]
Problem:
finished is only used to stop when all batch rows finish. Rows already marked finished remain eligible for later block updates, and sequence trimming only runs for batch_size == 1.
Impact:
In mixed-length batches, text after EOS can be generated and decoded for rows that should have stopped.
Reproduction:
from types import SimpleNamespace
import torch
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("anchor", torch.empty(0))

    @property
    def device(self): return self.anchor.device

    @property
    def dtype(self): return torch.float32

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        b, s = input_ids.shape
        logits = torch.zeros(b, s, 128)
        if s <= 3:
            logits[0, :, 5] = 10
            logits[0, 2, 2] = 20
            logits[1, :, 6] = 10
        else:
            logits[0, :, 7] = 10
            logits[1, :, 6] = 10
        return SimpleNamespace(logits=logits)

pipe = LLaDA2Pipeline(model=Model(), scheduler=BlockRefinementScheduler())
pipe.set_progress_bar_config(disable=True)
out = pipe(
    input_ids=torch.tensor([[10], [20]]),
    gen_length=5,
    block_length=3,
    num_inference_steps=3,
    threshold=2.0,
    mask_token_id=127,
    eos_token_id=2,
    eos_early_stop=True,
    output_type="seq",
)
print(out.sequences.tolist())  # row 0 has tokens after EOS: [5, 2, 7, 7, 7]
Relevant precedent:
Generation APIs should freeze or mask rows after EOS and decode per-row only up to EOS.
Suggested fix:
Freeze finished rows before applying future transfers, and trim per-row decode inputs:
if finished.any():
    final_transfer = final_transfer & ~finished[:, None]

decode_sequences = sequences
if eos_token_id is not None:
    decode_sequences = [
        seq[: int((seq == eos_token_id).nonzero(as_tuple=True)[0][0]) + 1]
        if (seq == eos_token_id).any()
        else seq
        for seq in sequences
    ]
texts = self.tokenizer.batch_decode(decode_sequences, skip_special_tokens=True)
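The freeze expression can be sanity-checked in isolation:

import torch

# Finished row 0: its pending transfers are dropped; unfinished row 1 keeps its own.
final_transfer = torch.tensor([[True, True], [True, False]])
finished = torch.tensor([True, False])
print((final_transfer & ~finished[:, None]).tolist())  # [[False, False], [True, False]]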
Issue 6: Inner progress bars ignore disable=True
Affected code:
diffusers/src/diffusers/pipelines/llada2/pipeline_llada2.py, lines 379-400 at 0f1abc4:

# 5. Block-wise refinement loop
block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()
block_progress_bar_config["position"] = 0
block_progress_bar_config["desc"] = "Blocks"
for num_block in tqdm(range(prefill_blocks, num_blocks), **block_progress_bar_config):
    current_window_end = (num_block + 1) * block_length
    block_x = x[:, :current_window_end]
    block_attn_mask = attn_mask[:, :current_window_end]
    block_position_ids = position_ids[:, :current_window_end]

    # Identify which positions in the block are prompt (non-editable).
    block_start_pos = num_block * block_length
    prompt_mask_in_block = torch.zeros(block_length, device=device, dtype=torch.bool)
    if block_start_pos < prompt_length:
        prompt_end_in_block = min(prompt_length - block_start_pos, block_length)
        prompt_mask_in_block[:prompt_end_in_block] = True

    post_steps = 0
    step_idx = 0
    should_continue = True
    self.set_progress_bar_config(position=1, leave=False, desc=f"Block {num_block} Inference Steps")
    progress_bar = self.progress_bar(total=num_inference_steps)
Problem:
The pipeline copies the outer progress config, but then calls self.set_progress_bar_config(position=1, leave=False, desc=...) inside the block loop. This replaces the user’s existing config, including disable=True.
Impact:
Users who disable progress bars still get inner progress output, and the pipeline leaves _progress_bar_config mutated after the call.
Reproduction:
from types import SimpleNamespace
import torch
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("anchor", torch.empty(0))

    @property
    def device(self): return self.anchor.device

    @property
    def dtype(self): return torch.float32

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        logits = torch.zeros(input_ids.shape[0], input_ids.shape[1], 16)
        logits[..., 1] = 10
        return SimpleNamespace(logits=logits)

pipe = LLaDA2Pipeline(model=Model(), scheduler=BlockRefinementScheduler())
pipe.set_progress_bar_config(disable=True)
pipe(input_ids=torch.empty((1, 0), dtype=torch.long), gen_length=1, block_length=1, num_inference_steps=1, mask_token_id=15)
print(pipe._progress_bar_config)  # {'position': 1, 'leave': False, 'desc': ...}
Relevant precedent:
DiffusionPipeline.progress_bar() already merges the current config and supplies distributed defaults:
diffusers/src/diffusers/pipelines/pipeline_utils.py, lines 1964-1984 at 0f1abc4:

def progress_bar(self, iterable=None, total=None):
    if not hasattr(self, "_progress_bar_config"):
        self._progress_bar_config = {}
    elif not isinstance(self._progress_bar_config, dict):
        raise ValueError(
            f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
        )

    progress_bar_config = dict(self._progress_bar_config)
    if "disable" not in progress_bar_config:
        progress_bar_config["disable"] = not is_torch_dist_rank_zero()

    if iterable is not None:
        return tqdm(iterable, **progress_bar_config)
    elif total is not None:
        return tqdm(total=total, **progress_bar_config)
    else:
        raise ValueError("Either `total` or `iterable` has to be defined.")

def set_progress_bar_config(self, **kwargs):
    self._progress_bar_config = kwargs
Suggested fix:
Preserve and restore the prior config, or build a local config without calling set_progress_bar_config() inside __call__.
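For example, a local config for the inner bar (a sketch for the block loop in __call__, using the tqdm import the loop already has):

# Derive the inner bar's settings from the user's config without mutating
# self._progress_bar_config; a user's disable=True survives the copy.
inner_config = dict(getattr(self, "_progress_bar_config", {}))
inner_config.update(position=1, leave=False, desc=f"Block {num_block} Inference Steps")
progress_bar = tqdm(total=num_inference_steps, **inner_config)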
Issue 7: Missing slow/integration tests for LLaDA2
Affected code:
diffusers/tests/pipelines/llada2/test_llada2.py, lines 44-242 at 0f1abc4:

class LLaDA2PipelineTest(unittest.TestCase):
    def test_pipeline_runs(self):
        pipe = _make_pipeline().to("cpu")

        input_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]], dtype=torch.long)
        out = pipe(
            input_ids=input_ids,
            use_chat_template=False,
            gen_length=24,
            block_length=8,
            num_inference_steps=8,
            temperature=0.0,
            threshold=2.0,  # force top-k commits
            minimal_topk=1,
            eos_early_stop=False,
            mask_token_id=31,
            eos_token_id=None,
            output_type="seq",
        )

        self.assertEqual(out.sequences.shape, (2, 24))
        self.assertFalse((out.sequences == 31).any().item())

    def test_pipeline_return_tuple(self):
        pipe = _make_pipeline().to("cpu")

        input_ids = torch.tensor([[5, 6, 7, 8]], dtype=torch.long)
        sequences, texts = pipe(
            input_ids=input_ids,
            use_chat_template=False,
            gen_length=16,
            block_length=8,
            num_inference_steps=4,
            temperature=0.0,
            threshold=2.0,
            minimal_topk=1,
            eos_early_stop=False,
            mask_token_id=31,
            output_type="seq",
            return_dict=False,
        )

        self.assertEqual(sequences.shape, (1, 16))
        self.assertIsNone(texts)

    def test_output_type_seq(self):
        """output_type='seq' should return sequences but no texts."""
        pipe = _make_pipeline().to("cpu")

        out = pipe(
            input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
            use_chat_template=False,
            gen_length=16,
            block_length=8,
            num_inference_steps=4,
            temperature=0.0,
            threshold=2.0,
            minimal_topk=1,
            eos_early_stop=False,
            mask_token_id=31,
            output_type="seq",
        )

        self.assertIsNotNone(out.sequences)
        self.assertEqual(out.sequences.shape, (1, 16))
        self.assertIsNone(out.texts)

    def test_output_type_text_without_tokenizer(self):
        """output_type='text' without a tokenizer should return texts=None."""
        pipe = _make_pipeline(tokenizer=None).to("cpu")

        out = pipe(
            input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
            use_chat_template=False,
            gen_length=16,
            block_length=8,
            num_inference_steps=4,
            temperature=0.0,
            threshold=2.0,
            minimal_topk=1,
            eos_early_stop=False,
            mask_token_id=31,
            output_type="text",
        )

        self.assertIsNotNone(out.sequences)
        self.assertIsNone(out.texts)

    def test_output_type_text_with_tokenizer(self):
        """output_type='text' with a tokenizer should return decoded texts."""
        tok = type(
            "Tok",
            (),
            {
                "eos_token_id": None,
                "mask_token_id": 31,
                "batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs],
            },
        )()
        pipe = _make_pipeline(tokenizer=tok).to("cpu")

        out = pipe(
            input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
            use_chat_template=False,
            gen_length=16,
            block_length=8,
            num_inference_steps=4,
            temperature=0.0,
            threshold=2.0,
            minimal_topk=1,
            eos_early_stop=False,
            output_type="text",
        )

        self.assertIsNotNone(out.sequences)
        self.assertIsNotNone(out.texts)
        self.assertEqual(len(out.texts), 1)
        self.assertTrue(out.texts[0].startswith("decoded_"))

    def test_output_type_invalid_raises(self):
        """Invalid output_type should raise ValueError."""
        pipe = _make_pipeline().to("cpu")

        with self.assertRaises(ValueError):
            pipe(
                input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
                use_chat_template=False,
                gen_length=16,
                block_length=8,
                num_inference_steps=4,
                mask_token_id=31,
                output_type="invalid",
            )

    def test_prepare_input_ids_from_tensor(self):
        pipe = _make_pipeline()
        ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
        result = pipe._prepare_input_ids(
            prompt=None,
            messages=None,
            input_ids=ids,
            use_chat_template=False,
            add_generation_prompt=False,
            chat_template_kwargs=None,
        )
        self.assertTrue(torch.equal(result, ids))

    def test_prepare_input_ids_from_1d_tensor(self):
        pipe = _make_pipeline()
        ids = torch.tensor([1, 2, 3], dtype=torch.long)
        result = pipe._prepare_input_ids(
            prompt=None,
            messages=None,
            input_ids=ids,
            use_chat_template=False,
            add_generation_prompt=False,
            chat_template_kwargs=None,
        )
        self.assertEqual(result.shape, (1, 3))

    def test_prepare_input_ids_no_tokenizer_raises(self):
        pipe = _make_pipeline(tokenizer=None)
        with self.assertRaises(ValueError):
            pipe._prepare_input_ids(
                prompt="hello",
                messages=None,
                input_ids=None,
                use_chat_template=False,
                add_generation_prompt=False,
                chat_template_kwargs=None,
            )

    def test_prepare_input_ids_both_prompt_and_messages_raises(self):
        pipe = _make_pipeline()
        # Manually set tokenizer to a simple object so _prepare_input_ids doesn't short-circuit
        pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
        with self.assertRaises(ValueError):
            pipe._prepare_input_ids(
                prompt="hello",
                messages=[{"role": "user", "content": "hi"}],
                input_ids=None,
                use_chat_template=False,
                add_generation_prompt=False,
                chat_template_kwargs=None,
            )

    def test_prepare_input_ids_neither_raises(self):
        pipe = _make_pipeline()
        pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
        with self.assertRaises(ValueError):
            pipe._prepare_input_ids(
                prompt=None,
                messages=None,
                input_ids=None,
                use_chat_template=False,
                add_generation_prompt=False,
                chat_template_kwargs=None,
            )
Problem:
Only fast dummy-model tests exist under tests/pipelines/llada2. There is no slow test that loads an actual or tiny Hub fixture, exercises tokenizer chat templating/padding, or checks a realistic LLaDA2Pipeline call.
Impact:
The fast suite misses real integration risks, including the tokenizer mask and EOS/callback failures above.
Reproduction:
from pathlib import Path

hits = []
for path in Path("tests/pipelines/llada2").rglob("*.py"):
    text = path.read_text()
    if "@slow" in text or "hf-internal-testing" in text or "inclusionAI/" in text:
        hits.append(str(path))
print(hits)  # []
Relevant precedent:
The repo generally pairs pipeline fast tests with slow or tiny-fixture coverage when Hub loading/tokenizer behavior is part of the public path.
Suggested fix:
Add a slow test using a hf-internal-testing/ tiny LLaDA2-style fixture if available, or create one. At minimum, cover from_pretrained/manual construction with a real tokenizer, chat template path, batched padded prompts, eos_early_stop, and callback tensor selection.
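One possible shape for such a test, as a sketch (the fixture repo id is hypothetical and would need to be published; slow is the existing marker from diffusers.utils.testing_utils):

import unittest

from diffusers import LLaDA2Pipeline
from diffusers.utils.testing_utils import slow


@slow
class LLaDA2PipelineSlowTests(unittest.TestCase):
    def test_batched_padded_prompts_with_eos_early_stop(self):
        # "hf-internal-testing/tiny-llada2" is a hypothetical fixture, per the suggestion above.
        pipe = LLaDA2Pipeline.from_pretrained("hf-internal-testing/tiny-llada2")
        out = pipe(
            prompt=["a deliberately longer prompt than the second entry", "short"],
            use_chat_template=True,
            gen_length=16,
            block_length=8,
            num_inference_steps=8,
            eos_early_stop=True,
            output_type="text",
        )
        # Both rows decode, and neither should contain text after its EOS token.
        self.assertEqual(len(out.texts), 2)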