81 changes: 80 additions & 1 deletion src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -19,6 +19,7 @@
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import T5EncoderModel, T5Tokenizer

from ...image_processor import VaeImageProcessor
@@ -43,7 +44,6 @@
if is_ftfy_available():
    import ftfy


EXAMPLE_DOC_STRING = """
    Examples:
        ```py
@@ -60,6 +60,42 @@
        ```
"""

ASPECT_RATIO_1024_BIN = {
    "0.25": [512.0, 2048.0],
    "0.28": [512.0, 1856.0],
    "0.32": [576.0, 1792.0],
    "0.33": [576.0, 1728.0],
    "0.35": [576.0, 1664.0],
    "0.4": [640.0, 1600.0],
    "0.42": [640.0, 1536.0],
    "0.48": [704.0, 1472.0],
    "0.5": [704.0, 1408.0],
    "0.52": [704.0, 1344.0],
    "0.57": [768.0, 1344.0],
    "0.6": [768.0, 1280.0],
    "0.68": [832.0, 1216.0],
    "0.72": [832.0, 1152.0],
    "0.78": [896.0, 1152.0],
    "0.82": [896.0, 1088.0],
    "0.88": [960.0, 1088.0],
    "0.94": [960.0, 1024.0],
    "1.0": [1024.0, 1024.0],
    "1.07": [1024.0, 960.0],
    "1.13": [1088.0, 960.0],
    "1.21": [1088.0, 896.0],
    "1.29": [1152.0, 896.0],
    "1.38": [1152.0, 832.0],
    "1.46": [1216.0, 832.0],
    "1.67": [1280.0, 768.0],
    "1.75": [1344.0, 768.0],
    "2.0": [1408.0, 704.0],
    "2.09": [1472.0, 704.0],
    "2.4": [1536.0, 640.0],
    "2.5": [1600.0, 640.0],
    "3.0": [1728.0, 576.0],
    "4.0": [2048.0, 512.0],
}
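The table is consumed by a nearest-key lookup: the requested height/width ratio is compared against every key (as strings) and snapped to the closest bin. A minimal sketch of that lookup, assuming `ASPECT_RATIO_1024_BIN` is in scope and using a hypothetical 720x1280 portrait request:

```py
# Sketch of the nearest-bin lookup (mirrors classify_height_width_bin below).
# The 720x1280 request is an illustrative example, not part of this PR.
ar = 720 / 1280  # 0.5625
closest = min(ASPECT_RATIO_1024_BIN, key=lambda k: abs(float(k) - ar))
print(closest, ASPECT_RATIO_1024_BIN[closest])  # "0.57" [768.0, 1344.0]
```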


class PixArtAlphaPipeline(DiffusionPipeline):
    r"""
@@ -495,6 +531,38 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @staticmethod
    def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
        """Returns binned height and width."""
        ar = float(height / width)
        closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
        default_hw = ratios[closest_ratio]
        return int(default_hw[0]), int(default_hw[1])
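For instance (a hedged illustration; the 1024x768 request is hypothetical, and the import path assumes this module's location), a landscape request snaps to the nearest bin even when no key matches exactly:

```py
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
    ASPECT_RATIO_1024_BIN,
    PixArtAlphaPipeline,
)

# 1024/768 ≈ 1.33 sits between the "1.29" and "1.38" keys; "1.29" is closer.
h, w = PixArtAlphaPipeline.classify_height_width_bin(1024, 768, ratios=ASPECT_RATIO_1024_BIN)
print(h, w)  # 1152 896
```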

    @staticmethod
    def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
        orig_height, orig_width = samples.shape[2], samples.shape[3]

        # Check if resizing is needed
        if orig_height != new_height or orig_width != new_width:
            ratio = max(new_height / orig_height, new_width / orig_width)
            resized_width = int(orig_width * ratio)
            resized_height = int(orig_height * ratio)

            # Resize
            samples = F.interpolate(
                samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
            )

            # Center Crop
            start_x = (resized_width - new_width) // 2
            end_x = start_x + new_width
            start_y = (resized_height - new_height) // 2
            end_y = start_y + new_height
            samples = samples[:, :, start_y:end_y, start_x:end_x]

        return samples
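A quick sanity check of the resize-then-center-crop behavior (a sketch with a hypothetical random batch; the shapes intentionally match the non-square test below):

```py
import torch

from diffusers import PixArtAlphaPipeline

imgs = torch.rand(1, 3, 64, 64)  # (B, C, H, W), standing in for a decoded batch
out = PixArtAlphaPipeline.resize_and_crop_tensor(imgs, new_width=48, new_height=32)
# ratio = max(32/64, 48/64) = 0.75 -> bilinear resize to 48x48, then center-crop to 32x48
print(out.shape)  # torch.Size([1, 3, 32, 48])
```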

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -518,6 +586,7 @@ def __call__(
        callback_steps: int = 1,
        clean_caption: bool = True,
        mask_feature: bool = True,
        use_resolution_binning: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -580,6 +649,10 @@ def __call__(
                be installed. If the dependencies are not installed, the embeddings will be created from the raw
                prompt.
            mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
            use_resolution_binning (`bool` defaults to `True`):
                If set to `True`, the requested height and width are first mapped to the closest resolutions
                using `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are
                resized back to the requested resolution. Useful for generating non-square images.

        Examples:

@@ -591,6 +664,10 @@
        # 1. Check inputs. Raise error if not correct
        height = height or self.transformer.config.sample_size * self.vae_scale_factor
        width = width or self.transformer.config.sample_size * self.vae_scale_factor
        if use_resolution_binning:
            orig_height, orig_width = height, width
            height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN)

        self.check_inputs(
            prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
        )
@@ -709,6 +786,8 @@ def __call__(

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
image = self.resize_and_crop_tensor(image, orig_width, orig_height)
else:
image = latents

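Taken together, the new flag works roughly like this (a hedged sketch: the checkpoint id and prompt are illustrative, and the bin arithmetic assumes `ASPECT_RATIO_1024_BIN` above):

```py
import torch

from diffusers import PixArtAlphaPipeline

# Illustrative checkpoint id; substitute whatever PixArt-alpha weights you use.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# Default path: 1000x1400 (ar ≈ 0.71) is denoised at the "0.72" bin (832x1152),
# then resized and center-cropped back to 1000x1400 after VAE decoding.
image = pipe("an astronaut riding a horse", height=1000, width=1400).images[0]

# Opt out to denoise at the raw requested resolution instead.
image = pipe(
    "an astronaut riding a horse", height=1000, width=1400, use_resolution_binning=False
).images[0]
```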
9 changes: 7 additions & 2 deletions tests/pipelines/pixart/test_pixart.py
@@ -89,7 +89,8 @@ def get_dummy_inputs(self, device, seed=0):
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "numpy",
"use_resolution_binning": False,
"output_type": "np",
}
return inputs

@@ -120,6 +121,7 @@ def test_save_load_optional_components(self):
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"use_resolution_binning": False,
}

        # set all optional components to None
@@ -154,6 +156,7 @@ def test_save_load_optional_components(self):
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"use_resolution_binning": False,
}

        output_loaded = pipe_loaded(**inputs)[0]
@@ -189,8 +192,8 @@ def test_inference_non_square_images(self):
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs, height=32, width=48).images
        image_slice = image[0, -3:, -3:, -1]

        self.assertEqual(image.shape, (1, 32, 48, 3))

        expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416])
        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
        self.assertLessEqual(max_diff, 1e-3)
@@ -219,6 +222,7 @@ def test_inference_with_embeddings_and_multiple_images(self):
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"num_images_per_prompt": 2,
"use_resolution_binning": False,
}

        # set all optional components to None
@@ -254,6 +258,7 @@ def test_inference_with_embeddings_and_multiple_images(self):
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"num_images_per_prompt": 2,
"use_resolution_binning": False,
}

        output_loaded = pipe_loaded(**inputs)[0]