Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/source/en/api/pipelines/panorama.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,25 @@ and increase the VRAM usage.

</Tip>

<Tip>

Circular padding is applied to ensure there are no stitching artifacts when working with
panoramas that needs to seamlessly transition from the rightmost part to the leftmost part.
By enabling circular padding (set `circular_padding=True`), the operation applies additional
crops after the rightmost point of the image, allowing the model to "see” the transition
from the rightmost part to the leftmost part. This helps maintain visual consistency in
a 360-degree sense and creates a proper “panorama” that can be viewed using 360-degree
panorama viewers. When decoding latents in StableDiffusion, circular padding is applied
to ensure that the decoded latents match in the RGB space.

Without circular padding, there is a stitching artifact (default):
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/indoor_%20no_circular_padding.png)

With circular padding, the right and the left parts are matching (`circular_padding=True`):
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/indoor_%20circular_padding.png)

</Tip>

## StableDiffusionPanoramaPipeline
[[autodoc]] StableDiffusionPanoramaPipeline
- __call__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,19 @@ def decode_latents(self, latents):
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image

def decode_latents_with_padding(self, latents, padding=8):
# Add padding to latents for circular inference
# padding is the number of latents to add on each side
# it would slightly increase the memory usage, but remove the boundary artifacts
latents = 1 / self.vae.config.scaling_factor * latents
latents_left = latents[..., :padding]
latents_right = latents[..., -padding:]
latents = torch.cat((latents_right, latents, latents_left), axis=-1)
image = self.vae.decode(latents, return_dict=False)[0]
padding_pix = self.vae_scale_factor * padding
image = image[..., padding_pix:-padding_pix]
return image

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
Expand Down Expand Up @@ -457,13 +470,16 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
latents = latents * self.scheduler.init_noise_sigma
return latents

def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
def get_views(self, panorama_height, panorama_width, window_size=64, stride=8, circular_padding=False):
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
# if panorama's height/width < window_size, num_blocks of height/width should return 1
panorama_height /= 8
panorama_width /= 8
num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1
num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
if circular_padding:
num_blocks_width = panorama_width // stride if panorama_width > window_size else 1
else:
num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
total_num_blocks = int(num_blocks_height * num_blocks_width)
views = []
for i in range(total_num_blocks):
Expand Down Expand Up @@ -496,6 +512,7 @@ def __call__(
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
circular_padding: bool = False,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -560,6 +577,10 @@ def __call__(
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
circular_padding (`bool`, *optional*, defaults to `False`):
If set to True, circular padding is applied to ensure there are no stitching artifacts. Circular
padding allows the model to seamlessly generate a transition from the rightmost part of the image to
the leftmost part, maintaining consistency in a 360-degree sense.

Examples:

Expand Down Expand Up @@ -627,10 +648,9 @@ def __call__(

# 6. Define panorama grid and initialize views for synthesis.
# prepare batch grid
views = self.get_views(height, width)
views = self.get_views(height, width, circular_padding=circular_padding)
views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch)

count = torch.zeros_like(latents)
value = torch.zeros_like(latents)

Expand All @@ -655,9 +675,29 @@ def __call__(
for j, batch_view in enumerate(views_batch):
vb_size = len(batch_view)
# get the latents corresponding to the current view coordinates
latents_for_view = torch.cat(
[latents[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view]
)
if circular_padding:
latents_for_view = []
for h_start, h_end, w_start, w_end in batch_view:
if w_end > latents.shape[3]:
# Add circular horizontal padding
latent_view = torch.cat(
(
latents[:, :, h_start:h_end, w_start:],
latents[:, :, h_start:h_end, : w_end - latents.shape[3]],
),
axis=-1,
)
else:
latent_view = latents[:, :, h_start:h_end, w_start:w_end]
latents_for_view.append(latent_view)
latents_for_view = torch.cat(latents_for_view)
else:
latents_for_view = torch.cat(
[
latents[:, :, h_start:h_end, w_start:w_end]
for h_start, h_end, w_start, w_end in batch_view
]
)

# rematch block's scheduler status
self.scheduler.__dict__.update(views_scheduler_status[j])
Expand Down Expand Up @@ -698,8 +738,19 @@ def __call__(
for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
latents_denoised_batch.chunk(vb_size), batch_view
):
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1
if circular_padding and w_end > latents.shape[3]:
# Case for circular padding
value[:, :, h_start:h_end, w_start:] += latents_view_denoised[
:, :, h_start:h_end, : latents.shape[3] - w_start
]
value[:, :, h_start:h_end, : w_end - latents.shape[3]] += latents_view_denoised[
:, :, h_start:h_end, latents.shape[3] - w_start :
]
count[:, :, h_start:h_end, w_start:] += 1
count[:, :, h_start:h_end, : w_end - latents.shape[3]] += 1
else:
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1

# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = torch.where(count > 0, value / count, value)
Expand All @@ -711,7 +762,10 @@ def __call__(
callback(i, t, latents)

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if circular_padding:
image = self.decode_latents_with_padding(latents)
else:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
Expand Down
34 changes: 34 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,22 @@ def test_stable_diffusion_panorama_default_case(self):

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_panorama_circular_padding_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionPanoramaPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs, circular_padding=True).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)

expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

# override to speed the overall test timing up.
def test_inference_batch_consistent(self):
super().test_inference_batch_consistent(batch_sizes=[1, 2])
Expand Down Expand Up @@ -170,6 +186,24 @@ def test_stable_diffusion_panorama_views_batch(self):

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_panorama_views_batch_circular_padding(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionPanoramaPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
output = sd_pipe(**inputs, circular_padding=True, view_batch_size=2)
image = output.images
image_slice = image[0, -3:, -3:, -1]

assert image.shape == (1, 64, 64, 3)

expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_panorama_euler(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
Expand Down