diff --git a/src/diffusers/pipelines/animatediff/freeinit_utils.py b/src/diffusers/pipelines/animatediff/freeinit_utils.py
new file mode 100644
index 000000000000..4e96ac884e7e
--- /dev/null
+++ b/src/diffusers/pipelines/animatediff/freeinit_utils.py
@@ -0,0 +1,211 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Tuple, Union
+
+import torch
+import torch.fft as fft
+
+from ...utils.torch_utils import randn_tensor
+
+
+class FreeInitMixin:
+    r"""
+    Mixin class for FreeInit-related utilities. A pipeline that derives from this class supports the
+    FreeInit mechanism described in https://arxiv.org/abs/2312.07537.
+
+    The methods exposed by a pipeline that derives from this class are:
+        - `free_init_enabled`: Returns whether or not FreeInit has been enabled for generation.
+        - `enable_free_init`: Enables the FreeInit mechanism.
+        - `disable_free_init`: Disables the FreeInit mechanism.
+    """
+
+    @property
+    def free_init_enabled(self):
+        return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None
+
+    def enable_free_init(
+        self,
+        num_iters: int = 3,
+        use_fast_sampling: bool = False,
+        method: str = "butterworth",
+        order: int = 4,
+        spatial_stop_frequency: float = 0.25,
+        temporal_stop_frequency: float = 0.25,
+    ):
+        r"""Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.
+
+        This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit).
+
+        Args:
+            num_iters (`int`, *optional*, defaults to `3`):
+                Number of FreeInit noise re-initialization iterations.
+            use_fast_sampling (`bool`, *optional*, defaults to `False`):
+                Whether or not to speed up the sampling procedure at the cost of potentially lower quality results.
+                Enables the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
+            method (`str`, *optional*, defaults to `butterworth`):
+                The filtering method to use for the FreeInit low-pass filter. Must be one of `butterworth`,
+                `ideal` or `gaussian`.
+            order (`int`, *optional*, defaults to `4`):
+                Order of the filter used in the `butterworth` method. Larger values lead to `ideal` method behaviour
+                whereas lower values lead to `gaussian` method behaviour.
+            spatial_stop_frequency (`float`, *optional*, defaults to `0.25`):
+                Normalized stop frequency for spatial dimensions. Must be between 0 and 1. Referred to as `d_s` in
+                the original implementation.
+            temporal_stop_frequency (`float`, *optional*, defaults to `0.25`):
+                Normalized stop frequency for temporal dimensions. Must be between 0 and 1. Referred to as `d_t` in
+                the original implementation.
+ """ + self._free_init_num_iters = num_iters + self._free_init_use_fast_sampling = use_fast_sampling + self._free_init_method = method + self._free_init_order = order + self._free_init_spatial_stop_frequency = spatial_stop_frequency + self._free_init_temporal_stop_frequency = temporal_stop_frequency + + def disable_free_init(self): + """Disables the FreeInit mechanism if enabled.""" + self._free_init_num_iters = None + + @staticmethod + def _get_freeinit_freq_filter( + shape: Tuple[int, ...], + device: Union[str, torch.dtype], + filter_type: str, + order: float, + spatial_stop_frequency: float, + temporal_stop_frequency: float, + ) -> torch.Tensor: + r"""Returns the FreeInit filter based on filter type and other input conditions.""" + + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + + if spatial_stop_frequency == 0 or temporal_stop_frequency == 0: + return mask + + if filter_type == "butterworth": + + def retrieve_mask(x): + return 1 / (1 + (x / spatial_stop_frequency**2) ** order) + elif filter_type == "gaussian": + + def retrieve_mask(x): + return math.exp(-1 / (2 * spatial_stop_frequency**2) * x) + elif filter_type == "ideal": + + def retrieve_mask(x): + return 1 if x <= spatial_stop_frequency * 2 else 0 + else: + raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal") + + for t in range(T): + for h in range(H): + for w in range(W): + d_square = ( + ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / T - 1)) ** 2 + + (2 * h / H - 1) ** 2 + + (2 * w / W - 1) ** 2 + ) + mask[..., t, h, w] = retrieve_mask(d_square) + + return mask.to(device) + + @staticmethod + def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor: + r"""Noise reinitialization.""" + # FFT + x_freq = fft.fftn(x, dim=(-3, -2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) + noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) + noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) + + # frequency mix + HPF = 1 - LPF + x_freq_low = x_freq * LPF + noise_freq_high = noise_freq * HPF + x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain + + # IFFT + x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) + x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real + + return x_mixed + + def _free_init_loop( + self, + height: int, + width: int, + num_frames: int, + num_channels_latents: int, + batch_size: int, + num_videos_per_prompt: int, + denoise_args: Dict[str, Any], + device: Union[str, torch.dtype], + generator: torch.Generator, + ): + r"""Denoising loop for AnimateDiff using FreeInit noise reinitialization technique.""" + + latents = denoise_args.get("latents") + prompt_embeds = denoise_args.get("prompt_embeds") + timesteps = denoise_args.get("timesteps") + num_inference_steps = denoise_args.get("num_inference_steps") + H = height // self.vae_scale_factor + W = width // self.vae_scale_factor + bs = batch_size * num_videos_per_prompt + + latent_shape = (bs, num_channels_latents, num_frames, H, W) + free_init_filter_shape = (1, num_channels_latents, num_frames, H, W) + free_init_freq_filter = self._get_freeinit_freq_filter( + shape=free_init_filter_shape, + device=device, + filter_type=self._free_init_method, + order=self._free_init_order, + spatial_stop_frequency=self._free_init_spatial_stop_frequency, + temporal_stop_frequency=self._free_init_temporal_stop_frequency, + ) + + with self.progress_bar(total=self._free_init_num_iters) as free_init_progress_bar: + for i in 
+                # For the first FreeInit iteration, the original latent is used without modification.
+                # Subsequent iterations apply the noise reinitialization technique.
+                if i == 0:
+                    initial_noise = latents.detach().clone()
+                else:
+                    current_diffuse_timestep = (
+                        self.scheduler.config.num_train_timesteps - 1
+                    )  # diffuse to t=999 noise level
+                    # timesteps must match the latent batch size (batch_size * num_videos_per_prompt)
+                    diffuse_timesteps = torch.full((bs,), current_diffuse_timestep).long()
+                    z_T = self.scheduler.add_noise(
+                        original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device)
+                    ).to(dtype=torch.float32)
+                    z_rand = randn_tensor(shape=latent_shape, generator=generator, device=device, dtype=torch.float32)
+                    latents = self._freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter)
+                    latents = latents.to(prompt_embeds.dtype)
+
+                # Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
+                if self._free_init_use_fast_sampling:
+                    current_num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (i + 1))
+                    self.scheduler.set_timesteps(current_num_inference_steps, device=device)
+                    timesteps = self.scheduler.timesteps
+                    denoise_args.update({"timesteps": timesteps, "num_inference_steps": current_num_inference_steps})
+
+                num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+                denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps})
+                latents = self._denoise_loop(**denoise_args)
+
+                free_init_progress_bar.update()
+
+        return latents
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index 25b664fe3b29..29c3a6544853 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -13,13 +13,11 @@
 # limitations under the License.
import inspect -import math from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch -import torch.fft as fft from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor @@ -46,6 +44,7 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .freeinit_utils import FreeInitMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -88,77 +87,14 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: return outputs -def _get_freeinit_freq_filter( - shape: Tuple[int, ...], - device: Union[str, torch.dtype], - filter_type: str, - order: float, - spatial_stop_frequency: float, - temporal_stop_frequency: float, -) -> torch.Tensor: - r"""Returns the FreeInit filter based on filter type and other input conditions.""" - - T, H, W = shape[-3], shape[-2], shape[-1] - mask = torch.zeros(shape) - - if spatial_stop_frequency == 0 or temporal_stop_frequency == 0: - return mask - - if filter_type == "butterworth": - - def retrieve_mask(x): - return 1 / (1 + (x / spatial_stop_frequency**2) ** order) - elif filter_type == "gaussian": - - def retrieve_mask(x): - return math.exp(-1 / (2 * spatial_stop_frequency**2) * x) - elif filter_type == "ideal": - - def retrieve_mask(x): - return 1 if x <= spatial_stop_frequency * 2 else 0 - else: - raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal") - - for t in range(T): - for h in range(H): - for w in range(W): - d_square = ( - ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / T - 1)) ** 2 - + (2 * h / H - 1) ** 2 - + (2 * w / W - 1) ** 2 - ) - mask[..., t, h, w] = retrieve_mask(d_square) - - return mask.to(device) - - -def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor: - r"""Noise reinitialization.""" - # FFT - x_freq = fft.fftn(x, dim=(-3, -2, -1)) - x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) - noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) - noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) - - # frequency mix - HPF = 1 - LPF - x_freq_low = x_freq * LPF - noise_freq_high = noise_freq * HPF - x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain - - # IFFT - x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) - x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real - - return x_mixed - - @dataclass class AnimateDiffPipelineOutput(BaseOutput): frames: Union[torch.Tensor, np.ndarray] -class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): +class AnimateDiffPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FreeInitMixin +): r""" Pipeline for text-to-video generation. 
@@ -170,6 +106,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + - [`~pipelines.animatediff.FreeInitMixin.enable_free_init`] for enabling the FreeInit mechanism for generation + - [`~pipelines.animatediff.FreeInitMixin.disable_free_init`] for disabling the FreeInit mechanism for generation Args: vae ([`AutoencoderKL`]): @@ -517,58 +455,6 @@ def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() - @property - def free_init_enabled(self): - return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None - - def enable_free_init( - self, - num_iters: int = 3, - use_fast_sampling: bool = False, - method: str = "butterworth", - order: int = 4, - spatial_stop_frequency: float = 0.25, - temporal_stop_frequency: float = 0.25, - generator: torch.Generator = None, - ): - """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537. - - This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit). - - Args: - num_iters (`int`, *optional*, defaults to `3`): - Number of FreeInit noise re-initialization iterations. - use_fast_sampling (`bool`, *optional*, defaults to `False`): - Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables - the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`. - method (`str`, *optional*, defaults to `butterworth`): - Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the - FreeInit low pass filter. - order (`int`, *optional*, defaults to `4`): - Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour - whereas lower values lead to `gaussian` method behaviour. - spatial_stop_frequency (`float`, *optional*, defaults to `0.25`): - Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in - the original implementation. - temporal_stop_frequency (`float`, *optional*, defaults to `0.25`): - Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in - the original implementation. - generator (`torch.Generator`, *optional*, defaults to `0.25`): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - FreeInit generation deterministic. 
- """ - self._free_init_num_iters = num_iters - self._free_init_use_fast_sampling = use_fast_sampling - self._free_init_method = method - self._free_init_order = order - self._free_init_spatial_stop_frequency = spatial_stop_frequency - self._free_init_temporal_stop_frequency = temporal_stop_frequency - self._free_init_generator = generator - - def disable_free_init(self): - """Disables the FreeInit mechanism if enabled.""" - self._free_init_num_iters = None - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -726,85 +612,6 @@ def _denoise_loop( return latents - def _free_init_loop( - self, - height, - width, - num_frames, - num_channels_latents, - batch_size, - num_videos_per_prompt, - denoise_args, - device, - ): - """Denoising loop for AnimateDiff using FreeInit noise reinitialization technique.""" - - latents = denoise_args.get("latents") - prompt_embeds = denoise_args.get("prompt_embeds") - timesteps = denoise_args.get("timesteps") - num_inference_steps = denoise_args.get("num_inference_steps") - - latent_shape = ( - batch_size * num_videos_per_prompt, - num_channels_latents, - num_frames, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - free_init_filter_shape = ( - 1, - num_channels_latents, - num_frames, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - free_init_freq_filter = _get_freeinit_freq_filter( - shape=free_init_filter_shape, - device=device, - filter_type=self._free_init_method, - order=self._free_init_order, - spatial_stop_frequency=self._free_init_spatial_stop_frequency, - temporal_stop_frequency=self._free_init_temporal_stop_frequency, - ) - - with self.progress_bar(total=self._free_init_num_iters) as free_init_progress_bar: - for i in range(self._free_init_num_iters): - # For the first FreeInit iteration, the original latent is used without modification. - # Subsequent iterations apply the noise reinitialization technique. 
- if i == 0: - initial_noise = latents.detach().clone() - else: - current_diffuse_timestep = ( - self.scheduler.config.num_train_timesteps - 1 - ) # diffuse to t=999 noise level - diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long() - z_T = self.scheduler.add_noise( - original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device) - ).to(dtype=torch.float32) - z_rand = randn_tensor( - shape=latent_shape, - generator=self._free_init_generator, - device=device, - dtype=torch.float32, - ) - latents = _freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter) - latents = latents.to(prompt_embeds.dtype) - - # Coarse-to-Fine Sampling for faster inference (can lead to lower quality) - if self._free_init_use_fast_sampling: - current_num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (i + 1)) - self.scheduler.set_timesteps(current_num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - denoise_args.update({"timesteps": timesteps, "num_inference_steps": current_num_inference_steps}) - - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps}) - latents = self._denoise_loop(**denoise_args) - - free_init_progress_bar.update() - - return latents - def _retrieve_video_frames(self, latents, output_type, return_dict): """Helper function to handle latents to output conversion.""" if output_type == "latent": @@ -1070,6 +877,7 @@ def __call__( num_videos_per_prompt=num_videos_per_prompt, denoise_args=denoise_args, device=device, + generator=generator, ) else: latents = self._denoise_loop(**denoise_args) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 80a8fd19f5a0..fb10872456ee 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -242,7 +242,6 @@ def test_free_init(self): inputs_normal = self.get_dummy_inputs(torch_device) frames_normal = pipe(**inputs_normal).frames[0] - free_init_generator = torch.Generator(device=torch_device).manual_seed(0) pipe.enable_free_init( num_iters=2, use_fast_sampling=True, @@ -250,7 +249,6 @@ def test_free_init(self): order=4, spatial_stop_frequency=0.25, temporal_stop_frequency=0.25, - generator=free_init_generator, ) inputs_enable_free_init = self.get_dummy_inputs(torch_device) frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]
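Usage sketch (not part of the patch): after this refactor, `enable_free_init` no longer takes a `generator`; the `generator` passed to the pipeline call is forwarded to `_free_init_loop` instead, as shown in the `__call__` hunk above. The snippet below is a minimal example of driving the new mixin API from `AnimateDiffPipeline`. The checkpoints (`guoyww/animatediff-motion-adapter-v1-5-2`, `runwayml/stable-diffusion-v1-5`) and scheduler settings are illustrative assumptions, not taken from this diff.

```python
import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

# Illustrative checkpoints; any AnimateDiff-compatible base model and motion adapter should work.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear", clip_sample=False)

# Enable FreeInit. Note that there is no `generator` argument here anymore.
pipe.enable_free_init(
    num_iters=3,
    use_fast_sampling=False,
    method="butterworth",
    order=4,
    spatial_stop_frequency=0.25,
    temporal_stop_frequency=0.25,
)
assert pipe.free_init_enabled

# Determinism comes from the pipeline-level generator, which the pipeline
# now forwards to `_free_init_loop` for FreeInit noise reinitialization.
output = pipe(
    prompt="a panda surfing on a wave, high quality",
    num_frames=16,
    num_inference_steps=25,
    generator=torch.Generator(device="cuda").manual_seed(42),
)

pipe.disable_free_init()
export_to_gif(output.frames[0], "animation.gif")
```

Routing determinism through the pipeline-level generator keeps a single seed in control of both the initial latents and FreeInit's reinitialized noise, which is why the test no longer constructs a separate `free_init_generator`.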