From cd743e1d5c42bdb73d0455f69723bbaa66579cd3 Mon Sep 17 00:00:00 2001 From: Goat Date: Sun, 28 Jan 2024 22:40:25 +0800 Subject: [PATCH] fix #6742 --- src/diffusers/models/unets/unet_3d_blocks.py | 31 +++++++------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 6c20b1175349..765ad994246a 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -1031,16 +1031,10 @@ def custom_forward(*inputs): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, scale ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(motion_module), - hidden_states.requires_grad_(), - temb, - num_frames, - ) else: hidden_states = resnet(hidden_states, temb, scale=scale) - hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] + hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] output_states = output_states + (hidden_states,) @@ -1221,10 +1215,10 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - )[0] + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + )[0] # apply additional residuals to the output of the last pair of resnet and attention blocks if i == len(blocks) - 1 and additional_residuals is not None: @@ -1425,10 +1419,10 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - )[0] + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + )[0] if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1563,15 +1557,10 @@ def custom_forward(*inputs): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - ) else: hidden_states = resnet(hidden_states, temb, scale=scale) - hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] + hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] if self.upsamplers is not None: for upsampler in self.upsamplers: