Description
Describe the bug
I have 2 Python environments, one on Windows and another on Linux (over WSL), both using diffusers. To avoid keeping multiple copies of the same model on disk, I try to make the two installations share a single diffusers model cache.
- Ubuntu on WSL, Python 3.10.12, cache is a symlink pointing to an NTFS disk, i.e. the hub folder points to /mnt/c/python/StableDiffusion/hub
- Windows 11 Home, Python 3.11.6, cache resides in a folder on the NTFS disk, i.e. HF_HOME=C:\python\StableDiffusion (a path-check sketch follows this list)
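For reference, a minimal sketch of how the WSL side of this layout can be verified from Python; the HF_HOME fallback and the hub subfolder follow the standard huggingface_hub cache layout, and the printed values are from my setup:

import os
from pathlib import Path

# HF_HOME, falling back to the default huggingface_hub location; the model
# cache lives in the "hub" folder underneath it
hf_home = Path(os.environ.get("HF_HOME", str(Path.home() / ".cache" / "huggingface")))
hub = hf_home / "hub"

print("HF_HOME   :", hf_home)
print("hub folder:", hub)
print("symlink?  :", hub.is_symlink())  # True on my WSL setup
print("resolves  :", hub.resolve())     # /mnt/c/python/StableDiffusion/hub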
I am running the same Python code in a Jupyter notebook that just loads an SDXL model using diffusers.
On the Windows installation, the progress bar finishes quickly, in 1-2 seconds.
On Linux, the same progress bar finishes just as quickly, but there is another ~2 minutes of execution before the cell finishes. I manually interrupted the script and could see that some sort of additional data format conversion code was running.
I tried 2 additional setups:
- I cleared the shared cache and let the Ubuntu installation re-download and re-create it (the cache still resides behind a symlink pointing to an NTFS folder) => I still see the additional slowdown after the model is fully downloaded. The Windows installation still works perfectly with the cache created by Ubuntu, with no conversion delay.
- I stopped using the symlink to the NTFS folder and let the Ubuntu installation use its own cache on its own file system => no more slowdown, but I now have 2 copies of each model. (A timing sketch that separates the two loading phases follows this list.)
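To pin down which phase eats the time, here is a minimal timing sketch; the pipeline calls are the same as in the reproduction code below, only the time.perf_counter bookkeeping is added by me:

import time
import torch
from diffusers import StableDiffusionXLPipeline, AutoencoderKL

t0 = time.perf_counter()
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
t1 = time.perf_counter()
pipe.to("cuda")
t2 = time.perf_counter()

# With the symlinked NTFS cache I expect the extra ~2 minutes to land in the
# .to("cuda") step, matching the traceback below; on ext4 both steps are fast.
print(f"from_pretrained: {t1 - t0:.1f}s, to('cuda'): {t2 - t1:.1f}s")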
from diffusers import StableDiffusionXLPipeline, AutoencoderKL, DPMSolverMultistepScheduler
import torch

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")

This is the traceback for the slow model loading on Linux; I can see some sort of model data conversion executing when I interrupted the execution a minute in.
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
Cell In[3], line 10
2 import torch
4 vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
5 pipe = StableDiffusionXLPipeline.from_pretrained(
6 "stabilityai/stable-diffusion-xl-base-1.0",
7 vae=vae,
8 torch_dtype=torch.float16, variant="fp16",
9 use_safetensors=True
---> 10 ).to("cuda")
File ~/stable_diffusion_venv/lib/python3.10/site-packages/diffusers/pipelines/pipeline_utils.py:733, in DiffusionPipeline.to(self, torch_device, torch_dtype, silence_dtype_warnings)
729 logger.warning(
730 f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
731 )
732 else:
--> 733 module.to(torch_device, torch_dtype)
735 if (
736 module.dtype == torch.float16
737 and str(torch_device) in ["cpu"]
738 and not silence_dtype_warnings
739 and not is_offloaded
740 ):
741 logger.warning(
742 "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
743 " is not recommended to move them to `cpu` as running them will fail. Please make"
(...)
746 " `torch_dtype=torch.float16` argument, or use another device for inference."
747 )
File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1160, in Module.to(self, *args, **kwargs)
1156 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
1157 non_blocking, memory_format=convert_to_format)
1158 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
-> 1160 return self._apply(convert)
File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:810, in Module._apply(self, fn, recurse)
808 if recurse:
809 for module in self.children():
--> 810 module._apply(fn)
812 def compute_should_use_set_data(tensor, tensor_applied):
813 if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
814 # If the new tensor has compatible tensor type as the existing tensor,
815 # the current behavior is to change the tensor in-place using `.data =`,
(...)
820 # global flag to let the user control whether they want the future
821 # behavior of overwriting the existing tensor or not.
File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:810, in Module._apply(self, fn, recurse)
808 if recurse:
809 for module in self.children():
--> 810 module._apply(fn)
812 def compute_should_use_set_data(tensor, tensor_applied):
813 if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
814 # If the new tensor has compatible tensor type as the existing tensor,
815 # the current behavior is to change the tensor in-place using `.data =`,
(...)
820 # global flag to let the user control whether they want the future
821 # behavior of overwriting the existing tensor or not.
[... skipping similar frames: Module._apply at line 810 (5 times)]
File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:810, in Module._apply(self, fn, recurse)
808 if recurse:
809 for module in self.children():
--> 810 module._apply(fn)
812 def compute_should_use_set_data(tensor, tensor_applied):
813 if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
814 # If the new tensor has compatible tensor type as the existing tensor,
815 # the current behavior is to change the tensor in-place using `.data =`,
(...)
820 # global flag to let the user control whether they want the future
821 # behavior of overwriting the existing tensor or not.
File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:833, in Module._apply(self, fn, recurse)
829 # Tensors stored in modules are graph leaves, and we don't want to
830 # track autograd history of `param_applied`, so we have to use
831 # `with torch.no_grad():`
832 with torch.no_grad():
--> 833 param_applied = fn(param)
834 should_use_set_data = compute_should_use_set_data(param, param_applied)
835 if should_use_set_data:
File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1158, in Module.to.<locals>.convert(t)
1155 if convert_to_format is not None and t.dim() in (4, 5):
1156 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
1157 non_blocking, memory_format=convert_to_format)
-> 1158 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
KeyboardInterrupt:
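Every interrupted frame sits inside torch's Module._apply/convert, i.e. the per-parameter fp16 cast and CUDA transfer, so my assumption (not something the traceback proves) is that the parameter data is only actually read from disk around this point, and the slow 9P-backed /mnt/c mount is what stretches it to ~2 minutes. A hedged sketch for comparing raw sequential read throughput of the same cached blob on both file systems; both paths are placeholders to be pointed at a real .safetensors blob:

import time

def read_throughput(path, chunk_mb=8):
    # Sequentially read the whole file and report MB/s. Note that the OS
    # page cache will inflate a second run over the same file.
    chunk = chunk_mb * 1024 * 1024
    total = 0
    t0 = time.perf_counter()
    with open(path, "rb") as f:
        while True:
            buf = f.read(chunk)
            if not buf:
                break
            total += len(buf)
    dt = time.perf_counter() - t0
    print(f"{path}: {total / 1e6:.0f} MB in {dt:.1f} s ({total / 1e6 / dt:.0f} MB/s)")

# Placeholder paths: substitute an actual blob file from the cache.
read_throughput("/mnt/c/python/StableDiffusion/hub/models--stabilityai--stable-diffusion-xl-base-1.0/blobs/<blob>")
read_throughput("/home/<user>/.cache/huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0/blobs/<blob>")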
Reproduction
- Have a Linux distribution running on Windows WSL
- Use diffusers on Windows to load and cache a diffusion model (say, the Windows HF_HOME is c:\somewhere, so the cache is in c:\somewhere\hub)
- Use diffusers on Linux (on WSL) to load the same model; point the Linux HF_HOME to the root of the virtualenv, with the hub folder inside HF_HOME being a symlink to a folder on NTFS (e.g. /mnt/c/somewhere/hub; a sketch of this symlink step follows below)
- Diffusers should load the model from cache
Expected: the model loads quickly from the cache
Actual: some unexpected data conversion happens and delays model loading by ~2 minutes
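For completeness, the symlink step from the reproduction above, expressed in Python; the paths are the example ones from the steps, and HF_HOME must be set before huggingface_hub or diffusers is imported for it to take effect:

import os

# Step 3 of the reproduction: HF_HOME points at the virtualenv root
os.environ["HF_HOME"] = os.path.expanduser("~/stable_diffusion_venv")

hub = os.path.join(os.environ["HF_HOME"], "hub")
# Make the hub cache a symlink to the Windows-side cache on NTFS (step 2)
if not os.path.lexists(hub):
    os.symlink("/mnt/c/somewhere/hub", hub, target_is_directory=True)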
Logs
No response
System Info
pip freeze on Linux
accelerate==0.23.0
anyio==4.0.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.0
async-lru==2.0.4
attrs==23.1.0
Babel==2.13.0
backcall==0.2.0
beautifulsoup4==4.12.2
bleach==6.1.0
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.0
comm==0.1.4
debugpy==1.8.0
decorator==5.1.1
defusedxml==0.7.1
diffusers==0.21.4
exceptiongroup==1.1.3
executing==2.0.0
fastjsonschema==2.18.1
filelock==3.12.4
fqdn==1.5.1
fsspec==2023.9.2
huggingface-hub==0.17.3
idna==3.4
importlib-metadata==6.8.0
ipykernel==6.25.2
ipython==8.16.1
ipywidgets==8.1.1
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.2
json5==0.9.14
jsonpointer==2.4
jsonschema==4.19.1
jsonschema-specifications==2023.7.1
jupyter-events==0.8.0
jupyter-lsp==2.2.0
jupyter_client==8.4.0
jupyter_core==5.4.0
jupyter_server==2.8.0
jupyter_server_terminals==0.4.4
jupyterlab==4.0.7
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.9
jupyterlab_server==2.25.0
MarkupSafe==2.1.3
matplotlib-inline==0.1.6
mistune==3.0.2
mpmath==1.3.0
nbclient==0.8.0
nbconvert==7.9.2
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.2
notebook_shim==0.2.3
numpy==1.26.1
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.52
nvidia-nvtx-cu12==12.1.105
overrides==7.4.0
packaging==23.2
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==10.1.0
platformdirs==3.11.0
prometheus-client==0.17.1
prompt-toolkit==3.0.39
psutil==5.9.6
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
Pygments==2.16.1
python-dateutil==2.8.2
python-json-logger==2.0.7
PyYAML==6.0.1
pyzmq==25.1.1
referencing==0.30.2
regex==2023.10.3
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.10.6
safetensors==0.4.0
Send2Trash==1.8.2
six==1.16.0
sniffio==1.3.0
soupsieve==2.5
stack-data==0.6.3
sympy==1.12
terminado==0.17.1
tinycss2==1.2.1
tokenizers==0.14.1
tomli==2.0.1
torch==2.1.0
torchaudio==2.1.0
torchvision==0.16.0
tornado==6.3.3
tqdm==4.66.1
traitlets==5.11.2
transformers==4.34.1
triton==2.1.0
types-python-dateutil==2.8.19.14
typing_extensions==4.8.0
uri-template==1.3.0
urllib3==2.0.7
wcwidth==0.2.8
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.4
widgetsnbextension==4.0.9
zipp==3.17.0
pip freeze on Windows
absl-py==2.0.0
accelerate==0.23.0
antlr4-python3-runtime==4.9.3
anyio==4.0.0
appdirs==1.4.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.4.0
async-lru==2.0.4
attrs==23.1.0
Babel==2.12.1
backcall==0.2.0
beautifulsoup4==4.12.2
bitsandbytes==0.41.1
bleach==6.0.0
cachetools==5.3.1
certifi==2022.12.7
cffi==1.16.0
charset-normalizer==2.1.1
click==8.1.7
colorama==0.4.6
comm==0.1.4
debugpy==1.8.0
decorator==5.1.1
defusedxml==0.7.1
diffusers==0.21.4
docker-pycreds==0.4.0
executing==1.2.0
fastjsonschema==2.18.0
filelock==3.9.0
fqdn==1.5.1
fsspec==2023.9.2
ftfy==6.1.1
gitdb==4.0.10
GitPython==3.1.37
google-auth==2.23.2
google-auth-oauthlib==1.0.0
grpcio==1.59.0
huggingface-hub==0.17.3
idna==3.4
importlib-metadata==6.8.0
ipykernel==6.25.2
ipython==8.15.0
ipython-genutils==0.2.0
ipywidgets==8.1.1
isoduration==20.11.0
jedi==0.19.0
Jinja2==3.1.2
json5==0.9.14
jsonpointer==2.4
jsonschema==4.19.1
jsonschema-specifications==2023.7.1
jupyter-events==0.7.0
jupyter-lsp==2.2.0
jupyter_client==8.3.1
jupyter_core==5.3.2
jupyter_server==2.7.3
jupyter_server_terminals==0.4.4
jupyterlab==4.0.6
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.9
jupyterlab_server==2.25.0
Markdown==3.5
MarkupSafe==2.1.2
matplotlib-inline==0.1.6
mistune==3.0.1
mpmath==1.3.0
nbclient==0.8.0
nbconvert==7.8.0
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.0
notebook==7.0.4
notebook_shim==0.2.3
numpy==1.24.1
nvidia-cublas-cu12==12.2.5.6
nvidia-cuda-nvrtc-cu12==12.2.140
nvidia-cuda-runtime-cu12==12.2.140
nvidia-cudnn-cu12==8.9.4.25
oauthlib==3.2.2
omegaconf==2.3.0
opencv-python==4.8.1.78
overrides==7.4.0
packaging==23.1
pandocfilters==1.5.0
parso==0.8.3
pathtools==0.1.2
pickleshare==0.7.5
Pillow==9.3.0
platformdirs==3.10.0
prometheus-client==0.17.1
prompt-toolkit==3.0.39
protobuf==4.24.4
psutil==5.9.5
pure-eval==0.2.2
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycparser==2.21
Pygments==2.16.1
python-dateutil==2.8.2
python-json-logger==2.0.7
pywin32==306
pywinpty==2.0.11
PyYAML==6.0.1
pyzmq==25.1.1
qtconsole==5.4.4
QtPy==2.4.0
referencing==0.30.2
regex==2023.8.8
requests==2.31.0
requests-oauthlib==1.3.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.10.3
rsa==4.9
safetensors==0.3.3
scipy==1.11.3
Send2Trash==1.8.2
sentry-sdk==1.31.0
setproctitle==1.3.3
six==1.16.0
smmap==5.0.1
sniffio==1.3.0
soupsieve==2.5
stack-data==0.6.2
sympy==1.12
tensorboard==2.14.1
tensorboard-data-server==0.7.1
tensorrt-bindings==9.0.1.post12.dev4
tensorrt-libs==9.0.1.post12.dev4
terminado==0.17.1
tinycss2==1.2.1
tokenizers==0.13.3
torch==2.0.1+cu118
torchaudio==2.0.2+cu118
torchvision==0.15.2+cu118
tornado==6.3.3
tqdm==4.66.1
traitlets==5.10.1
transformers==4.33.3
typing_extensions==4.4.0
uri-template==1.3.0
urllib3==1.26.13
wandb==0.15.12
wcwidth==0.2.7
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.3
Werkzeug==3.0.0
widgetsnbextension==4.0.9
xformers==0.0.22
zipp==3.17.0