diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index ea71ae4af26c..729367a0c164 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -88,19 +88,17 @@ def output_shape(self): def prepare_init_args_and_inputs_for_common(self): init_dict = { - "block_out_channels": (32, 64, 64, 64), + "block_out_channels": (32, 64), "down_block_types": ( - "CrossAttnDownBlock3D", - "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D", ), - "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D"), "cross_attention_dim": 32, - "attention_head_dim": 4, + "attention_head_dim": 8, "out_channels": 4, "in_channels": 4, - "layers_per_block": 2, + "layers_per_block": 1, "sample_size": 32, } inputs_dict = self.dummy_input