From 28469de9d185437896e2d07cace58c1c36f8433c Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Fri, 2 Aug 2024 07:29:47 +0000 Subject: [PATCH 01/10] Fix ```from_single_file``` for xl_inpaint --- src/diffusers/loaders/single_file_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 483125f24825..b700197fb9dc 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -442,6 +442,8 @@ def infer_diffusers_model_type(checkpoint): ): if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: model_type = "inpainting_v2" + elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint: + model_type = "xl_inpaint" else: model_type = "inpainting" From 31d3707cb9bea7b6dceaf51fefaf487ba0ccf493 Mon Sep 17 00:00:00 2001 From: Gothos Date: Thu, 8 Aug 2024 18:41:26 +0000 Subject: [PATCH 02/10] Add basic flux inpaint pipeline --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/flux/__init__.py | 1 + .../pipelines/flux/pipeline_flux_inpaint.py | 988 ++++++++++++++++++ 4 files changed, 993 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_inpaint.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 73acea26c1f2..9b750c00fbbf 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -255,6 +255,7 @@ "CogVideoXPipeline", "CycleDiffusionPipeline", "FluxPipeline", + "FluxInpaintPipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", @@ -694,6 +695,7 @@ CogVideoXPipeline, CycleDiffusionPipeline, FluxPipeline, + FluxInpaintPipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, HunyuanDiTPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 0ec57c43b8e0..8584f7d97e91 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -124,7 +124,7 @@ "AnimateDiffSparseControlNetPipeline", "AnimateDiffVideoToVideoPipeline", ] - _import_structure["flux"] = ["FluxPipeline"] + _import_structure["flux"] = ["FluxPipeline","FluxInpaintPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", @@ -494,7 +494,7 @@ VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) - from .flux import FluxPipeline + from .flux import FluxPipeline,FluxInpaintPipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index d8c3edf0eaca..16b88b2310b4 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_flux"] = ["FluxPipeline"] + _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py new file mode 100644 index 000000000000..09c9560542f5 --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -0,0 +1,988 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput,VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxPipeline + + >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("flux.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + vae_latent_channels=self.vae.config.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self,timesteps,num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max = None, + ): + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + # return latents.to(device=device, dtype=dtype), latent_image_ids + + if latents is None: + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = latents + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise) + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, noise, image_latents, latent_image_ids + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(2 * height // self.vae_scale_factor, 2 * width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + print(masked_image_latents.shape,mask.shape, "mask im la", "mask") + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will ge generated by `mask_image`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + is_strength_max = strength == 1.0 + + + # 2. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4.Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (int(height)//self.vae_scale_factor) * (int(width)//self.vae_scale_factor) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps,num_inference_steps = self.get_timesteps(timesteps,num_inference_steps,strength,device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + num_channels_transformer = self.transformer.config.in_channels + + latents, noise,image_latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + init_image, + latent_timestep, + is_strength_max + + ) + + + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + masked_image_latents = self._pack_latents(masked_image_latents,batch_size,num_channels_latents,2 * (int(height) // self.vae_scale_factor),2 * (int(width) // self.vae_scale_factor)) + mask = self._pack_latents(mask.repeat(1,num_channels_latents,1,1),batch_size,num_channels_latents,2 * (int(height) // self.vae_scale_factor),2 * (int(width) // self.vae_scale_factor)) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # # match the inpainting pipeline and will be updated with input + mask inpainting model later + # if num_channels_transformer = : + # # default case for runwayml/stable-diffusion-inpainting + # num_channels_mask = mask.shape[1] + # num_channels_masked_image = masked_image_latents.shape[1] + # if ( + # num_channels_latents + num_channels_mask + num_channels_masked_image + # != self.transformer.config.in_channels + # ): + # raise ValueError( + # f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects" + # f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + # f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + # f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + # " `pipeline.transformer` or your `mask_image` or `image` input." + # ) + if num_channels_transformer != 64: + raise ValueError( + f"The transformer {self.transformer.__class__} should have 64 input channels, not {self.transformer.config.in_channels}." + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.tensor([guidance_scale], device=device) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + noise_pred = self.transformer( + hidden_states=latents, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if num_channels_transformer == 64: + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) From e305a0a65aa7ba50240b7e0356c94eae52d2d14a Mon Sep 17 00:00:00 2001 From: Gothos Date: Thu, 8 Aug 2024 21:18:16 +0000 Subject: [PATCH 03/10] style, quality, stray print --- src/diffusers/__init__.py | 4 +- src/diffusers/pipelines/__init__.py | 4 +- .../pipelines/flux/pipeline_flux_inpaint.py | 50 +++++++++++-------- .../dummy_torch_and_transformers_objects.py | 15 ++++++ 4 files changed, 48 insertions(+), 25 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9b750c00fbbf..61f782d16836 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -254,8 +254,8 @@ "CLIPImageProjection", "CogVideoXPipeline", "CycleDiffusionPipeline", - "FluxPipeline", "FluxInpaintPipeline", + "FluxPipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", @@ -694,8 +694,8 @@ CLIPImageProjection, CogVideoXPipeline, CycleDiffusionPipeline, - FluxPipeline, FluxInpaintPipeline, + FluxPipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, HunyuanDiTPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 8584f7d97e91..3c3a9b05ed83 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -124,7 +124,7 @@ "AnimateDiffSparseControlNetPipeline", "AnimateDiffVideoToVideoPipeline", ] - _import_structure["flux"] = ["FluxPipeline","FluxInpaintPipeline"] + _import_structure["flux"] = ["FluxPipeline", "FluxInpaintPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", @@ -494,7 +494,7 @@ VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) - from .flux import FluxPipeline,FluxInpaintPipeline + from .flux import FluxInpaintPipeline, FluxPipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 09c9560542f5..5d051b9ce3b1 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -19,7 +19,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast -from ...image_processor import PipelineImageInput,VaeImageProcessor +from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxLoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel @@ -90,6 +90,7 @@ def retrieve_latents( else: raise AttributeError("Could not access latents of provided encoder_output") + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -388,7 +389,7 @@ def encode_prompt( text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) return prompt_embeds, pooled_prompt_embeds, text_ids - + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ @@ -402,9 +403,9 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor return image_latents - + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps - def get_timesteps(self,timesteps,num_inference_steps, strength, device): + def get_timesteps(self, timesteps, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(num_inference_steps * strength, num_inference_steps) @@ -429,7 +430,7 @@ def check_inputs( ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - + if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -516,9 +517,8 @@ def prepare_latents( latents=None, image=None, timestep=None, - is_strength_max = None, + is_strength_max=None, ): - if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -530,14 +530,14 @@ def prepare_latents( "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." "However, either the image or the noise timestep has not been provided." ) - + height = 2 * (int(height) // self.vae_scale_factor) width = 2 * (int(width) // self.vae_scale_factor) - + shape = (batch_size, num_channels_latents, height, width) latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) - # return latents.to(device=device, dtype=dtype), latent_image_ids - + # return latents.to(device=device, dtype=dtype), latent_image_ids + if latents is None: image = image.to(device=device, dtype=dtype) image_latents = self._encode_vae_image(image=image, generator=generator) @@ -602,7 +602,6 @@ def prepare_mask_latents( # aligning device to prevent device errors when concating it with the latent model input masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - print(masked_image_latents.shape,mask.shape, "mask im la", "mask") return mask, masked_image_latents @property @@ -769,7 +768,6 @@ def __call__( self._interrupt = False is_strength_max = strength == 1.0 - # 2. Preprocess mask and image if padding_mask_crop is not None: crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) @@ -814,7 +812,7 @@ def __call__( # 4.Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = (int(height)//self.vae_scale_factor) * (int(width)//self.vae_scale_factor) + image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, @@ -830,7 +828,7 @@ def __call__( sigmas, mu=mu, ) - timesteps,num_inference_steps = self.get_timesteps(timesteps,num_inference_steps,strength,device) + timesteps, num_inference_steps = self.get_timesteps(timesteps, num_inference_steps, strength, device) if num_inference_steps < 1: raise ValueError( @@ -843,7 +841,7 @@ def __call__( num_channels_latents = self.transformer.config.in_channels // 4 num_channels_transformer = self.transformer.config.in_channels - latents, noise,image_latents, latent_image_ids = self.prepare_latents( + latents, noise, image_latents, latent_image_ids = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, @@ -854,11 +852,9 @@ def __call__( latents, init_image, latent_timestep, - is_strength_max - + is_strength_max, ) - mask_condition = self.mask_processor.preprocess( mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords ) @@ -879,8 +875,20 @@ def __call__( device, generator, ) - masked_image_latents = self._pack_latents(masked_image_latents,batch_size,num_channels_latents,2 * (int(height) // self.vae_scale_factor),2 * (int(width) // self.vae_scale_factor)) - mask = self._pack_latents(mask.repeat(1,num_channels_latents,1,1),batch_size,num_channels_latents,2 * (int(height) // self.vae_scale_factor),2 * (int(width) // self.vae_scale_factor)) + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + 2 * (int(height) // self.vae_scale_factor), + 2 * (int(width) // self.vae_scale_factor), + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + 2 * (int(height) // self.vae_scale_factor), + 2 * (int(width) // self.vae_scale_factor), + ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) # # match the inpainting pipeline and will be updated with input + mask inpainting model later diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index fb722553ed28..5a51c08dffee 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -287,6 +287,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class FluxInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 2f580032905737973ac8330ff1e916b703fd4fdc Mon Sep 17 00:00:00 2001 From: Gothos Date: Fri, 9 Aug 2024 06:10:04 +0000 Subject: [PATCH 04/10] Fix stray changes --- src/diffusers/loaders/single_file_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 2ca37630e7c4..9c2a2cbf2942 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -449,8 +449,6 @@ def infer_diffusers_model_type(checkpoint): ): if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: model_type = "inpainting_v2" - elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint: - model_type = "xl_inpaint" else: model_type = "inpainting" From 3d1b1c20ea97beb3545b2a024d62ff232b98b991 Mon Sep 17 00:00:00 2001 From: Gothos Date: Tue, 13 Aug 2024 10:54:09 +0000 Subject: [PATCH 05/10] --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/flux/__init__.py | 1 + .../pipelines/flux/pipeline_flux_img2img.py | 862 ++++++++++++++++++ .../pipelines/flux/pipeline_flux_inpaint.py | 2 +- .../dummy_torch_and_transformers_objects.py | 15 + 6 files changed, 883 insertions(+), 3 deletions(-) create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_img2img.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 61f782d16836..a1085b4485a8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -254,6 +254,7 @@ "CLIPImageProjection", "CogVideoXPipeline", "CycleDiffusionPipeline", + "FluxImg2ImgPipeline", "FluxInpaintPipeline", "FluxPipeline", "HunyuanDiTControlNetPipeline", @@ -694,6 +695,7 @@ CLIPImageProjection, CogVideoXPipeline, CycleDiffusionPipeline, + FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline, HunyuanDiTControlNetPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3c3a9b05ed83..a9ffe6119757 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -124,7 +124,7 @@ "AnimateDiffSparseControlNetPipeline", "AnimateDiffVideoToVideoPipeline", ] - _import_structure["flux"] = ["FluxPipeline", "FluxInpaintPipeline"] + _import_structure["flux"] = ["FluxPipeline", "FluxImg2ImgPipeline", "FluxInpaintPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", @@ -494,7 +494,7 @@ VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) - from .flux import FluxInpaintPipeline, FluxPipeline + from .flux import FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index 16b88b2310b4..15c07e3f6606 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_flux"] = ["FluxPipeline"] + _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py new file mode 100644 index 000000000000..52e3c8347245 --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -0,0 +1,862 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxPipeline + + >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("flux.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + vae_latent_channels=self.vae.config.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, timesteps, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + # return latents.to(device=device, dtype=dtype), latent_image_ids + + if latents is None: + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = latents + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise) + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, noise, image_latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will ge generated by `mask_image`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + is_strength_max = strength == 1.0 + + # 2. Preprocess image + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4.Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(timesteps, num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + num_channels_transformer = self.transformer.config.in_channels + + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + init_image, + latent_timestep, + is_strength_max, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.tensor([guidance_scale], device=device) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + noise_pred = self.transformer( + hidden_states=latents, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 5d051b9ce3b1..613909116129 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -837,7 +837,7 @@ def __call__( ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - # 4. Prepare latent variables + # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 num_channels_transformer = self.transformer.config.in_channels diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 5a51c08dffee..05589ee385b4 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -287,6 +287,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class FluxImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 7bbf12cc94575a9d98f5d29471e612e84009951a Mon Sep 17 00:00:00 2001 From: Gothos Date: Tue, 13 Aug 2024 12:06:23 +0000 Subject: [PATCH 06/10] Add inpainting model support --- .../pipelines/flux/pipeline_flux_inpaint.py | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 613909116129..3c927402fbdb 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -891,22 +891,21 @@ def __call__( ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - # # match the inpainting pipeline and will be updated with input + mask inpainting model later - # if num_channels_transformer = : - # # default case for runwayml/stable-diffusion-inpainting - # num_channels_mask = mask.shape[1] - # num_channels_masked_image = masked_image_latents.shape[1] - # if ( - # num_channels_latents + num_channels_mask + num_channels_masked_image - # != self.transformer.config.in_channels - # ): - # raise ValueError( - # f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects" - # f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - # f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - # f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - # " `pipeline.transformer` or your `mask_image` or `image` input." - # ) + if num_channels_transformer == 132: + # default case for inpainting models + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if ( + num_channels_latents + num_channels_mask + num_channels_masked_image + != self.transformer.config.in_channels + ): + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects" + f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.transformer` or your `mask_image` or `image` input." + ) if num_channels_transformer != 64: raise ValueError( f"The transformer {self.transformer.__class__} should have 64 input channels, not {self.transformer.config.in_channels}." @@ -921,6 +920,9 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) + if num_channels_transformer == 132: + latents = torch.cat([latents, mask, masked_image_latents], dim=1) + # handle guidance if self.transformer.config.guidance_embeds: guidance = torch.tensor([guidance_scale], device=device) From 2140924c150ece6620811027cc92bd340e70d3a5 Mon Sep 17 00:00:00 2001 From: Gothos Date: Mon, 2 Sep 2024 22:29:56 +0000 Subject: [PATCH 07/10] Change lora conversion for xlabs --- src/diffusers/loaders/lora_conversion_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 4b5426947906..8adeb70b47c9 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -562,7 +562,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): new_key += ".attn.to_out.0" elif "processor.proj_lora2" in old_key: new_key += ".attn.to_add_out" - elif "processor.qkv_lora1" in old_key and "up" not in old_key: + elif "processor.qkv_lora2" in old_key and "up" not in old_key: handle_qkv( old_state_dict, new_state_dict, @@ -574,7 +574,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): ], ) # continue - elif "processor.qkv_lora2" in old_key and "up" not in old_key: + elif "processor.qkv_lora1" in old_key and "up" not in old_key: handle_qkv( old_state_dict, new_state_dict, From 417b758c6439b633287e720d914cba3250a13950 Mon Sep 17 00:00:00 2001 From: Gothos Date: Mon, 2 Sep 2024 22:35:07 +0000 Subject: [PATCH 08/10] Fix stray changes --- src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/flux/__init__.py | 2 - .../pipelines/flux/pipeline_flux_img2img.py | 862 --------------- .../pipelines/flux/pipeline_flux_inpaint.py | 998 ------------------ .../dummy_torch_and_transformers_objects.py | 30 - 5 files changed, 1 insertion(+), 1893 deletions(-) delete mode 100644 src/diffusers/pipelines/flux/pipeline_flux_img2img.py delete mode 100644 src/diffusers/pipelines/flux/pipeline_flux_inpaint.py diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 20cb18c8f5c6..63436e9be6b5 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -124,7 +124,7 @@ "AnimateDiffSparseControlNetPipeline", "AnimateDiffVideoToVideoPipeline", ] - _import_structure["flux"] = ["FluxPipeline", "FluxImg2ImgPipeline", "FluxInpaintPipeline", "FluxControlNetPipeline"] + _import_structure["flux"] = ["FluxPipeline", "FluxControlNetPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index 7b1243d92acd..900189102c5b 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -23,8 +23,6 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_flux"] = ["FluxPipeline"] - _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] - _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py deleted file mode 100644 index 52e3c8347245..000000000000 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ /dev/null @@ -1,862 +0,0 @@ -# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Callable, Dict, List, Optional, Union - -import numpy as np -import torch -from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast - -from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin -from ...models.autoencoders import AutoencoderKL -from ...models.transformers import FluxTransformer2DModel -from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( - USE_PEFT_BACKEND, - is_torch_xla_available, - logging, - replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, -) -from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import FluxPipelineOutput - - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import FluxPipeline - - >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) - >>> pipe.to("cuda") - >>> prompt = "A cat holding a sign that says hello world" - >>> # Depending on the variant being used, the pipeline call will slightly vary. - >>> # Refer to the pipeline documentation for more details. - >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] - >>> image.save("flux.png") - ``` -""" - - -def calculate_shift( - image_seq_len, - base_seq_len: int = 256, - max_seq_len: int = 4096, - base_shift: float = 0.5, - max_shift: float = 1.16, -): - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - b = base_shift - m * base_seq_len - mu = image_seq_len * m + b - return mu - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - """ - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): - r""" - The Flux pipeline for image inpainting. - - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ - - Args: - transformer ([`FluxTransformer2DModel`]): - Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. - scheduler ([`FlowMatchEulerDiscreteScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - text_encoder_2 ([`T5EncoderModel`]): - [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). - tokenizer_2 (`T5TokenizerFast`): - Second Tokenizer of class - [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). - """ - - model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" - _optional_components = [] - _callback_tensor_inputs = ["latents", "prompt_embeds"] - - def __init__( - self, - scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - text_encoder_2: T5EncoderModel, - tokenizer_2: T5TokenizerFast, - transformer: FluxTransformer2DModel, - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - transformer=transformer, - scheduler=scheduler, - ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 - ) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, - vae_latent_channels=self.vae.config.latent_channels, - do_normalize=False, - do_binarize=True, - do_convert_grayscale=True, - ) - self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 - ) - self.default_sample_size = 64 - - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Union[str, List[str]], - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) - - return prompt_embeds, pooled_prompt_embeds, text_ids - - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): - if isinstance(generator, list): - image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) - - image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - return image_latents - - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps - def get_timesteps(self, timesteps, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(num_inference_steps * strength, num_inference_steps) - - t_start = int(max(num_inference_steps - init_timestep, 0)) - timesteps = timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) - - return timesteps, num_inference_steps - t_start - - def check_inputs( - self, - prompt, - prompt_2, - strength, - height, - width, - prompt_embeds=None, - pooled_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - max_sequence_length=None, - ): - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if max_sequence_length is not None and max_sequence_length > 512: - raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - - @staticmethod - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) - latent_image_ids = latent_image_ids.reshape( - batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - height = height // vae_scale_factor - width = width // vae_scale_factor - - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) - - return latents - - def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - image=None, - timestep=None, - is_strength_max=None, - ): - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if (image is None or timestep is None) and not is_strength_max: - raise ValueError( - "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." - "However, either the image or the noise timestep has not been provided." - ) - - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) - - shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) - # return latents.to(device=device, dtype=dtype), latent_image_ids - - if latents is None: - image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(image=image, generator=generator) - else: - image_latents = latents - - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise) - noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) - image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - return latents, noise, image_latents, latent_image_ids - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - image: PipelineImageInput = None, - height: Optional[int] = None, - width: Optional[int] = None, - strength: float = 0.6, - num_inference_steps: int = 28, - timesteps: List[int] = None, - guidance_scale: float = 7.0, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - will be used instead - image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): - `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both - numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list - or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a - list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image - latents as `image`, but if passing latents directly it is not encoded again. - mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): - `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask - are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a - single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one - color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, - H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, - 1)`, or `(H, W)`. - mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): - `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask - latents tensor will ge generated by `mask_image`. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. - padding_mask_crop (`int`, *optional*, defaults to `None`): - The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to - image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region - with the same aspect ration of the image and contains all masked area, and then expand that area based - on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before - resizing to the original image size for inpainting. This is useful when the masked area is small while - the image is large and contain information irrelevant for inpainting, such as background. - strength (`float`, *optional*, defaults to 1.0): - Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a - starting point and more noise is added the higher the `strength`. The number of denoising steps depends - on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising - process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 - essentially ignores `image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 7.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. - - Examples: - - Returns: - [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` - is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated - images. - """ - - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - strength, - height, - width, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - max_sequence_length=max_sequence_length, - ) - - self._guidance_scale = guidance_scale - self._joint_attention_kwargs = joint_attention_kwargs - self._interrupt = False - is_strength_max = strength == 1.0 - - # 2. Preprocess image - init_image = self.image_processor.preprocess(image, height=height, width=width) - init_image = init_image.to(dtype=torch.float32) - - # 3. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - - lora_scale = ( - self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None - ) - ( - prompt_embeds, - pooled_prompt_embeds, - text_ids, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - device=device, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) - - # 4.Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) - mu = calculate_shift( - image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, - ) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, - num_inference_steps, - device, - timesteps, - sigmas, - mu=mu, - ) - timesteps, num_inference_steps = self.get_timesteps(timesteps, num_inference_steps, strength, device) - - if num_inference_steps < 1: - raise ValueError( - f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" - f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." - ) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 - num_channels_transformer = self.transformer.config.in_channels - - latents, noise, image_latents, latent_image_ids = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - init_image, - latent_timestep, - is_strength_max, - ) - - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - # 6. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) - - # handle guidance - if self.transformer.config.guidance_embeds: - guidance = torch.tensor([guidance_scale], device=device) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None - - noise_pred = self.transformer( - hidden_states=latents, - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - if output_type == "latent": - image = latents - - else: - latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) - latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor - image = self.vae.decode(latents, return_dict=False)[0] - image = self.image_processor.postprocess(image, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image,) - - return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py deleted file mode 100644 index 3c927402fbdb..000000000000 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ /dev/null @@ -1,998 +0,0 @@ -# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Callable, Dict, List, Optional, Union - -import numpy as np -import torch -from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast - -from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin -from ...models.autoencoders import AutoencoderKL -from ...models.transformers import FluxTransformer2DModel -from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( - USE_PEFT_BACKEND, - is_torch_xla_available, - logging, - replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, -) -from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import FluxPipelineOutput - - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import FluxPipeline - - >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) - >>> pipe.to("cuda") - >>> prompt = "A cat holding a sign that says hello world" - >>> # Depending on the variant being used, the pipeline call will slightly vary. - >>> # Refer to the pipeline documentation for more details. - >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] - >>> image.save("flux.png") - ``` -""" - - -def calculate_shift( - image_seq_len, - base_seq_len: int = 256, - max_seq_len: int = 4096, - base_shift: float = 0.5, - max_shift: float = 1.16, -): - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - b = base_shift - m * base_seq_len - mu = image_seq_len * m + b - return mu - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - """ - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): - r""" - The Flux pipeline for image inpainting. - - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ - - Args: - transformer ([`FluxTransformer2DModel`]): - Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. - scheduler ([`FlowMatchEulerDiscreteScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - text_encoder_2 ([`T5EncoderModel`]): - [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). - tokenizer_2 (`T5TokenizerFast`): - Second Tokenizer of class - [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). - """ - - model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" - _optional_components = [] - _callback_tensor_inputs = ["latents", "prompt_embeds"] - - def __init__( - self, - scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - text_encoder_2: T5EncoderModel, - tokenizer_2: T5TokenizerFast, - transformer: FluxTransformer2DModel, - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - transformer=transformer, - scheduler=scheduler, - ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 - ) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, - vae_latent_channels=self.vae.config.latent_channels, - do_normalize=False, - do_binarize=True, - do_convert_grayscale=True, - ) - self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 - ) - self.default_sample_size = 64 - - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Union[str, List[str]], - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) - - return prompt_embeds, pooled_prompt_embeds, text_ids - - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): - if isinstance(generator, list): - image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) - - image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - return image_latents - - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps - def get_timesteps(self, timesteps, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(num_inference_steps * strength, num_inference_steps) - - t_start = int(max(num_inference_steps - init_timestep, 0)) - timesteps = timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) - - return timesteps, num_inference_steps - t_start - - def check_inputs( - self, - prompt, - prompt_2, - strength, - height, - width, - prompt_embeds=None, - pooled_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - max_sequence_length=None, - ): - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if max_sequence_length is not None and max_sequence_length > 512: - raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - - @staticmethod - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) - latent_image_ids = latent_image_ids.reshape( - batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - height = height // vae_scale_factor - width = width // vae_scale_factor - - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) - - return latents - - def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - image=None, - timestep=None, - is_strength_max=None, - ): - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if (image is None or timestep is None) and not is_strength_max: - raise ValueError( - "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." - "However, either the image or the noise timestep has not been provided." - ) - - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) - - shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) - # return latents.to(device=device, dtype=dtype), latent_image_ids - - if latents is None: - image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(image=image, generator=generator) - else: - image_latents = latents - - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise) - noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) - image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - return latents, noise, image_latents, latent_image_ids - - def prepare_mask_latents( - self, - mask, - masked_image, - batch_size, - num_images_per_prompt, - height, - width, - dtype, - device, - generator, - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(2 * height // self.vae_scale_factor, 2 * width // self.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - masked_image = masked_image.to(device=device, dtype=dtype) - - if masked_image.shape[1] == 16: - masked_image_latents = masked_image - else: - masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) - - masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - return mask, masked_image_latents - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - image: PipelineImageInput = None, - mask_image: PipelineImageInput = None, - masked_image_latents: PipelineImageInput = None, - height: Optional[int] = None, - width: Optional[int] = None, - padding_mask_crop: Optional[int] = None, - strength: float = 0.6, - num_inference_steps: int = 28, - timesteps: List[int] = None, - guidance_scale: float = 7.0, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - will be used instead - image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): - `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both - numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list - or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a - list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image - latents as `image`, but if passing latents directly it is not encoded again. - mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): - `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask - are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a - single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one - color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, - H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, - 1)`, or `(H, W)`. - mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): - `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask - latents tensor will ge generated by `mask_image`. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. - padding_mask_crop (`int`, *optional*, defaults to `None`): - The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to - image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region - with the same aspect ration of the image and contains all masked area, and then expand that area based - on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before - resizing to the original image size for inpainting. This is useful when the masked area is small while - the image is large and contain information irrelevant for inpainting, such as background. - strength (`float`, *optional*, defaults to 1.0): - Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a - starting point and more noise is added the higher the `strength`. The number of denoising steps depends - on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising - process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 - essentially ignores `image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 7.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. - - Examples: - - Returns: - [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` - is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated - images. - """ - - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - strength, - height, - width, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - max_sequence_length=max_sequence_length, - ) - - self._guidance_scale = guidance_scale - self._joint_attention_kwargs = joint_attention_kwargs - self._interrupt = False - is_strength_max = strength == 1.0 - - # 2. Preprocess mask and image - if padding_mask_crop is not None: - crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) - resize_mode = "fill" - else: - crops_coords = None - resize_mode = "default" - - original_image = image - init_image = self.image_processor.preprocess( - image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode - ) - init_image = init_image.to(dtype=torch.float32) - - # 3. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - - lora_scale = ( - self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None - ) - ( - prompt_embeds, - pooled_prompt_embeds, - text_ids, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - device=device, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) - - # 4.Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) - mu = calculate_shift( - image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, - ) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, - num_inference_steps, - device, - timesteps, - sigmas, - mu=mu, - ) - timesteps, num_inference_steps = self.get_timesteps(timesteps, num_inference_steps, strength, device) - - if num_inference_steps < 1: - raise ValueError( - f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" - f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." - ) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 - num_channels_transformer = self.transformer.config.in_channels - - latents, noise, image_latents, latent_image_ids = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - init_image, - latent_timestep, - is_strength_max, - ) - - mask_condition = self.mask_processor.preprocess( - mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords - ) - - if masked_image_latents is None: - masked_image = init_image * (mask_condition < 0.5) - else: - masked_image = masked_image_latents - - mask, masked_image_latents = self.prepare_mask_latents( - mask_condition, - masked_image, - batch_size, - num_images_per_prompt, - height, - width, - prompt_embeds.dtype, - device, - generator, - ) - masked_image_latents = self._pack_latents( - masked_image_latents, - batch_size, - num_channels_latents, - 2 * (int(height) // self.vae_scale_factor), - 2 * (int(width) // self.vae_scale_factor), - ) - mask = self._pack_latents( - mask.repeat(1, num_channels_latents, 1, 1), - batch_size, - num_channels_latents, - 2 * (int(height) // self.vae_scale_factor), - 2 * (int(width) // self.vae_scale_factor), - ) - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - - if num_channels_transformer == 132: - # default case for inpainting models - num_channels_mask = mask.shape[1] - num_channels_masked_image = masked_image_latents.shape[1] - if ( - num_channels_latents + num_channels_mask + num_channels_masked_image - != self.transformer.config.in_channels - ): - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects" - f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.transformer` or your `mask_image` or `image` input." - ) - if num_channels_transformer != 64: - raise ValueError( - f"The transformer {self.transformer.__class__} should have 64 input channels, not {self.transformer.config.in_channels}." - ) - - # 6. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) - - if num_channels_transformer == 132: - latents = torch.cat([latents, mask, masked_image_latents], dim=1) - - # handle guidance - if self.transformer.config.guidance_embeds: - guidance = torch.tensor([guidance_scale], device=device) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None - - noise_pred = self.transformer( - hidden_states=latents, - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if num_channels_transformer == 64: - init_latents_proper = image_latents - init_mask = mask - - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.scale_noise( - init_latents_proper, torch.tensor([noise_timestep]), noise - ) - - latents = (1 - init_mask) * init_latents_proper + init_mask * latents - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - if output_type == "latent": - image = latents - - else: - latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) - latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor - image = self.vae.decode(latents, return_dict=False)[0] - image = self.image_processor.postprocess(image, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image,) - - return FluxPipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index ec4ddf7307f2..477beed49f52 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -287,36 +287,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class FluxImg2ImgPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class FluxInpaintPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class FluxControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From be077f7f4ef0a19c6c1c5d49e7c98d09a0dcd276 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 3 Sep 2024 10:29:25 +0530 Subject: [PATCH 09/10] Apply suggestions from code review --- src/diffusers/loaders/lora_conversion_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 8adeb70b47c9..ae5c131feaf4 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -562,6 +562,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): new_key += ".attn.to_out.0" elif "processor.proj_lora2" in old_key: new_key += ".attn.to_add_out" + # Handle text latents. elif "processor.qkv_lora2" in old_key and "up" not in old_key: handle_qkv( old_state_dict, @@ -574,6 +575,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): ], ) # continue + # Handle image latents. elif "processor.qkv_lora1" in old_key and "up" not in old_key: handle_qkv( old_state_dict, From abaa00c54808392292bcba0677d2b51337e2ce2a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 3 Sep 2024 10:31:07 +0530 Subject: [PATCH 10/10] style --- src/diffusers/loaders/lora_conversion_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index ae5c131feaf4..f6dea33e8e82 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -562,7 +562,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): new_key += ".attn.to_out.0" elif "processor.proj_lora2" in old_key: new_key += ".attn.to_add_out" - # Handle text latents. + # Handle text latents. elif "processor.qkv_lora2" in old_key and "up" not in old_key: handle_qkv( old_state_dict, @@ -575,7 +575,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): ], ) # continue - # Handle image latents. + # Handle image latents. elif "processor.qkv_lora1" in old_key and "up" not in old_key: handle_qkv( old_state_dict,