From f8c2cbf63eae94614957287ac743f786fd980f1e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 7 Nov 2023 10:56:08 +0530 Subject: [PATCH 1/4] debug --- src/diffusers/models/transformer_2d.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 7c0cd12d1c67..fbfeb967543c 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -426,6 +426,7 @@ def forward( # unpatchify height = width = int(hidden_states.shape[1] ** 0.5) + print(f"height: {height}, width: {width}") hidden_states = hidden_states.reshape( shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) ) From 044384f7dad3ddd0bf2f4a34561592c5b1a617e0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 7 Nov 2023 11:04:29 +0530 Subject: [PATCH 2/4] support non-square images --- src/diffusers/models/transformer_2d.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index fbfeb967543c..24abf54d6da7 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -339,6 +339,7 @@ def forward( elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) elif self.is_input_patches: + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size hidden_states = self.pos_embed(hidden_states) if self.adaln_single is not None: @@ -425,8 +426,8 @@ def forward( hidden_states = hidden_states.squeeze(1) # unpatchify - height = width = int(hidden_states.shape[1] ** 0.5) - print(f"height: {height}, width: {width}") + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) hidden_states = hidden_states.reshape( shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) ) From bcb81d9b56f60a5460169b1d8bb6a117a499db4b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 7 Nov 2023 11:10:15 +0530 Subject: [PATCH 3/4] add: test --- tests/pipelines/pixart/test_pixart.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/pixart/test_pixart.py b/tests/pipelines/pixart/test_pixart.py index 1797f7e0fec2..a8923234c5a0 100644 --- a/tests/pipelines/pixart/test_pixart.py +++ b/tests/pipelines/pixart/test_pixart.py @@ -174,13 +174,32 @@ def test_inference(self): inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] - print(torch.from_numpy(image_slice.flatten())) self.assertEqual(image.shape, (1, 8, 8, 3)) expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675]) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) + def test_inference_non_square_images(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs, height=32, width=48).images + print(image.shape) + image_slice = image[0, -3:, -3:, -1] + slice = image_slice.flatten().tolist() + print(", ".join([str(round(x, 4)) for x in slice])) + + # self.assertEqual(image.shape, (1, 8, 8, 3)) + expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675]) + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=1e-3) From f55e0d5abba7ed73fdf621f7e5a6a61c54cf459b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 7 Nov 2023 11:20:11 +0530 Subject: [PATCH 4/4] fix: test --- tests/pipelines/pixart/test_pixart.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/pipelines/pixart/test_pixart.py b/tests/pipelines/pixart/test_pixart.py index a8923234c5a0..10e6e6e79244 100644 --- a/tests/pipelines/pixart/test_pixart.py +++ b/tests/pipelines/pixart/test_pixart.py @@ -190,13 +190,10 @@ def test_inference_non_square_images(self): inputs = self.get_dummy_inputs(device) image = pipe(**inputs, height=32, width=48).images - print(image.shape) image_slice = image[0, -3:, -3:, -1] - slice = image_slice.flatten().tolist() - print(", ".join([str(round(x, 4)) for x in slice])) - # self.assertEqual(image.shape, (1, 8, 8, 3)) - expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675]) + self.assertEqual(image.shape, (1, 32, 48, 3)) + expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416]) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3)