Status: Closed
Labels: bug (Something isn't working)
Description
Hello 👋
Describe the bug
If I use `SemanticStableDiffusionPipeline` + `enable_model_cpu_offload` + a seed (set either through `torch.manual_seed(0)` or via a `torch.Generator`), I get tensors on different devices.
Reproduction
First, install the latest versions of `diffusers` and `accelerate`:
```bash
pip install -U diffusers accelerate
```
Loading the model and enabling offloading works:
```python
from diffusers import SemanticStableDiffusionPipeline
import torch

semantic_pipeline = SemanticStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
)
semantic_pipeline.enable_model_cpu_offload()
```
Running the pipeline with no editing works:
```python
# this works fine
torch.manual_seed(0)
out = semantic_pipeline(
    prompt="a photo of the face of a man",
)
```
But specifying `editing_prompt` makes it fail:
```python
# this fails
torch.manual_seed(0)
out = semantic_pipeline(
    prompt="a photo of the face of a man",
    editing_prompt="smiling, smile",
)
```
Logs
```
/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py in __call__(self, prompt, height, width, num_inference_steps, guidance_scale, negative_prompt, num_images_per_prompt, eta, generator, latents, output_type, return_dict, callback, callback_steps, editing_prompt, editing_prompt_embeddings, reverse_editing_direction, edit_guidance_scale, edit_warmup_steps, edit_cooldown_steps, edit_threshold, edit_momentum_scale, edit_mom_beta, edit_weights, sem_guidance)
    681                 noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
    682 
--> 683                 noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum
    684 
    685                 edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
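The sketch below makes the failing lines 681-685 easier to read. It is a minimal pure-Python stand-in, not the diffusers implementation. Floats replace tensors, so it cannot reproduce the device error itself, and the weights and momentum values are hypothetical.

What it does show is the data flow. `edit_momentum` is running state updated from its own previous value on every step. It is then combined elementwise with the freshly computed `noise_guidance_edit`. In PyTorch those two operands must live on the same device, which is exactly where the `cuda:0` / `cpu` mismatch surfaces.

```python
# Hypothetical pure-Python stand-in for lines 681-685 of the traceback.
# Tensors are replaced by nested lists; the spatial dims (i, j, k) are
# collapsed to one float per (concept, batch) entry.


def combine_concepts(concept_weights, noise_guidance_edit):
    """Mimic torch.einsum("cb,cbijk->bijk", ...): weighted sum over concepts c."""
    num_concepts = len(concept_weights)
    batch = len(concept_weights[0])
    return [
        sum(concept_weights[c][b] * noise_guidance_edit[c][b] for c in range(num_concepts))
        for b in range(batch)
    ]


def apply_momentum(noise_guidance_edit, edit_momentum, edit_momentum_scale, edit_mom_beta):
    """Mimic lines 683 and 685: add the momentum term, then update it as an EMA."""
    # Line 683: elementwise add of the current guidance and the running momentum.
    # In PyTorch both operands must be on the same device, or this raises.
    guided = [g + edit_momentum_scale * m for g, m in zip(noise_guidance_edit, edit_momentum)]
    # Line 685: exponential moving average of the momentum-adjusted guidance.
    new_momentum = [
        edit_mom_beta * m + (1 - edit_mom_beta) * g for m, g in zip(edit_momentum, guided)
    ]
    return guided, new_momentum


# Usage with made-up values: 2 editing concepts, batch size 2.
weights = [[0.5, 1.0], [0.5, 0.0]]
edits = [[2.0, 3.0], [4.0, 5.0]]
guidance = combine_concepts(weights, edits)
momentum = [0.0, 0.0]
for _ in range(2):  # two denoising steps; momentum carries over between them
    step_guidance, momentum = apply_momentum(guidance, momentum, 0.1, 0.4)
```

Because `edit_momentum` persists across steps while `noise_guidance_edit` is rebuilt each step, the two can end up on different devices if they were created with different device arguments.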
System Info
- 🤗 Diffusers version: 0.28.2
- Platform: Ubuntu 22.04.3 LTS - Linux-6.1.85+-x86_64-with-glibc2.35
- Running on a notebook?: Yes
- Running on Google Colab?: Yes
- Python version: 3.10.12
- PyTorch version (GPU?): 2.3.0+cu121 (True)
- Flax version (CPU?/GPU?/TPU?): 0.8.4 (gpu)
- Jax version: 0.4.26
- JaxLib version: 0.4.26
- Huggingface_hub version: 0.23.2
- Transformers version: 4.41.2
- Accelerate version: 0.31.0
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.4.3
- xFormers version: not installed
- Accelerator: Tesla T4, 15360 MiB VRAM