40 changes: 19 additions & 21 deletions examples/community/pipeline_animatediff_controlnet.py
@@ -13,7 +13,6 @@
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -27,6 +26,7 @@
from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.models.unets.unet_motion_model import MotionAdapter
from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.schedulers import (
@@ -37,7 +37,7 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor


@@ -91,10 +91,8 @@
"""


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
# Based on:
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78

batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
@@ -103,12 +101,16 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):

outputs.append(batch_output)

return outputs
if output_type == "np":
outputs = np.stack(outputs)

elif output_type == "pt":
outputs = torch.stack(outputs)

elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

@dataclass
class AnimateDiffControlNetPipelineOutput(BaseOutput):
frames: Union[torch.Tensor, np.ndarray]
return outputs


class AnimateDiffControlNetPipeline(
@@ -843,8 +845,8 @@ def __call__(
Examples:

Returns:
[`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
[`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
"""

@@ -1020,7 +1022,7 @@ def __call__(
]
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

# Denoising loop
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -1096,21 +1098,17 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

# 9. Post processing
if output_type == "latent":
return AnimateDiffControlNetPipelineOutput(frames=latents)

# Post-processing
video_tensor = self.decode_latents(latents)

if output_type == "pt":
video = video_tensor
video = latents
else:
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

# Offload all models
# 10. Offload all models
self.maybe_free_model_hooks()

if not return_dict:
return (video,)

return AnimateDiffControlNetPipelineOutput(frames=video)
return AnimateDiffPipelineOutput(frames=video)
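For reference, here is a minimal, self-contained sketch of the `tensor2vid` contract this file now shares with the library pipelines: per-batch outputs are stacked for `"np"`/`"pt"`, and an unknown `output_type` raises. The `StubProcessor` is an illustrative stand-in for diffusers' `VaeImageProcessor`, whose `postprocess` does the real per-frame conversion (including `"pil"`).

```python
# Minimal sketch of the tensor2vid contract shown in the hunks above,
# runnable with a stand-in processor; VaeImageProcessor is the real one.
import numpy as np
import torch


class StubProcessor:
    """Illustrative stand-in: maps (F, C, H, W) tensors to the requested type."""

    def postprocess(self, images: torch.Tensor, output_type: str):
        if output_type == "pt":
            return images
        # "np": channels-last per frame (a real processor also handles "pil")
        return images.permute(0, 2, 3, 1).cpu().numpy()


def tensor2vid(video: torch.Tensor, processor, output_type: str = "np"):
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        # (C, F, H, W) -> (F, C, H, W): treat each frame as an image
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        outputs.append(processor.postprocess(batch_vid, output_type))

    if output_type == "np":
        outputs = np.stack(outputs)      # (B, F, H, W, C)
    elif output_type == "pt":
        outputs = torch.stack(outputs)   # (B, F, C, H, W)
    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
    return outputs


video = torch.rand(2, 3, 8, 32, 32)  # (B, C, F, H, W)
print(tensor2vid(video, StubProcessor(), "np").shape)  # (2, 8, 32, 32, 3)
print(tensor2vid(video, StubProcessor(), "pt").shape)  # (2, 8, 3, 32, 32)
```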
24 changes: 15 additions & 9 deletions examples/community/pipeline_animatediff_img2video.py
@@ -158,10 +158,8 @@ def slerp(
return v2


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
# Based on:
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78

batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
@@ -170,6 +168,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):

outputs.append(batch_output)

if output_type == "np":
outputs = np.stack(outputs)

elif output_type == "pt":
outputs = torch.stack(outputs)

elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

return outputs


@@ -826,8 +833,8 @@ def __call__(
Examples:

Returns:
[`AnimateDiffPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
[`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
"""
# 0. Default height and width to unet
@@ -958,11 +965,10 @@ def __call__(
return AnimateDiffPipelineOutput(frames=latents)

# 10. Post-processing
video_tensor = self.decode_latents(latents)

if output_type == "pt":
video = video_tensor
if output_type == "latent":
video = latents
else:
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

# 11. Offload all models
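The img2video hunks above follow the same pattern as the ControlNet file: the early `return AnimateDiffPipelineOutput(frames=latents)` is dropped so that `output_type="latent"` flows through the same tail as every other output type. A condensed sketch of that tail, with stubs standing in for the pipeline's decode/offload machinery (names are illustrative, not the library code):

```python
# Condensed sketch of the unified post-processing tail. Only the control flow
# matters: "latent" no longer early-returns, so offloading and the return_dict
# contract now apply to every output type.
import torch


def decode_latents(latents):          # stub for self.decode_latents
    return latents * 2.0


def tensor2vid(video, output_type):   # stub for the helper sketched earlier
    return [video]


def maybe_free_model_hooks():         # stub for self.maybe_free_model_hooks
    print("models offloaded")


def finish(latents, output_type="np", return_dict=True):
    if output_type == "latent":
        video = latents                                   # skip the VAE decode
    else:
        video = tensor2vid(decode_latents(latents), output_type)

    maybe_free_model_hooks()                              # now runs for "latent" too

    if not return_dict:
        return (video,)
    return {"frames": video}  # stand-in for AnimateDiffPipelineOutput


print(finish(torch.zeros(1, 4, 2, 8, 8), output_type="latent"))
```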
19 changes: 11 additions & 8 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -81,7 +81,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
outputs = torch.stack(outputs)

elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

return outputs

@@ -668,8 +668,8 @@ def __call__(
Examples:

Returns:
[`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
[`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
"""

@@ -790,6 +790,8 @@ def __call__(

self._num_timesteps = len(timesteps)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

# 8. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
@@ -829,13 +831,14 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

# 9. Post processing
if output_type == "latent":
return AnimateDiffPipelineOutput(frames=latents)

video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
video = latents
else:
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

# 9. Offload all models
# 10. Offload all models
self.maybe_free_model_hooks()

if not return_dict:
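From the caller's side, the practical effect of the `pipeline_animatediff.py` change is that `output_type="latent"` and `return_dict=False` now compose with model offloading like any other output type. A hypothetical usage sketch (model ids are illustrative examples, not taken from this PR):

```python
# Hypothetical usage sketch; checkpoint names are assumptions for illustration.
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# After this change, "latent" goes through the normal return path:
out = pipe("a rocket launching", num_frames=16, output_type="latent")
latents = out.frames  # undecoded latents, wrapped in AnimateDiffPipelineOutput

# And the return_dict=False tuple contract holds for every output type:
frames = pipe("a rocket launching", num_frames=16, return_dict=False)[0]
```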
@@ -100,7 +100,7 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
outputs = torch.stack(outputs)

elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

return outputs

@@ -828,8 +828,8 @@ def __call__(
Examples:

Returns:
[`AnimateDiffPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
[`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
"""

@@ -942,6 +942,7 @@

self._num_timesteps = len(timesteps)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

# 8. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -980,15 +981,11 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()

if output_type == "latent":
return AnimateDiffPipelineOutput(frames=latents)

# 9. Post-processing
video_tensor = self.decode_latents(latents)

if output_type == "pt":
video = video_tensor
if output_type == "latent":
video = latents
else:
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

# 10. Offload all models
13 changes: 7 additions & 6 deletions src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
@@ -83,7 +83,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
outputs = torch.stack(outputs)

elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

return outputs

@@ -726,13 +726,14 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()

# 8. Post processing
if output_type == "latent":
return I2VGenXLPipelineOutput(frames=latents)

video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
video = latents
else:
video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

# Offload all models
# 9. Offload all models
self.maybe_free_model_hooks()

if not return_dict:
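The I2VGen-XL hunk also threads `decode_chunk_size` into `decode_latents`. A minimal sketch of chunked VAE decoding under diffusers conventions (`scaling_factor`, `.sample`); this is an illustration of the technique, not the library code:

```python
# Minimal sketch of chunked latent decoding: decode_chunk_size bounds peak
# VAE memory by decoding a few frames at a time. `vae` is an AutoencoderKL.
import torch


def decode_latents_chunked(latents: torch.Tensor, vae, decode_chunk_size: int = 16):
    # latents: (B, C, F, h, w) -> flatten frames so the VAE sees 2D images
    b, c, f, h, w = latents.shape
    latents = latents.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
    latents = latents / vae.config.scaling_factor

    frames = []
    for i in range(0, latents.shape[0], decode_chunk_size):
        # decode a few frames at a time instead of all B*F at once
        frames.append(vae.decode(latents[i : i + decode_chunk_size]).sample)
    frames = torch.cat(frames, dim=0)  # (B*F, 3, H, W)

    # back to (B, 3, F, H, W), the layout tensor2vid expects
    return frames.reshape(b, f, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
```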
17 changes: 9 additions & 8 deletions src/diffusers/pipelines/pia/pipeline_pia.py
@@ -107,7 +107,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
outputs = torch.stack(outputs)

elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

return outputs

@@ -860,8 +860,8 @@ def __call__(
Examples:

Returns:
[`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
[`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
"""
# 0. Default height and width to unet
@@ -1018,13 +1018,14 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()

# 9. Post processing
if output_type == "latent":
return PIAPipelineOutput(frames=latents)

video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
video = latents
else:
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

# 9. Offload all models
# 10. Offload all models
self.maybe_free_model_hooks()

if not return_dict:
@@ -57,7 +57,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
outputs = torch.stack(outputs)

elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

return outputs

@@ -76,7 +76,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
outputs = torch.stack(outputs)

elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

return outputs

@@ -646,13 +646,14 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)

# 8. Post processing
if output_type == "latent":
return TextToVideoSDPipelineOutput(frames=latents)

video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type)
video = latents
else:
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type)

# Offload all models
# 9. Offload all models
self.maybe_free_model_hooks()

if not return_dict: