From b0b58d0fc69255f8dab0448421b41b21e441f9e8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 25 Apr 2024 07:03:42 +0530 Subject: [PATCH] decorate UNetControlNetXSModelTests::test_forward_no_control with is_flaky --- tests/models/unets/test_models_unet_controlnetxs.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index 8c9b43a20ad6..6f3662e01750 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -22,11 +22,7 @@ from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel from diffusers.utils import logging -from diffusers.utils.testing_utils import ( - enable_full_determinism, - floats_tensor, - torch_device, -) +from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin @@ -305,6 +301,7 @@ def _set_gradient_checkpointing_new(self, module, value=False): assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET assert all(modules_with_gc_enabled.values()), "All modules should be enabled" + @is_flaky def test_forward_no_control(self): unet = self.get_dummy_unet() controlnet = self.get_dummy_controlnet_from_unet(unet)