From ac1fe1d51173e67c9cfb236e8fb9b7e933c2adab Mon Sep 17 00:00:00 2001
From: 1lint <105617163+1lint@users.noreply.github.com>
Date: Tue, 7 Mar 2023 11:06:50 -0800
Subject: [PATCH] add test_to_dtype to check pipe.to(fp16)

---
 src/diffusers/pipelines/pipeline_utils.py | 11 ++++++++---
 tests/test_pipelines_common.py            | 17 +++++++++++++----
 2 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 65b348d2e7d3..3784b4ccecee 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -342,8 +342,13 @@ def is_saveable_module(name, value):
 
             save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
 
-    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
-        if torch_device is None:
+    def to(
+        self,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+        silence_dtype_warnings: bool = False,
+    ):
+        if torch_device is None and torch_dtype is None:
             return self
 
         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
@@ -380,6 +385,7 @@ def module_is_offloaded(module):
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
+                module.to(torch_device, torch_dtype)
                 if (
                     module.dtype == torch.float16
                     and str(torch_device) in ["cpu"]
@@ -393,7 +399,6 @@ def module_is_offloaded(module):
                         " support for`float16` operations on this device in PyTorch. Please, remove the"
                         " `torch_dtype=torch.float16` argument, or use another device for inference."
                     )
-                module.to(torch_device)
         return self
 
     @property
diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py
index 986770bedea6..96d54b23efb5 100644
--- a/tests/test_pipelines_common.py
+++ b/tests/test_pipelines_common.py
@@ -344,11 +344,8 @@ def test_float16_inference(self):
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
-        for name, module in components.items():
-            if hasattr(module, "half"):
-                components[name] = module.half()
         pipe_fp16 = self.pipeline_class(**components)
-        pipe_fp16.to(torch_device)
+        pipe_fp16.to(torch_device, torch.float16)
         pipe_fp16.set_progress_bar_config(disable=None)
 
         output = pipe(**self.get_dummy_inputs(torch_device))[0]
@@ -447,6 +444,18 @@ def test_to_device(self):
         output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
         self.assertTrue(np.isnan(output_cuda).sum() == 0)
 
+    def test_to_dtype(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+        pipe.to(torch_dtype=torch.float16)
+        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
     def test_attention_slicing_forward_pass(self):
         self._test_attention_slicing_forward_pass()
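
Usage note (not part of the patch): with the torch_dtype argument added above, a
pipeline can be moved and cast in a single call instead of calling .half() on each
component by hand. A minimal sketch, assuming a CUDA device and an example Stable
Diffusion checkpoint; the model id and prompt are illustrative placeholders only:

    import torch
    from diffusers import DiffusionPipeline

    # weights load in float32 by default
    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # with this patch: move to the device and cast every torch.nn.Module
    # component to float16 in one call
    pipe.to("cuda", torch.float16)

    # casting without moving also works, mirroring test_to_dtype above:
    # pipe.to(torch_dtype=torch.float16)

    image = pipe("an astronaut riding a horse").images[0]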