@@ -792,7 +792,7 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
 
         # 8. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
+        with self.progress_bar(total=self._num_timesteps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
@@ -944,7 +944,7 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
 
         # 8. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
+        with self.progress_bar(total=self._num_timesteps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
src/diffusers/pipelines/free_init_utils.py (61 changes: 31 additions & 30 deletions)
@@ -146,39 +146,40 @@ def _apply_free_init(
     ):
         if free_init_iteration == 0:
             self._free_init_initial_noise = latents.detach().clone()
-            return latents, self.scheduler.timesteps
-
-        latent_shape = latents.shape
-
-        free_init_filter_shape = (1, *latent_shape[1:])
-        free_init_freq_filter = self._get_free_init_freq_filter(
-            shape=free_init_filter_shape,
-            device=device,
-            filter_type=self._free_init_method,
-            order=self._free_init_order,
-            spatial_stop_frequency=self._free_init_spatial_stop_frequency,
-            temporal_stop_frequency=self._free_init_temporal_stop_frequency,
-        )
-
-        current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
-        diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()
-
-        z_t = self.scheduler.add_noise(
-            original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
-        ).to(dtype=torch.float32)
-
-        z_rand = randn_tensor(
-            shape=latent_shape,
-            generator=generator,
-            device=device,
-            dtype=torch.float32,
-        )
-        latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
-        latents = latents.to(dtype)
+        else:
+            latent_shape = latents.shape
+
+            free_init_filter_shape = (1, *latent_shape[1:])
+            free_init_freq_filter = self._get_free_init_freq_filter(
+                shape=free_init_filter_shape,
+                device=device,
+                filter_type=self._free_init_method,
+                order=self._free_init_order,
+                spatial_stop_frequency=self._free_init_spatial_stop_frequency,
+                temporal_stop_frequency=self._free_init_temporal_stop_frequency,
+            )
+
+            current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
+            diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()
+
+            z_t = self.scheduler.add_noise(
+                original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
+            ).to(dtype=torch.float32)
+
+            z_rand = randn_tensor(
+                shape=latent_shape,
+                generator=generator,
+                device=device,
+                dtype=torch.float32,
+            )
+            latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
+            latents = latents.to(dtype)
 
         # Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
         if self._free_init_use_fast_sampling:
-            num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
+            num_inference_steps = max(
+                1, int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
+            )
             self.scheduler.set_timesteps(num_inference_steps, device=device)
 
         return latents, self.scheduler.timesteps
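
Note on the hunk above: folding the iteration-0 early return into an `if`/`else` means the coarse-to-fine sampling block now also runs on the first FreeInit iteration, where `int(num_inference_steps / num_iters * 1)` can truncate to zero; the new `max(1, ...)` clamp guards `scheduler.set_timesteps` against an empty schedule. A minimal sketch of the schedule arithmetic (hypothetical step counts, not library code):

# Hypothetical numbers illustrating the coarse-to-fine schedule from the
# diff above: steps per FreeInit iteration, with and without the clamp.
num_inference_steps = 2   # user-requested steps
num_iters = 5             # stand-in for self._free_init_num_iters

for it in range(num_iters):
    unclamped = int(num_inference_steps / num_iters * (it + 1))
    clamped = max(1, unclamped)
    print(it, unclamped, clamped)
# it=0 and it=1 give unclamped == 0: scheduler.set_timesteps(0) would produce
# an empty timestep schedule; the clamp guarantees at least one denoising step.
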
src/diffusers/pipelines/pia/pipeline_pia.py (82 changes: 9 additions & 73 deletions)
@@ -13,14 +13,12 @@
 # limitations under the License.
 
 import inspect
-import math
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import PIL
 import torch
-import torch.fft as fft
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
@@ -130,81 +128,16 @@ def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_sca
     return coef
 
 
-def _get_freeinit_freq_filter(
-    shape: Tuple[int, ...],
-    device: Union[str, torch.dtype],
-    filter_type: str,
-    order: float,
-    spatial_stop_frequency: float,
-    temporal_stop_frequency: float,
-) -> torch.Tensor:
-    r"""Returns the FreeInit filter based on filter type and other input conditions."""
-
-    time, height, width = shape[-3], shape[-2], shape[-1]
-    mask = torch.zeros(shape)
-
-    if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
-        return mask
-
-    if filter_type == "butterworth":
-
-        def retrieve_mask(x):
-            return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
-    elif filter_type == "gaussian":
-
-        def retrieve_mask(x):
-            return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
-    elif filter_type == "ideal":
-
-        def retrieve_mask(x):
-            return 1 if x <= spatial_stop_frequency * 2 else 0
-    else:
-        raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")
-
-    for t in range(time):
-        for h in range(height):
-            for w in range(width):
-                d_square = (
-                    ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2
-                    + (2 * h / height - 1) ** 2
-                    + (2 * w / width - 1) ** 2
-                )
-                mask[..., t, h, w] = retrieve_mask(d_square)
-
-    return mask.to(device)
-
-
-def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor:
-    r"""Noise reinitialization."""
-    # FFT
-    x_freq = fft.fftn(x, dim=(-3, -2, -1))
-    x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
-    noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
-    noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
-
-    # frequency mix
-    HPF = 1 - LPF
-    x_freq_low = x_freq * LPF
-    noise_freq_high = noise_freq * HPF
-    x_freq_mixed = x_freq_low + noise_freq_high  # mix in freq domain
-
-    # IFFT
-    x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
-    x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
-
-    return x_mixed
-
-
 @dataclass
 class PIAPipelineOutput(BaseOutput):
     r"""
     Output class for PIAPipeline.
 
     Args:
         frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
-            Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`,
-            NumPy array of shape `(batch_size, num_frames, channels, height, width,
-            Torch tensor of shape `(batch_size, num_frames, channels, height, width)`.
+            Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`, a NumPy
+            array of shape `(batch_size, num_frames, channels, height, width)`, or a torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
     """
 
     frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
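
The two module-level helpers deleted above duplicated the `FreeInitMixin` utilities in `free_init_utils.py` (`_get_free_init_freq_filter` and `_apply_freq_filter`), which PIA already inherits, so the pipeline now relies on the shared implementations. For reference, a minimal self-contained sketch of the noise reinitialization the deleted `_freq_mix_3d` performed (standalone illustration, not the library API):

import torch
import torch.fft as fft


def freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, lpf: torch.Tensor) -> torch.Tensor:
    # Keep the low frequencies of the noised latents, take the high
    # frequencies from fresh noise, then return to the signal domain.
    dims = (-3, -2, -1)
    x_freq = fft.fftshift(fft.fftn(x, dim=dims), dim=dims)
    noise_freq = fft.fftshift(fft.fftn(noise, dim=dims), dim=dims)
    mixed = x_freq * lpf + noise_freq * (1 - lpf)
    return fft.ifftn(fft.ifftshift(mixed, dim=dims), dim=dims).real


latents = torch.randn(1, 4, 16, 32, 32)   # (batch, channels, frames, height, width)
low_pass = torch.zeros_like(latents)      # stand-in for a gaussian/butterworth/ideal mask
out = freq_mix_3d(latents, torch.randn_like(latents), low_pass)
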
@@ -788,7 +721,8 @@ def __call__(
                 The input image to be used for video generation.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
-            strength (`float`, *optional*, defaults to 1.0): Indicates extent to transform the reference `image`. Must be between 0 and 1.
+            strength (`float`, *optional*, defaults to 1.0):
+                Indicates extent to transform the reference `image`. Must be between 0 and 1.
             height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The height in pixels of the generated video.
             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -979,8 +913,10 @@ def __call__(
                     latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
                 )
 
+            self._num_timesteps = len(timesteps)
             num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-            with self.progress_bar(total=num_inference_steps) as progress_bar:
+
+            with self.progress_bar(total=self._num_timesteps) as progress_bar:
                 for i, t in enumerate(timesteps):
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
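
Sizing the progress bar with `self._num_timesteps` (captured from `len(timesteps)` right after `_apply_free_init`) instead of the user-supplied `num_inference_steps` matters under FreeInit's fast sampling, where each iteration runs a different number of steps. A sketch of the mismatch, with hypothetical values:

# Hypothetical values: with num_inference_steps=20 and 3 FreeInit iterations,
# _apply_free_init rescales the schedule each iteration (see free_init_utils.py).
num_inference_steps = 20
num_iters = 3

for it in range(num_iters):
    steps_this_iter = max(1, int(num_inference_steps / num_iters * (it + 1)))
    print(it, steps_this_iter)  # 0 -> 6, 1 -> 13, 2 -> 20
# A bar with total=20 would stall at 6/20 and 13/20 on the first two iterations;
# total=len(timesteps), i.e. self._num_timesteps, lets each bar run to completion.
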