From f4ce63361c1fc156f5b781af955a466585326fa4 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 19 Sep 2024 09:51:34 +0530
Subject: [PATCH 1/6] fix positional arguments in check_inputs().

---
 .../cogvideo/pipeline_cogvideox_video2video.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index 16686d1ab7ac..0aa0c5b1e65a 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -657,14 +657,14 @@ def __call__(
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt,
-            height,
-            width,
-            strength,
-            negative_prompt,
-            callback_on_step_end_tensor_inputs,
-            prompt_embeds,
-            negative_prompt_embeds,
+            prompt=prompt,
+            height=height,
+            width=width,
+            strength=strength,
+            negative_prompt=negative_prompt,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
         self._interrupt = False

From 588d75922c8d0f1234cf18ec208913ab3cef559f Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 19 Sep 2024 09:56:28 +0530
Subject: [PATCH 2/6] add video and latents to check_inputs().

---
 .../pipelines/cogvideo/pipeline_cogvideox_video2video.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index 0aa0c5b1e65a..627503a95c7d 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -663,6 +663,8 @@ def __call__(
             strength=strength,
             negative_prompt=negative_prompt,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            video=video,
+            latents=latents,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
         )

From 37c89224ccaee687ca1b2f55fab270e2ccd792e8 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 19 Sep 2024 10:12:01 +0530
Subject: [PATCH 3/6] prep latents_in_channels.
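
Cache the latent channel count on the pipeline at init time instead of
reading `self.transformer.config.in_channels` inside `__call__`, so the
call path no longer assumes a transformer is registered. A minimal sketch
of the guarded-lookup pattern this patch applies (the class name and
constructor argument below are illustrative, not the pipeline's real
signature; only the fallback value of 16 comes from the diff):

    # Read a config value defensively, with a fallback for when the
    # component is absent (e.g. not yet registered or swapped out).
    class PipelineSketch:
        def __init__(self, transformer=None):
            self.transformer = transformer
            self.latents_in_channels = (
                transformer.config.in_channels
                if transformer is not None
                else 16  # fallback channel count taken from the diff below
            )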
---
 .../pipelines/cogvideo/pipeline_cogvideox_video2video.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index 627503a95c7d..c744cabcef79 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -206,6 +206,9 @@ def __init__(
         self.register_modules(
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
+        self.latents_in_channels = (
+            self.transformer.config.in_channels if hasattr(self, "transformer") and self.transformer is not None else 16
+        )
         self.vae_scale_factor_spatial = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
@@ -711,7 +714,7 @@ def __call__(
 
         video = self.video_processor.preprocess_video(video, height=height, width=width)
         video = video.to(device=device, dtype=prompt_embeds.dtype)
-        latent_channels = self.transformer.config.in_channels
+        latent_channels = self.latents_in_channels
         latents = self.prepare_latents(
             video,
             batch_size * num_videos_per_prompt,

From 4b0dc80c7c87c6742342be26514de02643bd54b4 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 19 Sep 2024 10:16:21 +0530
Subject: [PATCH 4/6] quality

---
 .../pipelines/cogvideo/pipeline_cogvideox_video2video.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index c744cabcef79..2db9d00af192 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -207,7 +207,9 @@ def __init__(
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
         self.latents_in_channels = (
-            self.transformer.config.in_channels if hasattr(self, "transformer") and self.transformer is not None else 16
+            self.transformer.config.in_channels
+            if hasattr(self, "transformer") and self.transformer is not None
+            else 16
         )
         self.vae_scale_factor_spatial = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8

From 24b83a6a5de517b1da6afde52186de77a1ef440d Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 19 Sep 2024 12:56:32 +0530
Subject: [PATCH 5/6] multiple fixes.

---
 .../pipelines/cogvideo/pipeline_cogvideox.py  | 21 ++++----
 .../pipeline_cogvideox_image2video.py         | 52 +++++++++----------
 .../pipeline_cogvideox_video2video.py         | 30 +++++------
 3 files changed, 47 insertions(+), 56 deletions(-)

diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
index 3af47c177437..2206b1e63592 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -316,6 +316,12 @@ def encode_prompt(
     def prepare_latents(
         self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         shape = (
             batch_size,
             (num_frames - 1) // self.vae_scale_factor_temporal + 1,
@@ -323,11 +329,6 @@ def prepare_latents(
             height // self.vae_scale_factor_spatial,
             width // self.vae_scale_factor_spatial,
         )
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
 
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -504,10 +505,10 @@ def __call__(
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The height in pixels of the generated image. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):
+                The width in pixels of the generated image. This is set to 720 by default for the best results.
             num_frames (`int`, defaults to `48`):
                 Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
                 contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
@@ -577,8 +578,6 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
-        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
-        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
         num_videos_per_prompt = 1
 
         # 1. Check inputs. Raise error if not correct

diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
index a1576be97977..afc11bce00d5 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -207,6 +207,9 @@ def __init__(
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
         )
+        self.vae_scaling_factor_image = (
+            self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+        )
 
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
@@ -348,6 +351,12 @@ def prepare_latents(
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.Tensor] = None,
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
         shape = (
             batch_size,
@@ -357,12 +366,6 @@ def prepare_latents(
             width // self.vae_scale_factor_spatial,
         )
 
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
         image = image.unsqueeze(2)  # [B, C, F, H, W]
 
         if isinstance(generator, list):
@@ -373,7 +376,7 @@ def prepare_latents(
             image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
 
         image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
-        image_latents = self.vae.config.scaling_factor * image_latents
+        image_latents = self.vae_scaling_factor_image * image_latents
 
         padding_shape = (
             batch_size,
@@ -397,7 +400,7 @@ def prepare_latents(
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
     def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
-        latents = 1 / self.vae.config.scaling_factor * latents
+        latents = 1 / self.vae_scaling_factor_image * latents
         frames = self.vae.decode(latents).sample
         return frames
 
@@ -438,7 +441,6 @@ def check_inputs(
         width,
         negative_prompt,
         callback_on_step_end_tensor_inputs,
-        video=None,
         latents=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
@@ -494,9 +496,6 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )
 
-        if video is not None and latents is not None:
-            raise ValueError("Only one of `video` or `latents` should be provided")
-
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
     def fuse_qkv_projections(self) -> None:
         r"""Enables fused QKV projections."""
@@ -584,7 +583,7 @@ def __call__(
 
         Args:
             image (`PipelineImageInput`):
-                The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+                The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
@@ -592,10 +591,10 @@ def __call__(
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The height in pixels of the generated image. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):
+                The width in pixels of the generated image. This is set to 720 by default for the best results.
             num_frames (`int`, defaults to `48`):
                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
                contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
@@ -665,20 +664,19 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
-        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
-        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
         num_videos_per_prompt = 1
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            image,
-            prompt,
-            height,
-            width,
-            negative_prompt,
-            callback_on_step_end_tensor_inputs,
-            prompt_embeds,
-            negative_prompt_embeds,
+            image=image,
+            prompt=prompt,
+            height=height,
+            width=width,
+            negative_prompt=negative_prompt,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            latents=latents,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
         self._interrupt = False

diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index 2db9d00af192..4c3e0196e0e1 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -206,11 +206,7 @@ def __init__(
         self.register_modules(
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
-        self.latents_in_channels = (
-            self.transformer.config.in_channels
-            if hasattr(self, "transformer") and self.transformer is not None
-            else 16
-        )
+
         self.vae_scale_factor_spatial = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
@@ -358,6 +354,12 @@ def prepare_latents(
         latents: Optional[torch.Tensor] = None,
         timestep: Optional[torch.Tensor] = None,
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
 
         shape = (
@@ -368,12 +370,6 @@ def prepare_latents(
             width // self.vae_scale_factor_spatial,
         )
 
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
         if latents is None:
             if isinstance(generator, list):
                 if len(generator) != batch_size:
@@ -591,10 +587,10 @@ def __call__(
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The height in pixels of the generated image. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):
+                The width in pixels of the generated image. This is set to 720 by default for the best results.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
@@ -656,8 +652,6 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
-        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
-        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
         num_videos_per_prompt = 1
 
         # 1. Check inputs. Raise error if not correct
@@ -716,7 +710,7 @@ def __call__(
 
         video = self.video_processor.preprocess_video(video, height=height, width=width)
         video = video.to(device=device, dtype=prompt_embeds.dtype)
-        latent_channels = self.latents_in_channels
+        latent_channels = self.transformer.config.in_channels
         latents = self.prepare_latents(
             video,
             batch_size * num_videos_per_prompt,

From 514ed236c9cc4f8fb8afb5b97179006493e0ca8e Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 19 Sep 2024 13:02:29 +0530
Subject: [PATCH 6/6] fix

---
 src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py   | 5 ++++-
 .../pipelines/cogvideo/pipeline_cogvideox_video2video.py | 7 +++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
index 2206b1e63592..9dbe6c12725f 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -187,6 +187,9 @@ def __init__(
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
         )
+        self.vae_scaling_factor_image = (
+            self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+        )
 
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
@@ -341,7 +344,7 @@ def prepare_latents(
 
     def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
-        latents = 1 / self.vae.config.scaling_factor * latents
+        latents = 1 / self.vae_scaling_factor_image * latents
         frames = self.vae.decode(latents).sample
         return frames

diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index 4c3e0196e0e1..5eef76d39334 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -213,6 +213,9 @@ def __init__(
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
         )
+        self.vae_scaling_factor_image = (
+            self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+        )
 
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
@@ -385,7 +388,7 @@ def prepare_latents(
                init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
 
             init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
-            init_latents = self.vae.config.scaling_factor * init_latents
+            init_latents = self.vae_scaling_factor_image * init_latents
 
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
             latents = self.scheduler.add_noise(init_latents, noise, timestep)
@@ -399,7 +402,7 @@ def prepare_latents(
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
     def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
-        latents = 1 / self.vae.config.scaling_factor * latents
+        latents = 1 / self.vae_scaling_factor_image * latents
         frames = self.vae.decode(latents).sample
         return frames
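
Note (reviewer aside, not part of the patches above): the switch to keyword
arguments in PATCH 1/6 guards against silent argument misbinding when a
parameter is added mid-signature. A hypothetical, self-contained example of
the failure mode (the signature below is illustrative, not the pipeline's
actual one):

    # A checker whose signature gained `strength` between two releases.
    def check_inputs(prompt, height, width, strength=0.8, negative_prompt=None):
        return {"strength": strength, "negative_prompt": negative_prompt}

    # Positional call written against the old signature: "blurry" silently
    # binds to `strength` instead of `negative_prompt`.
    print(check_inputs("a cat", 480, 720, "blurry"))
    # Keyword call keeps working across the signature change.
    print(check_inputs("a cat", 480, 720, negative_prompt="blurry"))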