diff --git a/docs/source/en/api/pipelines/panorama.mdx b/docs/source/en/api/pipelines/panorama.mdx
index 044901f24bf3..75c27f129ad8 100644
--- a/docs/source/en/api/pipelines/panorama.mdx
+++ b/docs/source/en/api/pipelines/panorama.mdx
@@ -60,6 +60,25 @@ and increase the VRAM usage.
+
+
+Circular padding is applied to ensure there are no stitching artifacts when working with
+panoramas that need to seamlessly transition from the rightmost part to the leftmost part.
+By enabling circular padding (set `circular_padding=True`), the operation applies additional
+crops after the rightmost point of the image, allowing the model to "see" the transition
+from the rightmost part to the leftmost part. This helps maintain visual consistency in
+a 360-degree sense and creates a proper "panorama" that can be viewed using 360-degree
+panorama viewers. When decoding latents, circular padding is also applied so that the
+left and right edges of the decoded image match in RGB space.
+
+Without circular padding, there is a stitching artifact (default):
+![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/indoor_%20no_circular_padding.png)
+
+With circular padding, the right and the left parts match (`circular_padding=True`):
+![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/indoor_%20circular_padding.png)
+
+
+
 
 ## StableDiffusionPanoramaPipeline
 
 [[autodoc]] StableDiffusionPanoramaPipeline
 	- __call__
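For reviewers who want to try the new flag, here is a minimal usage sketch (not part of the diff); the checkpoint, prompt, and output filename are illustrative:

```python
import torch
from diffusers import DDIMScheduler, StableDiffusionPanoramaPipeline

# Illustrative checkpoint; any Stable Diffusion checkpoint supported by the
# panorama pipeline should work here.
model_id = "stabilityai/stable-diffusion-2-base"
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPanoramaPipeline.from_pretrained(
    model_id, scheduler=scheduler, torch_dtype=torch.float16
).to("cuda")

# circular_padding=True lets the rightmost views wrap around to the leftmost
# latents, so the panorama closes seamlessly in a 360-degree viewer.
image = pipe(
    "a photo of the dolomites",
    height=512,
    width=2048,
    circular_padding=True,
).images[0]
image.save("dolomites_panorama.png")
```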
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
index e03687e89eb1..7ebeaa17c0bb 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
@@ -373,6 +373,19 @@ def decode_latents(self, latents):
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
+    def decode_latents_with_padding(self, latents, padding=8):
+        # Add padding to latents for circular inference
+        # padding is the number of latent columns to add on each side
+        # it slightly increases memory usage, but removes the boundary artifacts
+        latents = 1 / self.vae.config.scaling_factor * latents
+        latents_left = latents[..., :padding]
+        latents_right = latents[..., -padding:]
+        latents = torch.cat((latents_right, latents, latents_left), axis=-1)
+        image = self.vae.decode(latents, return_dict=False)[0]
+        padding_pix = self.vae_scale_factor * padding
+        image = image[..., padding_pix:-padding_pix]
+        return image
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -457,13 +470,16 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
-    def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
+    def get_views(self, panorama_height, panorama_width, window_size=64, stride=8, circular_padding=False):
         # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
         # if panorama's height/width < window_size, num_blocks of height/width should return 1
         panorama_height /= 8
         panorama_width /= 8
         num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1
-        num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
+        if circular_padding:
+            num_blocks_width = panorama_width // stride if panorama_width > window_size else 1
+        else:
+            num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
         total_num_blocks = int(num_blocks_height * num_blocks_width)
         views = []
         for i in range(total_num_blocks):
@@ -496,6 +512,7 @@ def __call__(
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: Optional[int] = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        circular_padding: bool = False,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -560,6 +577,10 @@ def __call__(
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+            circular_padding (`bool`, *optional*, defaults to `False`):
+                If set to `True`, circular padding is applied to ensure there are no stitching artifacts. Circular
+                padding allows the model to seamlessly generate a transition from the rightmost part of the image to
+                the leftmost part, maintaining consistency in a 360-degree sense.
 
         Examples:
 
@@ -627,10 +648,9 @@ def __call__(
 
         # 6. Define panorama grid and initialize views for synthesis.
         # prepare batch grid
-        views = self.get_views(height, width)
+        views = self.get_views(height, width, circular_padding=circular_padding)
         views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
         views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch)
-
         count = torch.zeros_like(latents)
         value = torch.zeros_like(latents)
 
@@ -655,9 +675,29 @@ def __call__(
             for j, batch_view in enumerate(views_batch):
                 vb_size = len(batch_view)
                 # get the latents corresponding to the current view coordinates
-                latents_for_view = torch.cat(
-                    [latents[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view]
-                )
+                if circular_padding:
+                    latents_for_view = []
+                    for h_start, h_end, w_start, w_end in batch_view:
+                        if w_end > latents.shape[3]:
+                            # Add circular horizontal padding
+                            latent_view = torch.cat(
+                                (
+                                    latents[:, :, h_start:h_end, w_start:],
+                                    latents[:, :, h_start:h_end, : w_end - latents.shape[3]],
+                                ),
+                                axis=-1,
+                            )
+                        else:
+                            latent_view = latents[:, :, h_start:h_end, w_start:w_end]
+                        latents_for_view.append(latent_view)
+                    latents_for_view = torch.cat(latents_for_view)
+                else:
+                    latents_for_view = torch.cat(
+                        [
+                            latents[:, :, h_start:h_end, w_start:w_end]
+                            for h_start, h_end, w_start, w_end in batch_view
+                        ]
+                    )
 
                 # rematch block's scheduler status
                 self.scheduler.__dict__.update(views_scheduler_status[j])
@@ -698,8 +738,19 @@ def __call__(
                 for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
                     latents_denoised_batch.chunk(vb_size), batch_view
                 ):
-                    value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
-                    count[:, :, h_start:h_end, w_start:w_end] += 1
+                    if circular_padding and w_end > latents.shape[3]:
+                        # Case for circular padding: the view wraps past the right edge of the latents
+                        value[:, :, h_start:h_end, w_start:] += latents_view_denoised[
+                            :, :, h_start:h_end, : latents.shape[3] - w_start
+                        ]
+                        value[:, :, h_start:h_end, : w_end - latents.shape[3]] += latents_view_denoised[
+                            :, :, h_start:h_end, latents.shape[3] - w_start :
+                        ]
+                        count[:, :, h_start:h_end, w_start:] += 1
+                        count[:, :, h_start:h_end, : w_end - latents.shape[3]] += 1
+                    else:
+                        value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
+                        count[:, :, h_start:h_end, w_start:w_end] += 1
 
             # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
             latents = torch.where(count > 0, value / count, value)
@@ -711,7 +762,10 @@ def __call__(
                 callback(i, t, latents)
 
         if not output_type == "latent":
-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            if circular_padding:
+                image = self.decode_latents_with_padding(latents)
+            else:
+                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
         else:
             image = latents
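To make the wrap-around indexing above concrete, here is a small standalone sketch (not part of the diff) of the same slicing pattern on a toy latent row; the shapes and values are illustrative:

```python
import torch

# Toy latent grid: (batch, channels, height, width), width 16.
latents = torch.arange(16.0).reshape(1, 1, 1, 16)
window = 8

# A view whose end coordinate runs past the right edge of the grid.
w_start, w_end = 12, 12 + window

if w_end > latents.shape[3]:
    # Same wrap-around as in the pipeline: stitch the rightmost columns
    # to the leftmost ones along the width axis.
    view = torch.cat(
        (latents[..., w_start:], latents[..., : w_end - latents.shape[3]]),
        dim=-1,
    )
else:
    view = latents[..., w_start:w_end]

print(view.squeeze())  # tensor([12., 13., 14., 15.,  0.,  1.,  2.,  3.])
```

The denoised result for such a view is split back the same way: its first `latents.shape[3] - w_start` columns update the right edge of the canvas and the remainder updates the left edge, which is exactly what the accumulation branch in the diff does.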
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py
index 080bd0091f4f..131e9402c7eb 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py
@@ -125,6 +125,22 @@ def test_stable_diffusion_panorama_default_case(self):
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_stable_diffusion_panorama_circular_padding_case(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPanoramaPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs, circular_padding=True).images
+        image_slice = image[0, -3:, -3:, -1]
+        assert image.shape == (1, 64, 64, 3)
+
+        expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     # override to speed the overall test timing up.
     def test_inference_batch_consistent(self):
         super().test_inference_batch_consistent(batch_sizes=[1, 2])
@@ -170,6 +186,24 @@ def test_stable_diffusion_panorama_views_batch(self):
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_stable_diffusion_panorama_views_batch_circular_padding(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPanoramaPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs, circular_padding=True, view_batch_size=2)
+        image = output.images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+
+        expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     def test_stable_diffusion_panorama_euler(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
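The new tests compare image slices against fixed expected values. A complementary, hypothetical check (not part of this PR) would assert directly that the output wraps: with circular padding, the leftmost and rightmost pixel columns of the panorama should be close. A minimal sketch, with an illustrative tolerance:

```python
import numpy as np

def assert_seamless(image: np.ndarray, tol: float = 5e-2):
    # image is (height, width, channels) with values in [0, 1]; a seamless
    # 360-degree panorama should have nearly identical leftmost and
    # rightmost pixel columns.
    left_edge = image[:, 0, :]
    right_edge = image[:, -1, :]
    seam_gap = np.abs(left_edge - right_edge).mean()
    assert seam_gap < tol, f"visible seam: mean edge difference {seam_gap:.4f}"

# e.g. assert_seamless(sd_pipe(**inputs, circular_padding=True).images[0])
```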