Sana 4K: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! #10520

@nitinmukesh

Description

Describe the bug

Inference fails when the Sana 4K model is loaded with 8-bit quantization: the pipeline raises a device-mismatch RuntimeError on the first denoising step (see logs below).

Reproduction

Use the sample code from
https://github.com/NVlabs/Sana/blob/main/asset/docs/8bit_sana.md#quantization

Replace the model with Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers
and the dtype with torch.bfloat16.
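For reference, a minimal sketch of what the modified sample script presumably looks like, following the usual diffusers pattern for 8-bit Sana quantization (the prompt and output filename here are placeholders, not taken from the actual script):

```python
import torch
from diffusers import BitsAndBytesConfig, SanaTransformer2DModel, SanaPipeline
from transformers import BitsAndBytesConfig as TBitsAndBytesConfig, AutoModel

model_id = "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers"

# Quantize the Gemma text encoder (transformers-side config) ...
text_encoder_8bit = AutoModel.from_pretrained(
    model_id,
    subfolder="text_encoder",
    quantization_config=TBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.bfloat16,
)

# ... and the Sana transformer (diffusers-side config).
transformer_8bit = SanaTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.bfloat16,
)

pipeline = SanaPipeline.from_pretrained(
    model_id,
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.bfloat16,
)

prompt = "a tiny astronaut hatching from an egg on the moon"  # placeholder prompt
image = pipeline(prompt).images[0]  # raises the RuntimeError below
image.save("sana_4k.png")
```

Note this requires a CUDA GPU and downloads the full 4K checkpoint, so it is not runnable as a quick smoke test.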

Logs

(venv) C:\ai1\diffuser_t2i>python Sana_4K-Quant.py
`low_cpu_mem_usage` was None, now default to True since model is quantized.
Loading checkpoint shards: 100%|████████████████████████████████████| 2/2 [00:28<00:00, 14.45s/it]
Expected types for text_encoder: ['AutoModelForCausalLM'], got Gemma2Model.
Loading pipeline components...: 100%|███████████████████████████████| 5/5 [00:15<00:00,  3.17s/it]
The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
C:\ai1\diffuser_t2i\venv\lib\site-packages\bitsandbytes\autograd\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
  0%|                                                                      | 0/20 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "C:\ai1\diffuser_t2i\Sana_4K-Quant.py", line 30, in <module>
    image = pipeline(prompt).images[0]
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\diffusers\pipelines\sana\pipeline_sana.py", line 882, in __call__
    noise_pred = self.transformer(
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\diffusers\models\transformers\sana_transformer.py", line 414, in forward
    hidden_states = self.patch_embed(hidden_states)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\diffusers\models\embeddings.py", line 569, in forward
    return (latent + pos_embed).to(latent.dtype)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

System Info

python 3.10.11

accelerate 1.2.0.dev0
aiofiles 23.2.1
annotated-types 0.7.0
anyio 4.7.0
bitsandbytes 0.45.0
certifi 2024.12.14
charset-normalizer 3.4.1
click 8.1.8
colorama 0.4.6
diffusers 0.33.0.dev0
einops 0.8.0
exceptiongroup 1.2.2
fastapi 0.115.6
ffmpy 0.5.0
filelock 3.16.1
fsspec 2024.12.0
gguf 0.13.0
gradio 5.9.1
gradio_client 1.5.2
h11 0.14.0
httpcore 1.0.7
httpx 0.28.1
huggingface-hub 0.25.2
idna 3.10
imageio 2.36.1
imageio-ffmpeg 0.5.1
importlib_metadata 8.5.0
Jinja2 3.1.5
markdown-it-py 3.0.0
MarkupSafe 2.1.5
mdurl 0.1.2
mpmath 1.3.0
networkx 3.4.2
ninja 1.11.1.3
numpy 2.2.1
opencv-python 4.10.0.84
optimum-quanto 0.2.6.dev0
orjson 3.10.13
packaging 24.2
pandas 2.2.3
patch-conv 0.0.1b0
pillow 11.1.0
pip 23.0.1
protobuf 5.29.2
psutil 6.1.1
pydantic 2.10.4
pydantic_core 2.27.2
pydub 0.25.1
Pygments 2.18.0
python-dateutil 2.9.0.post0
python-multipart 0.0.20
pytz 2024.2
PyYAML 6.0.2
regex 2024.11.6
requests 2.32.3
rich 13.9.4
ruff 0.8.6
safehttpx 0.1.6
safetensors 0.5.0
semantic-version 2.10.0
sentencepiece 0.2.0
setuptools 65.5.0
shellingham 1.5.4
six 1.17.0
sniffio 1.3.1
starlette 0.41.3
sympy 1.13.1
tokenizers 0.21.0
tomlkit 0.13.2
torch 2.5.1+cu124
torchao 0.7.0
torchvision 0.20.1+cu124
tqdm 4.67.1
transformers 4.47.1
typer 0.15.1
typing_extensions 4.12.2
tzdata 2024.2
urllib3 2.3.0
uvicorn 0.34.0
websockets 14.1
wheel 0.45.1
zipp 3.21.0

Who can help?

No response

Labels

bug: Something isn't working