From 7438d798c2c4542bb5e8180e453488c8aee8dd0e Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Tue, 14 Mar 2023 07:14:34 +0800 Subject: [PATCH] AutoencoderKL: clamp indices of blend_h and blend_v to input size --- src/diffusers/models/autoencoder_kl.py | 4 ++-- tests/pipelines/stable_diffusion/test_stable_diffusion.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 9cb0a4b2432b..3ee0c56796fe 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -190,12 +190,12 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode return DecoderOutput(sample=decoded) def blend_v(self, a, b, blend_extent): - for y in range(blend_extent): + for y in range(min(a.shape[2], b.shape[2], blend_extent)): b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) return b def blend_h(self, a, b, blend_extent): - for x in range(blend_extent): + for x in range(min(a.shape[3], b.shape[3], blend_extent)): b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) return b diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index d4fd30458373..4d4f680dbb1d 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -445,6 +445,12 @@ def test_stable_diffusion_vae_tiling(self): assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 5e-1 + # test that tiled decode works with various shapes + shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)] + for shape in shapes: + zeros = torch.zeros(shape).to(device) + sd_pipe.vae.decode(zeros) + def test_stable_diffusion_negative_prompt(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components()