diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index 88ad49d2b776..339cf0bcaa33 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -22,12 +22,15 @@ from ..loaders import PeftAdapterMixin from ..models.attention_processor import AttentionProcessor from ..models.modeling_utils import ModelMixin -from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from .controlnet import BaseOutput, zero_module -from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed +from ..utils import (USE_PEFT_BACKEND, is_torch_version, logging, + scale_lora_layers, unscale_lora_layers) +from .controlnet import (BaseOutput, ControlNetConditioningEmbedding, + zero_module) +from .embeddings import (CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, FluxPosEmbed) from .modeling_outputs import Transformer2DModelOutput -from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock - +from .transformers.transformer_flux import (FluxSingleTransformerBlock, + FluxTransformerBlock) logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -55,6 +58,7 @@ def __init__( guidance_embeds: bool = False, axes_dims_rope: List[int] = [16, 56, 56], num_mode: int = None, + conditioning_embedding_channels: int = None, ): super().__init__() self.out_channels = in_channels @@ -106,7 +110,15 @@ def __init__( if self.union: self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) - self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) + if conditioning_embedding_channels is not None: + self.input_hint_block = ControlNetConditioningEmbedding( + conditioning_embedding_channels=conditioning_embedding_channels, + block_out_channels=(16,16,16,16) + ) + self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim) + else: + self.input_hint_block = None + self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) self.gradient_checkpointing = False @@ -269,6 +281,16 @@ def forward( ) hidden_states = self.x_embedder(hidden_states) + if self.input_hint_block is not None: + controlnet_cond = self.input_hint_block(controlnet_cond) + batch_size, channels, height_pw, width_pw = controlnet_cond.shape + height = height_pw // self.config.patch_size + width = width_pw // self.config.patch_size + controlnet_cond = controlnet_cond.reshape( + batch_size, channels, height, self.config.patch_size, width, self.config.patch_size + ) + controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5) + controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1) # add hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 6238ab8044bb..e7b0c930e25d 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -402,6 +402,7 @@ def forward( controlnet_block_samples=None, controlnet_single_block_samples=None, return_dict: bool = True, + controlnet_blocks_repeat: bool = False, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ The [`FluxTransformer2DModel`] forward method. @@ -508,7 +509,11 @@ def custom_forward(*inputs): if controlnet_block_samples is not None: interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) interval_control = int(np.ceil(interval_control)) - hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + # For Xlabs ControlNet. + if controlnet_blocks_repeat: + hidden_states = hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index a301f6742c05..89c76eba54d4 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -739,6 +739,7 @@ def __call__( ) # 3. Prepare control image + controlnet_blocks_repeat = False num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.controlnet, FluxControlNetModel): control_image = self.prepare_image( @@ -752,19 +753,22 @@ def __call__( ) height, width = control_image.shape[-2:] - # vae encode - control_image = self.vae.encode(control_image).latent_dist.sample() - control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - # pack - height_control_image, width_control_image = control_image.shape[2:] - control_image = self._pack_latents( - control_image, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) + if self.controlnet.input_hint_block is None: + # vae encode + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + else: + controlnet_blocks_repeat = True # Here we ensure that `control_mode` has the same length as the control_image. if control_mode is not None: @@ -776,7 +780,7 @@ def __call__( elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] - for control_image_ in control_image: + for i, control_image_ in enumerate(control_image): control_image_ = self.prepare_image( image=control_image_, width=width, @@ -788,19 +792,22 @@ def __call__( ) height, width = control_image_.shape[-2:] - # vae encode - control_image_ = self.vae.encode(control_image_).latent_dist.sample() - control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - # pack - height_control_image, width_control_image = control_image_.shape[2:] - control_image_ = self._pack_latents( - control_image_, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) + if self.controlnet.nets[i].input_hint_block is None: + # vae encode + control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + else: + controlnet_blocks_repeat = True control_images.append(control_image_) @@ -925,6 +932,7 @@ def __call__( img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, )[0] # compute the previous noisy sample x_t -> x_t-1