ErnieImage callback locals() scope bug. #13477

@songh11

Describe the bug

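ErnieImagePipeline.__call__ builds the step-end callback kwargs with locals() inside a dict comprehension ({k: locals()[k] for k in callback_on_step_end_tensor_inputs}). Before Python 3.12, a comprehension runs in its own nested function scope. locals() there therefore contains only the comprehension's own variables (such as k), not latents or any other local of __call__. As a result, a step-end callback that requests "latents" fails with KeyError: 'latents', as in the traceback below (Python 3.10). Python 3.12 inlines comprehensions (PEP 709), so the same line happens to work there, which makes the behavior depend on the interpreter version. I will submit a PR that replaces the comprehension with an explicit for loop, so that locals() is evaluated in the scope of __call__.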
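
The scope difference can be reproduced without downloading the model. This is a standalone illustration, not the pipeline code. build_with_comprehension and build_with_loop are hypothetical helpers, and the string "latents-tensor" stands in for the latents tensor.

```python
import sys


def build_with_comprehension(tensor_inputs):
    latents = "latents-tensor"  # hypothetical stand-in for the latents tensor in __call__
    # Before Python 3.12 the comprehension body runs in its own nested scope,
    # so locals() here contains only the comprehension's variables, not `latents`.
    return {k: locals()[k] for k in tensor_inputs}


def build_with_loop(tensor_inputs):
    latents = "latents-tensor"  # hypothetical stand-in for the latents tensor in __call__
    callback_kwargs = {}
    for k in tensor_inputs:
        # locals() is evaluated in the function's own scope, so `latents` is visible.
        callback_kwargs[k] = locals()[k]
    return callback_kwargs


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    for builder in (build_with_comprehension, build_with_loop):
        try:
            print(builder.__name__, "->", builder(["latents"]))
        except KeyError as err:
            print(builder.__name__, "-> KeyError:", err)
```

On Python 3.10, as in this report, the comprehension variant raises KeyError: 'latents' while the loop variant returns {'latents': 'latents-tensor'}. On Python 3.12 and later both variants return the dict.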

Reproduction

import torch
from diffusers import ErnieImagePipeline

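def on_step_end(pipe, step, timestep, callback_kwargs):
    # Requesting "latents" is enough to trigger the bug: the pipeline fails while
    # building callback_kwargs, before this callback is ever called.
    latents = callback_kwargs["latents"]
    print(f"[callback] step={step}")
    return {}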

def main():
    model_id = "PaddlePaddle/ERNIE-Image/"

    pipe = ErnieImagePipeline.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    image = pipe(
        prompt="A tiny red car on a rainy street at dusk, cinematic lighting",
        height=512,
        width=512,
        num_inference_steps=8,
        guidance_scale=1.0,
        use_pe=False,  # keep demo minimal and deterministic
        callback_on_step_end=on_step_end,
        callback_on_step_end_tensor_inputs=["latents"],
    ).images[0]

    image.save("output.png")

if __name__ == "__main__":
    main()

Logs

Traceback (most recent call last):
  File "/root/workspace/demo.py", line 35, in <module>
    main()
  File "/root/workspace/demo.py", line 19, in main
    image = pipe(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/root/workspace/diffusers/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py", line 355, in __call__
    callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
  File "/root/workspace/diffusers/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py", line 355, in <dictcomp>
    callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
KeyError: 'latents'
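
A sketch of where the explicit loop sits in the step-end callback flow. It mirrors the pattern other diffusers pipelines (for example StableDiffusionPipeline) use: collect the requested tensors with a plain for loop over locals(), call the callback, then take latents back with callback_outputs.pop("latents", latents). toy_denoise, its scalar "latents", and its integer timesteps are hypothetical stand-ins, not the real ErnieImagePipeline.__call__, whose loop and variable names may differ.

```python
def toy_denoise(num_inference_steps, callback_on_step_end=None,
                callback_on_step_end_tensor_inputs=("latents",)):
    """Hypothetical stand-in for a pipeline denoising loop; not the real ErnieImage code."""
    latents = 0.0  # toy scalar; a real pipeline holds a torch.Tensor here
    timesteps = list(range(num_inference_steps, 0, -1))  # toy timesteps, e.g. [3, 2, 1]
    for i, t in enumerate(timesteps):
        latents = latents + 1.0  # toy update standing in for the scheduler step

        if callback_on_step_end is not None:
            # Explicit loop: locals() is evaluated in this function's scope, so `latents`
            # is found on every Python version (a dict comprehension misses it before 3.12).
            callback_kwargs = {}
            for k in callback_on_step_end_tensor_inputs:
                callback_kwargs[k] = locals()[k]
            callback_outputs = callback_on_step_end(None, i, t, callback_kwargs)
            # A returned "latents" replaces the loop's value; returning {} leaves it unchanged.
            latents = callback_outputs.pop("latents", latents)
    return latents


# Usage with the same callback signature as the reproduction script.
steps_seen = []


def on_step_end(pipe, step, timestep, callback_kwargs):
    steps_seen.append((step, timestep, callback_kwargs["latents"]))
    return {}


result = toy_denoise(3, on_step_end, ["latents"])
print(steps_seen)  # [(0, 3, 1.0), (1, 2, 2.0), (2, 1, 3.0)]
print(result)      # 3.0
```

Returning {} from the callback, as the reproduction script does, leaves latents untouched in this sketch; returning {"latents": ...} would replace the value used by the next step.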

System Info

  • 🤗 Diffusers version: 0.38.0.dev0
  • Platform: Linux-5.10.134-15.al8.x86_64-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.8.0+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 1.10.2
  • Transformers version: 5.5.4
  • Accelerate version: 1.13.0
  • PEFT version: 0.19.0
  • Bitsandbytes version: not installed
  • Safetensors version: 0.7.0
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce RTX 4090, 49140 MiB
    NVIDIA GeForce RTX 4090, 49140 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

No response

Labels: bug (Something isn't working)
