diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index 8c9b43a20ad6..6f3662e01750 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -22,11 +22,7 @@ from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel from diffusers.utils import logging -from diffusers.utils.testing_utils import ( - enable_full_determinism, - floats_tensor, - torch_device, -) +from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin @@ -305,6 +301,7 @@ def _set_gradient_checkpointing_new(self, module, value=False): assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET assert all(modules_with_gc_enabled.values()), "All modules should be enabled" + @is_flaky def test_forward_no_control(self): unet = self.get_dummy_unet() controlnet = self.get_dummy_controlnet_from_unet(unet)