From 06b1a978bb4a6ae55170be6dd7ed3dc9ba77eaf9 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 12 Aug 2024 00:23:48 +0200 Subject: [PATCH 01/17] refactor context parallel cache; update torch compile time benchmark --- docs/source/en/api/pipelines/cogvideox.md | 18 ++++----- .../autoencoders/autoencoder_kl_cogvideox.py | 39 ++++++++++--------- .../transformers/cogvideox_transformer_3d.py | 6 +-- 3 files changed, 32 insertions(+), 31 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 51026091b348..493c933b2cc5 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -33,9 +33,7 @@ This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRz ## Inference -Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency. - -First, load the pipeline: +Inference can be run with CogVideoX by following the code below: ```python import torch @@ -55,29 +53,29 @@ video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] export_to_video(video, "output.mp4", fps=8) ``` -Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`: +Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency. + +Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`: ```python pipeline.transformer.to(memory_format=torch.channels_last) -pipeline.vae.to(memory_format=torch.channels_last) ``` Finally, compile the components and run inference: ```python pipeline.transformer = torch.compile(pipeline.transformer) -pipeline.vae.decode = torch.compile(pipeline.vae.decode) -# CogVideoX works very well with long and well-described prompts +# CogVideoX works well with long and well-described prompts prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance." video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] ``` -The [benchmark](TODO: link) results on an 80GB A100 machine are: +The [benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are: ``` -Without torch.compile(): Average inference time: TODO seconds. -With torch.compile(): Average inference time: TODO seconds. +Without torch.compile(): Average inference time: 96.89 seconds. +With torch.compile(): Average inference time: 84.60 seconds. ``` ## CogVideoXPipeline diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 6aad4d63410f..ecfe7cfd5ae6 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -118,19 +118,12 @@ def __init__( self.conv_cache = None def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor: - dim = self.temporal_dim kernel_size = self.time_kernel_size - if kernel_size == 1: - return inputs - - inputs = inputs.transpose(0, dim) - - if self.conv_cache is not None: - inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0) - else: - inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0) - - inputs = inputs.transpose(0, dim).contiguous() + if kernel_size > 1: + cached_inputs = ( + [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) + ) + inputs = torch.cat(cached_inputs + [inputs], dim=2) return inputs def _clear_fake_context_parallel_cache(self): @@ -138,16 +131,26 @@ def _clear_fake_context_parallel_cache(self): self.conv_cache = None def forward(self, inputs: torch.Tensor) -> torch.Tensor: - input_parallel = self.fake_context_parallel_forward(inputs) + inputs = self.fake_context_parallel_forward(inputs) self._clear_fake_context_parallel_cache() - self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu() + # Note: we could move these to the cpu for a lower maximum memory usage but its only a few + # hundred megabytes and so let's not do it for now + self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0) - - output_parallel = self.conv(input_parallel) - output = output_parallel + inputs = F.pad(inputs, padding_2d, mode="constant", value=0) + + # Memory assessment: + # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. + # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. + # - Assume fp16 (2 bytes per value). + # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB + # + # Memory assessment when using tiling: + # - Assume everything as above but now HxW is 240x360 by tiling in half + # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB + output = self.conv(inputs) return output diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 9eae35d62e69..c1253fbf183f 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -304,7 +304,7 @@ def forward( encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length] hidden_states = hidden_states[:, self.config.max_text_seq_length :] - # 5. Transformer blocks + # 4. Transformer blocks for i, block in enumerate(self.transformer_blocks): if self.training and self.gradient_checkpointing: @@ -331,11 +331,11 @@ def custom_forward(*inputs): hidden_states = self.norm_final(hidden_states) - # 6. Final block + # 5. Final block hidden_states = self.norm_out(hidden_states, temb=emb) hidden_states = self.proj_out(hidden_states) - # 7. Unpatchify + # 6. Unpatchify p = self.config.patch_size output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p) output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) From d962677714c7f57efe3c492fb3c8a5223569fcab Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 12 Aug 2024 10:55:36 +0200 Subject: [PATCH 02/17] add tiling support --- .../autoencoders/autoencoder_kl_cogvideox.py | 163 ++++++++++++++++-- .../pipelines/cogvideo/pipeline_cogvideox.py | 21 +-- 2 files changed, 155 insertions(+), 29 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index ecfe7cfd5ae6..1165ede2a378 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -914,7 +914,8 @@ def __init__( norm_eps: float = 1e-6, norm_num_groups: int = 32, temporal_compression_ratio: float = 4, - sample_size: int = 256, + sample_height: int = 480, + sample_width: int = 720, scaling_factor: float = 1.15258426, shift_factor: Optional[float] = None, latents_mean: Optional[Tuple[float]] = None, @@ -953,25 +954,56 @@ def __init__( self.use_slicing = False self.use_tiling = False - self.tile_sample_min_size = self.config.sample_size - sample_size = ( - self.config.sample_size[0] - if isinstance(self.config.sample_size, (list, tuple)) - else self.config.sample_size - ) - self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) - self.tile_overlap_factor = 0.25 + # We make the minimum height and width of sample for tiling half that of the generally supported + self.tile_sample_min_height = sample_height // 2 + self.tile_sample_min_width = sample_width // 2 + self.tile_latent_min_height = int(self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.33 def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)): module.gradient_checkpointing = value - def clear_fake_context_parallel_cache(self): + def _clear_fake_context_parallel_cache(self): for name, module in self.named_modules(): if isinstance(module, CogVideoXCausalConv3d): logger.debug(f"Clearing fake Context Parallel cache for layer: {name}") module._clear_fake_context_parallel_cache() + def enable_tiling(self, use_tiling: bool = True, tile_sample_min_height: Optional[int] = None, tile_sample_min_width: Optional[int] = None) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_latent_min_height = int(self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + @apply_forward_hook def encode( self, x: torch.Tensor, return_dict: bool = True @@ -996,8 +1028,34 @@ def encode( return (posterior,) return AutoencoderKLOutput(latent_dist=posterior) + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_width or z.shape[-2] > self.tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + num_latent_frames = z.shape[2] + dec = [] + for i in range(num_latent_frames // 2): + if num_latent_frames % 2 == 0: + start_frame, end_frame = (2 * i, 2 * i + 2) + else: + start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3) + + z_intermediate = z[:, :, start_frame:end_frame] + if self.post_quant_conv is not None: + z_intermediate = self.post_quant_conv(z_intermediate) + z_intermediate = self.decoder(z_intermediate) + dec.append(z_intermediate) + + self._clear_fake_context_parallel_cache() + dec = torch.cat(dec, dim=2) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + @apply_forward_hook - def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: """ Decode a batch of images. @@ -1010,15 +1068,92 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode [`~models.vae.DecoderOutput`] or `tuple`: If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. """ - if self.post_quant_conv is not None: - z = self.post_quant_conv(z) - dec = self.decoder(z) + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor)) + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor)) + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor) + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor) + row_limit_height = self.tile_sample_min_height - blend_extent_height + row_limit_width = self.tile_sample_min_width - blend_extent_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[3], overlap_height): + row = [] + for j in range(0, z.shape[4], overlap_width): + time = [] + for k in range(z.shape[2] // 2): + if z.shape[2] % 2 == 0: + start_frame, end_frame = (2 * k, 2 * k + 2) + else: + start_frame, end_frame = (0, 3) if k == 0 else (2 * k + 1, 2 * k + 3) + tile = z[:, :, start_frame:end_frame, i : i + self.tile_latent_min_height, j : j + self.tile_latent_min_width] + if self.post_quant_conv is not None: + tile = self.post_quant_conv(tile) + tile = self.decoder(tile) + time.append(tile) + self._clear_fake_context_parallel_cache() + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + if not return_dict: return (dec,) + return DecoderOutput(sample=dec) + def forward( self, sample: torch.Tensor, diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 04f2752175af..2c150b72f8bf 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -332,20 +332,11 @@ def prepare_latents( latents = latents * self.scheduler.init_noise_sigma return latents - def decode_latents(self, latents: torch.Tensor, num_seconds: int): + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] latents = 1 / self.vae.config.scaling_factor * latents - frames = [] - for i in range(num_seconds): - start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3) - - current_frames = self.vae.decode(latents[:, :, start_frame:end_frame]).sample - frames.append(current_frames) - - self.vae.clear_fake_context_parallel_cache() - - frames = torch.cat(frames, dim=2) + frames = self.vae.decode(latents).sample return frames # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs @@ -534,9 +525,9 @@ def __call__( `tuple`. When returning a tuple, the first element is a list with the generated images. """ - assert ( - num_frames <= 48 and num_frames % fps == 0 and fps == 8 - ), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX." + # assert ( + # num_frames <= 48 and num_frames % fps == 0 and fps == 8 + # ), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX." if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -673,7 +664,7 @@ def __call__( progress_bar.update() if not output_type == "latent": - video = self.decode_latents(latents, num_frames // fps) + video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: video = latents From 2b923a8d262700480a115cb5632d384913ae5a31 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 12 Aug 2024 10:55:59 +0200 Subject: [PATCH 03/17] make style --- .../autoencoders/autoencoder_kl_cogvideox.py | 48 +++++++++++++------ 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 1165ede2a378..f45537364925 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -954,10 +954,12 @@ def __init__( self.use_slicing = False self.use_tiling = False - # We make the minimum height and width of sample for tiling half that of the generally supported + # We make the minimum height and width of sample for tiling half that of the generally supported self.tile_sample_min_height = sample_height // 2 self.tile_sample_min_width = sample_width // 2 - self.tile_latent_min_height = int(self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_latent_min_height = int( + self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) + ) self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_overlap_factor = 0.33 @@ -971,7 +973,12 @@ def _clear_fake_context_parallel_cache(self): logger.debug(f"Clearing fake Context Parallel cache for layer: {name}") module._clear_fake_context_parallel_cache() - def enable_tiling(self, use_tiling: bool = True, tile_sample_min_height: Optional[int] = None, tile_sample_min_width: Optional[int] = None) -> None: + def enable_tiling( + self, + use_tiling: bool = True, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + ) -> None: r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow @@ -980,7 +987,9 @@ def enable_tiling(self, use_tiling: bool = True, tile_sample_min_height: Optiona self.use_tiling = True self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width - self.tile_latent_min_height = int(self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_latent_min_height = int( + self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) + ) self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) def disable_tiling(self) -> None: @@ -1003,7 +1012,7 @@ def disable_slicing(self) -> None: decoding in one step. """ self.use_slicing = False - + @apply_forward_hook def encode( self, x: torch.Tensor, return_dict: bool = True @@ -1039,7 +1048,7 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut start_frame, end_frame = (2 * i, 2 * i + 2) else: start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3) - + z_intermediate = z[:, :, start_frame:end_frame] if self.post_quant_conv is not None: z_intermediate = self.post_quant_conv(z_intermediate) @@ -1051,7 +1060,7 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut if not return_dict: return (dec,) - + return DecoderOutput(sample=dec) @apply_forward_hook @@ -1078,19 +1087,23 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp if not return_dict: return (decoded,) return DecoderOutput(sample=decoded) - + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: blend_extent = min(a.shape[3], b.shape[3], blend_extent) for y in range(blend_extent): - b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) return b def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: blend_extent = min(a.shape[4], b.shape[4], blend_extent) for x in range(blend_extent): - b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) return b - + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: r""" Decode a batch of images using a tiled decoder. @@ -1124,7 +1137,13 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod start_frame, end_frame = (2 * k, 2 * k + 2) else: start_frame, end_frame = (0, 3) if k == 0 else (2 * k + 1, 2 * k + 3) - tile = z[:, :, start_frame:end_frame, i : i + self.tile_latent_min_height, j : j + self.tile_latent_min_width] + tile = z[ + :, + :, + start_frame:end_frame, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ] if self.post_quant_conv is not None: tile = self.post_quant_conv(tile) tile = self.decoder(tile) @@ -1132,7 +1151,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod self._clear_fake_context_parallel_cache() row.append(torch.cat(time, dim=2)) rows.append(row) - + result_rows = [] for i, row in enumerate(rows): result_row = [] @@ -1145,7 +1164,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod tile = self.blend_h(row[j - 1], tile, blend_extent_width) result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) result_rows.append(torch.cat(result_row, dim=4)) - + dec = torch.cat(result_rows, dim=3) if not return_dict: @@ -1153,7 +1172,6 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod return DecoderOutput(sample=dec) - def forward( self, sample: torch.Tensor, From e54db72eb42a151eb06d01b3cf9bfbd588566b51 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 12 Aug 2024 11:44:57 +0200 Subject: [PATCH 04/17] remove num_frames % 8 == 0 requirement --- .../autoencoders/autoencoder_kl_cogvideox.py | 20 +++++++++---------- .../pipelines/cogvideo/pipeline_cogvideox.py | 9 ++++----- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index f45537364925..78034824385f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -141,15 +141,6 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) inputs = F.pad(inputs, padding_2d, mode="constant", value=0) - # Memory assessment: - # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. - # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. - # - Assume fp16 (2 bytes per value). - # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB - # - # Memory assessment when using tiling: - # - Assume everything as above but now HxW is 240x360 by tiling in half - # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB output = self.conv(inputs) return output @@ -975,7 +966,6 @@ def _clear_fake_context_parallel_cache(self): def enable_tiling( self, - use_tiling: bool = True, tile_sample_min_height: Optional[int] = None, tile_sample_min_width: Optional[int] = None, ) -> None: @@ -1118,6 +1108,16 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is returned. """ + # Rough memory assessment: + # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. + # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. + # - Assume fp16 (2 bytes per value). + # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB + # + # Memory assessment when using tiling: + # - Assume everything as above but now HxW is 240x360 by tiling in half + # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor)) overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor)) blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 2c150b72f8bf..cfdec2345715 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -430,7 +430,6 @@ def __call__( height: int = 480, width: int = 720, num_frames: int = 48, - fps: int = 8, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, guidance_scale: float = 6, @@ -525,9 +524,10 @@ def __call__( `tuple`. When returning a tuple, the first element is a list with the generated images. """ - # assert ( - # num_frames <= 48 and num_frames % fps == 0 and fps == 8 - # ), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX." + if num_frames > 48: + raise ValueError( + "The number of frames must be less than 48 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -584,7 +584,6 @@ def __call__( # 5. Prepare latents. latent_channels = self.transformer.config.in_channels - num_frames += 1 latents = self.prepare_latents( batch_size * num_videos_per_prompt, latent_channels, From 84d6416c397b9479942869e5ea93a6cac2fe6300 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 12 Aug 2024 11:53:34 +0200 Subject: [PATCH 05/17] update default num_frames to original value --- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index cfdec2345715..b19f5a98d18d 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -429,7 +429,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 480, width: int = 720, - num_frames: int = 48, + num_frames: int = 49, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, guidance_scale: float = 6, From 1ed1cfb1fee402dbacf1629f48cac18706c63401 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 12 Aug 2024 16:07:18 +0200 Subject: [PATCH 06/17] add explanations + refactor --- .../autoencoders/autoencoder_kl_cogvideox.py | 58 ++++++++++++++----- 1 file changed, 44 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 78034824385f..5a9744dd2000 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -945,6 +945,22 @@ def __init__( self.use_slicing = False self.use_tiling = False + # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not + # recommended because the temporal parts of the VAE, here, are tricky to understand. + # If you decode X latent frames together, the number of output frames is (X + 2 + 4) - 2 frames => X + 4 frames + # Example with num_latent_frames_batch_size = 2: + # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together + # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale frames) + (4 time upscale frames) - (2 causal conv downscale frames)) + # => 6 * 8 = 48 frames + # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together + # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale frames) + (4 time upscale frames) - (2 causal conv downscale frames)) + + # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale frames) + (4 time upscale frames) - (2 causal conv downscale frames)) + # => 1 * 9 + 5 * 8 = 49 frames + # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that + # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different + # number of temporal frames. + self.num_latent_frames_batch_size = 2 + # We make the minimum height and width of sample for tiling half that of the generally supported self.tile_sample_min_height = sample_height // 2 self.tile_sample_min_width = sample_width // 2 @@ -952,7 +968,12 @@ def __init__( self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) ) self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) - self.tile_overlap_factor = 0.33 + + # These are experimental overlap factors that were chosen based on experimentation and seem to work best for + # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX + # and so the tiling implementation has only been tested on those specific resolutions. + self.tile_overlap_factor_height = 1 / 6 + self.tile_overlap_factor_width = 1 / 5 def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)): @@ -1031,13 +1052,17 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut if self.use_tiling and (z.shape[-1] > self.tile_latent_min_width or z.shape[-2] > self.tile_latent_min_height): return self.tiled_decode(z, return_dict=return_dict) - num_latent_frames = z.shape[2] + frame_batch_size = self.num_latent_frames_batch_size dec = [] - for i in range(num_latent_frames // 2): - if num_latent_frames % 2 == 0: - start_frame, end_frame = (2 * i, 2 * i + 2) + for i in range(z.shape[2] // frame_batch_size): + if z.shape[2] % frame_batch_size == 0: + start_frame, end_frame = (frame_batch_size * i, frame_batch_size * (i + 1)) else: - start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3) + if i == 0: + remaining_frames = z.shape[2] % frame_batch_size + start_frame, end_frame = (0, frame_batch_size + remaining_frames) + else: + start_frame, end_frame = (frame_batch_size * i + 1, frame_batch_size * (i + 1) + 1) z_intermediate = z[:, :, start_frame:end_frame] if self.post_quant_conv is not None: @@ -1118,12 +1143,13 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod # - Assume everything as above but now HxW is 240x360 by tiling in half # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB - overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor)) - overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor)) - blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor) - blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor) + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) row_limit_height = self.tile_sample_min_height - blend_extent_height row_limit_width = self.tile_sample_min_width - blend_extent_width + frame_batch_size = self.num_latent_frames_batch_size # Split z into overlapping tiles and decode them separately. # The tiles have an overlap to avoid seams between tiles. @@ -1132,11 +1158,15 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod row = [] for j in range(0, z.shape[4], overlap_width): time = [] - for k in range(z.shape[2] // 2): - if z.shape[2] % 2 == 0: - start_frame, end_frame = (2 * k, 2 * k + 2) + for k in range(z.shape[2] // frame_batch_size): + if z.shape[2] % frame_batch_size == 0: + start_frame, end_frame = (frame_batch_size * k, frame_batch_size * (k + 1)) else: - start_frame, end_frame = (0, 3) if k == 0 else (2 * k + 1, 2 * k + 3) + if k == 0: + remaining_frames = z.shape[2] % frame_batch_size + start_frame, end_frame = (0, frame_batch_size + remaining_frames) + else: + start_frame, end_frame = (frame_batch_size * k + 1, frame_batch_size * (k + 1) + 1) tile = z[ :, :, From 0094792d86829f234bb4a47f89a813fab827eb96 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 12 Aug 2024 16:07:40 +0200 Subject: [PATCH 07/17] update torch compile example --- docs/source/en/api/pipelines/cogvideox.md | 4 ++-- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 493c933b2cc5..a55440fc13e9 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -64,7 +64,7 @@ pipeline.transformer.to(memory_format=torch.channels_last) Finally, compile the components and run inference: ```python -pipeline.transformer = torch.compile(pipeline.transformer) +pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True) # CogVideoX works well with long and well-described prompts prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance." @@ -75,7 +75,7 @@ The [benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd ``` Without torch.compile(): Average inference time: 96.89 seconds. -With torch.compile(): Average inference time: 84.60 seconds. +With torch.compile(): Average inference time: 76.27 seconds. ``` ## CogVideoXPipeline diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index b19f5a98d18d..cfdec2345715 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -429,7 +429,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 480, width: int = 720, - num_frames: int = 49, + num_frames: int = 48, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, guidance_scale: float = 6, From 50fa1d0020064ca3f3e486a0bedf59d4b91b9945 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 12 Aug 2024 16:08:43 +0200 Subject: [PATCH 08/17] update docs --- docs/source/en/api/pipelines/cogvideox.md | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index a55440fc13e9..c896567318d4 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -33,7 +33,9 @@ This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRz ## Inference -Inference can be run with CogVideoX by following the code below: +Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency. + +First, load the pipeline: ```python import torch @@ -41,20 +43,8 @@ from diffusers import CogVideoXPipeline from diffusers.utils import export_to_video pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda") -prompt = ( - "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " - "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " - "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " - "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " - "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " - "atmosphere of this unique musical performance." -) -video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] -export_to_video(video, "output.mp4", fps=8) ``` -Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency. - Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`: ```python From 9de509dcf458b0620233f435d12121a05bb6aede Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 12 Aug 2024 16:15:28 +0200 Subject: [PATCH 09/17] update --- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 6 +++--- tests/pipelines/cogvideox/test_cogvideox.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index cfdec2345715..f43edab987fe 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -429,7 +429,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 480, width: int = 720, - num_frames: int = 48, + num_frames: int = 49, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, guidance_scale: float = 6, @@ -524,9 +524,9 @@ def __call__( `tuple`. When returning a tuple, the first element is a list with the generated images. """ - if num_frames > 48: + if num_frames > 49: raise ValueError( - "The number of frames must be less than 48 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." ) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): diff --git a/tests/pipelines/cogvideox/test_cogvideox.py b/tests/pipelines/cogvideox/test_cogvideox.py index 2219cde57088..a407ed2278fd 100644 --- a/tests/pipelines/cogvideox/test_cogvideox.py +++ b/tests/pipelines/cogvideox/test_cogvideox.py @@ -148,8 +148,8 @@ def test_inference(self): video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (9, 3, 16, 16)) - expected_video = torch.randn(9, 3, 16, 16) + self.assertEqual(generated_video.shape, (8, 3, 16, 16)) + expected_video = torch.randn(8, 3, 16, 16) max_diff = np.abs(generated_video - expected_video).max() self.assertLessEqual(max_diff, 1e10) From 7f63ee2ed802182c9c96c64a1dcb95dd29f66d6f Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 12 Aug 2024 16:28:25 +0200 Subject: [PATCH 10/17] clean up if-statements --- .../autoencoders/autoencoder_kl_cogvideox.py | 23 +++++-------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 5a9744dd2000..7a26ce270abc 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -1055,15 +1055,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut frame_batch_size = self.num_latent_frames_batch_size dec = [] for i in range(z.shape[2] // frame_batch_size): - if z.shape[2] % frame_batch_size == 0: - start_frame, end_frame = (frame_batch_size * i, frame_batch_size * (i + 1)) - else: - if i == 0: - remaining_frames = z.shape[2] % frame_batch_size - start_frame, end_frame = (0, frame_batch_size + remaining_frames) - else: - start_frame, end_frame = (frame_batch_size * i + 1, frame_batch_size * (i + 1) + 1) - + remaining_frames = z.shape[2] % frame_batch_size + start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) + end_frame = frame_batch_size * (i + 1) + remaining_frames z_intermediate = z[:, :, start_frame:end_frame] if self.post_quant_conv is not None: z_intermediate = self.post_quant_conv(z_intermediate) @@ -1159,14 +1153,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod for j in range(0, z.shape[4], overlap_width): time = [] for k in range(z.shape[2] // frame_batch_size): - if z.shape[2] % frame_batch_size == 0: - start_frame, end_frame = (frame_batch_size * k, frame_batch_size * (k + 1)) - else: - if k == 0: - remaining_frames = z.shape[2] % frame_batch_size - start_frame, end_frame = (0, frame_batch_size + remaining_frames) - else: - start_frame, end_frame = (frame_batch_size * k + 1, frame_batch_size * (k + 1) + 1) + remaining_frames = z.shape[2] % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames tile = z[ :, :, From 76dda5ea8e2f244681c16ae57a9b171df77880d5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 13 Aug 2024 15:12:39 +0200 Subject: [PATCH 11/17] address review comments --- .../autoencoders/autoencoder_kl_cogvideox.py | 46 ++++++++++++++----- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 7a26ce270abc..687687a11337 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -947,14 +947,16 @@ def __init__( # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not # recommended because the temporal parts of the VAE, here, are tricky to understand. - # If you decode X latent frames together, the number of output frames is (X + 2 + 4) - 2 frames => X + 4 frames + # If you decode X latent frames together, the number of output frames is: + # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames + # # Example with num_latent_frames_batch_size = 2: # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together - # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale frames) + (4 time upscale frames) - (2 causal conv downscale frames)) + # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale frames)) # => 6 * 8 = 48 frames # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together - # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale frames) + (4 time upscale frames) - (2 causal conv downscale frames)) + - # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale frames) + (4 time upscale frames) - (2 causal conv downscale frames)) + # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale frames)) + + # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale frames)) # => 1 * 9 + 5 * 8 = 49 frames # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different @@ -989,11 +991,27 @@ def enable_tiling( self, tile_sample_min_height: Optional[int] = None, tile_sample_min_width: Optional[int] = None, + tile_overlap_factor_height: Optional[float] = None, + tile_overlap_factor_width: Optional[float] = None, ) -> None: r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_overlap_factor_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + tile_overlap_factor_width (`int`, *optional*): + The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there + are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. """ self.use_tiling = True self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height @@ -1002,6 +1020,8 @@ def enable_tiling( self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) ) self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height + self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width def disable_tiling(self) -> None: r""" @@ -1049,13 +1069,15 @@ def encode( return AutoencoderKLOutput(latent_dist=posterior) def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - if self.use_tiling and (z.shape[-1] > self.tile_latent_min_width or z.shape[-2] > self.tile_latent_min_height): + batch_size, num_channels, num_frames, height, width = z.shape + + if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height): return self.tiled_decode(z, return_dict=return_dict) frame_batch_size = self.num_latent_frames_batch_size dec = [] - for i in range(z.shape[2] // frame_batch_size): - remaining_frames = z.shape[2] % frame_batch_size + for i in range(num_frames // frame_batch_size): + remaining_frames = num_frames % frame_batch_size start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) end_frame = frame_batch_size * (i + 1) + remaining_frames z_intermediate = z[:, :, start_frame:end_frame] @@ -1137,6 +1159,8 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod # - Assume everything as above but now HxW is 240x360 by tiling in half # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB + batch_size, num_channels, num_frames, height, width = z.shape + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) @@ -1148,12 +1172,12 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod # Split z into overlapping tiles and decode them separately. # The tiles have an overlap to avoid seams between tiles. rows = [] - for i in range(0, z.shape[3], overlap_height): + for i in range(0, height, overlap_height): row = [] - for j in range(0, z.shape[4], overlap_width): + for j in range(0, width, overlap_width): time = [] - for k in range(z.shape[2] // frame_batch_size): - remaining_frames = z.shape[2] % frame_batch_size + for k in range(num_frames // frame_batch_size): + remaining_frames = num_frames % frame_batch_size start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) end_frame = frame_batch_size * (k + 1) + remaining_frames tile = z[ From ea86c32a8d3be6bd30337a095b7c777656cb0b8a Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 13 Aug 2024 15:51:16 +0200 Subject: [PATCH 12/17] add test for vae tiling --- .../autoencoders/autoencoder_kl_cogvideox.py | 6 ++-- tests/pipelines/cogvideox/test_cogvideox.py | 35 ++++++++++++++++--- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 687687a11337..405071d41e18 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -952,11 +952,11 @@ def __init__( # # Example with num_latent_frames_batch_size = 2: # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together - # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale frames)) + # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) # => 6 * 8 = 48 frames # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together - # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale frames)) + - # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale frames)) + # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) + + # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) # => 1 * 9 + 5 * 8 = 49 frames # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different diff --git a/tests/pipelines/cogvideox/test_cogvideox.py b/tests/pipelines/cogvideox/test_cogvideox.py index a407ed2278fd..3ae500eb9567 100644 --- a/tests/pipelines/cogvideox/test_cogvideox.py +++ b/tests/pipelines/cogvideox/test_cogvideox.py @@ -125,11 +125,6 @@ def get_dummy_inputs(self, device, seed=0): # Cannot reduce because convolution kernel becomes bigger than sample "height": 16, "width": 16, - # TODO(aryan): improve this - # Cannot make this lower due to assert condition in pipeline at the moment. - # The reason why 8 can't be used here is due to how context-parallel cache works where the first - # second of video is decoded from latent frames (0, 3) instead of [(0, 2), (2, 3)]. If 8 is used, - # the number of output frames that you get are 5. "num_frames": 8, "max_sequence_length": 16, "output_type": "pt", @@ -250,6 +245,36 @@ def test_attention_slicing_forward_pass( "Attention slicing should not affect the inference results", ) + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_overlap_factor_height=1 / 12, + tile_overlap_factor_width=1 / 12, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + @slow @require_torch_gpu From 878890dd071a14dac4e4d130c15acd78184ec824 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 13 Aug 2024 16:01:36 +0200 Subject: [PATCH 13/17] update docs --- docs/source/en/api/pipelines/cogvideox.md | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index c896567318d4..71a2b3180fb5 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -15,9 +15,7 @@ # CogVideoX - - -[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) from Tsinghua University & ZhipuAI. +[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang. The abstract from the paper is: @@ -68,6 +66,16 @@ Without torch.compile(): Average inference time: 96.89 seconds. With torch.compile(): Average inference time: 76.27 seconds. ``` +### Memory optimization + +CogVideoX requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. The following optimizations can be applied (for replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script): + +- `pipe.enable_model_cpu_offload()`: + - Without enabling cpu offloading, memory usage is `33 GB` + - With enabling cpu offloading, memory usage is `19 GB` +- `pipe.vae.enable_tiling()`: + - With enabling cpu offloading and tiling, memory usage is `11 GB` + ## CogVideoXPipeline [[autodoc]] CogVideoXPipeline From a40e8d24904066b0eefdb1985386e3a9868e1d1a Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 13 Aug 2024 16:02:09 +0200 Subject: [PATCH 14/17] update docs --- docs/source/en/api/pipelines/cogvideox.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 71a2b3180fb5..3df12589ceee 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -46,17 +46,17 @@ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda") Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`: ```python -pipeline.transformer.to(memory_format=torch.channels_last) +pipe.transformer.to(memory_format=torch.channels_last) ``` Finally, compile the components and run inference: ```python -pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True) +pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True) # CogVideoX works well with long and well-described prompts prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance." -video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] +video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] ``` The [benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are: From 836f5d0644e4c9a5c81a126e029afaca5cc5f5f4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 13 Aug 2024 23:04:03 +0200 Subject: [PATCH 15/17] update docstrings --- docs/source/en/api/pipelines/cogvideox.md | 3 +- .../autoencoders/autoencoder_kl_cogvideox.py | 157 +++++++++++------- .../transformers/cogvideox_transformer_3d.py | 92 ++++++---- 3 files changed, 157 insertions(+), 95 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 3df12589ceee..549666e60ebc 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -68,13 +68,14 @@ With torch.compile(): Average inference time: 76.27 seconds. ### Memory optimization -CogVideoX requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. The following optimizations can be applied (for replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script): +CogVideoX requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script. - `pipe.enable_model_cpu_offload()`: - Without enabling cpu offloading, memory usage is `33 GB` - With enabling cpu offloading, memory usage is `19 GB` - `pipe.vae.enable_tiling()`: - With enabling cpu offloading and tiling, memory usage is `11 GB` +- `pipe.vae.enable_slicing()` ## CogVideoXPipeline diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 405071d41e18..3bf6e68d2628 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -36,7 +36,7 @@ class CogVideoXSafeConv3d(nn.Conv3d): - """ + r""" A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. """ @@ -68,12 +68,12 @@ class CogVideoXCausalConv3d(nn.Module): r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. Args: - in_channels (int): Number of channels in the input tensor. - out_channels (int): Number of output channels. - kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel. - stride (int, optional): Stride of the convolution. Default is 1. - dilation (int, optional): Dilation rate of the convolution. Default is 1. - pad_mode (str, optional): Padding mode. Default is "constant". + in_channels (`int`): Number of channels in the input tensor. + out_channels (`int`): Number of output channels produced by the convolution. + kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel. + stride (`int`, defaults to `1`): Stride of the convolution. + dilation (`int`, defaults to `1`): Dilation rate of the convolution. + pad_mode (`str`, defaults to `"constant"`): Padding mode. """ def __init__( @@ -157,6 +157,8 @@ class CogVideoXSpatialNorm3D(nn.Module): The number of channels for input to group normalization layer, and output of the spatial norm layer. zq_channels (`int`): The number of channels for the quantized vector as described in the paper. + groups (`int`): + Number of groups to separate the channels into for group normalization. """ def __init__( @@ -191,17 +193,26 @@ class CogVideoXResnetBlock3D(nn.Module): A 3D ResNet block used in the CogVideoX model. Args: - in_channels (int): Number of input channels. - out_channels (Optional[int], optional): - Number of output channels. If None, defaults to `in_channels`. Default is None. - dropout (float, optional): Dropout rate. Default is 0.0. - temb_channels (int, optional): Number of time embedding channels. Default is 512. - groups (int, optional): Number of groups for group normalization. Default is 32. - eps (float, optional): Epsilon value for normalization layers. Default is 1e-6. - non_linearity (str, optional): Activation function to use. Default is "swish". - conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False. - spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None. - pad_mode (str, optional): Padding mode. Default is "first". + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. + spatial_norm_dim (`int`, *optional*): + The dimension to use for spatial norm if it is to be used instead of group norm. + pad_mode (str, defaults to `"first"`): + Padding mode. """ def __init__( @@ -303,18 +314,28 @@ class CogVideoXDownBlock3D(nn.Module): A downsampling block used in the CogVideoX model. Args: - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - temb_channels (int): Number of time embedding channels. - dropout (float, optional): Dropout rate. Default is 0.0. - num_layers (int, optional): Number of layers in the block. Default is 1. - resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6. - resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish". - resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32. - add_downsample (bool, optional): If True, add a downsampling layer at the end of the block. Default is True. - downsample_padding (int, optional): Padding for the downsampling layer. Default is 0. - compress_time (bool, optional): If True, apply temporal compression. Default is False. - pad_mode (str, optional): Padding mode. Default is "first". + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + add_downsample (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + compress_time (`bool`, defaults to `False`): + Whether or not to downsample across temporal dimension. + pad_mode (str, defaults to `"first"`): + Padding mode. """ _supports_gradient_checkpointing = True @@ -399,15 +420,24 @@ class CogVideoXMidBlock3D(nn.Module): A middle block used in the CogVideoX model. Args: - in_channels (int): Number of input channels. - temb_channels (int): Number of time embedding channels. - dropout (float, optional): Dropout rate. Default is 0.0. - num_layers (int, optional): Number of layers in the block. Default is 1. - resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6. - resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish". - resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32. - spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None. - pad_mode (str, optional): Padding mode. Default is "first". + in_channels (`int`): + Number of input channels. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + dropout (`float`, defaults to `0.0`): + Dropout rate. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + spatial_norm_dim (`int`, *optional*): + The dimension to use for spatial norm if it is to be used instead of group norm. + pad_mode (str, defaults to `"first"`): + Padding mode. """ _supports_gradient_checkpointing = True @@ -474,19 +504,30 @@ class CogVideoXUpBlock3D(nn.Module): An upsampling block used in the CogVideoX model. Args: - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - temb_channels (int): Number of time embedding channels. - dropout (float, optional): Dropout rate. Default is 0.0. - num_layers (int, optional): Number of layers in the block. Default is 1. - resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6. - resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish". - resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32. - spatial_norm_dim (int, optional): Dimension of the spatial normalization. Default is 16. - add_upsample (bool, optional): If True, add an upsampling layer at the end of the block. Default is True. - upsample_padding (int, optional): Padding for the upsampling layer. Default is 1. - compress_time (bool, optional): If True, apply temporal compression. Default is False. - pad_mode (str, optional): Padding mode. Default is "first". + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + dropout (`float`, defaults to `0.0`): + Dropout rate. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + spatial_norm_dim (`int`, defaults to `16`): + The dimension to use for spatial norm if it is to be used instead of group norm. + add_upsample (`bool`, defaults to `True`): + Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension. + compress_time (`bool`, defaults to `False`): + Whether or not to downsample across temporal dimension. + pad_mode (str, defaults to `"first"`): + Padding mode. """ def __init__( @@ -581,14 +622,12 @@ class CogVideoXEncoder3D(nn.Module): options. block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): The number of output channels for each block. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups for normalization. - act_fn (`str`, *optional*, defaults to `"silu"`): - The activation function to use. See `~diffusers.models.activations.get_activation` for available options. - double_z (`bool`, *optional*, defaults to `True`): - Whether to double the number of output channels for the last block. """ _supports_gradient_checkpointing = True @@ -717,14 +756,12 @@ class CogVideoXDecoder3D(nn.Module): The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): The number of output channels for each block. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups for normalization. - act_fn (`str`, *optional*, defaults to `"silu"`): - The activation function to use. See `~diffusers.models.activations.get_activation` for available options. - norm_type (`str`, *optional*, defaults to `"group"`): - The normalization type to use. Can be either `"group"` or `"spatial"`. """ _supports_gradient_checkpointing = True diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index c1253fbf183f..1030b0df04ff 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -37,13 +37,20 @@ class CogVideoXBlock(nn.Module): Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + time_embed_dim (`int`): + The number of channels in timestep embedding. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to be used in feed-forward. + attention_bias (`bool`, defaults to `False`): + Whether or not to use bias in attention projection layers. qk_norm (`bool`, defaults to `True`): Whether or not to use normalization after query and key projections in Attention. norm_elementwise_affine (`bool`, defaults to `True`): @@ -147,36 +154,53 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): + num_attention_heads (`int`, defaults to `30`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): The number of channels in the input. - out_channels (`int`, *optional*): + out_channels (`int`, *optional*, defaults to `16`): The number of channels in the output. - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - attention_bias (`bool`, *optional*): - Configure if the `TransformerBlocks` attention should contain a bias parameter. - sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - patch_size (`int`, *optional*): + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + attention_bias (`bool`, defaults to `True`): + Whether or not to use bias in the attention projection layers. + sample_width (`int`, defaults to `90`): + The width of the input latents. + sample_height (`int`, defaults to `60`): + The height of the input latents. + sample_frames (`int`, defaults to `49`): + The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49 + instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings, + but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with + K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1). + patch_size (`int`, defaults to `2`): The size of the patches to use in the patch embedding layer. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): - The number of diffusion steps used during training. Pass if at least one of the norm_layers is - `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are - added to the hidden states. During inference, you can denoise for up to but not more steps than - `num_embeds_ada_norm`. - norm_type (`str`, *optional*, defaults to `"layer_norm"`): - The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`. - norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + temporal_compression_ratio (`int`, defaults to `4`): + The compression ratio across the temporal dimension. See documentation for `sample_frames`. + max_text_seq_length (`int`, defaults to `226`): + The maximum sequence length of the input text embeddings. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + timestep_activation_fn (`str`, defaults to `"silu"`): + Activation function to use when generating the timestep embeddings. + norm_elementwise_affine (`bool`, defaults to `True`): Whether or not to use elementwise affine in normalization layers. - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers. - caption_channels (`int`, *optional*): - The number of channels in the caption embeddings. - video_length (`int`, *optional*): - The number of frames in the video-like data. + norm_eps (`float`, defaults to `1e-5`): + The epsilon value to use in normalization layers. + spatial_interpolation_scale (`float`, defaults to `1.875`): + Scaling factor to apply in 3D positional embeddings across spatial dimensions. + temporal_interpolation_scale (`float`, defaults to `1.0`): + Scaling factor to apply in 3D positional embeddings across temporal dimensions. """ _supports_gradient_checkpointing = True @@ -186,7 +210,7 @@ def __init__( self, num_attention_heads: int = 30, attention_head_dim: int = 64, - in_channels: Optional[int] = 16, + in_channels: int = 16, out_channels: Optional[int] = 16, flip_sin_to_cos: bool = True, freq_shift: int = 0, From 661f7b87cf0665f1665039f56d84fcdaf6e89a4e Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 14 Aug 2024 00:04:34 +0200 Subject: [PATCH 16/17] add modeling test for cogvideox transformer --- .../test_models_transformer_cogvideox.py | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 tests/models/transformers/test_models_transformer_cogvideox.py diff --git a/tests/models/transformers/test_models_transformer_cogvideox.py b/tests/models/transformers/test_models_transformer_cogvideox.py new file mode 100644 index 000000000000..4a86cd120039 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_cogvideox.py @@ -0,0 +1,82 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import CogVideoXTransformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = CogVideoXTransformer3DModel + main_input_name = "hidden_states" + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + num_frames = 1 + height = 8 + width = 8 + embedding_dim = 8 + sequence_length = 8 + + hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + } + + @property + def input_shape(self): + return (1, 4, 8, 8) + + @property + def output_shape(self): + return (1, 4, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings. + "num_attention_heads": 2, + "attention_head_dim": 8, + "in_channels": 4, + "out_channels": 4, + "time_embed_dim": 2, + "text_embed_dim": 8, + "num_layers": 1, + "sample_width": 8, + "sample_height": 8, + "sample_frames": 8, + "patch_size": 2, + "temporal_compression_ratio": 4, + "max_text_seq_length": 8, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict From 1b6c527ab1ea2c69c8c253f1445aa98d8f0a067c Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 14 Aug 2024 00:07:23 +0200 Subject: [PATCH 17/17] make style --- .../models/transformers/test_models_transformer_cogvideox.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_cogvideox.py b/tests/models/transformers/test_models_transformer_cogvideox.py index 4a86cd120039..83cdf87baa4f 100644 --- a/tests/models/transformers/test_models_transformer_cogvideox.py +++ b/tests/models/transformers/test_models_transformer_cogvideox.py @@ -52,11 +52,11 @@ def dummy_input(self): "encoder_hidden_states": encoder_hidden_states, "timestep": timestep, } - + @property def input_shape(self): return (1, 4, 8, 8) - + @property def output_shape(self): return (1, 4, 8, 8)