diff --git a/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py b/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
index 7eb830cd5097..2476ab92f77a 100644
--- a/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
+++ b/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
@@ -35,13 +35,14 @@
     torch_all_close,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
 
 
 enable_full_determinism()
 
 
-class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AsymmetricAutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AsymmetricAutoencoderKL
     main_input_name = "sample"
     base_precision = 1e-2
diff --git a/tests/models/autoencoders/test_models_autoencoder_cosmos.py b/tests/models/autoencoders/test_models_autoencoder_cosmos.py
index ceccc2364e26..5898ae776a1b 100644
--- a/tests/models/autoencoders/test_models_autoencoder_cosmos.py
+++ b/tests/models/autoencoders/test_models_autoencoder_cosmos.py
@@ -17,13 +17,14 @@
 from diffusers import AutoencoderKLCosmos
 
 from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
 
 
 enable_full_determinism()
 
 
-class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLCosmosTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLCosmos
     main_input_name = "sample"
     base_precision = 1e-2
@@ -80,7 +81,3 @@ def test_gradient_checkpointing_is_applied(self):
     @unittest.skip("Not sure why this test fails. Investigate later.")
     def test_effective_gradient_checkpointing(self):
         pass
-
-    @unittest.skip("Unsupported test.")
-    def test_forward_with_norm_groups(self):
-        pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_dc.py b/tests/models/autoencoders/test_models_autoencoder_dc.py
index 56f172f1c869..a6912f3ebab7 100644
--- a/tests/models/autoencoders/test_models_autoencoder_dc.py
+++ b/tests/models/autoencoders/test_models_autoencoder_dc.py
@@ -22,13 +22,14 @@
     floats_tensor,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
 
 
 enable_full_determinism()
 
 
-class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderDC
     main_input_name = "sample"
     base_precision = 1e-2
@@ -81,7 +82,3 @@ def prepare_init_args_and_inputs_for_common(self):
         init_dict = self.get_autoencoder_dc_config()
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
-
-    @unittest.skip("AutoencoderDC does not support `norm_num_groups` because it does not use GroupNorm.")
-    def test_forward_with_norm_groups(self):
-        pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
index 6f91f8bfa91b..9813772a7c55 100644
--- a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
+++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
@@ -20,18 +20,15 @@
 from diffusers import AutoencoderKLHunyuanVideo
 from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
 
-from ...testing_utils import (
-    enable_full_determinism,
-    floats_tensor,
-    torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
 
 
 enable_full_determinism()
 
 
-class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLHunyuanVideo
     main_input_name = "sample"
     base_precision = 1e-2
@@ -87,68 +84,6 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
-    def test_enable_disable_tiling(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_tiling()
-        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE tiling should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_tiling()
-        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_tiling.detach().cpu().numpy().all(),
-            output_without_tiling_2.detach().cpu().numpy().all(),
-            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
-        )
-
-    def test_enable_disable_slicing(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_slicing()
-        output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE slicing should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_slicing()
-        output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_slicing.detach().cpu().numpy().all(),
-            output_without_slicing_2.detach().cpu().numpy().all(),
-            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
-        )
-
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {
             "HunyuanVideoDecoder3D",
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl.py b/tests/models/autoencoders/test_models_autoencoder_kl.py
index 662a3f1b80b7..5f11c6cb0ab3 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl.py
@@ -35,13 +35,14 @@
     torch_all_close,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
 
 
 enable_full_determinism()
 
 
-class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL
     main_input_name = "sample"
     base_precision = 1e-2
@@ -83,68 +84,6 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
-    def test_enable_disable_tiling(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_tiling()
-        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE tiling should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_tiling()
-        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_tiling.detach().cpu().numpy().all(),
-            output_without_tiling_2.detach().cpu().numpy().all(),
-            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
-        )
-
-    def test_enable_disable_slicing(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_slicing()
-        output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE slicing should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_slicing()
-        output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_slicing.detach().cpu().numpy().all(),
-            output_without_slicing_2.detach().cpu().numpy().all(),
-            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
-        )
-
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
index 739daf2a492d..b6d59489d9c6 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
@@ -24,13 +24,14 @@
     floats_tensor,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
 
 
 enable_full_determinism()
 
 
-class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLCogVideoXTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLCogVideoX
     main_input_name = "sample"
     base_precision = 1e-2
@@ -82,68 +83,6 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
-    def test_enable_disable_tiling(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_tiling()
-        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE tiling should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_tiling()
-        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_tiling.detach().cpu().numpy().all(),
-            output_without_tiling_2.detach().cpu().numpy().all(),
-            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
-        )
-
-    def test_enable_disable_slicing(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_slicing()
-        output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE slicing should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_slicing()
-        output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_slicing.detach().cpu().numpy().all(),
-            output_without_slicing_2.detach().cpu().numpy().all(),
-            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
-        )
-
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {
             "CogVideoXDownBlock3D",
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
index 6cb427bff8e1..93f40f44a919 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
@@ -22,13 +22,14 @@
     floats_tensor,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
 
 
 enable_full_determinism()
 
 
-class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLTemporalDecoder
     main_input_name = "sample"
     base_precision = 1e-2
@@ -67,7 +68,3 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @unittest.skip("Test unsupported.")
-    def test_forward_with_norm_groups(self):
-        pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py
index 21ab3896c890..527be1b4ecb5 100644
--- a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py
+++ b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py
@@ -24,13 +24,14 @@
     floats_tensor,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
 
 
 enable_full_determinism()
 
 
-class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLLTXVideo
     main_input_name = "sample"
     base_precision = 1e-2
@@ -99,7 +100,7 @@ def test_forward_with_norm_groups(self):
         pass
 
 
-class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLLTXVideo
     main_input_name = "sample"
     base_precision = 1e-2
@@ -167,34 +168,3 @@ def test_outputs_equivalence(self):
     @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
     def test_forward_with_norm_groups(self):
         pass
-
-    def test_enable_disable_tiling(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_tiling()
-        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE tiling should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_tiling()
-        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_tiling.detach().cpu().numpy().all(),
-            output_without_tiling_2.detach().cpu().numpy().all(),
-            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
-        )
diff --git a/tests/models/autoencoders/test_models_autoencoder_magvit.py b/tests/models/autoencoders/test_models_autoencoder_magvit.py
index 58cbfc05bd03..f7304df14048 100644
--- a/tests/models/autoencoders/test_models_autoencoder_magvit.py
+++ b/tests/models/autoencoders/test_models_autoencoder_magvit.py
@@ -18,13 +18,14 @@
 from diffusers import AutoencoderKLMagvit
 
 from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
 
 
 enable_full_determinism()
 
 
-class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLMagvitTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLMagvit
     main_input_name = "sample"
     base_precision = 1e-2
@@ -88,3 +89,9 @@ def test_effective_gradient_checkpointing(self):
     @unittest.skip("Unsupported test.")
     def test_forward_with_norm_groups(self):
         pass
+
+    @unittest.skip(
+        "Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 9 but got size 12 for tensor number 1 in the list."
+    )
+    def test_enable_disable_slicing(self):
+        pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_mochi.py b/tests/models/autoencoders/test_models_autoencoder_mochi.py
index b8c5aaaa1eb6..ab8d429a67f6 100755
--- a/tests/models/autoencoders/test_models_autoencoder_mochi.py
+++ b/tests/models/autoencoders/test_models_autoencoder_mochi.py
@@ -17,18 +17,15 @@
 
 from diffusers import AutoencoderKLMochi
 
-from ...testing_utils import (
-    enable_full_determinism,
-    floats_tensor,
-    torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
 
 
 enable_full_determinism()
 
 
-class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLMochiTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLMochi
     main_input_name = "sample"
     base_precision = 1e-2
@@ -79,14 +76,6 @@ def test_gradient_checkpointing_is_applied(self):
         }
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
-    @unittest.skip("Unsupported test.")
-    def test_forward_with_norm_groups(self):
-        """
-        tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_forward_with_norm_groups -
-        TypeError: AutoencoderKLMochi.__init__() got an unexpected keyword argument 'norm_num_groups'
-        """
-        pass
-
     @unittest.skip("Unsupported test.")
     def test_model_parallelism(self):
         """
diff --git a/tests/models/autoencoders/test_models_autoencoder_oobleck.py b/tests/models/autoencoders/test_models_autoencoder_oobleck.py
index eb7bd50f4a54..d10e8ba33a12 100644
--- a/tests/models/autoencoders/test_models_autoencoder_oobleck.py
+++ b/tests/models/autoencoders/test_models_autoencoder_oobleck.py
@@ -30,13 +30,14 @@
     torch_all_close,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
 
 
 enable_full_determinism()
 
 
-class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderOobleckTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderOobleck
     main_input_name = "sample"
     base_precision = 1e-2
@@ -106,10 +107,6 @@ def test_enable_disable_slicing(self):
             "Without slicing outputs should match with the outputs when slicing is manually disabled.",
         )
 
-    @unittest.skip("Test unsupported.")
-    def test_forward_with_norm_groups(self):
-        pass
-
     @unittest.skip("No attention module used in this model")
     def test_set_attn_processor_for_determinism(self):
         return
diff --git a/tests/models/autoencoders/test_models_autoencoder_tiny.py b/tests/models/autoencoders/test_models_autoencoder_tiny.py
index 4d1dc69cfaad..68232aa12fdf 100644
--- a/tests/models/autoencoders/test_models_autoencoder_tiny.py
+++ b/tests/models/autoencoders/test_models_autoencoder_tiny.py
@@ -31,13 +31,14 @@
     torch_all_close,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
 
 
 enable_full_determinism()
 
 
-class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderTinyTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderTiny
     main_input_name = "sample"
     base_precision = 1e-2
@@ -81,37 +82,6 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_enable_disable_tiling(self):
         pass
 
-    def test_enable_disable_slicing(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_slicing = model(**inputs_dict)[0]
-
-        torch.manual_seed(0)
-        model.enable_slicing()
-        output_with_slicing = model(**inputs_dict)[0]
-
-        self.assertLess(
-            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE slicing should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_slicing()
-        output_without_slicing_2 = model(**inputs_dict)[0]
-
-        self.assertEqual(
-            output_without_slicing.detach().cpu().numpy().all(),
-            output_without_slicing_2.detach().cpu().numpy().all(),
-            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
-        )
-
     @unittest.skip("Test not supported.")
     def test_outputs_equivalence(self):
         pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_wan.py b/tests/models/autoencoders/test_models_autoencoder_wan.py
index cc9c88868157..051098dc7aac 100644
--- a/tests/models/autoencoders/test_models_autoencoder_wan.py
+++ b/tests/models/autoencoders/test_models_autoencoder_wan.py
@@ -15,18 +15,17 @@
 
 import unittest
 
-import torch
-
 from diffusers import AutoencoderKLWan
 
 from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
 
 
 enable_full_determinism()
 
 
-class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLWan
     main_input_name = "sample"
     base_precision = 1e-2
@@ -76,68 +75,6 @@ def prepare_init_args_and_inputs_for_tiling(self):
         inputs_dict = self.dummy_input_tiling
         return init_dict, inputs_dict
 
-    def test_enable_disable_tiling(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_tiling()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_tiling(96, 96, 64, 64)
-        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE tiling should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_tiling()
-        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_tiling.detach().cpu().numpy().all(),
-            output_without_tiling_2.detach().cpu().numpy().all(),
-            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
-        )
-
-    def test_enable_disable_slicing(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_slicing()
-        output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
-            0.05,
-            "VAE slicing should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_slicing()
-        output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_slicing.detach().cpu().numpy().all(),
-            output_without_slicing_2.detach().cpu().numpy().all(),
-            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
-        )
-
     @unittest.skip("Gradient checkpointing has not been implemented yet")
     def test_gradient_checkpointing_is_applied(self):
         pass
diff --git a/tests/models/autoencoders/test_models_consistency_decoder_vae.py b/tests/models/autoencoders/test_models_consistency_decoder_vae.py
index 7e44edba3624..ef04d151ecd1 100644
--- a/tests/models/autoencoders/test_models_consistency_decoder_vae.py
+++ b/tests/models/autoencoders/test_models_consistency_decoder_vae.py
@@ -31,12 +31,13 @@
     torch_device,
 )
 from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
 
 
 enable_full_determinism()
 
 
-class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
+class ConsistencyDecoderVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = ConsistencyDecoderVAE
     main_input_name = "sample"
     base_precision = 1e-2
@@ -92,70 +93,6 @@ def init_dict(self):
     def prepare_init_args_and_inputs_for_common(self):
         return self.init_dict, self.inputs_dict()
 
-    def test_enable_disable_tiling(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-        _ = inputs_dict.pop("generator")
-
-        torch.manual_seed(0)
-        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_tiling()
-        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE tiling should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_tiling()
-        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_tiling.detach().cpu().numpy().all(),
-            output_without_tiling_2.detach().cpu().numpy().all(),
-            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
-        )
-
-    def test_enable_disable_slicing(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-        _ = inputs_dict.pop("generator")
-
-        torch.manual_seed(0)
-        output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_slicing()
-        output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE slicing should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_slicing()
-        output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_slicing.detach().cpu().numpy().all(),
-            output_without_slicing_2.detach().cpu().numpy().all(),
-            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
-        )
-
 
 @slow
 class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py
index 1c636b081733..b88d24d1f2d8 100644
--- a/tests/models/autoencoders/test_models_vq.py
+++ b/tests/models/autoencoders/test_models_vq.py
@@ -19,19 +19,15 @@
 
 from diffusers import VQModel
 
-from ...testing_utils import (
-    backend_manual_seed,
-    enable_full_determinism,
-    floats_tensor,
-    torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ...testing_utils import backend_manual_seed, enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
 
 
 enable_full_determinism()
 
 
-class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class VQModelTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = VQModel
     main_input_name = "sample"
 
diff --git a/tests/models/autoencoders/testing_utils.py b/tests/models/autoencoders/testing_utils.py
new file mode 100644
index 000000000000..cf1f10a4a545
--- /dev/null
+++ b/tests/models/autoencoders/testing_utils.py
@@ -0,0 +1,142 @@
+import inspect
+
+import numpy as np
+import pytest
+import torch
+
+from diffusers.models.autoencoders.vae import DecoderOutput
+from diffusers.utils.torch_utils import torch_device
+
+
+class AutoencoderTesterMixin:
+    """
+    Test mixin class specific to VAEs to test for slicing and tiling. Diffusion networks
+    usually don't do slicing and tiling.
+    """
+
+    @staticmethod
+    def _accepts_generator(model):
+        model_sig = inspect.signature(model.forward)
+        accepts_generator = "generator" in model_sig.parameters
+        return accepts_generator
+
+    @staticmethod
+    def _accepts_norm_num_groups(model_class):
+        model_sig = inspect.signature(model_class.__init__)
+        accepts_norm_groups = "norm_num_groups" in model_sig.parameters
+        return accepts_norm_groups
+
+    def test_forward_with_norm_groups(self):
+        if not self._accepts_norm_num_groups(self.model_class):
+            pytest.skip(f"Test not supported for {self.model_class.__name__}")
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["norm_num_groups"] = 16
+        init_dict["block_out_channels"] = (16, 32)
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+        if isinstance(output, dict):
+            output = output.to_tuple()[0]
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+    def test_enable_disable_tiling(self):
+        if not hasattr(self.model_class, "enable_tiling"):
+            pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict).to(torch_device)
+
+        inputs_dict.update({"return_dict": False})
+        _ = inputs_dict.pop("generator", None)
+        accepts_generator = self._accepts_generator(model)
+
+        torch.manual_seed(0)
+        if accepts_generator:
+            inputs_dict["generator"] = torch.manual_seed(0)
+        output_without_tiling = model(**inputs_dict)[0]
+        # Mochi-1
+        if isinstance(output_without_tiling, DecoderOutput):
+            output_without_tiling = output_without_tiling.sample
+
+        torch.manual_seed(0)
+        model.enable_tiling()
+        if accepts_generator:
+            inputs_dict["generator"] = torch.manual_seed(0)
+        output_with_tiling = model(**inputs_dict)[0]
+        if isinstance(output_with_tiling, DecoderOutput):
+            output_with_tiling = output_with_tiling.sample
+
+        assert (
+            output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()
+        ).max() < 0.5, "VAE tiling should not affect the inference results"
+
+        torch.manual_seed(0)
+        model.disable_tiling()
+        if accepts_generator:
+            inputs_dict["generator"] = torch.manual_seed(0)
+        output_without_tiling_2 = model(**inputs_dict)[0]
+        if isinstance(output_without_tiling_2, DecoderOutput):
+            output_without_tiling_2 = output_without_tiling_2.sample
+
+        assert np.allclose(
+            output_without_tiling.detach().cpu().numpy().all(),
+            output_without_tiling_2.detach().cpu().numpy().all(),
+        ), "Without tiling outputs should match with the outputs when tiling is manually disabled."
+
+    def test_enable_disable_slicing(self):
+        if not hasattr(self.model_class, "enable_slicing"):
+            pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.")
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict).to(torch_device)
+
+        inputs_dict.update({"return_dict": False})
+        _ = inputs_dict.pop("generator", None)
+        accepts_generator = self._accepts_generator(model)
+
+        if accepts_generator:
+            inputs_dict["generator"] = torch.manual_seed(0)
+
+        torch.manual_seed(0)
+        output_without_slicing = model(**inputs_dict)[0]
+        # Mochi-1
+        if isinstance(output_without_slicing, DecoderOutput):
+            output_without_slicing = output_without_slicing.sample
+
+        torch.manual_seed(0)
+        model.enable_slicing()
+        if accepts_generator:
+            inputs_dict["generator"] = torch.manual_seed(0)
+        output_with_slicing = model(**inputs_dict)[0]
+        if isinstance(output_with_slicing, DecoderOutput):
+            output_with_slicing = output_with_slicing.sample
+
+        assert (
+            output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()
+        ).max() < 0.5, "VAE slicing should not affect the inference results"
+
+        torch.manual_seed(0)
+        model.disable_slicing()
+        if accepts_generator:
+            inputs_dict["generator"] = torch.manual_seed(0)
+        output_without_slicing_2 = model(**inputs_dict)[0]
+        if isinstance(output_without_slicing_2, DecoderOutput):
+            output_without_slicing_2 = output_without_slicing_2.sample
+
+        assert np.allclose(
+            output_without_slicing.detach().cpu().numpy().all(),
+            output_without_slicing_2.detach().cpu().numpy().all(),
+        ), "Without slicing outputs should match with the outputs when slicing is manually disabled."
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index a44ef571c5be..6f4c3d544b45 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -450,7 +450,15 @@ def get_dummy_inputs():
 
 
 class UNetTesterMixin:
+    @staticmethod
+    def _accepts_norm_num_groups(model_class):
+        model_sig = inspect.signature(model_class.__init__)
+        accepts_norm_groups = "norm_num_groups" in model_sig.parameters
+        return accepts_norm_groups
+
     def test_forward_with_norm_groups(self):
+        if not self._accepts_norm_num_groups(self.model_class):
+            pytest.skip(f"Test not supported for {self.model_class.__name__}")
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
         init_dict["norm_num_groups"] = 16