Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}

# remove `null` components
init_dict = {k: v for k, v in init_dict.items() if v[0] is not None}
def load_module(name, value):
if value[0] is None:
return False
if name in passed_class_obj and passed_class_obj[name] is None:
return False
return True

init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

if len(unused_kwargs) > 0:
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
Expand All @@ -560,12 +567,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True

# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class
if not is_pipeline_module and passed_class_obj[name] is not None:
if not is_pipeline_module:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
Expand All @@ -581,12 +587,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None and name not in pipeline_class._optional_components:
logger.warning(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else:
logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
Expand All @@ -608,7 +608,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

if loaded_sub_model is None and sub_model_should_be_defined:
if loaded_sub_model is None:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if class_candidate is not None and issubclass(class_obj, class_candidate):
Expand Down
11 changes: 6 additions & 5 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_xformers_memory_efficient_attention(self):
Expand Down Expand Up @@ -379,7 +380,7 @@ def check_inputs(self, prompt, height, width, callback_steps):
)

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
Expand Down Expand Up @@ -420,9 +421,9 @@ def __call__(
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
Expand Down Expand Up @@ -469,8 +470,8 @@ def __call__(
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
):
super().__init__()
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)

@torch.no_grad()
def __call__(
Expand All @@ -79,9 +80,9 @@ def __call__(
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
Expand All @@ -107,8 +108,8 @@ def __call__(
generated images.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

if isinstance(prompt, str):
batch_size = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

def prepare_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
Expand Down Expand Up @@ -168,8 +169,8 @@ def _generate(
neg_prompt_ids: jnp.array = None,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
Expand All @@ -192,7 +193,12 @@ def _generate(
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])

latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
latents_shape = (
batch_size,
self.unet.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None:
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
else:
Expand Down Expand Up @@ -269,9 +275,9 @@ def __call__(
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
Expand Down Expand Up @@ -307,8 +313,8 @@ def __call__(
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

if jit:
images = _p_generate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
Expand Down Expand Up @@ -206,8 +207,8 @@ def __call__(
**kwargs,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

if isinstance(prompt, str):
batch_size = 1
Expand Down Expand Up @@ -241,7 +242,12 @@ def __call__(

# get the initial random noise unless the user supplied it
latents_dtype = text_embeddings.dtype
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
latents_shape = (
batch_size * num_images_per_prompt,
4,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)
elif latents.shape != latents_shape:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
Expand Down Expand Up @@ -273,9 +274,9 @@ def __call__(
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, H, W, 1)`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
Expand Down Expand Up @@ -321,8 +322,8 @@ def __call__(
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

if isinstance(prompt, str):
batch_size = 1
Expand Down Expand Up @@ -358,7 +359,12 @@ def __call__(
)

num_channels_latents = NUM_LATENT_CHANNELS
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
latents_shape = (
batch_size * num_images_per_prompt,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
latents_dtype = text_embeddings.dtype
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ def preprocess(image):
return 2.0 * image - 1.0


def preprocess_mask(mask):
def preprocess_mask(mask, scale_factor=8):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST)
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
Expand Down Expand Up @@ -143,6 +143,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
Expand Down Expand Up @@ -349,7 +350,7 @@ def __call__(

# preprocess mask
if not isinstance(mask_image, np.ndarray):
mask_image = preprocess_mask(mask_image)
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
mask_image = mask_image.astype(latents_dtype)
mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_xformers_memory_efficient_attention(self):
Expand Down Expand Up @@ -378,7 +379,7 @@ def check_inputs(self, prompt, height, width, callback_steps):
)

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
Expand Down Expand Up @@ -419,9 +420,9 @@ def __call__(
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
Expand Down Expand Up @@ -468,8 +469,8 @@ def __call__(
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * 8
width = width or self.unet.config.sample_size * 8
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
Expand Down
Loading