diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py index a53758ce2808..d02ff7429046 100644 --- a/tests/pipelines/lumina/test_lumina_nextdit.py +++ b/tests/pipelines/lumina/test_lumina_nextdit.py @@ -34,19 +34,19 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM def get_dummy_components(self): torch.manual_seed(0) transformer = LuminaNextDiT2DModel( - sample_size=16, + sample_size=4, patch_size=2, in_channels=4, - hidden_size=24, + hidden_size=4, num_layers=2, - num_attention_heads=3, + num_attention_heads=1, num_kv_heads=1, multiple_of=16, ffn_dim_multiplier=None, norm_eps=1e-5, learn_sigma=True, qk_norm=True, - cross_attention_dim=32, + cross_attention_dim=8, scaling_factor=1.0, ) torch.manual_seed(0) @@ -57,8 +57,8 @@ def get_dummy_components(self): torch.manual_seed(0) config = GemmaConfig( - head_dim=4, - hidden_size=32, + head_dim=2, + hidden_size=8, intermediate_size=37, num_attention_heads=4, num_hidden_layers=2,