From 99fb2b424529914cbc2e0b4aa7da24562eec4a8b Mon Sep 17 00:00:00 2001
From: Miguel Farinha
Date: Sat, 20 Jan 2024 01:32:30 +0000
Subject: [PATCH] allow resolutions not multiple of 64 in SVD

---
 src/diffusers/models/unet_3d_blocks.py       |  6 +++--
 .../models/unet_spatio_temporal_condition.py | 23 +++++++++++++++++++
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py
index e9c505c347b0..c8f076a628f4 100644
--- a/src/diffusers/models/unet_3d_blocks.py
+++ b/src/diffusers/models/unet_3d_blocks.py
@@ -2231,6 +2231,7 @@ def forward(
         hidden_states: torch.FloatTensor,
         res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
         temb: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         for resnet in self.resnets:
@@ -2272,7 +2273,7 @@ def custom_forward(*inputs):
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)
 
         return hidden_states
 
@@ -2341,6 +2342,7 @@ def forward(
         res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
         temb: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         for resnet, attn in zip(self.resnets, self.attentions):
@@ -2390,6 +2392,6 @@ def custom_forward(*inputs):
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)
 
         return hidden_states
diff --git a/src/diffusers/models/unet_spatio_temporal_condition.py b/src/diffusers/models/unet_spatio_temporal_condition.py
index 8d0d3e61d879..8c035b31e52a 100644
--- a/src/diffusers/models/unet_spatio_temporal_condition.py
+++ b/src/diffusers/models/unet_spatio_temporal_condition.py
@@ -381,6 +381,20 @@ def forward(
             If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
             returned, otherwise a `tuple` is returned where the first element is the sample tensor.
         """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
         # 1. time
         timesteps = timestep
         if not torch.is_tensor(timesteps):
@@ -456,15 +470,23 @@ def forward(
 
         # 5. up
         for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
 
+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
             if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                 sample = upsample_block(
                     hidden_states=sample,
                     temb=emb,
                     res_hidden_states_tuple=res_samples,
                     encoder_hidden_states=encoder_hidden_states,
+                    upsample_size=upsample_size,
                     image_only_indicator=image_only_indicator,
                 )
             else:
@@ -472,6 +494,7 @@ def forward(
                 hidden_states=sample,
                 temb=emb,
                 res_hidden_states_tuple=res_samples,
+                upsample_size=upsample_size,
                 image_only_indicator=image_only_indicator,
             )
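
Note on usage (not part of the patch): the sketch below is an illustrative assumption of what this change enables, not code from the PR; the input path and exact frame size are placeholders. Before this patch, SVD required height and width to be multiples of 64, i.e. the VAE's spatial factor of 8 times the UNet's overall upsampling factor of 2**num_upsamplers (8 for the default SVD UNet). With `upsample_size` taken from the next skip connection's spatial shape and forwarded into the decoder's upsamplers, the interpolation output is forced to exactly the size the following skip-connection concatenation expects, so any multiple of 8 should work:

import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.to("cuda")

image = load_image("input.png")  # hypothetical input image path

# 584 = 8 * 73 is a multiple of 8 but not of 64; without the patch the UNet
# decoder's upsampled features would not match the skip-connection shapes.
frames = pipe(image, height=576, width=584).frames[0]

Height and width still have to be multiples of the VAE scale factor (8), since the patch only relaxes the UNet-side constraint; when the non-multiple-of-64 path is taken, the forward pass logs "Forward upsample size to force interpolation output size."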