From cdca0bf43bd588626b776b202bdbf03798a62e49 Mon Sep 17 00:00:00 2001 From: Anzhella Pankratova Date: Thu, 10 Oct 2024 19:51:56 +0300 Subject: [PATCH 1/9] Add support for XLabs ControlNets --- src/diffusers/models/controlnet_flux.py | 27 ++++++++++++++++++- .../models/transformers/transformer_flux.py | 5 +++- .../flux/pipeline_flux_controlnet.py | 13 ++++++++- 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index 88ad49d2b776..e752c30aa0ce 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -17,6 +17,7 @@ import torch import torch.nn as nn +from einops import rearrange from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import PeftAdapterMixin @@ -55,6 +56,7 @@ def __init__( guidance_embeds: bool = False, axes_dims_rope: List[int] = [16, 56, 56], num_mode: int = None, + is_xlabs_controlnet: bool = False, ): super().__init__() self.out_channels = in_channels @@ -106,7 +108,27 @@ def __init__( if self.union: self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) - self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) + if self.is_xlabs_controlnet: + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(nn.Conv2d(16, 16, 3, padding=1)) + ) + self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim) + else: + self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) self.gradient_checkpointing = False @@ -269,6 +291,9 @@ def forward( ) hidden_states = self.x_embedder(hidden_states) + if self.is_xlabs_controlnet: + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) # add hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 6238ab8044bb..61b59be7f04b 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -508,7 +508,10 @@ def custom_forward(*inputs): if controlnet_block_samples is not None: interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) interval_control = int(np.ceil(interval_control)) - hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + if len(controlnet_block_samples) == 2: + hidden_states = hidden_states + controlnet_block_samples[index_block % 2] + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index a301f6742c05..38d3f7a37656 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -740,7 +740,7 @@ def __call__( # 3. 
Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 - if isinstance(self.controlnet, FluxControlNetModel): + if isinstance(self.controlnet, FluxControlNetModel) and not self.controlnet.is_xlabs_controlnet: control_image = self.prepare_image( image=control_image, width=width, @@ -773,6 +773,17 @@ def __call__( control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1) + elif isinstance(self.controlnet, FluxControlNetModel) and self.controlnet.is_xlabs_controlnet: + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] From f38671ebc9cdf566c23b4b86530866f41c8b69f8 Mon Sep 17 00:00:00 2001 From: Anzhella Pankratova Date: Fri, 11 Oct 2024 16:33:48 +0300 Subject: [PATCH 2/9] use torch reshape instead of einops, fix pipeline_flux_controlnet.py --- src/diffusers/models/controlnet_flux.py | 10 ++++- .../models/transformers/transformer_flux.py | 1 + .../flux/pipeline_flux_controlnet.py | 40 +++++++------------ 3 files changed, 24 insertions(+), 27 deletions(-) diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index e752c30aa0ce..7c3690912b9b 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -17,7 +17,6 @@ import torch import torch.nn as nn -from einops import rearrange from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import PeftAdapterMixin @@ -293,7 +292,14 @@ def forward( if self.is_xlabs_controlnet: controlnet_cond = self.input_hint_block(controlnet_cond) - controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + batch_size, channels, height_pw, width_pw = controlnet_cond.shape + height = height_pw // self.config.patch_size + width = width_pw // self.config.patch_size + controlnet_cond = controlnet_cond.reshape( + batch_size, channels, height, self.config.patch_size, width, self.config.patch_size + ) + controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5) + controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1) # add hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 61b59be7f04b..a216636f2f49 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -508,6 +508,7 @@ def custom_forward(*inputs): if controlnet_block_samples is not None: interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) interval_control = int(np.ceil(interval_control)) + # For Xlabs ControlNet. if len(controlnet_block_samples) == 2: hidden_states = hidden_states + controlnet_block_samples[index_block % 2] else: diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 38d3f7a37656..c8a2c26b5096 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -740,7 +740,7 @@ def __call__( # 3. 
Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 - if isinstance(self.controlnet, FluxControlNetModel) and not self.controlnet.is_xlabs_controlnet: + if isinstance(self.controlnet, FluxControlNetModel): control_image = self.prepare_image( image=control_image, width=width, @@ -752,19 +752,20 @@ def __call__( ) height, width = control_image.shape[-2:] - # vae encode - control_image = self.vae.encode(control_image).latent_dist.sample() - control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - # pack - height_control_image, width_control_image = control_image.shape[2:] - control_image = self._pack_latents( - control_image, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) + if not self.controlnet.is_xlabs_controlnet: + # vae encode + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) # Here we ensure that `control_mode` has the same length as the control_image. if control_mode is not None: @@ -773,17 +774,6 @@ def __call__( control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1) - elif isinstance(self.controlnet, FluxControlNetModel) and self.controlnet.is_xlabs_controlnet: - control_image = self.prepare_image( - image=control_image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=self.vae.dtype, - ) - elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] From b8c9496310414d460dcb6f47b7141aba1485ba04 Mon Sep 17 00:00:00 2001 From: Anzhella Pankratova Date: Fri, 11 Oct 2024 18:04:26 +0300 Subject: [PATCH 3/9] Use ControlNetConditioningEmbedding for input_hint_block --- src/diffusers/models/controlnet_flux.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index 7c3690912b9b..b0eb3815a0db 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -23,7 +23,7 @@ from ..models.attention_processor import AttentionProcessor from ..models.modeling_utils import ModelMixin from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from .controlnet import BaseOutput, zero_module +from .controlnet import BaseOutput, zero_module, ControlNetConditioningEmbedding from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from .modeling_outputs import Transformer2DModelOutput from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock @@ -108,22 +108,9 @@ def __init__( self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) if self.is_xlabs_controlnet: - self.input_hint_block = nn.Sequential( - nn.Conv2d(3, 16, 3, padding=1), - nn.SiLU(), - nn.Conv2d(16, 16, 3, padding=1), - nn.SiLU(), - nn.Conv2d(16, 16, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(16, 16, 3, padding=1), - 
nn.SiLU(), - nn.Conv2d(16, 16, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(16, 16, 3, padding=1), - nn.SiLU(), - nn.Conv2d(16, 16, 3, padding=1, stride=2), - nn.SiLU(), - zero_module(nn.Conv2d(16, 16, 3, padding=1)) + self.input_hint_block = ControlNetConditioningEmbedding( + conditioning_embedding_channels=16, + block_out_channels=(16,16,16,16) ) self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim) else: From 7be937e8230909dcb79d22da8f0b6297a40441bf Mon Sep 17 00:00:00 2001 From: Anzhella Pankratova Date: Mon, 14 Oct 2024 12:03:34 +0300 Subject: [PATCH 4/9] Use conditioning_embedding_channels instead of is_xlabs_controlnet in config --- src/diffusers/models/controlnet_flux.py | 9 +++++---- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index b0eb3815a0db..abfecaa9d29e 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -55,7 +55,7 @@ def __init__( guidance_embeds: bool = False, axes_dims_rope: List[int] = [16, 56, 56], num_mode: int = None, - is_xlabs_controlnet: bool = False, + conditioning_embedding_channels: int = None, ): super().__init__() self.out_channels = in_channels @@ -107,13 +107,14 @@ def __init__( if self.union: self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) - if self.is_xlabs_controlnet: + if conditioning_embedding_channels is not None: self.input_hint_block = ControlNetConditioningEmbedding( - conditioning_embedding_channels=16, + conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16,16,16,16) ) self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim) else: + self.input_hint_block = None self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) self.gradient_checkpointing = False @@ -277,7 +278,7 @@ def forward( ) hidden_states = self.x_embedder(hidden_states) - if self.is_xlabs_controlnet: + if self.input_hint_block is not None: controlnet_cond = self.input_hint_block(controlnet_cond) batch_size, channels, height_pw, width_pw = controlnet_cond.shape height = height_pw // self.config.patch_size diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index c8a2c26b5096..0990c6f4aced 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -752,7 +752,7 @@ def __call__( ) height, width = control_image.shape[-2:] - if not self.controlnet.is_xlabs_controlnet: + if self.controlnet.input_hint_block is None: # vae encode control_image = self.vae.encode(control_image).latent_dist.sample() control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor From b07b48dc3c8c32838619b1937190411ede8650af Mon Sep 17 00:00:00 2001 From: Anzhella Pankratova Date: Tue, 15 Oct 2024 11:16:54 +0300 Subject: [PATCH 5/9] Add controlnet_blocks_repeat to Flux forward --- src/diffusers/models/transformers/transformer_flux.py | 5 +++-- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 4 ++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index a216636f2f49..e7b0c930e25d 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ 
b/src/diffusers/models/transformers/transformer_flux.py @@ -402,6 +402,7 @@ def forward( controlnet_block_samples=None, controlnet_single_block_samples=None, return_dict: bool = True, + controlnet_blocks_repeat: bool = False, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ The [`FluxTransformer2DModel`] forward method. @@ -509,8 +510,8 @@ def custom_forward(*inputs): interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) interval_control = int(np.ceil(interval_control)) # For Xlabs ControlNet. - if len(controlnet_block_samples) == 2: - hidden_states = hidden_states + controlnet_block_samples[index_block % 2] + if controlnet_blocks_repeat: + hidden_states = hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] else: hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 0990c6f4aced..9196529c61f0 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -739,6 +739,7 @@ def __call__( ) # 3. Prepare control image + controlnet_blocks_repeat = False num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.controlnet, FluxControlNetModel): control_image = self.prepare_image( @@ -766,6 +767,8 @@ def __call__( height_control_image, width_control_image, ) + else: + controlnet_blocks_repeat = True # Here we ensure that `control_mode` has the same length as the control_image. if control_mode is not None: @@ -926,6 +929,7 @@ def __call__( img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, )[0] # compute the previous noisy sample x_t -> x_t-1 From bff8644a73d79cbd99fe2f756de10fdb519e71a0 Mon Sep 17 00:00:00 2001 From: Anzhella Pankratova Date: Tue, 15 Oct 2024 12:23:20 +0300 Subject: [PATCH 6/9] Fix for FluxMultiControlNetModel --- .../flux/pipeline_flux_controlnet.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 9196529c61f0..89c76eba54d4 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -780,7 +780,7 @@ def __call__( elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] - for control_image_ in control_image: + for i, control_image_ in enumerate(control_image): control_image_ = self.prepare_image( image=control_image_, width=width, @@ -792,19 +792,22 @@ def __call__( ) height, width = control_image_.shape[-2:] - # vae encode - control_image_ = self.vae.encode(control_image_).latent_dist.sample() - control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - # pack - height_control_image, width_control_image = control_image_.shape[2:] - control_image_ = self._pack_latents( - control_image_, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) + if self.controlnet.nets[i].input_hint_block is None: + # vae encode + control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + 
height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + else: + controlnet_blocks_repeat = True control_images.append(control_image_) From 8cf0105b2d4e8fd9a51a9760130681f20d7262df Mon Sep 17 00:00:00 2001 From: Anzhella Pankratova Date: Tue, 15 Oct 2024 18:11:52 +0300 Subject: [PATCH 7/9] Fix import order --- src/diffusers/models/controlnet_flux.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index abfecaa9d29e..339cf0bcaa33 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -22,12 +22,15 @@ from ..loaders import PeftAdapterMixin from ..models.attention_processor import AttentionProcessor from ..models.modeling_utils import ModelMixin -from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from .controlnet import BaseOutput, zero_module, ControlNetConditioningEmbedding -from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed +from ..utils import (USE_PEFT_BACKEND, is_torch_version, logging, + scale_lora_layers, unscale_lora_layers) +from .controlnet import (BaseOutput, ControlNetConditioningEmbedding, + zero_module) +from .embeddings import (CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, FluxPosEmbed) from .modeling_outputs import Transformer2DModelOutput -from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock - +from .transformers.transformer_flux import (FluxSingleTransformerBlock, + FluxTransformerBlock) logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 5b01e1ee089e8979b886efe1f14672947d75155f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 15 Oct 2024 22:48:31 +0200 Subject: [PATCH 8/9] style --- src/diffusers/models/controlnet_flux.py | 16 ++++++---------- .../models/transformers/transformer_flux.py | 4 +++- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index 339cf0bcaa33..961e30155a3d 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -22,15 +22,12 @@ from ..loaders import PeftAdapterMixin from ..models.attention_processor import AttentionProcessor from ..models.modeling_utils import ModelMixin -from ..utils import (USE_PEFT_BACKEND, is_torch_version, logging, - scale_lora_layers, unscale_lora_layers) -from .controlnet import (BaseOutput, ControlNetConditioningEmbedding, - zero_module) -from .embeddings import (CombinedTimestepGuidanceTextProjEmbeddings, - CombinedTimestepTextProjEmbeddings, FluxPosEmbed) +from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module +from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from .modeling_outputs import Transformer2DModelOutput -from .transformers.transformer_flux import (FluxSingleTransformerBlock, - FluxTransformerBlock) +from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock + logger = logging.get_logger(__name__) 
# pylint: disable=invalid-name @@ -112,8 +109,7 @@ def __init__( if conditioning_embedding_channels is not None: self.input_hint_block = ControlNetConditioningEmbedding( - conditioning_embedding_channels=conditioning_embedding_channels, - block_out_channels=(16,16,16,16) + conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16) ) self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim) else: diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index e7b0c930e25d..5d39a1bb5391 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -511,7 +511,9 @@ def custom_forward(*inputs): interval_control = int(np.ceil(interval_control)) # For Xlabs ControlNet. if controlnet_blocks_repeat: - hidden_states = hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) else: hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] From 1cc5da30da3f40ca2988e920379f96ca4b4e2301 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 15 Oct 2024 23:51:36 +0200 Subject: [PATCH 9/9] up --- .../pipelines/flux/pipeline_flux_controlnet.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 89c76eba54d4..e9dedef0a58d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -739,7 +739,6 @@ def __call__( ) # 3. Prepare control image - controlnet_blocks_repeat = False num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.controlnet, FluxControlNetModel): control_image = self.prepare_image( @@ -753,6 +752,8 @@ def __call__( ) height, width = control_image.shape[-2:] + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True if self.controlnet.input_hint_block is None: # vae encode control_image = self.vae.encode(control_image).latent_dist.sample() @@ -767,8 +768,6 @@ def __call__( height_control_image, width_control_image, ) - else: - controlnet_blocks_repeat = True # Here we ensure that `control_mode` has the same length as the control_image. 
if control_mode is not None: @@ -779,7 +778,8 @@ def __call__( elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] - + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True for i, control_image_ in enumerate(control_image): control_image_ = self.prepare_image( image=control_image_, @@ -792,7 +792,7 @@ def __call__( ) height, width = control_image_.shape[-2:] - if self.controlnet.nets[i].input_hint_block is None: + if self.controlnet.nets[0].input_hint_block is None: # vae encode control_image_ = self.vae.encode(control_image_).latent_dist.sample() control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor @@ -806,9 +806,6 @@ def __call__( height_control_image, width_control_image, ) - else: - controlnet_blocks_repeat = True - control_images.append(control_image_) control_image = control_images
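
Usage note for the series: with these patches applied, an XLabs Flux ControlNet loads through the existing `FluxControlNetPipeline` API. A checkpoint whose config sets `conditioning_embedding_channels` gets an `input_hint_block`, receives its control image in pixel space (no VAE encode / latent packing), and has its block residuals repeated across the transformer via `controlnet_blocks_repeat`. Below is a minimal sketch of the intended call pattern; the repo ids and the precomputed canny image are illustrative assumptions, not part of this series.

```py
# Minimal usage sketch for the XLabs ControlNet path added in this series.
# Both repo ids below are assumptions for illustration; substitute whichever
# diffusers-format XLabs checkpoint and Flux base model you actually use.
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

controlnet = FluxControlNetModel.from_pretrained(
    "XLabs-AI/flux-controlnet-canny-diffusers",  # assumed repo id
    torch_dtype=torch.bfloat16,
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed base model
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Because this checkpoint has an input_hint_block, the pipeline feeds the
# control image through it in pixel space instead of VAE-encoding and
# packing it into latents (the InstantX-style path).
control_image = load_image("canny_edges.png")  # precomputed canny map (assumed file)

image = pipe(
    prompt="a red sports car on a coastal road, golden hour",
    control_image=control_image,
    controlnet_conditioning_scale=0.7,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_xlabs_canny.png")
```

Because an XLabs checkpoint exposes only two double-stream block residuals, `controlnet_blocks_repeat=True` makes the transformer cycle them (`index_block % len(controlnet_block_samples)`) rather than interval-spreading them as it does for InstantX-style checkpoints.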