Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/diffusers/loaders/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from huggingface_hub.utils import validate_hf_hub_args
from safetensors import safe_open

from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..utils import (
_get_model_file,
is_accelerate_available,
Expand Down Expand Up @@ -182,7 +182,7 @@ def load_ip_adapter(
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(model_file, map_location="cpu")
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/loaders/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from torch import nn

from .. import __version__
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
Expand Down Expand Up @@ -281,7 +281,7 @@ def lora_state_dict(
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict

Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/loaders/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn

from ..models.modeling_utils import load_state_dict
from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging


Expand Down Expand Up @@ -100,7 +101,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
Expand Down Expand Up @@ -214,7 +214,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict

Expand Down
7 changes: 6 additions & 1 deletion src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,12 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
else:
return torch.load(checkpoint_file, map_location="cpu")
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
return torch.load(
checkpoint_file,
map_location="cpu",
**weights_only_kwarg,
)
except Exception as e:
try:
with open(checkpoint_file) as f:
Expand Down