From 0a251303aa53a393fc5a83d5a9f4057ce924c2fb Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Tue, 23 Jan 2024 12:27:25 +0530
Subject: [PATCH 1/2] refactor decode latents in video pipelines

---
 .../pipelines/animatediff/pipeline_animatediff.py | 13 +------------
 .../pipeline_text_to_video_synth.py               | 13 +------------
 .../pipeline_text_to_video_synth_img2img.py       | 13 +------------
 3 files changed, 3 insertions(+), 36 deletions(-)

diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index 25b664fe3b29..23ea3eefc615 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -440,18 +440,7 @@ def decode_latents(self, latents):
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

         image = self.vae.decode(latents).sample
-        video = (
-            image[None, :]
-            .reshape(
-                (
-                    batch_size,
-                    num_frames,
-                    -1,
-                )
-                + image.shape[2:]
-            )
-            .permute(0, 2, 1, 3, 4)
-        )
+        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         video = video.float()
         return video
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
index 6e5db85c9e66..bb8164a223cb 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
@@ -384,18 +384,7 @@ def decode_latents(self, latents):
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

         image = self.vae.decode(latents).sample
-        video = (
-            image[None, :]
-            .reshape(
-                (
-                    batch_size,
-                    num_frames,
-                    -1,
-                )
-                + image.shape[2:]
-            )
-            .permute(0, 2, 1, 3, 4)
-        )
+        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         video = video.float()
         return video
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
index c781e490caae..a21bb5e3de59 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
@@ -461,18 +461,7 @@ def decode_latents(self, latents):
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

         image = self.vae.decode(latents).sample
-        video = (
-            image[None, :]
-            .reshape(
-                (
-                    batch_size,
-                    num_frames,
-                    -1,
-                )
-                + image.shape[2:]
-            )
-            .permute(0, 2, 1, 3, 4)
-        )
+        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         video = video.float()
         return video

From 406800465fa7e7226947355ee1bdd671607f062e Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Fri, 9 Feb 2024 07:06:19 +0530
Subject: [PATCH 2/2] make fix-copies

---
 .../animatediff/pipeline_animatediff_video2video.py | 13 +------------
 src/diffusers/pipelines/pia/pipeline_pia.py         | 13 +------------
 2 files changed, 2 insertions(+), 24 deletions(-)

diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
index 3d01009cbac7..02036783e609 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
@@ -445,18 +445,7 @@ def decode_latents(self, latents):
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

         image = self.vae.decode(latents).sample
-        video = (
-            image[None, :]
-            .reshape(
-                (
-                    batch_size,
-                    num_frames,
-                    -1,
-                )
-                + image.shape[2:]
-            )
-            .permute(0, 2, 1, 3, 4)
-        )
+        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         video = video.float()
         return video
diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py
index 077de49cdc87..2205ec37b494 100644
--- a/src/diffusers/pipelines/pia/pipeline_pia.py
+++ b/src/diffusers/pipelines/pia/pipeline_pia.py
@@ -494,18 +494,7 @@ def decode_latents(self, latents):
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

         image = self.vae.decode(latents).sample
-        video = (
-            image[None, :]
-            .reshape(
-                (
-                    batch_size,
-                    num_frames,
-                    -1,
-                )
-                + image.shape[2:]
-            )
-            .permute(0, 2, 1, 3, 4)
-        )
+        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         video = video.float()
         return video
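A minimal standalone sketch, not part of the patch series above, showing that the one-liner introduced by these commits produces the same tensor as the multi-line reshape/permute chain it replaces; it assumes only PyTorch, and the dimensions are arbitrary example values rather than anything taken from the pipelines:

import torch

# arbitrary example dimensions for the equivalence check
batch_size, num_frames, channels, height, width = 2, 4, 3, 8, 8
image = torch.randn(batch_size * num_frames, channels, height, width)

# original multi-line formulation (removed by the patch)
video_old = (
    image[None, :]
    .reshape(
        (
            batch_size,
            num_frames,
            -1,
        )
        + image.shape[2:]
    )
    .permute(0, 2, 1, 3, 4)
)

# refactored one-liner (added by the patch)
video_new = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)

# both yield shape (batch_size, channels, num_frames, height, width) with identical values
assert video_new.shape == (batch_size, channels, num_frames, height, width)
assert torch.equal(video_old, video_new)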