
animatediff model/pipeline review #13599

@hlky

Description


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules. Public exports/imports, config/load paths, dtype/device/offload behavior, model attention behavior, docs/examples, and fast/slow tests were checked. Reproductions were run with .venv.

Duplicate search: checked GitHub Issues and PRs for animatediff, affected classes/files, num_videos_per_prompt, MultiControlNet validation, stale unet_motion_model import, SparseControlNet tests, and slow-test coverage. No exact duplicates found. Related but not duplicates: #8664, #9326, #9508, #7378.

Issue 1: Video-to-video pipelines ignore num_videos_per_prompt

Affected code (from AnimateDiffVideoToVideoPipeline.__call__):

enforce_inference_steps: bool = False,
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,
guidance_scale: float = 7.5,
strength: float = 0.8,
negative_prompt: str | list[str] | None = None,
num_videos_per_prompt: int | None = 1,
eta: float = 0.0,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
ip_adapter_image: PipelineImageInput | None = None,
ip_adapter_image_embeds: list[torch.Tensor] | None = None,
output_type: str | None = "pil",
return_dict: bool = True,
cross_attention_kwargs: dict[str, Any] | None = None,
clip_skip: int | None = None,
callback_on_step_end: Callable[[int, int], None] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
decode_chunk_size: int = 16,
):
r"""
The call function to the pipeline for generation.
Args:
video (`list[PipelineImageInput]`):
The input video to condition the generation on. Must be a list of images/frames of the video.
prompt (`str` or `list[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated video.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated video.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
expense of slower inference.
timesteps (`list[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`list[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
strength (`float`, *optional*, defaults to 0.8):
Higher strength leads to more differences between original video and generated video.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `list[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
`(batch_size, num_channel, num_frames, height, width)`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`list`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
decode_chunk_size (`int`, defaults to `16`):
The number of frames to decode at a time when calling `decode_latents` method.
Examples:
Returns:
[`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
num_videos_per_prompt = 1

The same hard reset appears in AnimateDiffVideoToVideoControlNetPipeline.__call__:

timesteps: list[int] | None = None,
sigmas: list[float] | None = None,
guidance_scale: float = 7.5,
strength: float = 0.8,
negative_prompt: str | list[str] | None = None,
num_videos_per_prompt: int | None = 1,
eta: float = 0.0,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
ip_adapter_image: PipelineImageInput | None = None,
ip_adapter_image_embeds: list[torch.Tensor] | None = None,
conditioning_frames: list[PipelineImageInput] | None = None,
output_type: str | None = "pil",
return_dict: bool = True,
cross_attention_kwargs: dict[str, Any] | None = None,
controlnet_conditioning_scale: float | list[float] = 1.0,
guess_mode: bool = False,
control_guidance_start: float | list[float] = 0.0,
control_guidance_end: float | list[float] = 1.0,
clip_skip: int | None = None,
callback_on_step_end: Callable[[int, int], None] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
decode_chunk_size: int = 16,
):
r"""
The call function to the pipeline for generation.
Args:
video (`list[PipelineImageInput]`):
The input video to condition the generation on. Must be a list of images/frames of the video.
prompt (`str` or `list[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated video.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated video.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
expense of slower inference.
timesteps (`list[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`list[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
strength (`float`, *optional*, defaults to 0.8):
Higher strength leads to more differences between original video and generated video.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `list[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
`(batch_size, num_channel, num_frames, height, width)`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
conditioning_frames (`list[PipelineImageInput]`, *optional*):
The ControlNet input condition to provide guidance to the `unet` for generation. If multiple
ControlNets are specified, images must be passed as a list such that each element of the list can be
correctly batched for input to a single ControlNet.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0):
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
the corresponding scale as a list.
guess_mode (`bool`, *optional*, defaults to `False`):
The ControlNet encoder tries to recognize the content of the input image even if you remove all
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`list`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
decode_chunk_size (`int`, defaults to `16`):
The number of frames to decode at a time when calling `decode_latents` method.
Examples:
Returns:
[`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
"""
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
num_videos_per_prompt = 1

Problem:
Both public signatures expose num_videos_per_prompt, but __call__ overwrites it with 1 before input validation, latent preparation, prompt expansion, and denoising. Users requesting multiple videos per prompt silently receive one video.

Impact:
Batch semantics are wrong and no test catches it. The hard reset also hides a related gap in prepare_latents, which would need to duplicate/expand the encoded input-video latents when num_videos_per_prompt > 1.

Reproduction:

# Run from repo root with: .venv/Scripts/python.exe
# Shows actual batch is 1 even though num_videos_per_prompt=2.
# Uses the same tiny component shapes as the fast tests.
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import AnimateDiffVideoToVideoPipeline, AutoencoderKL, DDIMScheduler, MotionAdapter, UNet2DConditionModel

dim, blocks = 8, (8, 8)
unet = UNet2DConditionModel(
    block_out_channels=blocks, layers_per_block=2, sample_size=8, in_channels=4, out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"), up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=dim, norm_num_groups=2,
)
vae = AutoencoderKL(
    block_out_channels=blocks, in_channels=3, out_channels=3, down_block_types=["DownEncoderBlock2D"] * 2,
    up_block_types=["UpDecoderBlock2D"] * 2, latent_channels=4, norm_num_groups=2,
)
text_encoder = CLIPTextModel(
    CLIPTextConfig(
        bos_token_id=0, eos_token_id=2, hidden_size=dim, intermediate_size=37, num_attention_heads=4,
        num_hidden_layers=5, pad_token_id=1, vocab_size=1000,
    )
)
pipe = AnimateDiffVideoToVideoPipeline(
    unet=unet,
    scheduler=DDIMScheduler(),
    vae=vae,
    motion_adapter=MotionAdapter(
        block_out_channels=blocks, motion_layers_per_block=2, motion_norm_num_groups=2, motion_num_attention_heads=4
    ),
    text_encoder=text_encoder,
    tokenizer=CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip"),
    feature_extractor=None,
    image_encoder=None,
)
pipe.set_progress_bar_config(disable=True)

out = pipe(
    video=[Image.new("RGB", (32, 32)) for _ in range(2)],
    prompt="test",
    num_inference_steps=1,
    strength=1.0,
    num_videos_per_prompt=2,
    output_type="pt",
).frames
print(out.shape[0])  # expected 2, actual 1

Relevant precedent (the img2img prepare_latents pattern):

def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    image = image.to(device=device, dtype=dtype)

    batch_size = batch_size * num_images_per_prompt

    if image.shape[1] == 4:
        init_latents = image
    else:
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        elif isinstance(generator, list):
            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
                )
            init_latents = [
                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        init_latents = self.vae.config.scaling_factor * init_latents

    if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
        # expand init_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        additional_image_per_prompt = batch_size // init_latents.shape[0]
        init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)

Suggested fix:
Remove the hard reset in both pipelines and update prepare_latents to expand the input video batch like img2img does:

# remove this line from both __call__ methods
num_videos_per_prompt = 1

# in prepare_latents, before encoding with a generator list
if isinstance(generator, list) and video.shape[0] < batch_size and batch_size % video.shape[0] == 0:
    video = torch.cat([video] * (batch_size // video.shape[0]), dim=0)

# after init_latents is built
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
    init_latents = torch.cat([init_latents] * (batch_size // init_latents.shape[0]), dim=0)
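
To guard against regressions, a fast-test method along these lines could be added. This is a sketch, not an existing test: it assumes the get_dummy_components()/get_dummy_inputs() helpers used by the AnimateDiff video-to-video fast tests, and the test name is hypothetical.

# Sketch only: assumes the fast-test class provides get_dummy_components() and
# get_dummy_inputs(), and that torch_device comes from diffusers.utils.testing_utils.
def test_num_videos_per_prompt(self):
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.set_progress_bar_config(disable=None)
    pipe.to(torch_device)

    inputs = self.get_dummy_inputs(torch_device)
    inputs["num_videos_per_prompt"] = 2
    inputs["output_type"] = "pt"
    frames = pipe(**inputs).frames

    # One prompt with num_videos_per_prompt=2 should produce a batch of two videos.
    assert frames.shape[0] == 2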

Issue 2: AnimateDiff Multi-ControlNet validation can silently drop ControlNets

Affected code (from AnimateDiffControlNetPipeline.check_inputs):

elif (
    isinstance(self.controlnet, MultiControlNetModel)
    or is_compiled
    and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
):
    if not isinstance(video, list) or not isinstance(video[0], list):
        raise TypeError(f"For multiple controlnets: `image` must be type list of lists but got {type(video)=}")
    if len(video[0]) != num_frames:
        raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(video[0])=}")
    if any(len(img) != len(video[0]) for img in video):
        raise ValueError("All conditioning frame batches for multicontrolnet must be same size")
else:
    assert False

# Check `controlnet_conditioning_scale`
if (
    isinstance(self.controlnet, ControlNetModel)
    or is_compiled
    and isinstance(self.controlnet._orig_mod, ControlNetModel)
):
    if not isinstance(controlnet_conditioning_scale, float):
        raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif (
    isinstance(self.controlnet, MultiControlNetModel)
    or is_compiled
    and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
):
    if isinstance(controlnet_conditioning_scale, list):
        if any(isinstance(i, list) for i in controlnet_conditioning_scale):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
    elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(

The corresponding validation in AnimateDiffVideoToVideoControlNetPipeline.check_inputs follows the same pattern:

elif (
    isinstance(self.controlnet, MultiControlNetModel)
    or is_compiled
    and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
):
    if not isinstance(conditioning_frames, list) or not isinstance(conditioning_frames[0], list):
        raise TypeError(
            f"For multiple controlnets: `image` must be type list of lists but got {type(conditioning_frames)=}"
        )
    if len(conditioning_frames[0]) != num_frames:
        raise ValueError(
            f"Expected length of image sublist as {num_frames} but got {len(conditioning_frames)=}"
        )
    if any(len(img) != len(conditioning_frames[0]) for img in conditioning_frames):
        raise ValueError("All conditioning frame batches for multicontrolnet must be same size")
else:
    assert False

# Check `controlnet_conditioning_scale`
if (
    isinstance(self.controlnet, ControlNetModel)
    or is_compiled
    and isinstance(self.controlnet._orig_mod, ControlNetModel)
):
    if not isinstance(controlnet_conditioning_scale, float):
        raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif (
    isinstance(self.controlnet, MultiControlNetModel)
    or is_compiled
    and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
):
    if isinstance(controlnet_conditioning_scale, list):
        if any(isinstance(i, list) for i in controlnet_conditioning_scale):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
    elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(

Problem:
For MultiControlNetModel, the length of the conditioning-frame list is never checked against len(controlnet.nets). The scale-length check is also unreachable: it sits in an elif isinstance(controlnet_conditioning_scale, list) branch that can only run when the preceding if isinstance(controlnet_conditioning_scale, list) branch is false, so a list-valued scale never reaches it.

Impact:
A user can pass two ControlNets but only one conditioning video or one scale. Validation succeeds, then MultiControlNetModel.forward zips the lists and silently skips the extra ControlNet.
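
For illustration only (this is not repository code), the silent skip falls out of Python's zip semantics, which truncate to the shortest input:

# Two nets, but only one conditioning entry and one scale: zip() stops after the
# first tuple, so the second ControlNet is never called and no error is raised.
conds, scales, nets = ["cond_0"], [1.0], ["controlnet_0", "controlnet_1"]
for cond, scale, net in zip(conds, scales, nets):
    print(net)  # prints only "controlnet_0"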

Reproduction:

import torch

from diffusers import AnimateDiffControlNetPipeline, AnimateDiffVideoToVideoControlNetPipeline, MultiControlNetModel

multi = MultiControlNetModel([torch.nn.Identity(), torch.nn.Identity()])

pipe = object.__new__(AnimateDiffControlNetPipeline)
pipe.controlnet = multi
pipe.check_inputs(
    prompt="x",
    height=64,
    width=64,
    num_frames=2,
    video=[[object(), object()]],
    controlnet_conditioning_scale=[1.0],
    control_guidance_start=[0.0, 0.0],
    control_guidance_end=[1.0, 1.0],
)
print("text2video validation passed unexpectedly")

pipe = object.__new__(AnimateDiffVideoToVideoControlNetPipeline)
pipe.controlnet = multi
pipe.check_inputs(
    prompt="x",
    height=64,
    width=64,
    strength=0.8,
    video=[object(), object()],
    conditioning_frames=[[object(), object()]],
    latents=None,
    controlnet_conditioning_scale=[1.0],
    control_guidance_start=[0.0, 0.0],
    control_guidance_end=[1.0, 1.0],
)
print("video2video validation passed unexpectedly")

Relevant precedent (SD3 ControlNet validation, followed by the zip in MultiControlNetModel.forward):

if isinstance(controlnet, SD3ControlNetModel):
    self.check_image(image, prompt, prompt_embeds)
elif isinstance(controlnet, SD3MultiControlNetModel):
    if not isinstance(image, list):
        raise TypeError("For multiple controlnets: `image` must be type `list`")
    elif len(image) != len(self.controlnet.nets):
        raise ValueError(
            f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
        )
    for image_ in image:
        self.check_image(image_, prompt, prompt_embeds)

# Check `controlnet_conditioning_scale`
if isinstance(controlnet, SD3MultiControlNetModel):
    if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
        self.controlnet.nets
    ):
        raise ValueError(
            "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
            " the same length as the number of controlnets"
        )

for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
    down_samples, mid_sample = controlnet(

Suggested fix:

if isinstance(controlnet, MultiControlNetModel):
    if len(video) != len(controlnet.nets):
        raise ValueError(
            f"For multiple controlnets: expected {len(controlnet.nets)} conditioning videos, got {len(video)}."
        )

    if isinstance(controlnet_conditioning_scale, list):
        if any(isinstance(i, list) for i in controlnet_conditioning_scale):
            raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
        if len(controlnet_conditioning_scale) != len(controlnet.nets):
            raise ValueError(
                "`controlnet_conditioning_scale` must have the same length as the number of controlnets."
            )
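
A regression test mirroring the reproduction could then assert that the mismatch is rejected. This is a sketch with a hypothetical test name, assuming pytest as used elsewhere in the test suite:

import pytest
import torch

from diffusers import AnimateDiffControlNetPipeline, MultiControlNetModel


def test_multicontrolnet_conditioning_length_mismatch_raises():
    pipe = object.__new__(AnimateDiffControlNetPipeline)
    pipe.controlnet = MultiControlNetModel([torch.nn.Identity(), torch.nn.Identity()])

    # One conditioning video and one scale for two ControlNets should fail validation.
    with pytest.raises(ValueError):
        pipe.check_inputs(
            prompt="x",
            height=64,
            width=64,
            num_frames=2,
            video=[[object(), object()]],
            controlnet_conditioning_scale=[1.0],
            control_guidance_start=[0.0, 0.0],
            control_guidance_end=[1.0, 1.0],
        )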

Issue 3: Community AnimateDiff image-to-video example imports a removed module path

Affected code:

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.models.unet_motion_model import MotionAdapter
from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput

Problem:
The example imports MotionAdapter from diffusers.models.unet_motion_model, but that module path has been removed (the module now lives under diffusers.models.unets). MotionAdapter is exported from the top-level diffusers package.

Impact:
The community pipeline fails at import time before users can run it.

Reproduction:

from diffusers.models.unet_motion_model import MotionAdapter
# ModuleNotFoundError: No module named 'diffusers.models.unet_motion_model'

Relevant precedent:

from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.models.unets.unet_motion_model import MotionAdapter
from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput

Suggested fix:

from diffusers import MotionAdapter
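
If a module-level import is preferred (for example, to stay close to the in-tree pipelines shown in the precedent above), the moved module path also works:

# Alternative to the top-level export; matches the current in-tree location.
from diffusers.models.unets.unet_motion_model import MotionAdapter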

Issue 4: Missing slow coverage for most AnimateDiff variants and missing model tests for SparseControlNetModel

Affected code:

@slow
@require_torch_accelerator
class AnimateDiffPipelineSlowTests(unittest.TestCase):

class AnimateDiffControlNetPipelineFastTests(
IPAdapterTesterMixin, SDFunctionTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
):
pipeline_class = AnimateDiffControlNetPipeline

class AnimateDiffSparseControlNetPipelineFastTests(
IPAdapterTesterMixin, SDFunctionTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
):
pipeline_class = AnimateDiffSparseControlNetPipeline

class AnimateDiffPipelineSDXLFastTests(
IPAdapterTesterMixin,
SDFunctionTesterMixin,
PipelineTesterMixin,
unittest.TestCase,
):
pipeline_class = AnimateDiffSDXLPipeline

class SparseControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin):
"""
A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
Models](https://huggingface.co/papers/2311.16933).
Args:
in_channels (`int`, defaults to 4):
The number of channels in the input sample.
conditioning_channels (`int`, defaults to 4):
The number of input channels in the controlnet conditional embedding module. If
`concat_condition_embedding` is True, the value provided here is incremented by 1.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, defaults to 0):
The frequency shift to apply to the time embedding.
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
only_cross_attention (`bool | tuple[bool]`, defaults to `False`):
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, defaults to 2):
The number of layers per block.
downsample_padding (`int`, defaults to 1):
The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, defaults to 1):
The scale factor to use for the mid block.
act_fn (`str`, defaults to "silu"):
The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
in post-processing.
norm_eps (`float`, defaults to 1e-5):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int` or `tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
transformer_layers_per_mid_block (`int` or `tuple[int]`, *optional*, defaults to 1):
The number of transformer layers to use in each layer in the middle block.
attention_head_dim (`int` or `tuple[int]`, defaults to 8):
The dimension of the attention heads.
num_attention_heads (`int` or `tuple[int]`, *optional*):
The number of heads to use for multi-head attention.
use_linear_projection (`bool`, defaults to `False`):
upcast_attention (`bool`, defaults to `False`):
resnet_time_scale_shift (`str`, defaults to `"default"`):
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
global_pool_conditions (`bool`, defaults to `False`):
TODO(Patrick) - unused parameter
controlnet_conditioning_channel_order (`str`, defaults to `rgb`):
motion_max_seq_length (`int`, defaults to `32`):
The maximum sequence length to use in the motion module.
motion_num_attention_heads (`int` or `tuple[int]`, defaults to `8`):
The number of heads to use in each attention layer of the motion module.
concat_conditioning_mask (`bool`, defaults to `True`):
use_simplified_condition_embedding (`bool`, defaults to `True`):
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(

Problem:
Only the base AnimateDiffPipeline has a slow test. There are no slow tests for ControlNet, SparseCtrl, SDXL, video-to-video, or video-to-video ControlNet. SparseControlNetModel also has no model-level test under tests/models/controlnets, so model save/load, config roundtrip, attention processor behavior, and gradient checkpointing are only indirectly covered by pipeline tests.

Impact:
Checkpoint-specific regressions and model serialization issues can ship without coverage. This is especially risky for SparseCtrl because the model is public and loadable independently from the pipeline.

Reproduction:

from pathlib import Path

text = "\n".join(p.read_text(encoding="utf-8") for p in Path("tests").rglob("test*.py"))
for name in [
    "AnimateDiffPipelineSlowTests",
    "AnimateDiffControlNetPipelineSlowTests",
    "AnimateDiffSparseControlNetPipelineSlowTests",
    "AnimateDiffPipelineSDXLSlowTests",
    "AnimateDiffVideoToVideoPipelineSlowTests",
    "AnimateDiffVideoToVideoControlNetPipelineSlowTests",
    "SparseControlNetModelTests",
]:
    print(name, name in text)
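# Expected output on the reviewed commit: only AnimateDiffPipelineSlowTests prints True;
# every other name above is absent from the test suite.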

Relevant precedent:

class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNetMotionModel

@slow
@require_torch_accelerator
class AnimateDiffPipelineSlowTests(unittest.TestCase):
    def setUp(self):
        # clean up the VRAM before each test
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_animatediff(self):
        adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
        pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter)
        pipe = pipe.to(torch_device)
        pipe.scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="linear",
            steps_offset=1,
            clip_sample=False,
        )
        pipe.enable_vae_slicing()
        pipe.enable_model_cpu_offload(device=torch_device)
        pipe.set_progress_bar_config(disable=None)

        prompt = "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
        negative_prompt = "bad quality, worse quality"

        generator = torch.Generator("cpu").manual_seed(0)
        output = pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_frames=16,
            generator=generator,
            guidance_scale=7.5,
            num_inference_steps=3,
            output_type="np",
        )

        image = output.frames[0]
        assert image.shape == (16, 512, 512, 3)

        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array(
            [
                0.11357737,
                0.11285847,
                0.11180121,
                0.11084166,
                0.11414117,
                0.09785956,
                0.10742754,
                0.10510018,
                0.08045256,
            ]
        )
        assert numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice.flatten()) < 1e-3

Suggested fix:
Add slow smoke tests for the missing public pipelines using the documented small checkpoint paths where possible, and add tests/models/controlnets/test_models_controlnet_sparsectrl.py with the standard ModelTesterMixin coverage for forward shape, save/load, variant save/load, attention processors, and gradient checkpointing.
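
As one concrete example, a slow smoke test for the video-to-video pipeline could reuse the checkpoints and structure of the existing AnimateDiffPipelineSlowTests. The sketch below is an assumption about how such a test might look (the input video, frame count, and test name are placeholders, and an expected-slice check would be added once reference values are generated), not a finished test.

import gc
import unittest

import torch
from PIL import Image

from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils.testing_utils import backend_empty_cache, require_torch_accelerator, slow, torch_device


@slow
@require_torch_accelerator
class AnimateDiffVideoToVideoPipelineSlowTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_animatediff_video2video_smoke(self):
        adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
        pipe = AnimateDiffVideoToVideoPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter)
        pipe.scheduler = DDIMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="linear", steps_offset=1, clip_sample=False
        )
        pipe.enable_model_cpu_offload(device=torch_device)
        pipe.set_progress_bar_config(disable=None)

        # A trivial 8-frame input clip; any short video at the model resolution would do.
        video = [Image.new("RGB", (512, 512)) for _ in range(8)]
        generator = torch.Generator("cpu").manual_seed(0)
        output = pipe(
            video=video,
            prompt="a house in a forest, film grain",
            strength=0.75,
            guidance_scale=7.5,
            num_inference_steps=3,
            generator=generator,
            output_type="np",
        )

        # Video-to-video keeps the input frame count; the shape check is the smoke assertion.
        frames = output.frames[0]
        assert frames.shape == (8, 512, 512, 3)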
