From aa83563586fb479af442f9547cfe3f5fb4df394c Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 22 Aug 2022 16:05:33 +0000
Subject: [PATCH] [Loading] allow modules to be loaded in fp16

---
 src/diffusers/modeling_utils.py | 7 +++++++
 src/diffusers/pipeline_utils.py | 9 +++++++--
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py
index 3bbc298c6a26..7401526dfde7 100644
--- a/src/diffusers/modeling_utils.py
+++ b/src/diffusers/modeling_utils.py
@@ -315,6 +315,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         from_auto_class = kwargs.pop("_from_auto", False)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         subfolder = kwargs.pop("subfolder", None)
 
         user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -334,6 +335,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             subfolder=subfolder,
             **kwargs,
         )
+
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+            raise ValueError(f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}.")
+        elif torch_dtype is not None:
+            model = model.to(torch_dtype)
+
         model.register_to_config(_name_or_path=pretrained_model_name_or_path)
         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # Load model
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 6cb98d7c9b60..f27789f0bfa8 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -146,6 +146,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         local_files_only = kwargs.pop("local_files_only", False)
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
+        torch_dtype = kwargs.pop("torch_dtype", None)
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -237,12 +238,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
             load_method = getattr(class_obj, load_method_name)
 
+            loading_kwargs = {}
+            if issubclass(class_obj, torch.nn.Module):
+                loading_kwargs["torch_dtype"] = torch_dtype
+
             # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
-                loaded_sub_model = load_method(os.path.join(cached_folder, name))
+                loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
             else:
                 # else load from the root directory
-                loaded_sub_model = load_method(cached_folder)
+                loaded_sub_model = load_method(cached_folder, **loading_kwargs)
 
             init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)
 