Closed
Labels
bug, stale
Description
Describe the bug
When I call numpy_to_pil on the output of FlaxStableDiffusionPipeline, it crashes with:
'DeviceArray' object has no attribute '__array_interface__'
(Possibly related to safety_checker=None.)
Reproduction
import diffusers
import jax
import jax.numpy as jnp

pipe, params = diffusers.FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="flax",
    safety_checker=None,
    device_map='auto',
    dtype=jnp.float16,
)
prompt = "A delightful child enjoying a hearty bowl of JAX cereal."
prompt_inputs = pipe.prepare_inputs(prompt)
result = pipe(
    prompt_ids=prompt_inputs,
    params=params,
    num_inference_steps=12,
    prng_seed=jax.random.PRNGKey(0),
)
# The next line raises the AttributeError shown under Logs.
images = pipe.numpy_to_pil(result.images)
Logs
diffusers/pipeline_flax_utils.py:447 in <listcomp>
│ 444 │ │ if images.ndim == 3:
│ 445 │ │ │ images = images[None, ...]
│ 446 │ │ images = (images * 255).round().astype("uint8")
│ ❱ 447 │ │ pil_images = [Image.fromarray(image) for image in images]
│ 448 │ │
│ 449 │ │ return pil_images
Image.py:2803 in fromarray
│ ❱ 2803 │ arr = obj.__array_interface__
AttributeError: 'DeviceArray' object has no attribute '__array_interface__'
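
The traceback shows that Image.fromarray reads `__array_interface__` directly, which the DeviceArray does not provide. A possible workaround, sketched below, is to copy the data into a host NumPy array first. To keep the sketch runnable without JAX, it uses a hypothetical stand-in class (DeviceArrayStandIn, not part of any library) that, like the array in this report, supports np.asarray() but lacks `__array_interface__`. The helper name to_host_array is also an illustration, not a diffusers API.

```python
import numpy as np


class DeviceArrayStandIn:
    """Hypothetical stand-in for a JAX DeviceArray: convertible through
    __array__, but without an __array_interface__ attribute."""

    def __init__(self, data):
        self._data = np.array(data, dtype=np.float32)

    def __array__(self, dtype=None, copy=None):
        # NumPy calls this hook when asked to convert the object.
        return self._data if dtype is None else self._data.astype(dtype)


def to_host_array(images):
    """Return the images as a plain numpy.ndarray on the host."""
    return np.asarray(images)


fake = DeviceArrayStandIn(np.full((2, 4, 4, 3), 0.5))
host = to_host_array(fake)

print(hasattr(fake, "__array_interface__"))  # attribute PIL looks for
print(hasattr(host, "__array_interface__"))
print(type(host).__name__, host.shape)
```

Applied to the reproduction above, the same idea would presumably be `images = pipe.numpy_to_pil(np.asarray(result.images))`, assuming result.images has the (batch, height, width, channels) layout that numpy_to_pil expects.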
System Info
- diffusers version: 0.5.1
- Platform: Linux-5.15.0-50-generic-x86_64-with-glibc2.35
- Python version: 3.10.6
- PyTorch version (GPU?): 1.12.1+cu116 (True)
- Huggingface_hub version: 0.10.1
- Transformers version: 4.23.1
- Using GPU in script?: unsure. CPU utilization is high, VRAM allocation is high, GPU utilization is low. result.images.device() is CPU.
- Using distributed or parallel set-up in script?: no