diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index d7bf67288c0a..f6e6373ce035 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -555,7 +555,4 @@ def load_module(name, value): pipe = pipeline_class(**init_kwargs) - if torch_dtype is not None: - pipe.to(dtype=torch_dtype) - return pipe diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index ff076c82b00b..c58251139c49 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -1808,4 +1808,17 @@ def create_diffusers_t5_model_from_checkpoint( else: model.load_state_dict(diffusers_format_checkpoint) + + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = model._keep_in_fp32_modules + else: + keep_in_fp32_modules = [] + + if keep_in_fp32_modules is not None: + for name, param in model.named_parameters(): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules): + # param = param.to(torch.float32) does not work here as only in the local scope. + param.data = param.data.to(torch.float32) + return model diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py index 5157cd4ca63c..b2bb7fe827f9 100644 --- a/tests/single_file/single_file_testing_utils.py +++ b/tests/single_file/single_file_testing_utils.py @@ -201,6 +201,20 @@ def test_single_file_components_with_diffusers_config_local_files_only( self._compare_component_configs(pipe, single_file_pipe) + def test_single_file_setting_pipeline_dtype_to_fp16( + self, + single_file_pipe=None, + ): + single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( + self.ckpt_path, torch_dtype=torch.float16 + ) + + for component_name, component in single_file_pipe.components.items(): + if not isinstance(component, torch.nn.Module): + continue + + assert component.dtype == torch.float16 + class SDXLSingleFileTesterMixin: def _compare_component_configs(self, pipe, single_file_pipe): @@ -378,3 +392,17 @@ def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_d max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten()) assert max_diff < expected_max_diff + + def test_single_file_setting_pipeline_dtype_to_fp16( + self, + single_file_pipe=None, + ): + single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( + self.ckpt_path, torch_dtype=torch.float16 + ) + + for component_name, component in single_file_pipe.components.items(): + if not isinstance(component, torch.nn.Module): + continue + + assert component.dtype == torch.float16 diff --git a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py index 8e9ac7973609..1af3f5126ff3 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py @@ -180,3 +180,12 @@ def test_single_file_components_with_diffusers_config_local_files_only(self): local_files_only=True, ) super()._compare_component_configs(pipe, pipe_single_file) + + def test_single_file_setting_pipeline_dtype_to_fp16(self): + controlnet = ControlNetModel.from_pretrained( + "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16" + ) + single_file_pipe = self.pipeline_class.from_single_file( + self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16 + ) + super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe) diff --git a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py index 8c750437f719..1966ecfc207a 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py @@ -181,3 +181,12 @@ def test_single_file_components_with_diffusers_config_local_files_only(self): local_files_only=True, ) super()._compare_component_configs(pipe, pipe_single_file) + + def test_single_file_setting_pipeline_dtype_to_fp16(self): + controlnet = ControlNetModel.from_pretrained( + "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16" + ) + single_file_pipe = self.pipeline_class.from_single_file( + self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16 + ) + super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe) diff --git a/tests/single_file/test_stable_diffusion_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_single_file.py index abcf4c11d614..fe066f02cf36 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_single_file.py @@ -169,3 +169,12 @@ def test_single_file_components_with_diffusers_config_local_files_only(self): local_files_only=True, ) super()._compare_component_configs(pipe, pipe_single_file) + + def test_single_file_setting_pipeline_dtype_to_fp16(self): + controlnet = ControlNetModel.from_pretrained( + "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16" + ) + single_file_pipe = self.pipeline_class.from_single_file( + self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16 + ) + super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe) diff --git a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py index 43881914d3c0..7f478133c66f 100644 --- a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py +++ b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py @@ -200,3 +200,11 @@ def test_single_file_components_with_original_config_local_files_only(self): local_files_only=True, ) self._compare_component_configs(pipe, pipe_single_file) + + def test_single_file_setting_pipeline_dtype_to_fp16(self): + adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16) + + single_file_pipe = self.pipeline_class.from_single_file( + self.ckpt_path, adapter=adapter, torch_dtype=torch.float16 + ) + super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe) diff --git a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py index 6aebc2b01999..a8509510ad80 100644 --- a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py +++ b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py @@ -195,3 +195,12 @@ def test_single_file_components_with_diffusers_config_local_files_only(self): local_files_only=True, ) super()._compare_component_configs(pipe, pipe_single_file) + + def test_single_file_setting_pipeline_dtype_to_fp16(self): + controlnet = ControlNetModel.from_pretrained( + "diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16" + ) + single_file_pipe = self.pipeline_class.from_single_file( + self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16 + ) + super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)