diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index a1d9e848c230..1495ae54ee82 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -1031,16 +1031,10 @@ def custom_forward(*inputs): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, scale ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(motion_module), - hidden_states.requires_grad_(), - temb, - num_frames, - ) else: hidden_states = resnet(hidden_states, temb, scale=scale) - hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] + hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] output_states = output_states + (hidden_states,) @@ -1221,10 +1215,10 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - )[0] + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + )[0] # apply additional residuals to the output of the last pair of resnet and attention blocks if i == len(blocks) - 1 and additional_residuals is not None: @@ -1425,10 +1419,10 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - )[0] + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + )[0] if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1563,15 +1557,10 @@ def custom_forward(*inputs): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - ) else: hidden_states = resnet(hidden_states, temb, scale=scale) - hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] + hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] if self.upsamplers is not None: for upsampler in self.upsamplers: