### Describe the bug

Inference fails with a device-mismatch `RuntimeError` when running the Sana pipeline with 8-bit quantization.
### Reproduction

1. Use the sample code from https://github.com/NVlabs/Sana/blob/main/asset/docs/8bit_sana.md#quantization
2. Replace the model with `Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers` and the dtype with `torch.bfloat16`.
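For reference, a minimal sketch of the failing script, assuming the linked doc's 8-bit bitsandbytes recipe with only the model id and dtype swapped as described (the prompt and output filename here are placeholders; it requires a CUDA GPU and the model weights):

```python
# Repro sketch based on 8bit_sana.md, with model id and dtype changed as above.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import SanaPipeline, SanaTransformer2DModel

model_id = "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers"

# 8-bit text encoder (transformers-side bitsandbytes config)
text_encoder_8bit = AutoModelForCausalLM.from_pretrained(
    model_id,
    subfolder="text_encoder",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.bfloat16,
)

# 8-bit transformer (diffusers-side bitsandbytes config)
transformer_8bit = SanaTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=DiffusersBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.bfloat16,
)

pipeline = SanaPipeline.from_pretrained(
    model_id,
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.bfloat16,
    device_map="balanced",
)

prompt = "a tiny astronaut hatching from an egg on the moon"  # placeholder prompt
image = pipeline(prompt).images[0]  # fails here with the device-mismatch error below
image.save("sana_4k.png")
```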
### Logs

```
(venv) C:\ai1\diffuser_t2i>python Sana_4K-Quant.py
`low_cpu_mem_usage` was None, now default to True since model is quantized.
Loading checkpoint shards: 100%|████████████████████████████████████| 2/2 [00:28<00:00, 14.45s/it]
Expected types for text_encoder: ['AutoModelForCausalLM'], got Gemma2Model.
Loading pipeline components...: 100%|███████████████████████████████| 5/5 [00:15<00:00,  3.17s/it]
The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
C:\ai1\diffuser_t2i\venv\lib\site-packages\bitsandbytes\autograd\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
  0%|                                                                      | 0/20 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "C:\ai1\diffuser_t2i\Sana_4K-Quant.py", line 30, in <module>
    image = pipeline(prompt).images[0]
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\diffusers\pipelines\sana\pipeline_sana.py", line 882, in __call__
    noise_pred = self.transformer(
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\diffusers\models\transformers\sana_transformer.py", line 414, in forward
    hidden_states = self.patch_embed(hidden_states)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\ai1\diffuser_t2i\venv\lib\site-packages\diffusers\models\embeddings.py", line 569, in forward
    return (latent + pos_embed).to(latent.dtype)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

### System Info
```
python 3.10.11
accelerate         1.2.0.dev0
aiofiles           23.2.1
annotated-types    0.7.0
anyio              4.7.0
bitsandbytes       0.45.0
certifi            2024.12.14
charset-normalizer 3.4.1
click              8.1.8
colorama           0.4.6
diffusers          0.33.0.dev0
einops             0.8.0
exceptiongroup     1.2.2
fastapi            0.115.6
ffmpy              0.5.0
filelock           3.16.1
fsspec             2024.12.0
gguf               0.13.0
gradio             5.9.1
gradio_client      1.5.2
h11                0.14.0
httpcore           1.0.7
httpx              0.28.1
huggingface-hub    0.25.2
idna               3.10
imageio            2.36.1
imageio-ffmpeg     0.5.1
importlib_metadata 8.5.0
Jinja2             3.1.5
markdown-it-py     3.0.0
MarkupSafe         2.1.5
mdurl              0.1.2
mpmath             1.3.0
networkx           3.4.2
ninja              1.11.1.3
numpy              2.2.1
opencv-python      4.10.0.84
optimum-quanto     0.2.6.dev0
orjson             3.10.13
packaging          24.2
pandas             2.2.3
patch-conv         0.0.1b0
pillow             11.1.0
pip                23.0.1
protobuf           5.29.2
psutil             6.1.1
pydantic           2.10.4
pydantic_core      2.27.2
pydub              0.25.1
Pygments           2.18.0
python-dateutil    2.9.0.post0
python-multipart   0.0.20
pytz               2024.2
PyYAML             6.0.2
regex              2024.11.6
requests           2.32.3
rich               13.9.4
ruff               0.8.6
safehttpx          0.1.6
safetensors        0.5.0
semantic-version   2.10.0
sentencepiece      0.2.0
setuptools         65.5.0
shellingham        1.5.4
six                1.17.0
sniffio            1.3.1
starlette          0.41.3
sympy              1.13.1
tokenizers         0.21.0
tomlkit            0.13.2
torch              2.5.1+cu124
torchao            0.7.0
torchvision        0.20.1+cu124
tqdm               4.67.1
transformers       4.47.1
typer              0.15.1
typing_extensions  4.12.2
tzdata             2024.2
urllib3            2.3.0
uvicorn            0.34.0
websockets         14.1
wheel              0.45.1
zipp               3.21.0
```
### Who can help?

_No response_