
numpy_to_pil fails on flax pipeline output #835

@keturn


Describe the bug

When I pass the output of FlaxStableDiffusionPipeline to numpy_to_pil, it crashes with:

'DeviceArray' object has no attribute '__array_interface__'

(Possibly related to safety_checker=None.)

Reproduction

import diffusers
import jax
import jax.numpy as jnp

# Load the Flax weights without a safety checker.
pipe, params = diffusers.FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="flax",
    safety_checker=None,
    device_map='auto',
    dtype=jnp.float16,
)

prompt = "A delightful child enjoying a hearty bowl of JAX cereal."
prompt_inputs = pipe.prepare_inputs(prompt)
result = pipe(
    prompt_ids=prompt_inputs,
    params=params,
    num_inference_steps=12,
    prng_seed=jax.random.PRNGKey(0),
)

# Raises the AttributeError shown in the logs below.
images = pipe.numpy_to_pil(result.images)

Logs

diffusers/pipeline_flax_utils.py:447 in <listcomp>                                                          
│   444 │   │   if images.ndim == 3:                                                               
│   445 │   │   │   images = images[None, ...]                                                     
│   446 │   │   images = (images * 255).round().astype("uint8")                                    
│ ❱ 447 │   │   pil_images = [Image.fromarray(image) for image in images]                          
│   448 │   │                                                                                      
│   449 │   │   return pil_images                                          

Image.py:2803 in fromarray
│ ❱ 2803 │   arr = obj.__array_interface__                                                         
AttributeError: 'DeviceArray' object has no attribute '__array_interface__'
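A possible workaround, sketched under assumptions (this is not the fix adopted upstream). The traceback shows numpy_to_pil doing its arithmetic directly on the input, so a JAX array stays a JAX array all the way into Image.fromarray, which in the Pillow version from the traceback reads `obj.__array_interface__`. JAX arrays implement `__array__`, so `np.asarray` turns them into a real `numpy.ndarray` first. The helper `numpy_to_pil_safe` and the `FakeDeviceArray` stand-in are hypothetical, written only so the example runs without JAX or a GPU.

```python
import numpy as np
from PIL import Image


def numpy_to_pil_safe(images):
    """Hypothetical variant of numpy_to_pil that accepts any array-like input.

    np.asarray converts objects that implement __array__ (such as JAX arrays)
    into a real numpy.ndarray, which Image.fromarray can consume.
    """
    images = np.asarray(images)
    if images.ndim == 3:
        # Single (H, W, C) image: add a leading batch dimension.
        images = images[None, ...]
    # Map floats in [0, 1] to 8-bit pixel values.
    images = (images * 255).round().astype("uint8")
    return [Image.fromarray(image) for image in images]


class FakeDeviceArray:
    """Hypothetical stand-in for a JAX DeviceArray.

    It exposes __array__ but not __array_interface__, which is the
    combination the traceback above trips over.
    """

    def __init__(self, data):
        self._data = np.asarray(data)

    def __array__(self, dtype=None, copy=None):
        return self._data if dtype is None else self._data.astype(dtype)


rng = np.random.default_rng(0)
batch = FakeDeviceArray(rng.random((2, 8, 8, 3), dtype=np.float32))
pil_images = numpy_to_pil_safe(batch)
print(len(pil_images), pil_images[0].size, pil_images[0].mode)  # 2 (8, 8) RGB
```

Applied to the reproduction above, the same idea reduces to a one-line change (assuming `import numpy as np` is added): `images = pipe.numpy_to_pil(np.asarray(result.images))`. Converting at the call site leaves the library untouched; converting inside `numpy_to_pil` instead would protect every caller.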

System Info

  • diffusers version: 0.5.1
  • Platform: Linux-5.15.0-50-generic-x86_64-with-glibc2.35
  • Python version: 3.10.6
  • PyTorch version (GPU?): 1.12.1+cu116 (True)
  • Huggingface_hub version: 0.10.1
  • Transformers version: 4.23.1
  • Using GPU in script?: Unsure. CPU utilization is high, VRAM allocation is high, and GPU utilization is low; result.images.device() reports CPU.
  • Using distributed or parallel set-up in script?: no
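The PyTorch line above does not say whether JAX itself sees the GPU, and the Flax pipeline runs on JAX. A quick check, assuming a recent JAX release (`jax.default_backend()` and `jax.devices()` are public JAX functions). If only CPU devices are listed, JAX is running on the CPU, which would be consistent with the high CPU and low GPU utilization reported above.

```python
import jax

# Backend JAX selected by default, typically "cpu", "gpu", or "tpu".
backend = jax.default_backend()

# Devices visible on that default backend. If these are all CPU devices,
# JAX is not using the GPU, regardless of what PyTorch can see.
devices = jax.devices()

print("default backend:", backend)
for device in devices:
    print(device.platform, device)
```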


Labels: bug (Something isn't working), stale (Issues that haven't received updates)
