diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py
index 7c0cd12d1c67..24abf54d6da7 100644
--- a/src/diffusers/models/transformer_2d.py
+++ b/src/diffusers/models/transformer_2d.py
@@ -339,6 +339,7 @@ def forward(
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
         elif self.is_input_patches:
+            height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
             hidden_states = self.pos_embed(hidden_states)
 
             if self.adaln_single is not None:
@@ -425,7 +426,8 @@ def forward(
             hidden_states = hidden_states.squeeze(1)
 
         # unpatchify
-        height = width = int(hidden_states.shape[1] ** 0.5)
+        if self.adaln_single is None:
+            height = width = int(hidden_states.shape[1] ** 0.5)
         hidden_states = hidden_states.reshape(
             shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
         )
diff --git a/tests/pipelines/pixart/test_pixart.py b/tests/pipelines/pixart/test_pixart.py
index 2e033896d900..70548092fedf 100644
--- a/tests/pipelines/pixart/test_pixart.py
+++ b/tests/pipelines/pixart/test_pixart.py
@@ -174,13 +174,29 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
         image_slice = image[0, -3:, -3:, -1]
-        print(torch.from_numpy(image_slice.flatten()))
 
         self.assertEqual(image.shape, (1, 8, 8, 3))
         expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675])
         max_diff = np.abs(image_slice.flatten() - expected_slice).max()
         self.assertLessEqual(max_diff, 1e-3)
 
+    def test_inference_non_square_images(self):
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs, height=32, width=48).images
+        image_slice = image[0, -3:, -3:, -1]
+
+        self.assertEqual(image.shape, (1, 32, 48, 3))
+        expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416])
+        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
+        self.assertLessEqual(max_diff, 1e-3)
+
     def test_inference_with_embeddings_and_multiple_images(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
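
For context on the second hunk: the reshape that follows it only produces a valid image when `height` and `width` describe the actual patch grid, so deriving them from the latent shape (first hunk) rather than assuming a square token count is what enables non-square resolutions. Below is a minimal standalone sketch of that unpatchify step, not the diffusers implementation itself; `patch_size`, `out_channels`, and the 16x24 latent size are made-up values chosen to give a non-square 8x12 patch grid.

import torch

patch_size = 2
out_channels = 4
latent_height, latent_width = 16, 24

# Derive the patch-grid size from the latent shape, as the fix does,
# instead of assuming a square grid via int(num_tokens ** 0.5).
height = latent_height // patch_size   # 8
width = latent_width // patch_size     # 12
num_tokens = height * width            # 96

# Dummy transformer output: (batch, tokens, patch_size * patch_size * channels)
hidden_states = torch.randn(1, num_tokens, patch_size * patch_size * out_channels)

# unpatchify: (B, H*W, p*p*C) -> (B, C, H*p, W*p)
hidden_states = hidden_states.reshape(-1, height, width, patch_size, patch_size, out_channels)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(-1, out_channels, height * patch_size, width * patch_size)

print(output.shape)  # torch.Size([1, 4, 16, 24])

With the square-grid assumption, the 96 tokens would be folded into a 9x9 grid (int(96 ** 0.5) == 9) and the reshape would fail; the new height/width computation keeps the grid consistent with the input latent, which is what the added test_inference_non_square_images exercises via height=32, width=48.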