cogvideo model/pipeline review
Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423
Review performed against the repository review rules.
Reviewed target pipelines, model files, lazy exports, docs/tests/examples, dtype/device paths, offload-facing behavior, attention processors, and coverage. Public imports/lazy loading looked consistent. I did not find separate actionable issues in pipeline_output.py or autoencoder_kl_cogvideox.py.
Execution: standalone repros were run with .venv/Scripts/python.exe; no full pytest suite was run.
Duplicate search: searched GitHub Issues and PRs for cogvideo, affected class/file names, and the specific failure modes. Exact duplicate found only for Issue 5: #9641. Related but not exact: #11133, #9972, #13586, PR #11368, PR #9333.
Issue 1: num_videos_per_prompt is accepted but ignored
Affected code:
# __call__ signature (appears identically in all four CogVideoX pipelines):
num_videos_per_prompt: int = 1,

# Forced reset inside each __call__:
num_videos_per_prompt = 1
Problem:
All four pipelines expose num_videos_per_prompt, but each __call__ resets it to 1 before prompt encoding and latent preparation. The text-only pipeline can already use the parameter correctly if that reset is removed. The conditioned pipelines also need image/video/control latents expanded to the effective batch, or they should reject values above 1.
Impact:
Users requesting multiple videos per prompt silently get one video per prompt. Batch behavior and callback tensor shapes are also misleading.
Reproduction:
import torch

from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel


def tiny_pipe():
    transformer = CogVideoXTransformer3DModel(
        num_attention_heads=4, attention_head_dim=8, in_channels=4, out_channels=4,
        time_embed_dim=2, text_embed_dim=32, num_layers=1,
        sample_width=2, sample_height=2, sample_frames=9, patch_size=2,
        temporal_compression_ratio=4, max_text_seq_length=16,
    )
    vae = AutoencoderKLCogVideoX(
        in_channels=3, out_channels=3,
        down_block_types=("CogVideoXDownBlock3D",) * 4,
        up_block_types=("CogVideoXUpBlock3D",) * 4,
        block_out_channels=(8, 8, 8, 8), latent_channels=4,
        layers_per_block=1, norm_num_groups=2, temporal_compression_ratio=4,
    )
    # Keyword arguments: the pipeline signature orders vae before transformer.
    return CogVideoXPipeline(
        tokenizer=None, text_encoder=None, vae=vae, transformer=transformer,
        scheduler=CogVideoXDDIMScheduler(),
    )


pipe = tiny_pipe()
frames = pipe(
    prompt_embeds=torch.zeros(1, 16, 32),
    height=16, width=16, num_frames=5,
    num_inference_steps=1, guidance_scale=1,
    num_videos_per_prompt=2, output_type="latent",
).frames
print(frames.shape)
assert frames.shape[0] == 2, f"expected 2 videos, got {frames.shape[0]}"
Relevant precedent:
# A pipeline that threads num_videos_per_prompt through encode_prompt and
# prepare_latents instead of resetting it (note the transformer/transformer_2
# variant):
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
    prompt=prompt,
    negative_prompt=negative_prompt,
    do_classifier_free_guidance=self.do_classifier_free_guidance,
    num_videos_per_prompt=num_videos_per_prompt,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    max_sequence_length=max_sequence_length,
    device=device,
)

transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None:
    negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps

# 5. Prepare latent variables
num_channels_latents = (
    self.transformer.config.in_channels
    if self.transformer is not None
    else self.transformer_2.config.in_channels
)
latents = self.prepare_latents(
    batch_size * num_videos_per_prompt,

# A second precedent (a pipeline that also returns attention masks) follows the
# same pattern:
...
) = self.encode_prompt(
    prompt=prompt,
    negative_prompt=negative_prompt,
    do_classifier_free_guidance=self.do_classifier_free_guidance,
    num_videos_per_prompt=num_videos_per_prompt,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    max_sequence_length=max_sequence_length,
    device=device,
)

# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
    batch_size * num_videos_per_prompt,
Suggested fix:
# Text-to-video: remove the forced reset.
# num_videos_per_prompt = 1

# Conditioned pipelines should either expand conditioning latents:
def _repeat_to_effective_batch(tensor, batch_size, num_videos_per_prompt):
    if tensor.shape[0] == 1:
        return tensor.repeat_interleave(batch_size * num_videos_per_prompt, dim=0)
    if tensor.shape[0] == batch_size:
        return tensor.repeat_interleave(num_videos_per_prompt, dim=0)
    return tensor

# Or reject unsupported requests until expansion is implemented:
if num_videos_per_prompt != 1:
    raise ValueError(
        "`num_videos_per_prompt > 1` is not currently supported by this conditioned CogVideoX pipeline."
    )
Issue 2: CogVideoXFunControlPipeline crashes when control_video_latents is supplied
Affected code:
# __call__ signature (excerpt):
control_video: list[Image.Image] | None = None,
height: int | None = None,
width: int | None = None,
num_inference_steps: int = 50,
timesteps: list[int] | None = None,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
control_video_latents: torch.Tensor | None = None,

# Body: prepare_control_latents runs unconditionally after the branch; with
# only control_video_latents supplied, control_video is None, the user tensor
# is overwritten, and the subsequent .permute() crashes.
if control_video_latents is None:
    control_video = self.video_processor.preprocess_video(control_video, height=height, width=width)
    control_video = control_video.to(device=device, dtype=prompt_embeds.dtype)

_, control_video_latents = self.prepare_control_latents(None, control_video)
control_video_latents = control_video_latents.permute(0, 2, 1, 3, 4)
Problem:
The API documents control_video_latents, and check_inputs only rejects passing both raw control video and latents. But __call__ unconditionally runs prepare_control_latents(None, control_video) after the preprocessing branch. When only control_video_latents is supplied, control_video is None, so the user tensor is discarded and the next .permute() crashes.
Impact:
The precomputed control-latent path is unusable.
Reproduction:
import torch

from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXFunControlPipeline, CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel(
    num_attention_heads=4, attention_head_dim=8, in_channels=8, out_channels=4,
    time_embed_dim=2, text_embed_dim=32, num_layers=1,
    sample_width=2, sample_height=2, sample_frames=9, patch_size=2,
    temporal_compression_ratio=4, max_text_seq_length=16,
)
vae = AutoencoderKLCogVideoX(
    in_channels=3, out_channels=3,
    down_block_types=("CogVideoXDownBlock3D",) * 4,
    up_block_types=("CogVideoXUpBlock3D",) * 4,
    block_out_channels=(8, 8, 8, 8), latent_channels=4,
    layers_per_block=1, norm_num_groups=2, temporal_compression_ratio=4,
)
# Keyword arguments: the pipeline signature orders vae before transformer.
pipe = CogVideoXFunControlPipeline(
    tokenizer=None, text_encoder=None, vae=vae, transformer=transformer,
    scheduler=CogVideoXDDIMScheduler(),
)
try:
    pipe(
        prompt_embeds=torch.zeros(1, 16, 32),
        control_video_latents=torch.zeros(1, 4, 2, 2, 2),
        height=16, width=16, num_inference_steps=1,
        guidance_scale=1, output_type="latent",
    )
except Exception as e:
    print(type(e).__name__, e)
Relevant precedent:
The raw-image/video latent paths in the other CogVideoX pipelines keep the precomputed latents branch separate from preprocessing.
Suggested fix:
if control_video_latents is None:
    if control_video is None:
        raise ValueError("Provide either `control_video` or `control_video_latents`.")
    control_video = self.video_processor.preprocess_video(control_video, height=height, width=width)
    control_video = control_video.to(device=device, dtype=prompt_embeds.dtype)
    _, control_video_latents = self.prepare_control_latents(None, control_video)
else:
    control_video_latents = control_video_latents.to(device=device, dtype=prompt_embeds.dtype)
control_video_latents = control_video_latents.permute(0, 2, 1, 3, 4)
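For context, a sketch of the workflow this path serves: precompute the latents once with the pipeline's own helper and reuse them across calls (the zero tensor is a placeholder pixel video; shapes follow the tiny configuration from the reproduction):
# Pixel-space control video [B, C, F, H, W] encoded via the pipeline's helper,
# exactly as the raw-video branch would do internally:
control_video = torch.zeros(1, 3, 9, 16, 16)
_, control_video_latents = pipe.prepare_control_latents(None, control_video)
print(control_video_latents.shape)  # [B, latent_channels, F_latent, H/8, W/8]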
Issue 3: Supplied prompt_embeds and latents are not cast to execution dtype
Affected code:
# encode_prompt, repeated verbatim in all four CogVideoX pipelines (shown
# once): generated embeddings receive an explicit dtype, but user-supplied
# prompt_embeds / negative_prompt_embeds pass through and are returned
# unchanged.
if prompt_embeds is None:
    prompt_embeds = self._get_t5_prompt_embeds(
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )

if do_classifier_free_guidance and negative_prompt_embeds is None:
    negative_prompt = negative_prompt or ""
    negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

    if prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    elif batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )

    negative_prompt_embeds = self._get_t5_prompt_embeds(
        prompt=negative_prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )

return prompt_embeds, negative_prompt_embeds

# prepare_latents (text-to-video, image-to-video, fun-control): user-supplied
# latents are moved to device but never cast to dtype.
if latents is None:
    latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
    latents = latents.to(device)

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma

# prepare_latents (video-to-video): the noised branch gets the right dtype,
# while the user-latents branch again only moves to device.
    noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    latents = self.scheduler.add_noise(init_latents, noise, timestep)
else:
    latents = latents.to(device)

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
Problem:
Generated prompt embeddings are cast, but user-supplied prompt_embeds and negative_prompt_embeds are returned unchanged. User-supplied latents are moved to device but not dtype. With a bf16/fp16 pipeline and fp32 tensors, transformer projections hit dtype mismatches.
Impact:
Documented advanced inputs break mixed precision inference.
Reproduction:
import torch

from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel(
    num_attention_heads=4, attention_head_dim=8, in_channels=4, out_channels=4,
    time_embed_dim=2, text_embed_dim=32, num_layers=1,
    sample_width=2, sample_height=2, sample_frames=9, patch_size=2,
    temporal_compression_ratio=4, max_text_seq_length=16,
)
vae = AutoencoderKLCogVideoX(
    in_channels=3, out_channels=3,
    down_block_types=("CogVideoXDownBlock3D",) * 4,
    up_block_types=("CogVideoXUpBlock3D",) * 4,
    block_out_channels=(8, 8, 8, 8), latent_channels=4,
    layers_per_block=1, norm_num_groups=2, temporal_compression_ratio=4,
)
# Keyword arguments: the pipeline signature orders vae before transformer.
pipe = CogVideoXPipeline(
    tokenizer=None, text_encoder=None, vae=vae, transformer=transformer,
    scheduler=CogVideoXDDIMScheduler(),
).to(dtype=torch.bfloat16)
try:
    pipe(
        prompt_embeds=torch.zeros(1, 16, 32, dtype=torch.float32),
        height=16, width=16, num_frames=5,
        num_inference_steps=1, guidance_scale=1, output_type="latent",
    )
except RuntimeError as e:
    print(e)
Relevant precedent:
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None:
    negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

if latents is not None:
    return latents.to(device=device, dtype=dtype)
Suggested fix:
# After prompt embedding selection in encode_prompt:
prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
if negative_prompt_embeds is not None:
    negative_prompt_embeds = negative_prompt_embeds.to(device=device, dtype=dtype)

# In prepare_latents branches:
latents = latents.to(device=device, dtype=dtype)
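A regression check could mirror the reproduction and assert that fp32 advanced inputs run after the fix; a minimal sketch reusing pipe from the reproduction (the latent shape follows the tiny configuration, and the expected output dtype assumes the casts resolve to the transformer dtype):
# Hedged post-fix check: fp32 prompt_embeds and latents against a bf16 pipeline.
out = pipe(
    prompt_embeds=torch.zeros(1, 16, 32, dtype=torch.float32),
    latents=torch.zeros(1, 2, 4, 2, 2, dtype=torch.float32),
    height=16, width=16, num_frames=5,
    num_inference_steps=1, guidance_scale=1, output_type="latent",
)
print(out.frames.dtype)  # expected: torch.bfloat16 once inputs are cast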
Issue 4: Spatial validation accepts sizes that later fail patchification
Affected code:
# check_inputs (appears identically in all four pipelines):
if height % 8 != 0 or width % 8 != 0:
    raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

# Transformer unpatchify: latent height/width must additionally be divisible
# by patch_size.
p = self.config.patch_size
p_t = self.config.patch_size_t

if p_t is None:
    output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
    output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
else:
    output = hidden_states.reshape(
        batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
    )
    output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
Problem:
Pipelines only require height and width to be divisible by 8, the VAE scale factor. The transformer then patchifies latents with patch_size=2, so the original size must usually be divisible by 8 * 2 = 16. For example, 24x24 passes validation but produces latent 3x3, which fails later.
Impact:
Users get a late low-level tensor shape error instead of an actionable validation error.
Reproduction:
import torch

from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel(
    num_attention_heads=4, attention_head_dim=8, in_channels=4, out_channels=4,
    time_embed_dim=2, text_embed_dim=32, num_layers=1,
    sample_width=2, sample_height=2, sample_frames=9, patch_size=2,
    temporal_compression_ratio=4, max_text_seq_length=16,
)
vae = AutoencoderKLCogVideoX(
    in_channels=3, out_channels=3,
    down_block_types=("CogVideoXDownBlock3D",) * 4,
    up_block_types=("CogVideoXUpBlock3D",) * 4,
    block_out_channels=(8, 8, 8, 8), latent_channels=4,
    layers_per_block=1, norm_num_groups=2, temporal_compression_ratio=4,
)
# Keyword arguments: the pipeline signature orders vae before transformer.
pipe = CogVideoXPipeline(
    tokenizer=None, text_encoder=None, vae=vae, transformer=transformer,
    scheduler=CogVideoXDDIMScheduler(),
)
try:
    pipe(
        prompt_embeds=torch.zeros(1, 16, 32),
        height=24, width=24, num_frames=5,
        num_inference_steps=1, guidance_scale=1, output_type="latent",
    )
except Exception as e:
    print(type(e).__name__, e)
Relevant precedent:
# One pipeline adjusts sizes to the combined multiple with a warning:
patch_size = (
    self.transformer.config.patch_size
    if self.transformer is not None
    else self.transformer_2.config.patch_size
)
h_multiple_of = self.vae_scale_factor_spatial * patch_size[1]
w_multiple_of = self.vae_scale_factor_spatial * patch_size[2]
calc_height = height // h_multiple_of * h_multiple_of
calc_width = width // w_multiple_of * w_multiple_of
if height != calc_height or width != calc_width:
    logger.warning(
        f"`height` and `width` must be multiples of ({h_multiple_of}, {w_multiple_of}) for proper patchification. "
        f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})."
    )
    height, width = calc_height, calc_width

# Another validates the combined multiple up front:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
Suggested fix:
spatial_multiple = self.vae_scale_factor_spatial * self.transformer.config.patch_size
if height % spatial_multiple != 0 or width % spatial_multiple != 0:
    raise ValueError(
        f"`height` and `width` have to be divisible by {spatial_multiple} "
        f"because CogVideoX patchifies VAE latents with patch_size={self.transformer.config.patch_size}; "
        f"got {height} and {width}."
    )
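To make the 24x24 example concrete, a quick standalone check of the arithmetic (8 and 2 are the stock vae_scale_factor_spatial and patch_size values):
vae_scale_factor_spatial, patch_size = 8, 2
for size in (24, 32):
    latent = size // vae_scale_factor_spatial
    print(size, latent, latent % patch_size == 0)
# 24 -> 3x3 latents, not divisible into 2x2 patches; 32 -> 4x4 latents, fine.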
Issue 5: Attention backend selection cannot affect CogVideoX attention
Affected code:
# cogvideox_transformer_3d.py: the block wires in the shared legacy processor.
from ..attention_processor import CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0

self.attn1 = Attention(
    query_dim=dim,
    dim_head=attention_head_dim,
    heads=num_attention_heads,
    qk_norm="layer_norm" if qk_norm else None,
    eps=1e-6,
    bias=attention_bias,
    out_bias=attention_out_bias,
    processor=CogVideoXAttnProcessor2_0(),
)

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
def fuse_qkv_projections(self):
    ...
    self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())

# attention_processor.py: both processors call SDPA directly and carry no
# _attention_backend / _parallel_config attributes.
class CogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False

# FusedCogVideoXAttnProcessor2_0 is identical except that it projects QKV via
# attn.to_qkv, and it likewise calls F.scaled_dot_product_attention directly.
Problem:
CogVideoXAttnProcessor2_0 and FusedCogVideoXAttnProcessor2_0 are shared legacy processors without _attention_backend / _parallel_config, and they call F.scaled_dot_product_attention directly. CogVideoXTransformer3DModel.set_attention_backend(...) therefore leaves them unchanged and cannot route through the dispatcher.
Impact:
CogVideoX cannot use the model-level attention backend infrastructure consistently, including alternate kernels and parallel attention integrations covered by the review rules.
Reproduction:
from diffusers import CogVideoXTransformer3DModel
model = CogVideoXTransformer3DModel(
    num_attention_heads=2, attention_head_dim=8, in_channels=4, out_channels=4,
    time_embed_dim=2, text_embed_dim=8, num_layers=1,
    sample_width=8, sample_height=8, sample_frames=8,
    patch_size=2, temporal_compression_ratio=4, max_text_seq_length=8,
)
model.set_attention_backend("native")
processors = list(model.attn_processors.values())
print([type(p).__name__ for p in processors])
print([hasattr(p, "_attention_backend") for p in processors])
assert all(hasattr(p, "_attention_backend") for p in processors)
Relevant precedent:
# Wan processor: class-level backend/parallel attributes plus dispatcher calls.
class WanAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
            )

    def __call__(
        self,
        attn: "WanAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
        ...
        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
            ...
            hidden_states_img = dispatch_attention_fn(
                query,
                key_img,
                value_img,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False,
                backend=self._attention_backend,
                # Reference: https://github.com/huggingface/diffusers/pull/12909
                parallel_config=None,
            )
            hidden_states_img = hidden_states_img.flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

        hidden_states = dispatch_attention_fn(
        ...

# Flux keeps its processors model-local (transformer_flux.py) with the same
# pattern: a dispatcher import, fused/unfused projection helpers, and class
# attributes the backend machinery can set.
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import (
    CombinedTimestepGuidanceTextProjEmbeddings,
    CombinedTimestepTextProjEmbeddings,
    apply_rotary_emb,
    get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle


def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    if attn.fused_projections:
        return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
    return _get_projections(attn, hidden_states, encoder_hidden_states)


class FluxAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )
        ...
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
Suggested fix:
Move CogVideoX attention processors into cogvideox_transformer_3d.py as model-local processors, add _attention_backend and _parallel_config, and replace direct SDPA calls with dispatcher calls while preserving text/video split and RoPE behavior:
hidden_states = dispatch_attention_fn(
    query,
    key,
    value,
    attn_mask=attention_mask,
    dropout_p=0.0,
    is_causal=False,
    backend=self._attention_backend,
    parallel_config=self._parallel_config,
)
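A minimal skeleton of the resulting processor; _project_qkv is a hypothetical helper standing in for the existing concatenation / projection / qk-norm / RoPE code, and the [batch, seq, heads, head_dim] layout mirrors what the Wan and Flux processors feed the dispatcher:
class CogVideoXAttnProcessor:  # model-local replacement; name illustrative
    _attention_backend = None
    _parallel_config = None

    def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask=None, image_rotary_emb=None):
        text_seq_length = encoder_hidden_states.size(1)
        # Hypothetical helper: concatenate text/video tokens, project Q/K/V,
        # apply qk-norm and RoPE exactly as CogVideoXAttnProcessor2_0 does today.
        query, key, value = _project_qkv(attn, hidden_states, encoder_hidden_states, image_rotary_emb)
        hidden_states = dispatch_attention_fn(
            query, key, value,
            attn_mask=attention_mask, dropout_p=0.0, is_causal=False,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        hidden_states = hidden_states.flatten(2, 3).type_as(query)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        # Split text and video tokens back out, preserving the current contract.
        return hidden_states[:, text_seq_length:], hidden_states[:, :text_seq_length]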
Issue 6: Dynamic CFG uses scheduler timestep values as denoising progress
Affected code:
# Identical block in all four pipelines' denoising loops:
if use_dynamic_cfg:
    self._guidance_scale = 1 + guidance_scale * (
        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
    )
Problem:
This is already tracked by #9641, so I am not presenting it as new. The dynamic CFG formula uses t.item() from scheduler timesteps, not the denoising loop index/progress. With standard timesteps like 980, 960, ..., the expression does not represent normalized progress and can exceed the requested guidance scale because it computes 1 + guidance_scale * ....
Impact:
use_dynamic_cfg=True applies an unintuitive and scheduler-dependent guidance schedule across all CogVideoX pipelines.
Reproduction:
import math
from diffusers import CogVideoXDDIMScheduler
guidance_scale = 6
num_inference_steps = 50
scheduler = CogVideoXDDIMScheduler()
scheduler.set_timesteps(num_inference_steps)
scales = [
    1 + guidance_scale * (
        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
    )
    for t in scheduler.timesteps
]
print(scheduler.timesteps[:5].tolist())
print([round(x, 3) for x in scales[:12]], max(scales))
assert max(scales) <= guidance_scale  # fails: the schedule overshoots the requested scale
Relevant precedent:
The intended behavior should be resolved in the existing issue: #9641
Suggested fix:
Use loop progress instead of scheduler timestep value, and confirm the intended max scale against the CogVideoX implementation:
progress = (num_inference_steps - i) / num_inference_steps
self._guidance_scale = 1 + (guidance_scale - 1) * ((1 - math.cos(math.pi * progress**5.0)) / 2)
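A standalone check that the loop-progress variant stays within the requested scale (whether guidance should peak at the start or the end of the loop is part of what #9641 needs to settle):
import math

guidance_scale, num_inference_steps = 6, 50
scales = [
    1 + (guidance_scale - 1) * ((1 - math.cos(math.pi * ((num_inference_steps - i) / num_inference_steps) ** 5.0)) / 2)
    for i in range(num_inference_steps)
]
print(round(min(scales), 3), round(max(scales), 3))  # stays within [1, guidance_scale]
assert max(scales) <= guidance_scale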
Issue 7: Slow tests are missing for FunControl and Video-to-Video
Affected code:
# tests/pipelines/cogvideo/test_cogvideox_fun_control.py defines only a fast
# test class; there is no slow/integration counterpart in the file.
class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = CogVideoXFunControlPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"control_video"})
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    # get_dummy_components / get_dummy_inputs build tiny CPU models
    # (tiny-random-t5, 16x16 videos), and the test methods -- test_inference,
    # test_callback_inputs, test_inference_batch_single_identical,
    # test_attention_slicing_forward_pass, test_vae_tiling,
    # test_fused_qkv_projections -- all run fast CPU inference only.

# tests/pipelines/cogvideo/test_cogvideox_video2video.py has the same
# structure: only a fast test class with dummy components and CPU tests, and
# no @slow integration class.
class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = CogVideoXVideoToVideoPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"video"})
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_xformers_attention = False

    # get_dummy_components / get_dummy_inputs / test methods mirror the
    # FunControl file; the excerpt in this review is truncated here.
|
assert tensor_name in pipe._callback_tensor_inputs |
|
|
|
return callback_kwargs |
|
|
|
def callback_inputs_all(pipe, i, t, callback_kwargs): |
|
for tensor_name in pipe._callback_tensor_inputs: |
|
assert tensor_name in callback_kwargs |
|
|
|
# iterate over callback args |
|
for tensor_name, tensor_value in callback_kwargs.items(): |
|
# check that we're only passing in allowed tensor inputs |
|
assert tensor_name in pipe._callback_tensor_inputs |
|
|
|
return callback_kwargs |
|
|
|
inputs = self.get_dummy_inputs(torch_device) |
|
|
|
# Test passing in a subset |
|
inputs["callback_on_step_end"] = callback_inputs_subset |
|
inputs["callback_on_step_end_tensor_inputs"] = ["latents"] |
|
output = pipe(**inputs)[0] |
|
|
|
# Test passing in a everything |
|
inputs["callback_on_step_end"] = callback_inputs_all |
|
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs |
|
output = pipe(**inputs)[0] |
|
|
|
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): |
|
is_last = i == (pipe.num_timesteps - 1) |
|
if is_last: |
|
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) |
|
return callback_kwargs |
|
|
|
inputs["callback_on_step_end"] = callback_inputs_change_tensor |
|
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs |
|
output = pipe(**inputs)[0] |
|
assert output.abs().sum() < 1e10 |
|
|
|
def test_inference_batch_single_identical(self): |
|
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) |
|
|
|
def test_attention_slicing_forward_pass( |
|
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 |
|
): |
|
if not self.test_attention_slicing: |
|
return |
|
|
|
components = self.get_dummy_components() |
|
pipe = self.pipeline_class(**components) |
|
for component in pipe.components.values(): |
|
if hasattr(component, "set_default_attn_processor"): |
|
component.set_default_attn_processor() |
|
pipe.to(torch_device) |
|
pipe.set_progress_bar_config(disable=None) |
|
|
|
generator_device = "cpu" |
|
inputs = self.get_dummy_inputs(generator_device) |
|
output_without_slicing = pipe(**inputs)[0] |
|
|
|
pipe.enable_attention_slicing(slice_size=1) |
|
inputs = self.get_dummy_inputs(generator_device) |
|
output_with_slicing1 = pipe(**inputs)[0] |
|
|
|
pipe.enable_attention_slicing(slice_size=2) |
|
inputs = self.get_dummy_inputs(generator_device) |
|
output_with_slicing2 = pipe(**inputs)[0] |
|
|
|
if test_max_difference: |
|
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() |
|
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() |
|
self.assertLess( |
|
max(max_diff1, max_diff2), |
|
expected_max_diff, |
|
"Attention slicing should not affect the inference results", |
|
) |
|
|
|
def test_vae_tiling(self, expected_diff_max: float = 0.2): |
|
# Since VideoToVideo uses both encoder and decoder tiling, there seems to be much more numerical |
|
# difference. We seem to need a higher tolerance here... |
|
# TODO(aryan): Look into this more deeply |
|
expected_diff_max = 0.4 |
|
|
|
generator_device = "cpu" |
|
components = self.get_dummy_components() |
|
|
|
pipe = self.pipeline_class(**components) |
|
pipe.to("cpu") |
|
pipe.set_progress_bar_config(disable=None) |
|
|
|
# Without tiling |
|
inputs = self.get_dummy_inputs(generator_device) |
|
inputs["height"] = inputs["width"] = 128 |
|
output_without_tiling = pipe(**inputs)[0] |
|
|
|
# With tiling |
|
pipe.vae.enable_tiling( |
|
tile_sample_min_height=96, |
|
tile_sample_min_width=96, |
|
tile_overlap_factor_height=1 / 12, |
|
tile_overlap_factor_width=1 / 12, |
|
) |
|
inputs = self.get_dummy_inputs(generator_device) |
|
inputs["height"] = inputs["width"] = 128 |
|
output_with_tiling = pipe(**inputs)[0] |
|
|
|
self.assertLess( |
|
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(), |
|
expected_diff_max, |
|
"VAE tiling should not affect the inference results", |
|
) |
|
|
|
def test_fused_qkv_projections(self): |
|
device = "cpu" # ensure determinism for the device-dependent torch.Generator |
|
components = self.get_dummy_components() |
|
pipe = self.pipeline_class(**components) |
|
pipe = pipe.to(device) |
|
pipe.set_progress_bar_config(disable=None) |
|
|
|
inputs = self.get_dummy_inputs(device) |
|
frames = pipe(**inputs).frames # [B, F, C, H, W] |
|
original_image_slice = frames[0, -2:, -1, -3:, -3:] |
|
|
|
pipe.fuse_qkv_projections() |
|
assert check_qkv_fusion_processors_exist(pipe.transformer), ( |
|
"Something wrong with the fused attention processors. Expected all the attention processors to be fused." |
|
) |
|
assert check_qkv_fusion_matches_attn_procs_length( |
|
pipe.transformer, pipe.transformer.original_attn_processors |
|
), "Something wrong with the attention processors concerning the fused QKV projections." |
|
|
|
inputs = self.get_dummy_inputs(device) |
|
frames = pipe(**inputs).frames |
|
image_slice_fused = frames[0, -2:, -1, -3:, -3:] |
|
|
|
pipe.transformer.unfuse_qkv_projections() |
|
inputs = self.get_dummy_inputs(device) |
|
frames = pipe(**inputs).frames |
|
image_slice_disabled = frames[0, -2:, -1, -3:, -3:] |
|
|
|
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), ( |
|
"Fusion of QKV projections shouldn't affect the outputs." |
|
) |
|
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), ( |
|
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." |
|
) |
|
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), ( |
|
"Original outputs should match when fused QKV projections are disabled." |
|
) |
Problem:
Fast tests exist for these pipelines, but there are no @slow integration tests for CogVideoXFunControlPipeline or CogVideoXVideoToVideoPipeline. Text-to-video and image-to-video do have slow coverage.
Impact:
Checkpoint compatibility, preprocessing with real media, and end-to-end output regressions are not covered for two public CogVideoX pipelines.
Reproduction:
from pathlib import Path

# Run from the repository root; the assertion fails at the reviewed commit
# because neither file defines an @slow integration test class.
for path in [
    "tests/pipelines/cogvideo/test_cogvideox_fun_control.py",
    "tests/pipelines/cogvideo/test_cogvideox_video2video.py",
]:
    text = Path(path).read_text()
    print(path, "@slow" in text, "IntegrationTests" in text)
    assert "@slow" in text and "IntegrationTests" in text
Relevant precedent:
tests/pipelines/cogvideo/test_cogvideox.py (lines 339 to 357 in 0f1abc4):
|
@slow |
|
@require_torch_accelerator |
|
class CogVideoXPipelineIntegrationTests(unittest.TestCase): |
|
prompt = "A painting of a squirrel eating a burger." |
|
|
|
def setUp(self): |
|
super().setUp() |
|
gc.collect() |
|
backend_empty_cache(torch_device) |
|
|
|
def tearDown(self): |
|
super().tearDown() |
|
gc.collect() |
|
backend_empty_cache(torch_device) |
|
|
|
def test_cogvideox(self): |
|
generator = torch.Generator("cpu").manual_seed(0) |
|
|
|
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16) |
tests/pipelines/cogvideo/test_cogvideox_image2video.py (lines 351 to 369 in 0f1abc4):
|
@slow |
|
@require_torch_accelerator |
|
class CogVideoXImageToVideoPipelineIntegrationTests(unittest.TestCase): |
|
prompt = "A painting of a squirrel eating a burger." |
|
|
|
def setUp(self): |
|
super().setUp() |
|
gc.collect() |
|
backend_empty_cache(torch_device) |
|
|
|
def tearDown(self): |
|
super().tearDown() |
|
gc.collect() |
|
backend_empty_cache(torch_device) |
|
|
|
def test_cogvideox(self): |
|
generator = torch.Generator("cpu").manual_seed(0) |
|
|
|
pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16) |
Suggested fix:
Add slow integration classes for CogVideoXVideoToVideoPipeline and CogVideoXFunControlPipeline with published checkpoints, fixed seeds, small media fixtures, and output slice assertions. Also add a fast regression test for the control_video_latents path from Issue 2.
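A minimal sketch of the missing video-to-video class, mirroring the precedent above. It reuses the THUDM/CogVideoX-2b checkpoint from the existing text-to-video slow test; the fixture URL is a placeholder, and the strength, step count, frame count, and resolution are illustrative assumptions that would need to be pinned against a trusted reference run before enabling a slice assertion:

import gc
import unittest

import numpy as np
import torch

from diffusers import CogVideoXVideoToVideoPipeline
from diffusers.utils import load_video
from diffusers.utils.testing_utils import backend_empty_cache, require_torch_accelerator, slow, torch_device

@slow
@require_torch_accelerator
class CogVideoXVideoToVideoPipelineIntegrationTests(unittest.TestCase):
    prompt = "A painting of a squirrel eating a burger."

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_cogvideox_video2video(self):
        generator = torch.Generator("cpu").manual_seed(0)
        pipe = CogVideoXVideoToVideoPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
        pipe.enable_model_cpu_offload(device=torch_device)

        # Placeholder fixture: any short 480x720 clip hosted alongside the other test assets.
        video = load_video("https://huggingface.co/datasets/<test-fixtures>/resolve/main/cogvideox/input.mp4")

        frames = pipe(
            prompt=self.prompt,
            video=video,
            strength=0.8,
            guidance_scale=6.0,
            num_inference_steps=25,
            generator=generator,
            output_type="np",
        ).frames
        video_out = frames[0]

        # Video-to-video preserves the input clip's frame count and resolution,
        # here assumed to be 49 frames at 480x720.
        assert video_out.shape == (49, 480, 720, 3)
        assert np.isfinite(video_out).all()
        # A pinned slice assertion, recorded once from a trusted run, would replace
        # the finiteness check, e.g.:
        # expected_slice = np.array([...])
        # assert np.allclose(video_out[-1, -3:, -3:, 0], expected_slice, atol=1e-2)

A CogVideoXFunControlPipeline class would follow the same shape, loading a published Fun-Control checkpoint and passing a control video. The fast regression test for the control_video_latents path belongs next to the existing fast tests: encode the dummy control video with the pipeline's own VAE, pass the result as control_video_latents instead of control_video, and assert the call completes once Issue 2 is fixed.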