From b938d300fda21a824d8951b869ad787196aa396c Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 17 Oct 2025 10:01:22 +0800 Subject: [PATCH] adjust unit tests for wan pipeline Signed-off-by: Liu, Kaixuan --- tests/pipelines/test_pipelines_common.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index db8209835be4..022262a8eefe 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1438,19 +1438,19 @@ def test_save_load_float16(self, expected_max_diff=1e-2): with tempfile.TemporaryDirectory() as tmpdir: pipe.save_pretrained(tmpdir) pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16) - for component in pipe_loaded.components.values(): + for name, component in pipe_loaded.components.items(): if hasattr(component, "set_default_attn_processor"): component.set_default_attn_processor() - pipe_loaded.to(torch_device) + if hasattr(component, "dtype"): + self.assertTrue( + component.dtype == torch.float16, + f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", + ) + if hasattr(component, "half"): + # Although all components for pipe_loaded should be float16 now, some submodules still use fp32, like in https://github.com/huggingface/transformers/blob/v4.57.1/src/transformers/models/t5/modeling_t5.py#L783, so we need to do the conversion again manually to align with the datatype we use in pipe exactly + component = component.to(torch_device).half() pipe_loaded.set_progress_bar_config(disable=None) - for name, component in pipe_loaded.components.items(): - if hasattr(component, "dtype"): - self.assertTrue( - component.dtype == torch.float16, - f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", - ) - inputs = self.get_dummy_inputs(torch_device) output_loaded = pipe_loaded(**inputs)[0] max_diff = np.abs(to_np(output) - 
to_np(output_loaded)).max()