
Loading models from cache extra slow due to extra conversion #5460

@joe-chiu

Description


Describe the bug

I have 2 Python environments, one on Windows and another on Linux (over WSL), both using diffusers. To avoid having multiple copies of the same model on disk, I am trying to make these two installations share a single diffusers model cache.

  1. Ubuntu on WSL, Python 3.10.12: the cache is a symlink pointing to an NTFS disk, i.e. the hub folder points to /mnt/c/python/StableDiffusion/hub
  2. Windows 11 Home, Python 3.11.6: the cache resides in a folder on an NTFS disk, i.e. HF_HOME=C:\python\StableDiffusion
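To confirm which hub folder each installation actually resolves to, and whether the Linux one ends up on an NTFS drive once the symlink is followed, a small check can help. This is a minimal sketch, not huggingface_hub's own code: it assumes the default resolution ($HF_HOME/hub, with HF_HOME falling back to $XDG_CACHE_HOME/huggingface or ~/.cache/huggingface), ignores other overrides such as HUGGINGFACE_HUB_CACHE, and assumes WSL mounts Windows drives under the default /mnt/ root. `hub_cache_dir` and `on_windows_drive` are hypothetical helpers.

```python
import os
import re

# Assumption: WSL's default automount root (/mnt/<drive letter>), configurable in /etc/wsl.conf.
_WSL_DRIVE_MOUNT = re.compile(r"^/mnt/[a-zA-Z](/|$)")


def hub_cache_dir(env=None):
    """Return the hub cache folder implied by HF_HOME (default resolution only).

    Sketch of the default rule: $HF_HOME/hub, where HF_HOME falls back to
    $XDG_CACHE_HOME/huggingface or ~/.cache/huggingface. Other cache
    overrides (e.g. HUGGINGFACE_HUB_CACHE) are deliberately ignored.
    """
    env = os.environ if env is None else env
    xdg = env.get("XDG_CACHE_HOME", os.path.join("~", ".cache"))
    hf_home = env.get("HF_HOME", os.path.join(xdg, "huggingface"))
    return os.path.join(os.path.expanduser(hf_home), "hub")


def on_windows_drive(path):
    """True if `path`, after following symlinks, lives on a Windows drive mounted by WSL."""
    return bool(_WSL_DRIVE_MOUNT.match(os.path.realpath(path)))
```

Run on the WSL side as `on_windows_drive(hub_cache_dir())`, this would be expected to report True for the symlinked cache and False for a cache on the Linux filesystem.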

I am running the same Python code in a Jupyter notebook that just loads an SDXL model using diffusers.
On the Windows installation, the progress bar finishes quickly, in 1-2 seconds.
On Linux, the same progress bar finishes just as quickly, but the cell then keeps executing for about another 2 minutes. When I manually interrupted it, I could see that some sort of additional data format conversion code was running.
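To pin down where the extra ~2 minutes go, one option is to time `from_pretrained` and `.to("cuda")` separately instead of chaining them. A minimal stdlib sketch; `timed` is a hypothetical helper, and the commented usage assumes the same diffusers call as in the notebook cell.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, results):
    """Record the wall-clock seconds spent inside the `with` block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = time.perf_counter() - start


# Hypothetical usage in the notebook cell (requires diffusers, torch and a CUDA GPU):
# timings = {}
# with timed("from_pretrained", timings):
#     pipe = StableDiffusionXLPipeline.from_pretrained(...)
# with timed("to_cuda", timings):
#     pipe = pipe.to("cuda")
# print(timings)
```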

I tried 2 additional setups:

  1. I cleared the shared cache and let the Ubuntu installation re-download and re-create it (the cache still resides on a symlink pointing to an NTFS folder) => I still see the additional slowdown after the model is fully downloaded. The Windows installation still works perfectly with the cache created by Ubuntu, with no conversion delay.
  2. I stopped using a symlink to an NTFS folder and let the Ubuntu installation use its own cache on its own file system => no more slowdown. But now I have 2 copies of each model.
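To check whether the difference between these two setups tracks raw file-read speed, one could time a sequential read of the same weights file from the NTFS-backed cache and from a copy on the Linux filesystem. A minimal stdlib sketch; `read_throughput` is a hypothetical helper and the 16 MiB chunk size is an arbitrary choice.

```python
import time


def read_throughput(path, chunk_size=16 * 1024 * 1024):
    """Read `path` sequentially and return (bytes_read, seconds, MB/s).

    Note: a second run on the same file may be served from the OS page cache,
    so compare first (cold) reads where possible.
    """
    total = 0
    start = time.perf_counter()
    with open(path, "rb", buffering=0) as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:  # b"" signals end of file
                break
            total += len(chunk)
    elapsed = time.perf_counter() - start
    mb_per_s = (total / 1e6) / elapsed if elapsed > 0 else float("inf")
    return total, elapsed, mb_per_s
```

For example, pointing it at a `.safetensors` file under /mnt/c/python/StableDiffusion/hub and then at a copy in the Linux home directory gives two numbers that can be compared directly.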

Notebook cell used on both installations:

from diffusers import StableDiffusionXLPipeline, AutoencoderKL, DPMSolverMultistepScheduler
import torch

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16, variant="fp16",
    use_safetensors=True
).to("cuda")

This is the traceback for the slow model loading on Linux. When I interrupted the execution a minute in, some sort of model data conversion appeared to be executing.

---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[3], line 10
      2 import torch
      4 vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
      5 pipe = StableDiffusionXLPipeline.from_pretrained(
      6     "stabilityai/stable-diffusion-xl-base-1.0",
      7     vae=vae,
      8     torch_dtype=torch.float16, variant="fp16",
      9     use_safetensors=True
---> 10 ).to("cuda")

File ~/stable_diffusion_venv/lib/python3.10/site-packages/diffusers/pipelines/pipeline_utils.py:733, in DiffusionPipeline.to(self, torch_device, torch_dtype, silence_dtype_warnings)
    729     logger.warning(
    730         f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
    731     )
    732 else:
--> 733     module.to(torch_device, torch_dtype)
    735 if (
    736     module.dtype == torch.float16
    737     and str(torch_device) in ["cpu"]
    738     and not silence_dtype_warnings
    739     and not is_offloaded
    740 ):
    741     logger.warning(
    742         "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
    743         " is not recommended to move them to `cpu` as running them will fail. Please make"
   (...)
    746         " `torch_dtype=torch.float16` argument, or use another device for inference."
    747     )

File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1160, in Module.to(self, *args, **kwargs)
   1156         return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
   1157                     non_blocking, memory_format=convert_to_format)
   1158     return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
-> 1160 return self._apply(convert)

File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:810, in Module._apply(self, fn, recurse)
    808 if recurse:
    809     for module in self.children():
--> 810         module._apply(fn)
    812 def compute_should_use_set_data(tensor, tensor_applied):
    813     if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    814         # If the new tensor has compatible tensor type as the existing tensor,
    815         # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    820         # global flag to let the user control whether they want the future
    821         # behavior of overwriting the existing tensor or not.

File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:810, in Module._apply(self, fn, recurse)
    808 if recurse:
    809     for module in self.children():
--> 810         module._apply(fn)
    812 def compute_should_use_set_data(tensor, tensor_applied):
    813     if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    814         # If the new tensor has compatible tensor type as the existing tensor,
    815         # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    820         # global flag to let the user control whether they want the future
    821         # behavior of overwriting the existing tensor or not.

    [... skipping similar frames: Module._apply at line 810 (5 times)]

File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:810, in Module._apply(self, fn, recurse)
    808 if recurse:
    809     for module in self.children():
--> 810         module._apply(fn)
    812 def compute_should_use_set_data(tensor, tensor_applied):
    813     if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    814         # If the new tensor has compatible tensor type as the existing tensor,
    815         # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    820         # global flag to let the user control whether they want the future
    821         # behavior of overwriting the existing tensor or not.

File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:833, in Module._apply(self, fn, recurse)
    829 # Tensors stored in modules are graph leaves, and we don't want to
    830 # track autograd history of `param_applied`, so we have to use
    831 # `with torch.no_grad():`
    832 with torch.no_grad():
--> 833     param_applied = fn(param)
    834 should_use_set_data = compute_should_use_set_data(param, param_applied)
    835 if should_use_set_data:

File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1158, in Module.to.<locals>.convert(t)
   1155 if convert_to_format is not None and t.dim() in (4, 5):
   1156     return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
   1157                 non_blocking, memory_format=convert_to_format)
-> 1158 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)

KeyboardInterrupt: 
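The innermost frame above is `t.to(device, dtype ..., non_blocking)` inside `Module.to`, i.e. the per-tensor copy triggered by `.to("cuda")`; since no dtype is passed to `.to("cuda")`, this should be a device copy rather than a dtype conversion. One plausible explanation, not confirmed in this report, is that the safetensors weights are memory-mapped, so bytes are only read from disk when a tensor's pages are first touched; slow reads from /mnt/c would then show up during `.to("cuda")` rather than under the progress bar. The stdlib sketch below only illustrates that lazy behaviour with Python's `mmap` module on an arbitrary file; it does not use safetensors or torch, and `map_then_touch` is a hypothetical helper.

```python
import mmap
import time


def map_then_touch(path, page=mmap.PAGESIZE):
    """Map `path` read-only, then touch one byte per page.

    Returns (seconds_to_map, seconds_to_touch, checksum_of_touched_bytes).
    Mapping is typically cheap; the actual reads happen on first access.
    """
    with open(path, "rb") as f:
        t0 = time.perf_counter()
        mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
        t1 = time.perf_counter()
        checksum = 0
        for offset in range(0, len(mm), page):
            checksum += mm[offset]  # first access of each page may trigger a disk read
        t2 = time.perf_counter()
        mm.close()
    return t1 - t0, t2 - t1, checksum
```

On a file under /mnt/c that is not yet in the page cache, one would expect the touch phase to dominate if this explanation holds.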

Reproduction

  1. Have a Linux distribution running on Windows via WSL
  2. Use diffusers on Windows to load and cache a diffusion model (say, the Windows HF_HOME is c:\somewhere, so the cache is in c:\somewhere\hub)
  3. Use diffusers on Linux (on WSL) to load the same model: point the Linux HF_HOME to the root of the virtualenv, and make the hub folder within HF_HOME a symlink to a folder on NTFS (e.g. /mnt/c/somewhere/hub)
  4. Diffusers should load the model from cache
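Step 3 of the setup can be scripted as follows. This is a minimal stdlib sketch; `link_shared_hub` is a hypothetical helper, and HF_HOME should be exported before diffusers or huggingface_hub is imported, since the cache location is read at import time.

```python
import os


def link_shared_hub(hf_home, shared_hub):
    """Make `<hf_home>/hub` a symlink to `shared_hub` (step 3 of the reproduction).

    Raises FileExistsError if `<hf_home>/hub` already exists and is not that symlink.
    """
    os.makedirs(hf_home, exist_ok=True)
    hub = os.path.join(hf_home, "hub")
    if os.path.islink(hub) and os.readlink(hub) == shared_hub:
        return hub  # already set up
    os.symlink(shared_hub, hub, target_is_directory=True)
    return hub
```

On the WSL side this would be called as, e.g., `link_shared_hub(os.environ["HF_HOME"], "/mnt/c/somewhere/hub")`.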

Expected - model loads quickly from cache

Actual - some unexpected data conversion happens and delays model loading by ~2 minutes

Logs

No response

System Info

pip freeze on Linux

accelerate==0.23.0
anyio==4.0.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.0
async-lru==2.0.4
attrs==23.1.0
Babel==2.13.0
backcall==0.2.0
beautifulsoup4==4.12.2
bleach==6.1.0
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.0
comm==0.1.4
debugpy==1.8.0
decorator==5.1.1
defusedxml==0.7.1
diffusers==0.21.4
exceptiongroup==1.1.3
executing==2.0.0
fastjsonschema==2.18.1
filelock==3.12.4
fqdn==1.5.1
fsspec==2023.9.2
huggingface-hub==0.17.3
idna==3.4
importlib-metadata==6.8.0
ipykernel==6.25.2
ipython==8.16.1
ipywidgets==8.1.1
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.2
json5==0.9.14
jsonpointer==2.4
jsonschema==4.19.1
jsonschema-specifications==2023.7.1
jupyter-events==0.8.0
jupyter-lsp==2.2.0
jupyter_client==8.4.0
jupyter_core==5.4.0
jupyter_server==2.8.0
jupyter_server_terminals==0.4.4
jupyterlab==4.0.7
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.9
jupyterlab_server==2.25.0
MarkupSafe==2.1.3
matplotlib-inline==0.1.6
mistune==3.0.2
mpmath==1.3.0
nbclient==0.8.0
nbconvert==7.9.2
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.2
notebook_shim==0.2.3
numpy==1.26.1
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.52
nvidia-nvtx-cu12==12.1.105
overrides==7.4.0
packaging==23.2
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==10.1.0
platformdirs==3.11.0
prometheus-client==0.17.1
prompt-toolkit==3.0.39
psutil==5.9.6
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
Pygments==2.16.1
python-dateutil==2.8.2
python-json-logger==2.0.7
PyYAML==6.0.1
pyzmq==25.1.1
referencing==0.30.2
regex==2023.10.3
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.10.6
safetensors==0.4.0
Send2Trash==1.8.2
six==1.16.0
sniffio==1.3.0
soupsieve==2.5
stack-data==0.6.3
sympy==1.12
terminado==0.17.1
tinycss2==1.2.1
tokenizers==0.14.1
tomli==2.0.1
torch==2.1.0
torchaudio==2.1.0
torchvision==0.16.0
tornado==6.3.3
tqdm==4.66.1
traitlets==5.11.2
transformers==4.34.1
triton==2.1.0
types-python-dateutil==2.8.19.14
typing_extensions==4.8.0
uri-template==1.3.0
urllib3==2.0.7
wcwidth==0.2.8
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.4
widgetsnbextension==4.0.9
zipp==3.17.0

pip freeze on Windows

absl-py==2.0.0
accelerate==0.23.0
antlr4-python3-runtime==4.9.3
anyio==4.0.0
appdirs==1.4.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.4.0
async-lru==2.0.4
attrs==23.1.0
Babel==2.12.1
backcall==0.2.0
beautifulsoup4==4.12.2
bitsandbytes==0.41.1
bleach==6.0.0
cachetools==5.3.1
certifi==2022.12.7
cffi==1.16.0
charset-normalizer==2.1.1
click==8.1.7
colorama==0.4.6
comm==0.1.4
debugpy==1.8.0
decorator==5.1.1
defusedxml==0.7.1
diffusers==0.21.4
docker-pycreds==0.4.0
executing==1.2.0
fastjsonschema==2.18.0
filelock==3.9.0
fqdn==1.5.1
fsspec==2023.9.2
ftfy==6.1.1
gitdb==4.0.10
GitPython==3.1.37
google-auth==2.23.2
google-auth-oauthlib==1.0.0
grpcio==1.59.0
huggingface-hub==0.17.3
idna==3.4
importlib-metadata==6.8.0
ipykernel==6.25.2
ipython==8.15.0
ipython-genutils==0.2.0
ipywidgets==8.1.1
isoduration==20.11.0
jedi==0.19.0
Jinja2==3.1.2
json5==0.9.14
jsonpointer==2.4
jsonschema==4.19.1
jsonschema-specifications==2023.7.1
jupyter-events==0.7.0
jupyter-lsp==2.2.0
jupyter_client==8.3.1
jupyter_core==5.3.2
jupyter_server==2.7.3
jupyter_server_terminals==0.4.4
jupyterlab==4.0.6
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.9
jupyterlab_server==2.25.0
Markdown==3.5
MarkupSafe==2.1.2
matplotlib-inline==0.1.6
mistune==3.0.1
mpmath==1.3.0
nbclient==0.8.0
nbconvert==7.8.0
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.0
notebook==7.0.4
notebook_shim==0.2.3
numpy==1.24.1
nvidia-cublas-cu12==12.2.5.6
nvidia-cuda-nvrtc-cu12==12.2.140
nvidia-cuda-runtime-cu12==12.2.140
nvidia-cudnn-cu12==8.9.4.25
oauthlib==3.2.2
omegaconf==2.3.0
opencv-python==4.8.1.78
overrides==7.4.0
packaging==23.1
pandocfilters==1.5.0
parso==0.8.3
pathtools==0.1.2
pickleshare==0.7.5
Pillow==9.3.0
platformdirs==3.10.0
prometheus-client==0.17.1
prompt-toolkit==3.0.39
protobuf==4.24.4
psutil==5.9.5
pure-eval==0.2.2
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycparser==2.21
Pygments==2.16.1
python-dateutil==2.8.2
python-json-logger==2.0.7
pywin32==306
pywinpty==2.0.11
PyYAML==6.0.1
pyzmq==25.1.1
qtconsole==5.4.4
QtPy==2.4.0
referencing==0.30.2
regex==2023.8.8
requests==2.31.0
requests-oauthlib==1.3.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.10.3
rsa==4.9
safetensors==0.3.3
scipy==1.11.3
Send2Trash==1.8.2
sentry-sdk==1.31.0
setproctitle==1.3.3
six==1.16.0
smmap==5.0.1
sniffio==1.3.0
soupsieve==2.5
stack-data==0.6.2
sympy==1.12
tensorboard==2.14.1
tensorboard-data-server==0.7.1
tensorrt-bindings==9.0.1.post12.dev4
tensorrt-libs==9.0.1.post12.dev4
terminado==0.17.1
tinycss2==1.2.1
tokenizers==0.13.3
torch==2.0.1+cu118
torchaudio==2.0.2+cu118
torchvision==0.15.2+cu118
tornado==6.3.3
tqdm==4.66.1
traitlets==5.10.1
transformers==4.33.3
typing_extensions==4.4.0
uri-template==1.3.0
urllib3==1.26.13
wandb==0.15.12
wcwidth==0.2.7
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.3
Werkzeug==3.0.0
widgetsnbextension==4.0.9
xformers==0.0.22
zipp==3.17.0

Who can help?

@sayakpaul @patrickvonplaten

Metadata


Assignees: no one assigned

Labels: bug (Something isn't working)
