From 5289205316db5cc1d41e13224c9e0afd87a42a14 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 6 Feb 2024 08:13:39 +0000 Subject: [PATCH 01/10] update --- .../animatediff/pipeline_animatediff.py | 252 ++++++------------ 1 file changed, 84 insertions(+), 168 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 5988e7657e13..fe55dd3cf4f1 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -689,97 +689,19 @@ def prepare_latents( latents = latents * self.scheduler.init_noise_sigma return latents - def _denoise_loop( - self, - timesteps, - num_inference_steps, - do_classifier_free_guidance, - guidance_scale, - num_warmup_steps, - prompt_embeds, - negative_prompt_embeds, - latents, - cross_attention_kwargs, - added_cond_kwargs, - extra_step_kwargs, - callback, - callback_steps, - callback_on_step_end, - callback_on_step_end_tensor_inputs, - ): - """Denoising loop for AnimateDiff.""" - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + def _apply_freeinit(self, latents, free_init_iteration, num_inference_steps, device, dtype): + if free_init_iteration == 0: + return latents - return latents + latent_shape = latents.shape + batch_size, num_channels_latents, num_frames, height, width = latent_shape - def _free_init_loop( - self, - height, - width, - num_frames, - num_channels_latents, - batch_size, - num_videos_per_prompt, - denoise_args, - device, - ): - """Denoising loop for AnimateDiff using FreeInit noise reinitialization technique.""" - - latents = denoise_args.get("latents") - prompt_embeds = denoise_args.get("prompt_embeds") - timesteps = denoise_args.get("timesteps") - num_inference_steps = denoise_args.get("num_inference_steps") - - latent_shape = ( - batch_size * num_videos_per_prompt, - num_channels_latents, - num_frames, - height // self.vae_scale_factor, - width // 
self.vae_scale_factor, - ) free_init_filter_shape = ( 1, num_channels_latents, num_frames, - height // self.vae_scale_factor, - width // self.vae_scale_factor, + height, + width, ) free_init_freq_filter = _get_freeinit_freq_filter( shape=free_init_filter_shape, @@ -790,56 +712,30 @@ def _free_init_loop( temporal_stop_frequency=self._free_init_temporal_stop_frequency, ) - with self.progress_bar(total=self._free_init_num_iters) as free_init_progress_bar: - for i in range(self._free_init_num_iters): - # For the first FreeInit iteration, the original latent is used without modification. - # Subsequent iterations apply the noise reinitialization technique. - if i == 0: - initial_noise = latents.detach().clone() - else: - current_diffuse_timestep = ( - self.scheduler.config.num_train_timesteps - 1 - ) # diffuse to t=999 noise level - diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long() - z_T = self.scheduler.add_noise( - original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device) - ).to(dtype=torch.float32) - z_rand = randn_tensor( - shape=latent_shape, - generator=self._free_init_generator, - device=device, - dtype=torch.float32, - ) - latents = _freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter) - latents = latents.to(prompt_embeds.dtype) - - # Coarse-to-Fine Sampling for faster inference (can lead to lower quality) - if self._free_init_use_fast_sampling: - current_num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (i + 1)) - self.scheduler.set_timesteps(current_num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - denoise_args.update({"timesteps": timesteps, "num_inference_steps": current_num_inference_steps}) - - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps}) - latents = self._denoise_loop(**denoise_args) - - free_init_progress_bar.update() - - return latents - - def _retrieve_video_frames(self, latents, output_type, return_dict): - """Helper function to handle latents to output conversion.""" - if output_type == "latent": - return AnimateDiffPipelineOutput(frames=latents) + initial_noise = latents.detach().clone() + current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1 + diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long() - video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + z_T = self.scheduler.add_noise( + original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device) + ).to(dtype=torch.float32) + z_rand = randn_tensor( + shape=latent_shape, + generator=self._free_init_generator, + device=device, + dtype=torch.float32, + ) + latents = _freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter) + latents = latents.to(dtype) - if not return_dict: - return (video,) + # Coarse-to-Fine Sampling for faster inference (can lead to lower quality) + if self._free_init_use_fast_sampling: + current_num_inference_steps = int( + num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1) + ) + self.scheduler.set_timesteps(current_num_inference_steps, device=device) - return AnimateDiffPipelineOutput(frames=video) + return latents @property def guidance_scale(self): @@ -1039,6 +935,7 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps self._num_timesteps = len(timesteps) + 
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -1060,43 +957,62 @@ def __call__( # 7. Add image embeds for IP-Adapter added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - denoise_args = { - "timesteps": timesteps, - "num_inference_steps": num_inference_steps, - "do_classifier_free_guidance": self.do_classifier_free_guidance, - "guidance_scale": guidance_scale, - "num_warmup_steps": num_warmup_steps, - "prompt_embeds": prompt_embeds, - "negative_prompt_embeds": negative_prompt_embeds, - "latents": latents, - "cross_attention_kwargs": self.cross_attention_kwargs, - "added_cond_kwargs": added_cond_kwargs, - "extra_step_kwargs": extra_step_kwargs, - "callback": callback, - "callback_steps": callback_steps, - "callback_on_step_end": callback_on_step_end, - "callback_on_step_end_tensor_inputs": callback_on_step_end_tensor_inputs, - } - - if self.free_init_enabled: - latents = self._free_init_loop( - height=height, - width=width, - num_frames=num_frames, - num_channels_latents=num_channels_latents, - batch_size=batch_size, - num_videos_per_prompt=num_videos_per_prompt, - denoise_args=denoise_args, - device=device, - ) - else: - latents = self._denoise_loop(**denoise_args) - - video = self._retrieve_video_frames(latents, output_type, return_dict) + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents = self._apply_freeinit(latents, free_init_iter, num_inference_steps, device, latents.dtype) + timesteps = self.scheduler.timesteps + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 9. 
Offload all models self.maybe_free_model_hooks() - return video + if output_type == "latent": + return AnimateDiffPipelineOutput(frames=latents) + + video_tensor = self.decode_latents(latents) + video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) From a779a3202e08454ca9cfae8bd38328aa0503cc35 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 6 Feb 2024 11:16:06 +0000 Subject: [PATCH 02/10] update --- .../animatediff/pipeline_animatediff.py | 179 +------------ .../pipeline_animatediff_video2video.py | 83 +++--- src/diffusers/pipelines/pia/pipeline_pia.py | 241 ++++-------------- 3 files changed, 107 insertions(+), 396 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index fe55dd3cf4f1..cd514d7b4159 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -13,12 +13,10 @@ # limitations under the License. import inspect -import math -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch -import torch.fft as fft from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor @@ -44,6 +42,7 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .freeinit_utils import FreeInitMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -87,72 +86,9 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: return outputs -def _get_freeinit_freq_filter( - shape: Tuple[int, ...], - device: Union[str, torch.dtype], - filter_type: str, - order: float, - spatial_stop_frequency: float, - temporal_stop_frequency: float, -) -> torch.Tensor: - r"""Returns the FreeInit filter based on filter type and other input conditions.""" - - T, H, W = shape[-3], shape[-2], shape[-1] - mask = torch.zeros(shape) - - if spatial_stop_frequency == 0 or temporal_stop_frequency == 0: - return mask - - if filter_type == "butterworth": - - def retrieve_mask(x): - return 1 / (1 + (x / spatial_stop_frequency**2) ** order) - elif filter_type == "gaussian": - - def retrieve_mask(x): - return math.exp(-1 / (2 * spatial_stop_frequency**2) * x) - elif filter_type == "ideal": - - def retrieve_mask(x): - return 1 if x <= spatial_stop_frequency * 2 else 0 - else: - raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal") - - for t in range(T): - for h in range(H): - for w in range(W): - d_square = ( - ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / T - 1)) ** 2 - + (2 * h / H - 1) ** 2 - + (2 * w / W - 1) ** 2 - ) - mask[..., t, h, w] = retrieve_mask(d_square) - - return mask.to(device) - - -def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor: - r"""Noise reinitialization.""" - # FFT - x_freq = fft.fftn(x, dim=(-3, -2, -1)) - x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) - noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) - noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) - - # frequency mix - HPF = 1 - LPF - x_freq_low = x_freq * LPF - noise_freq_high = noise_freq * HPF - x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain - - # IFFT - 
x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) - x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real - - return x_mixed - - -class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): +class AnimateDiffPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FreeInitMixin +): r""" Pipeline for text-to-video generation. @@ -535,63 +471,6 @@ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): raise ValueError("The pipeline must have `unet` for using FreeU.") self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu - def disable_freeu(self): - """Disables the FreeU mechanism if enabled.""" - self.unet.disable_freeu() - - @property - def free_init_enabled(self): - return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None - - def enable_free_init( - self, - num_iters: int = 3, - use_fast_sampling: bool = False, - method: str = "butterworth", - order: int = 4, - spatial_stop_frequency: float = 0.25, - temporal_stop_frequency: float = 0.25, - generator: torch.Generator = None, - ): - """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537. - - This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit). - - Args: - num_iters (`int`, *optional*, defaults to `3`): - Number of FreeInit noise re-initialization iterations. - use_fast_sampling (`bool`, *optional*, defaults to `False`): - Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables - the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`. - method (`str`, *optional*, defaults to `butterworth`): - Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the - FreeInit low pass filter. - order (`int`, *optional*, defaults to `4`): - Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour - whereas lower values lead to `gaussian` method behaviour. - spatial_stop_frequency (`float`, *optional*, defaults to `0.25`): - Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in - the original implementation. - temporal_stop_frequency (`float`, *optional*, defaults to `0.25`): - Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in - the original implementation. - generator (`torch.Generator`, *optional*, defaults to `0.25`): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - FreeInit generation deterministic. 
- """ - self._free_init_num_iters = num_iters - self._free_init_use_fast_sampling = use_fast_sampling - self._free_init_method = method - self._free_init_order = order - self._free_init_spatial_stop_frequency = spatial_stop_frequency - self._free_init_temporal_stop_frequency = temporal_stop_frequency - self._free_init_generator = generator - - def disable_free_init(self): - """Disables the FreeInit mechanism if enabled.""" - self._free_init_num_iters = None - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -689,54 +568,6 @@ def prepare_latents( latents = latents * self.scheduler.init_noise_sigma return latents - def _apply_freeinit(self, latents, free_init_iteration, num_inference_steps, device, dtype): - if free_init_iteration == 0: - return latents - - latent_shape = latents.shape - batch_size, num_channels_latents, num_frames, height, width = latent_shape - - free_init_filter_shape = ( - 1, - num_channels_latents, - num_frames, - height, - width, - ) - free_init_freq_filter = _get_freeinit_freq_filter( - shape=free_init_filter_shape, - device=device, - filter_type=self._free_init_method, - order=self._free_init_order, - spatial_stop_frequency=self._free_init_spatial_stop_frequency, - temporal_stop_frequency=self._free_init_temporal_stop_frequency, - ) - - initial_noise = latents.detach().clone() - current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1 - diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long() - - z_T = self.scheduler.add_noise( - original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device) - ).to(dtype=torch.float32) - z_rand = randn_tensor( - shape=latent_shape, - generator=self._free_init_generator, - device=device, - dtype=torch.float32, - ) - latents = _freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter) - latents = latents.to(dtype) - - # Coarse-to-Fine Sampling for faster inference (can lead to lower quality) - if self._free_init_use_fast_sampling: - current_num_inference_steps = int( - num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1) - ) - self.scheduler.set_timesteps(current_num_inference_steps, device=device) - - return latents - @property def guidance_scale(self): return self._guidance_scale diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 3d01009cbac7..59f4c7592b8a 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -34,6 +34,7 @@ ) from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import randn_tensor +from ..animatediff.freeinit_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline from .pipeline_output import AnimateDiffPipelineOutput @@ -163,7 +164,9 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): +class AnimateDiffVideoToVideoPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FreeInitMixin +): r""" Pipeline for video-to-video generation. 
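
With `FreeInitMixin` in the base class list, every pipeline that inherits it exposes the same `enable_free_init` / `disable_free_init` API instead of carrying a private copy. A minimal usage sketch against the signature as it stands at this point in the series — the checkpoint ids are illustrative, not prescribed by this patch:

```python
from diffusers import AnimateDiffPipeline, MotionAdapter

# Illustrative checkpoints: any AnimateDiff-compatible base model plus motion adapter.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter)

pipe.enable_free_init(
    num_iters=3,                   # FreeInit iterations; the first uses the latents unmodified
    use_fast_sampling=False,       # True enables coarse-to-fine sampling (faster, lower quality)
    method="butterworth",          # low-pass filter: "butterworth", "ideal", or "gaussian"
    order=4,                       # butterworth order; large -> ideal-like, small -> gaussian-like
    spatial_stop_frequency=0.25,   # d_s in the FreeInit paper
    temporal_stop_frequency=0.25,  # d_t in the FreeInit paper
)

frames = pipe(prompt="a panda surfing", num_inference_steps=25).frames[0]
pipe.disable_free_init()
```
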
@@ -912,42 +915,48 @@ def __call__( # 7. Add image embeds for IP-Adapter added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=self.cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - ).sample - - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - progress_bar.update() + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents = self._apply_freeinit(latents, free_init_iter, num_inference_steps, device, latents.dtype) + timesteps = self.scheduler.timesteps + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + progress_bar.update() if output_type == "latent": return AnimateDiffPipelineOutput(frames=latents) diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 077de49cdc87..b927cb9a5597 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -45,6 +45,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ..animatediff.freeinit_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline @@ -210,7 +211,7 @@ class PIAPipelineOutput(BaseOutput): class PIAPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin, FreeInitMixin ): r""" Pipeline for text-to-video generation. 
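
The hunk below deletes PIA's private copy of the denoising and FreeInit loops. The operation all three pipelines shared is the frequency-domain mix: diffuse the latents back to the t=999 noise level, keep their low-frequency band, and take the high-frequency band from fresh Gaussian noise. Condensed from the `_freq_mix_3d` helper removed earlier in this patch (the standalone function name here is illustrative):

```python
import torch
import torch.fft as fft

def freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, lpf: torch.Tensor) -> torch.Tensor:
    # FFT over (frames, height, width), centered so low frequencies sit in the middle.
    x_freq = fft.fftshift(fft.fftn(x, dim=(-3, -2, -1)), dim=(-3, -2, -1))
    noise_freq = fft.fftshift(fft.fftn(noise, dim=(-3, -2, -1)), dim=(-3, -2, -1))

    # Low-pass band from the re-noised latents, high-pass band from fresh noise.
    mixed = x_freq * lpf + noise_freq * (1 - lpf)

    # Invert the shift and the FFT; the inputs were real-valued, so keep the real part.
    return fft.ifftn(fft.ifftshift(mixed, dim=(-3, -2, -1)), dim=(-3, -2, -1)).real
```
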
@@ -793,143 +794,6 @@ def prepare_masked_condition( return mask, masked_image - def _denoise_loop( - self, - timesteps, - num_inference_steps, - do_classifier_free_guidance, - guidance_scale, - num_warmup_steps, - prompt_embeds, - negative_prompt_embeds, - latents, - mask, - masked_image, - cross_attention_kwargs, - added_cond_kwargs, - extra_step_kwargs, - callback_on_step_end, - callback_on_step_end_tensor_inputs, - ): - """Denoising loop for PIA.""" - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - latent_model_input = torch.cat([latent_model_input, mask, masked_image], dim=1) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - return latents - - def _free_init_loop( - self, - height, - width, - num_frames, - batch_size, - num_videos_per_prompt, - denoise_args, - device, - ): - """Denoising loop for PIA using FreeInit noise reinitialization technique.""" - - latents = denoise_args.get("latents") - prompt_embeds = denoise_args.get("prompt_embeds") - timesteps = denoise_args.get("timesteps") - num_inference_steps = denoise_args.get("num_inference_steps") - - latent_shape = ( - batch_size * num_videos_per_prompt, - 4, - num_frames, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - free_init_filter_shape = ( - 1, - 4, - num_frames, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - free_init_freq_filter = _get_freeinit_freq_filter( - shape=free_init_filter_shape, - device=device, - filter_type=self._free_init_method, - order=self._free_init_order, - spatial_stop_frequency=self._free_init_spatial_stop_frequency, - temporal_stop_frequency=self._free_init_temporal_stop_frequency, - ) - - with self.progress_bar(total=self._free_init_num_iters) as free_init_progress_bar: - for i in range(self._free_init_num_iters): - # For the first FreeInit iteration, the original latent is used without modification. - # Subsequent iterations apply the noise reinitialization technique. 
- if i == 0: - initial_noise = latents.detach().clone() - else: - current_diffuse_timestep = ( - self.scheduler.config.num_train_timesteps - 1 - ) # diffuse to t=999 noise level - diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long() - z_T = self.scheduler.add_noise( - original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device) - ).to(dtype=torch.float32) - z_rand = randn_tensor( - shape=latent_shape, - generator=self._free_init_generator, - device=device, - dtype=torch.float32, - ) - latents = _freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter) - latents = latents.to(prompt_embeds.dtype) - - # Coarse-to-Fine Sampling for faster inference (can lead to lower quality) - if self._free_init_use_fast_sampling: - current_num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (i + 1)) - self.scheduler.set_timesteps(current_num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - denoise_args.update({"timesteps": timesteps, "num_inference_steps": current_num_inference_steps}) - - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps}) - latents = self._denoise_loop(**denoise_args) - - free_init_progress_bar.update() - - return latents - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep @@ -942,19 +806,6 @@ def get_timesteps(self, num_inference_steps, strength, device): return timesteps, num_inference_steps - t_start - def _retrieve_video_frames(self, latents, output_type, return_dict): - """Helper function to handle latents to output conversion.""" - if output_type == "latent": - return PIAPipelineOutput(frames=latents) - - video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) - - if not return_dict: - return (video,) - - return PIAPipelineOutput(frames=video) - @property def guidance_scale(self): return self._guidance_scale @@ -1183,41 +1034,61 @@ def __call__( added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None # 8. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - denoise_args = { - "timesteps": timesteps, - "num_inference_steps": num_inference_steps, - "do_classifier_free_guidance": self.do_classifier_free_guidance, - "guidance_scale": guidance_scale, - "num_warmup_steps": num_warmup_steps, - "prompt_embeds": prompt_embeds, - "negative_prompt_embeds": negative_prompt_embeds, - "latents": latents, - "mask": mask, - "masked_image": masked_image, - "cross_attention_kwargs": self.cross_attention_kwargs, - "added_cond_kwargs": added_cond_kwargs, - "extra_step_kwargs": extra_step_kwargs, - "callback_on_step_end": callback_on_step_end, - "callback_on_step_end_tensor_inputs": callback_on_step_end_tensor_inputs, - } - - if self.free_init_enabled: - latents = self._free_init_loop( - height=height, - width=width, - num_frames=num_frames, - batch_size=batch_size, - num_videos_per_prompt=num_videos_per_prompt, - denoise_args=denoise_args, - device=device, - ) - else: - latents = self._denoise_loop(**denoise_args) - - video = self._retrieve_video_frames(latents, output_type, return_dict) + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents = self._apply_freeinit(latents, free_init_iter, num_inference_steps, device, latents.dtype) + timesteps = self.scheduler.timesteps + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, mask, masked_image], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() # 9. 
Offload all models self.maybe_free_model_hooks() - return video + if output_type == "latent": + return PIAPipelineOutput(frames=latents) + + video_tensor = self.decode_latents(latents) + video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + + if not return_dict: + return (video,) + + return PIAPipelineOutput(frames=video) From 4330a8d7f848accde504981e588efb7ffb469b9b Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 6 Feb 2024 12:12:56 +0000 Subject: [PATCH 03/10] update --- .../pipelines/animatediff/freeinit_utils.py | 183 ++++++++++++++++++ .../pipeline_animatediff_video2video.py | 4 +- 2 files changed, 185 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/pipelines/animatediff/freeinit_utils.py diff --git a/src/diffusers/pipelines/animatediff/freeinit_utils.py b/src/diffusers/pipelines/animatediff/freeinit_utils.py new file mode 100644 index 000000000000..a8aaf0edce2a --- /dev/null +++ b/src/diffusers/pipelines/animatediff/freeinit_utils.py @@ -0,0 +1,183 @@ +import math +from typing import Tuple, Union + +import torch +import torch.fft as fft + +from ...utils.torch_utils import randn_tensor + + +def _get_freeinit_freq_filter( + shape: Tuple[int, ...], + device: Union[str, torch.dtype], + filter_type: str, + order: float, + spatial_stop_frequency: float, + temporal_stop_frequency: float, +) -> torch.Tensor: + r"""Returns the FreeInit filter based on filter type and other input conditions.""" + + time, height, width = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + + if spatial_stop_frequency == 0 or temporal_stop_frequency == 0: + return mask + + if filter_type == "butterworth": + + def retrieve_mask(x): + return 1 / (1 + (x / spatial_stop_frequency**2) ** order) + elif filter_type == "gaussian": + + def retrieve_mask(x): + return math.exp(-1 / (2 * spatial_stop_frequency**2) * x) + elif filter_type == "ideal": + + def retrieve_mask(x): + return 1 if x <= spatial_stop_frequency * 2 else 0 + else: + raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal") + + for t in range(time): + for h in range(height): + for w in range(width): + d_square = ( + ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2 + + (2 * h / height - 1) ** 2 + + (2 * w / width - 1) ** 2 + ) + mask[..., t, h, w] = retrieve_mask(d_square) + + return mask.to(device) + + +def _apply_freq_filter(x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor: + r"""Noise reinitialization.""" + # FFT + x_freq = fft.fftn(x, dim=(-3, -2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) + noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) + noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) + + # frequency mix + high_pass_filter = 1 - low_pass_filter + x_freq_low = x_freq * low_pass_filter + noise_freq_high = noise_freq * high_pass_filter + x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain + + # IFFT + x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) + x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real + + return x_mixed + + +class FreeInitMixin: + r"""Mixin class for FreeInit.""" + + def enable_free_init( + self, + num_iters: int = 3, + use_fast_sampling: bool = False, + method: str = "butterworth", + order: int = 4, + spatial_stop_frequency: float = 0.25, + temporal_stop_frequency: float = 0.25, + generator: torch.Generator = None, + ): + """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537. 
+ + This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit). + + Args: + num_iters (`int`, *optional*, defaults to `3`): + Number of FreeInit noise re-initialization iterations. + use_fast_sampling (`bool`, *optional*, defaults to `False`): + Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables + the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`. + method (`str`, *optional*, defaults to `butterworth`): + Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the + FreeInit low pass filter. + order (`int`, *optional*, defaults to `4`): + Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour + whereas lower values lead to `gaussian` method behaviour. + spatial_stop_frequency (`float`, *optional*, defaults to `0.25`): + Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in + the original implementation. + temporal_stop_frequency (`float`, *optional*, defaults to `0.25`): + Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in + the original implementation. + generator (`torch.Generator`, *optional*, defaults to `0.25`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + FreeInit generation deterministic. + """ + self._free_init_num_iters = num_iters + self._free_init_use_fast_sampling = use_fast_sampling + self._free_init_method = method + self._free_init_order = order + self._free_init_spatial_stop_frequency = spatial_stop_frequency + self._free_init_temporal_stop_frequency = temporal_stop_frequency + self._free_init_generator = generator + + def disable_free_init(self): + """Disables the FreeInit mechanism if enabled.""" + self._free_init_num_iters = None + + @property + def free_init_enabled(self): + return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None + + def _get_freeinit_freq_filter(self, shape: Tuple[int, ...]) -> torch.Tensor: + r"""Returns the FreeInit filter based on filter type and other input conditions.""" + return _get_freeinit_freq_filter( + shape, + self.device, + self.filter_type, + self.order, + self.spatial_stop_frequency, + self.temporal_stop_frequency, + ) + + def _apply_freq_filter(self, x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor: + r"""Noise reinitialization.""" + return _apply_freq_filter(x, noise, low_pass_filter) + + def _apply_freeinit(self, latents, free_init_iteration, num_inference_steps, device, dtype): + if free_init_iteration == 0: + self._free_init_initial_noise = latents.detach().clone() + return latents + + latent_shape = latents.shape + + free_init_filter_shape = (1, *latent_shape[1:]) + free_init_freq_filter = _get_freeinit_freq_filter( + shape=free_init_filter_shape, + device=device, + filter_type=self._free_init_method, + order=self._free_init_order, + spatial_stop_frequency=self._free_init_spatial_stop_frequency, + temporal_stop_frequency=self._free_init_temporal_stop_frequency, + ) + + current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1 + diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long() + + z_t = self.scheduler.add_noise( + original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device) + ).to(dtype=torch.float32) + + z_rand = randn_tensor( + 
shape=latent_shape, + generator=self._free_init_generator, + device=device, + dtype=torch.float32, + ) + latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter) + latents = latents.to(dtype) + + # Coarse-to-Fine Sampling for faster inference (can lead to lower quality) + if self._free_init_use_fast_sampling: + num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1)) + self.scheduler.set_timesteps(num_inference_steps, device=device) + + return latents diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 59f4c7592b8a..69f5f891f25c 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -919,10 +919,10 @@ def __call__( for free_init_iter in range(num_free_init_iters): if self.free_init_enabled: latents = self._apply_freeinit(latents, free_init_iter, num_inference_steps, device, latents.dtype) - timesteps = self.scheduler.timesteps + num_inference_steps = len(self.scheduler.timesteps) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance From fd2e462504eae896531c9c4f969a79ecb920ed76 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 8 Feb 2024 07:57:32 +0000 Subject: [PATCH 04/10] update --- .../pipelines/animatediff/freeinit_utils.py | 145 ++++++++---------- .../animatediff/pipeline_animatediff.py | 12 +- .../pipeline_animatediff_video2video.py | 22 ++- src/diffusers/pipelines/pia/pipeline_pia.py | 5 +- 4 files changed, 89 insertions(+), 95 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/freeinit_utils.py b/src/diffusers/pipelines/animatediff/freeinit_utils.py index a8aaf0edce2a..04fd294011e8 100644 --- a/src/diffusers/pipelines/animatediff/freeinit_utils.py +++ b/src/diffusers/pipelines/animatediff/freeinit_utils.py @@ -7,71 +7,6 @@ from ...utils.torch_utils import randn_tensor -def _get_freeinit_freq_filter( - shape: Tuple[int, ...], - device: Union[str, torch.dtype], - filter_type: str, - order: float, - spatial_stop_frequency: float, - temporal_stop_frequency: float, -) -> torch.Tensor: - r"""Returns the FreeInit filter based on filter type and other input conditions.""" - - time, height, width = shape[-3], shape[-2], shape[-1] - mask = torch.zeros(shape) - - if spatial_stop_frequency == 0 or temporal_stop_frequency == 0: - return mask - - if filter_type == "butterworth": - - def retrieve_mask(x): - return 1 / (1 + (x / spatial_stop_frequency**2) ** order) - elif filter_type == "gaussian": - - def retrieve_mask(x): - return math.exp(-1 / (2 * spatial_stop_frequency**2) * x) - elif filter_type == "ideal": - - def retrieve_mask(x): - return 1 if x <= spatial_stop_frequency * 2 else 0 - else: - raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal") - - for t in range(time): - for h in range(height): - for w in range(width): - d_square = ( - ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2 - + (2 * h / height - 1) ** 2 - + (2 * w / width - 1) ** 2 - ) - mask[..., t, h, w] = retrieve_mask(d_square) - - return mask.to(device) 
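
The mask construction being moved here builds the filter with a triple Python loop over every (t, h, w) position. For the butterworth case the same mask can be produced with broadcasting; a vectorized sketch, not part of this series, using the loop's `2 * i / n - 1` coordinate convention (the gaussian and ideal branches vectorize the same way):

```python
import torch

def butterworth_freq_filter(
    shape, spatial_stop_frequency=0.25, temporal_stop_frequency=0.25, order=4
) -> torch.Tensor:
    time, height, width = shape[-3], shape[-2], shape[-1]

    def coords(n: int) -> torch.Tensor:
        # Same normalized coordinates as the loop: 2 * i / n - 1 for i in range(n).
        return 2 * torch.arange(n) / n - 1

    grid_t, grid_h, grid_w = torch.meshgrid(
        coords(time), coords(height), coords(width), indexing="ij"
    )
    d_square = (
        ((spatial_stop_frequency / temporal_stop_frequency) * grid_t) ** 2
        + grid_h**2
        + grid_w**2
    )
    mask = 1 / (1 + (d_square / spatial_stop_frequency**2) ** order)
    return mask.expand(*shape)  # broadcast (T, H, W) up to the full filter shape
```
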
- - -def _apply_freq_filter(x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor: - r"""Noise reinitialization.""" - # FFT - x_freq = fft.fftn(x, dim=(-3, -2, -1)) - x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) - noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) - noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) - - # frequency mix - high_pass_filter = 1 - low_pass_filter - x_freq_low = x_freq * low_pass_filter - noise_freq_high = noise_freq * high_pass_filter - x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain - - # IFFT - x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) - x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real - - return x_mixed - - class FreeInitMixin: r"""Mixin class for FreeInit.""" @@ -117,7 +52,6 @@ def enable_free_init( self._free_init_order = order self._free_init_spatial_stop_frequency = spatial_stop_frequency self._free_init_temporal_stop_frequency = temporal_stop_frequency - self._free_init_generator = generator def disable_free_init(self): """Disables the FreeInit mechanism if enabled.""" @@ -127,30 +61,79 @@ def disable_free_init(self): def free_init_enabled(self): return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None - def _get_freeinit_freq_filter(self, shape: Tuple[int, ...]) -> torch.Tensor: + def _get_freeinit_freq_filter( + self, + shape: Tuple[int, ...], + device: Union[str, torch.dtype], + filter_type: str, + order: float, + spatial_stop_frequency: float, + temporal_stop_frequency: float, + ) -> torch.Tensor: r"""Returns the FreeInit filter based on filter type and other input conditions.""" - return _get_freeinit_freq_filter( - shape, - self.device, - self.filter_type, - self.order, - self.spatial_stop_frequency, - self.temporal_stop_frequency, - ) + + time, height, width = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + + if spatial_stop_frequency == 0 or temporal_stop_frequency == 0: + return mask + + if filter_type == "butterworth": + + def retrieve_mask(x): + return 1 / (1 + (x / spatial_stop_frequency**2) ** order) + elif filter_type == "gaussian": + + def retrieve_mask(x): + return math.exp(-1 / (2 * spatial_stop_frequency**2) * x) + elif filter_type == "ideal": + + def retrieve_mask(x): + return 1 if x <= spatial_stop_frequency * 2 else 0 + else: + raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal") + + for t in range(time): + for h in range(height): + for w in range(width): + d_square = ( + ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2 + + (2 * h / height - 1) ** 2 + + (2 * w / width - 1) ** 2 + ) + mask[..., t, h, w] = retrieve_mask(d_square) + + return mask.to(device) def _apply_freq_filter(self, x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor: r"""Noise reinitialization.""" - return _apply_freq_filter(x, noise, low_pass_filter) + # FFT + x_freq = fft.fftn(x, dim=(-3, -2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) + noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) + noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) + + # frequency mix + high_pass_filter = 1 - low_pass_filter + x_freq_low = x_freq * low_pass_filter + noise_freq_high = noise_freq * high_pass_filter + x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain + + # IFFT + x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) + x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real + + return x_mixed - def _apply_freeinit(self, 
latents, free_init_iteration, num_inference_steps, device, dtype): + def _apply_freeinit(self, latents, free_init_iteration, num_inference_steps, device, dtype, generator): if free_init_iteration == 0: self._free_init_initial_noise = latents.detach().clone() - return latents + return latents, self.scheduler.timesteps latent_shape = latents.shape free_init_filter_shape = (1, *latent_shape[1:]) - free_init_freq_filter = _get_freeinit_freq_filter( + free_init_freq_filter = self._get_freeinit_freq_filter( shape=free_init_filter_shape, device=device, filter_type=self._free_init_method, @@ -168,7 +151,7 @@ def _apply_freeinit(self, latents, free_init_iteration, num_inference_steps, dev z_rand = randn_tensor( shape=latent_shape, - generator=self._free_init_generator, + generator=generator, device=device, dtype=torch.float32, ) @@ -180,4 +163,4 @@ def _apply_freeinit(self, latents, free_init_iteration, num_inference_steps, dev num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1)) self.scheduler.set_timesteps(num_inference_steps, device=device) - return latents + return latents, self.scheduler.timesteps diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index cd514d7b4159..df75905b8a84 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -471,6 +471,10 @@ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): raise ValueError("The pipeline must have `unet` for using FreeU.") self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -765,8 +769,6 @@ def __call__( # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - self._num_timesteps = len(timesteps) - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order # 5. 
Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -791,9 +793,11 @@ def __call__( num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 for free_init_iter in range(num_free_init_iters): if self.free_init_enabled: - latents = self._apply_freeinit(latents, free_init_iter, num_inference_steps, device, latents.dtype) - timesteps = self.scheduler.timesteps + latents, timesteps = self._apply_freeinit( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + self._num_timesteps = len(timesteps) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 69f5f891f25c..6c7a6610be04 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -598,12 +598,12 @@ def check_inputs( if video is not None and latents is not None: raise ValueError("Only one of `video` or `latents` should be provided") - def get_timesteps(self, num_inference_steps, strength, device): + def get_timesteps(self, num_inference_steps, timesteps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + timesteps = timesteps[t_start * self.scheduler.order :] return timesteps, num_inference_steps - t_start @@ -890,9 +890,8 @@ def __call__( # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) - self._num_timesteps = len(timesteps) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -918,10 +917,15 @@ def __call__( num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 for free_init_iter in range(num_free_init_iters): if self.free_init_enabled: - latents = self._apply_freeinit(latents, free_init_iter, num_inference_steps, device, latents.dtype) - num_inference_steps = len(self.scheduler.timesteps) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latents, timesteps = self._apply_freeinit( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + num_inference_steps = len(timesteps) + # make sure to readjust timesteps based on strength + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order # 8. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -956,7 +960,9 @@ def __call__( prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - progress_bar.update() + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() if output_type == "latent": return AnimateDiffPipelineOutput(frames=latents) diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index b927cb9a5597..391116cda70a 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -1037,8 +1037,9 @@ def __call__( num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 for free_init_iter in range(num_free_init_iters): if self.free_init_enabled: - latents = self._apply_freeinit(latents, free_init_iter, num_inference_steps, device, latents.dtype) - timesteps = self.scheduler.timesteps + latents, timesteps = self._apply_freeinit( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: From a2b15b9e849e09f5fd14270ca765ec6c55385a61 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 8 Feb 2024 08:37:03 +0000 Subject: [PATCH 05/10] update --- .../pipelines/animatediff/pipeline_animatediff.py | 8 ++++---- src/diffusers/pipelines/pia/pipeline_pia.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index df75905b8a84..88bd294ecfab 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -118,7 +118,7 @@ class AnimateDiffPipeline( """ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" - _optional_components = ["feature_extractor", "image_encoder"] + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( @@ -838,15 +838,15 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 9. Offload all models - self.maybe_free_model_hooks() - if output_type == "latent": return AnimateDiffPipelineOutput(frames=latents) video_tensor = self.decode_latents(latents) video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + # 9. Offload all models + self.maybe_free_model_hooks() + if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 391116cda70a..c140a2d5fa69 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -1080,15 +1080,15 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - # 9. Offload all models - self.maybe_free_model_hooks() - if output_type == "latent": return PIAPipelineOutput(frames=latents) video_tensor = self.decode_latents(latents) video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + # 9. 
Offload all models + self.maybe_free_model_hooks() + if not return_dict: return (video,) From b23e579b51361957c3962a34ecb8eae3c0dbfd54 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 12 Feb 2024 13:15:06 +0000 Subject: [PATCH 06/10] update --- src/diffusers/pipelines/animatediff/freeinit_utils.py | 5 ++--- src/diffusers/pipelines/animatediff/pipeline_animatediff.py | 2 +- .../animatediff/pipeline_animatediff_video2video.py | 2 +- src/diffusers/pipelines/pia/pipeline_pia.py | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/freeinit_utils.py b/src/diffusers/pipelines/animatediff/freeinit_utils.py index 04fd294011e8..d322d8cdbe90 100644 --- a/src/diffusers/pipelines/animatediff/freeinit_utils.py +++ b/src/diffusers/pipelines/animatediff/freeinit_utils.py @@ -18,7 +18,6 @@ def enable_free_init( order: int = 4, spatial_stop_frequency: float = 0.25, temporal_stop_frequency: float = 0.25, - generator: torch.Generator = None, ): """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537. @@ -61,7 +60,7 @@ def disable_free_init(self): def free_init_enabled(self): return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None - def _get_freeinit_freq_filter( + def _get_free_init_freq_filter( self, shape: Tuple[int, ...], device: Union[str, torch.dtype], @@ -125,7 +124,7 @@ def _apply_freq_filter(self, x: torch.Tensor, noise: torch.Tensor, low_pass_filt return x_mixed - def _apply_freeinit(self, latents, free_init_iteration, num_inference_steps, device, dtype, generator): + def _apply_free_init(self, latents, free_init_iteration, num_inference_steps, device, dtype, generator): if free_init_iteration == 0: self._free_init_initial_noise = latents.detach().clone() return latents, self.scheduler.timesteps diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index cdafb7a0c810..b9366f13bc13 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -801,7 +801,7 @@ def __call__( num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 for free_init_iter in range(num_free_init_iters): if self.free_init_enabled: - latents, timesteps = self._apply_freeinit( + latents, timesteps = self._apply_free_init( latents, free_init_iter, num_inference_steps, device, latents.dtype, generator ) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 2b06156381f9..47985642708c 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -906,7 +906,7 @@ def __call__( num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 for free_init_iter in range(num_free_init_iters): if self.free_init_enabled: - latents, timesteps = self._apply_freeinit( + latents, timesteps = self._apply_free_init( latents, free_init_iter, num_inference_steps, device, latents.dtype, generator ) num_inference_steps = len(timesteps) diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 8caeb7d457f7..f45da64c7616 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -1045,7 +1045,7 @@ def __call__( num_free_init_iters = 
self._free_init_num_iters if self.free_init_enabled else 1 for free_init_iter in range(num_free_init_iters): if self.free_init_enabled: - latents, timesteps = self._apply_freeinit( + latents, timesteps = self._apply_free_init( latents, free_init_iter, num_inference_steps, device, latents.dtype, generator ) From 732fa8cebae4dbab4589bbc3cdf41d1f6d9a80c1 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 12 Feb 2024 13:24:03 +0000 Subject: [PATCH 07/10] update --- src/diffusers/pipelines/animatediff/freeinit_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/freeinit_utils.py b/src/diffusers/pipelines/animatediff/freeinit_utils.py index d322d8cdbe90..afaf521746d5 100644 --- a/src/diffusers/pipelines/animatediff/freeinit_utils.py +++ b/src/diffusers/pipelines/animatediff/freeinit_utils.py @@ -41,9 +41,6 @@ def enable_free_init( temporal_stop_frequency (`float`, *optional*, defaults to `0.25`): Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in the original implementation. - generator (`torch.Generator`, *optional*, defaults to `0.25`): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - FreeInit generation deterministic. """ self._free_init_num_iters = num_iters self._free_init_use_fast_sampling = use_fast_sampling From 138bc7f522098ed945f88b980292697d9fe76d3e Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 12 Feb 2024 13:26:46 +0000 Subject: [PATCH 08/10] update --- src/diffusers/pipelines/animatediff/freeinit_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/animatediff/freeinit_utils.py b/src/diffusers/pipelines/animatediff/freeinit_utils.py index afaf521746d5..2da0d9308e4e 100644 --- a/src/diffusers/pipelines/animatediff/freeinit_utils.py +++ b/src/diffusers/pipelines/animatediff/freeinit_utils.py @@ -121,7 +121,15 @@ def _apply_freq_filter(self, x: torch.Tensor, noise: torch.Tensor, low_pass_filt return x_mixed - def _apply_free_init(self, latents, free_init_iteration, num_inference_steps, device, dtype, generator): + def _apply_free_init( + self, + latents: torch.Tensor, + free_init_iteration: int, + num_inference_steps: int, + device: torch.device, + dtype: torch.dtype, + generator: torch.Generator, + ): if free_init_iteration == 0: self._free_init_initial_noise = latents.detach().clone() return latents, self.scheduler.timesteps From 1ed8b4dbacecbf80ee9fcce986ca110ded8f3653 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 12 Feb 2024 13:33:18 +0000 Subject: [PATCH 09/10] update --- .../animatediff/pipeline_animatediff.py | 2 +- .../pipeline_animatediff_video2video.py | 2 +- .../freeinit_utils.py => free_init_utils.py} | 16 +++++++++++++++- src/diffusers/pipelines/pia/pipeline_pia.py | 2 +- 4 files changed, 18 insertions(+), 4 deletions(-) rename src/diffusers/pipelines/{animatediff/freeinit_utils.py => free_init_utils.py} (91%) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index b9366f13bc13..dbb953a92f9f 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -41,8 +41,8 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline -from .freeinit_utils import FreeInitMixin from .pipeline_output import 
AnimateDiffPipelineOutput diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 47985642708c..fd07d2df2ff7 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -34,7 +34,7 @@ ) from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import randn_tensor -from ..animatediff.freeinit_utils import FreeInitMixin +from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline from .pipeline_output import AnimateDiffPipelineOutput diff --git a/src/diffusers/pipelines/animatediff/freeinit_utils.py b/src/diffusers/pipelines/free_init_utils.py similarity index 91% rename from src/diffusers/pipelines/animatediff/freeinit_utils.py rename to src/diffusers/pipelines/free_init_utils.py index 2da0d9308e4e..1d7dcde2873b 100644 --- a/src/diffusers/pipelines/animatediff/freeinit_utils.py +++ b/src/diffusers/pipelines/free_init_utils.py @@ -1,10 +1,24 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math from typing import Tuple, Union import torch import torch.fft as fft -from ...utils.torch_utils import randn_tensor +from ..utils.torch_utils import randn_tensor class FreeInitMixin: diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index f45da64c7616..7ce34fefac38 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -45,7 +45,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor -from ..animatediff.freeinit_utils import FreeInitMixin +from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline From 8834fe6861920c8f0f20019f74a6f4c4976c3ce6 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 12 Feb 2024 14:21:50 +0000 Subject: [PATCH 10/10] update --- .../animatediff/pipeline_animatediff.py | 3 +- .../pipeline_animatediff_video2video.py | 5 +- src/diffusers/pipelines/free_init_utils.py | 2 +- src/diffusers/pipelines/pia/pipeline_pia.py | 52 ------------------- .../pipelines/animatediff/test_animatediff.py | 2 - .../test_animatediff_video2video.py | 35 +++++++++++++ tests/pipelines/pia/test_pia.py | 2 - 7 files changed, 41 insertions(+), 60 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index dbb953a92f9f..586567fc742f 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -140,7 +140,8 @@ def __init__( image_encoder: CLIPVisionModelWithProjection = None, ): super().__init__() - unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + if isinstance(unet, UNet2DConditionModel): + unet = 
UNetMotionModel.from_unet2d(unet, motion_adapter) self.register_modules( vae=vae, diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index fd07d2df2ff7..beaa74ff151e 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -196,7 +196,7 @@ class AnimateDiffVideoToVideoPipeline( """ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" - _optional_components = ["feature_extractor", "image_encoder"] + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( @@ -218,7 +218,8 @@ def __init__( image_encoder: CLIPVisionModelWithProjection = None, ): super().__init__() - unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) self.register_modules( vae=vae, diff --git a/src/diffusers/pipelines/free_init_utils.py b/src/diffusers/pipelines/free_init_utils.py index 1d7dcde2873b..50c28cc69f44 100644 --- a/src/diffusers/pipelines/free_init_utils.py +++ b/src/diffusers/pipelines/free_init_utils.py @@ -151,7 +151,7 @@ def _apply_free_init( latent_shape = latents.shape free_init_filter_shape = (1, *latent_shape[1:]) - free_init_freq_filter = self._get_freeinit_freq_filter( + free_init_freq_filter = self._get_free_init_freq_filter( shape=free_init_filter_shape, device=device, filter_type=self._free_init_method, diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 7ce34fefac38..07d3746465b5 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -561,58 +561,6 @@ def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() - @property - def free_init_enabled(self): - return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None - - def enable_free_init( - self, - num_iters: int = 3, - use_fast_sampling: bool = False, - method: str = "butterworth", - order: int = 4, - spatial_stop_frequency: float = 0.25, - temporal_stop_frequency: float = 0.25, - generator: Optional[torch.Generator] = None, - ): - """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537. - - This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit). - - Args: - num_iters (`int`, *optional*, defaults to `3`): - Number of FreeInit noise re-initialization iterations. - use_fast_sampling (`bool`, *optional*, defaults to `False`): - Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables - the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`. - method (`str`, *optional*, defaults to `butterworth`): - Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the - FreeInit low pass filter. - order (`int`, *optional*, defaults to `4`): - Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour - whereas lower values lead to `gaussian` method behaviour. - spatial_stop_frequency (`float`, *optional*, defaults to `0.25`): - Normalized stop frequency for spatial dimensions. Must be between 0 to 1. 
Referred to as `d_s` in - the original implementation. - temporal_stop_frequency (`float`, *optional*, defaults to `0.25`): - Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in - the original implementation. - generator (`torch.Generator`, *optional*, defaults to `0.25`): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - FreeInit generation deterministic. - """ - self._free_init_num_iters = num_iters - self._free_init_use_fast_sampling = use_fast_sampling - self._free_init_method = method - self._free_init_order = order - self._free_init_spatial_stop_frequency = spatial_stop_frequency - self._free_init_temporal_stop_frequency = temporal_stop_frequency - self._free_init_generator = generator - - def disable_free_init(self): - """Disables the FreeInit mechanism if enabled.""" - self._free_init_num_iters = None - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 525ca24bbd9a..412d536c6e14 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -242,7 +242,6 @@ def test_free_init(self): inputs_normal = self.get_dummy_inputs(torch_device) frames_normal = pipe(**inputs_normal).frames[0] - free_init_generator = torch.Generator(device=torch_device).manual_seed(0) pipe.enable_free_init( num_iters=2, use_fast_sampling=True, @@ -250,7 +249,6 @@ def test_free_init(self): order=4, spatial_stop_frequency=0.25, temporal_stop_frequency=0.25, - generator=free_init_generator, ) inputs_enable_free_init = self.get_dummy_inputs(torch_device) frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0] diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index 767fc30b4eb5..bfb607ea507d 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -267,3 +267,38 @@ def test_xformers_attention_forwardGenerator_pass(self): max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max() self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results") + + def test_free_init(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs_normal = self.get_dummy_inputs(torch_device) + frames_normal = pipe(**inputs_normal).frames[0] + + pipe.enable_free_init( + num_iters=2, + use_fast_sampling=True, + method="butterworth", + order=4, + spatial_stop_frequency=0.25, + temporal_stop_frequency=0.25, + ) + inputs_enable_free_init = self.get_dummy_inputs(torch_device) + frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0] + + pipe.disable_free_init() + inputs_disable_free_init = self.get_dummy_inputs(torch_device) + frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0] + + sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum() + max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max() + self.assertGreater( + sum_enabled, 
1e1, "Enabling of FreeInit should lead to results different from the default pipeline results" + ) + self.assertLess( + max_diff_disabled, + 1e-4, + "Disabling of FreeInit should lead to results similar to the default pipeline results", + ) diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py index edd129560c63..214f085e057e 100644 --- a/tests/pipelines/pia/test_pia.py +++ b/tests/pipelines/pia/test_pia.py @@ -255,7 +255,6 @@ def test_free_init(self): inputs_normal = self.get_dummy_inputs(torch_device) frames_normal = pipe(**inputs_normal).frames[0] - free_init_generator = torch.Generator(device=torch_device).manual_seed(0) pipe.enable_free_init( num_iters=2, use_fast_sampling=True, @@ -263,7 +262,6 @@ def test_free_init(self): order=4, spatial_stop_frequency=0.25, temporal_stop_frequency=0.25, - generator=free_init_generator, ) inputs_enable_free_init = self.get_dummy_inputs(torch_device) frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]
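
For reviewers, a minimal end-to-end sketch of the FreeInit API as it stands after this series. The checkpoint IDs, scheduler settings, and prompt below are illustrative only (they are not part of the patch); any AnimateDiff-compatible base model and motion adapter should work. Note that `enable_free_init` no longer accepts a `generator` argument: determinism now comes from the generator passed to the pipeline call, which is forwarded to `_apply_free_init`.

import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

# Illustrative checkpoints; substitute any AnimateDiff-compatible pair.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    motion_adapter=adapter,
    torch_dtype=torch.float16,
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
pipe.to("cuda")

# FreeInit now lives in FreeInitMixin (src/diffusers/pipelines/free_init_utils.py).
# The keyword arguments below match the values exercised in the updated tests.
pipe.enable_free_init(
    num_iters=3,
    use_fast_sampling=False,  # True trades quality for speed (coarse-to-fine sampling)
    method="butterworth",
    order=4,
    spatial_stop_frequency=0.25,
    temporal_stop_frequency=0.25,
)

output = pipe(
    prompt="a panda surfing a wave, high quality",
    num_frames=16,
    num_inference_steps=25,
    guidance_scale=7.5,
    generator=torch.Generator("cpu").manual_seed(42),  # also drives FreeInit re-noising
)
pipe.disable_free_init()
export_to_gif(output.frames[0], "animation.gif")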
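
The isinstance guard added to __init__ means the pipelines can now be constructed either from a 2D UNet plus a MotionAdapter, or directly from an already-merged UNetMotionModel, in which case motion_adapter (now listed in _optional_components) may be None. A sketch of the two construction paths, with illustrative checkpoint IDs:

from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel

# Path 1: a 2D UNet plus an adapter; the pipeline merges them itself
# via UNetMotionModel.from_unet2d, exactly as before this series.
unet2d = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
motion_unet = UNetMotionModel.from_unet2d(unet2d, adapter)

# Path 2: an existing UNetMotionModel is passed through unchanged, so a
# pipeline saved with save_pretrained reloads without a motion_adapter.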
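
The heart of `_apply_free_init` is a 3D frequency-domain mix: the latents diffused back to the final train timestep keep their low-frequency band, while the high-frequency band is replaced with freshly sampled noise. The following is a simplified, self-contained sketch of that step; the helper names and the exact Butterworth normalization are illustrative and differ in detail from `_get_free_init_freq_filter` / `_apply_freq_filter` in free_init_utils.py.

import torch
import torch.fft as fft

def butterworth_low_pass_3d(num_frames, height, width, d_s=0.25, d_t=0.25, order=4):
    # Centered Butterworth low-pass mask over (T, H, W) with coordinates
    # normalized to [-1, 1]; d_s / d_t are the spatial / temporal stop
    # frequencies (simplified relative to the library version).
    t = torch.linspace(-1, 1, num_frames).view(-1, 1, 1)
    h = torch.linspace(-1, 1, height).view(1, -1, 1)
    w = torch.linspace(-1, 1, width).view(1, 1, -1)
    d_sq = (t / d_t) ** 2 + (h / d_s) ** 2 + (w / d_s) ** 2
    return 1.0 / (1.0 + d_sq ** order)

def freq_mix_3d(z_t, z_rand, low_pass):
    # z_t, z_rand: (B, C, T, H, W). Keep z_t's low frequencies, take the
    # high frequencies from z_rand, then return to the spatial domain.
    dims = (-3, -2, -1)
    z_t_freq = fft.fftshift(fft.fftn(z_t, dim=dims), dim=dims)
    z_rand_freq = fft.fftshift(fft.fftn(z_rand, dim=dims), dim=dims)
    mixed = z_t_freq * low_pass + z_rand_freq * (1 - low_pass)
    return fft.ifftn(fft.ifftshift(mixed, dim=dims), dim=dims).real

# e.g. for 16-frame, 32x32 latents:
# lpf = butterworth_low_pass_3d(16, 32, 32)
# latents = freq_mix_3d(z_T, z_rand, lpf).to(latents.dtype)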