Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/diffusers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
from_auto_class = kwargs.pop("_from_auto", False)
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None)

user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
Expand All @@ -334,6 +335,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
subfolder=subfolder,
**kwargs,
)

if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
raise ValueError(f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}.")
elif torch_dtype is not None:
model = model.to(torch_dtype)

model.register_to_config(_name_or_path=pretrained_model_name_or_path)
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
Expand Down
9 changes: 7 additions & 2 deletions src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)

# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
Expand Down Expand Up @@ -237,12 +238,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

load_method = getattr(class_obj, load_method_name)

loading_kwargs = {}
if issubclass(class_obj, torch.nn.Module):
loading_kwargs["torch_dtype"] = torch_dtype

# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name))
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
else:
# else load from the root directory
loaded_sub_model = load_method(cached_folder)
loaded_sub_model = load_method(cached_folder, **loading_kwargs)

init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)

Expand Down