diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py
index 88ad49d2b776..961e30155a3d 100644
--- a/src/diffusers/models/controlnet_flux.py
+++ b/src/diffusers/models/controlnet_flux.py
@@ -23,7 +23,7 @@
 from ..models.attention_processor import AttentionProcessor
 from ..models.modeling_utils import ModelMixin
 from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
-from .controlnet import BaseOutput, zero_module
+from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
 from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from .modeling_outputs import Transformer2DModelOutput
 from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
@@ -55,6 +55,7 @@ def __init__(
         guidance_embeds: bool = False,
         axes_dims_rope: List[int] = [16, 56, 56],
         num_mode: int = None,
+        conditioning_embedding_channels: int = None,
     ):
         super().__init__()
         self.out_channels = in_channels
@@ -106,7 +107,14 @@ def __init__(
         if self.union:
             self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)

-        self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
+        if conditioning_embedding_channels is not None:
+            self.input_hint_block = ControlNetConditioningEmbedding(
+                conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
+            )
+            self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
+        else:
+            self.input_hint_block = None
+            self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))

         self.gradient_checkpointing = False

@@ -269,6 +277,16 @@ def forward(
         )
         hidden_states = self.x_embedder(hidden_states)

+        if self.input_hint_block is not None:
+            controlnet_cond = self.input_hint_block(controlnet_cond)
+            batch_size, channels, height_pw, width_pw = controlnet_cond.shape
+            height = height_pw // self.config.patch_size
+            width = width_pw // self.config.patch_size
+            controlnet_cond = controlnet_cond.reshape(
+                batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
+            )
+            controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
+            controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
         # add
         hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
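In `FluxControlNetModel.forward`, the hint block's output is patch-packed so each token lines up with the 2x2-patch layout of the packed latents before being added through `controlnet_x_embedder`. Below is a minimal standalone sketch of just that packing step; `pack_patches` is a hypothetical helper name and the 16-channel input is an assumption (it matches the hint block's `block_out_channels` above), not the diffusers API:

```python
import torch

def pack_patches(cond: torch.Tensor, patch_size: int = 2) -> torch.Tensor:
    # (B, C, H, W) -> (B, H/p * W/p, C * p * p), mirroring the reshape/
    # permute/reshape sequence added to FluxControlNetModel.forward above.
    batch_size, channels, height_pw, width_pw = cond.shape
    height = height_pw // patch_size
    width = width_pw // patch_size
    # split H and W into (num_patches, patch_size) pairs
    cond = cond.reshape(batch_size, channels, height, patch_size, width, patch_size)
    # move the two patch axes next to the channel axis: (B, h, w, C, p, p)
    cond = cond.permute(0, 2, 4, 1, 3, 5)
    # flatten each 2x2 patch into one token
    return cond.reshape(batch_size, height * width, -1)

cond = torch.randn(1, 16, 64, 64)   # e.g. a hint-block output at latent resolution
print(pack_patches(cond).shape)     # torch.Size([1, 1024, 64])
```

With 16 hint channels each token ends up with 16 * 2 * 2 = 64 features, matching the `in_channels` that the (no longer zero-initialized) `controlnet_x_embedder` consumes.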
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 6238ab8044bb..5d39a1bb5391 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -402,6 +402,7 @@ def forward(
         controlnet_block_samples=None,
         controlnet_single_block_samples=None,
         return_dict: bool = True,
+        controlnet_blocks_repeat: bool = False,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`FluxTransformer2DModel`] forward method.
@@ -508,7 +509,13 @@ def custom_forward(*inputs):
             if controlnet_block_samples is not None:
                 interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                 interval_control = int(np.ceil(interval_control))
-                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+                # For Xlabs ControlNet.
+                if controlnet_blocks_repeat:
+                    hidden_states = (
+                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+                    )
+                else:
+                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
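`controlnet_blocks_repeat` switches how the controlnet residuals are distributed across the transformer blocks: XLabs-style checkpoints train only a few controlnet blocks and cycle them, while the existing behavior spreads each residual over a contiguous span. A toy comparison of the two index mappings, assuming illustrative block counts (e.g. Flux's 19 double-stream blocks against a 2-block controlnet):

```python
import math

num_transformer_blocks = 19   # illustrative: Flux's double-stream block count
num_controlnet_samples = 2    # illustrative: a small XLabs-style controlnet

interval = math.ceil(num_transformer_blocks / num_controlnet_samples)

# controlnet_blocks_repeat=True: cycle the residuals through every block
repeat_idx = [i % num_controlnet_samples for i in range(num_transformer_blocks)]
# controlnet_blocks_repeat=False: each residual covers one contiguous span
spread_idx = [i // interval for i in range(num_transformer_blocks)]

print(repeat_idx)  # [0, 1, 0, 1, ..., 0]
print(spread_idx)  # [0] * 10 + [1] * 9
```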
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
index a301f6742c05..e9dedef0a58d 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -752,19 +752,22 @@ def __call__(
             )
             height, width = control_image.shape[-2:]

-            # vae encode
-            control_image = self.vae.encode(control_image).latent_dist.sample()
-            control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
-
-            # pack
-            height_control_image, width_control_image = control_image.shape[2:]
-            control_image = self._pack_latents(
-                control_image,
-                batch_size * num_images_per_prompt,
-                num_channels_latents,
-                height_control_image,
-                width_control_image,
-            )
+            # The XLabs ControlNet has an input_hint_block; the InstantX ControlNet does not.
+            controlnet_blocks_repeat = self.controlnet.input_hint_block is not None
+            if self.controlnet.input_hint_block is None:
+                # vae encode
+                control_image = self.vae.encode(control_image).latent_dist.sample()
+                control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+                # pack
+                height_control_image, width_control_image = control_image.shape[2:]
+                control_image = self._pack_latents(
+                    control_image,
+                    batch_size * num_images_per_prompt,
+                    num_channels_latents,
+                    height_control_image,
+                    width_control_image,
+                )

             # Here we ensure that `control_mode` has the same length as the control_image.
             if control_mode is not None:
@@ -775,8 +778,9 @@

         elif isinstance(self.controlnet, FluxMultiControlNetModel):
             control_images = []
-
-            for control_image_ in control_image:
+            # The XLabs ControlNet has an input_hint_block; the InstantX ControlNet does not.
+            controlnet_blocks_repeat = self.controlnet.nets[0].input_hint_block is not None
+            for i, control_image_ in enumerate(control_image):
                 control_image_ = self.prepare_image(
                     image=control_image_,
                     width=width,
@@ -788,20 +792,20 @@
                 )
                 height, width = control_image_.shape[-2:]

-                # vae encode
-                control_image_ = self.vae.encode(control_image_).latent_dist.sample()
-                control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
-
-                # pack
-                height_control_image, width_control_image = control_image_.shape[2:]
-                control_image_ = self._pack_latents(
-                    control_image_,
-                    batch_size * num_images_per_prompt,
-                    num_channels_latents,
-                    height_control_image,
-                    width_control_image,
-                )
-
+                if self.controlnet.nets[0].input_hint_block is None:
+                    # vae encode
+                    control_image_ = self.vae.encode(control_image_).latent_dist.sample()
+                    control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+                    # pack
+                    height_control_image, width_control_image = control_image_.shape[2:]
+                    control_image_ = self._pack_latents(
+                        control_image_,
+                        batch_size * num_images_per_prompt,
+                        num_channels_latents,
+                        height_control_image,
+                        width_control_image,
+                    )
                 control_images.append(control_image_)

             control_image = control_images
@@ -925,6 +929,7 @@
                     img_ids=latent_image_ids,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
+                    controlnet_blocks_repeat=controlnet_blocks_repeat,
                 )[0]

                 # compute the previous noisy sample x_t -> x_t-1
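From the user's side the branch is picked automatically: a loaded controlnet that carries an `input_hint_block` gets the raw-pixel path and repeated residuals, anything else keeps the VAE-encode-and-pack path. A hedged usage sketch; the checkpoint ids, prompt, and image URL below are placeholders (assumptions, not pinned by this diff):

```python
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

# Placeholder checkpoint ids; substitute whichever XLabs-style Flux
# controlnet (i.e. one that ships an input_hint_block) you actually have.
controlnet = FluxControlNetModel.from_pretrained(
    "XLabs-AI/flux-controlnet-canny-diffusers", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

control_image = load_image("https://example.com/canny.png")  # placeholder URL
image = pipe(
    prompt="a handsome man with ray-ban sunglasses",
    control_image=control_image,
    controlnet_conditioning_scale=0.7,
    num_inference_steps=25,
    guidance_scale=3.5,
).images[0]
image.save("flux_controlnet.png")
```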