34 changes: 28 additions & 6 deletions src/diffusers/models/controlnet_flux.py
@@ -22,12 +22,15 @@
from ..loaders import PeftAdapterMixin
from ..models.attention_processor import AttentionProcessor
from ..models.modeling_utils import ModelMixin
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from .controlnet import BaseOutput, zero_module
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..utils import (USE_PEFT_BACKEND, is_torch_version, logging,
                     scale_lora_layers, unscale_lora_layers)
from .controlnet import (BaseOutput, ControlNetConditioningEmbedding,
                         zero_module)
from .embeddings import (CombinedTimestepGuidanceTextProjEmbeddings,
                         CombinedTimestepTextProjEmbeddings, FluxPosEmbed)
from .modeling_outputs import Transformer2DModelOutput
from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock

from .transformers.transformer_flux import (FluxSingleTransformerBlock,
                                             FluxTransformerBlock)

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -55,6 +58,7 @@ def __init__(
        guidance_embeds: bool = False,
        axes_dims_rope: List[int] = [16, 56, 56],
        num_mode: int = None,
        conditioning_embedding_channels: int = None,
    ):
        super().__init__()
        self.out_channels = in_channels
@@ -106,7 +110,15 @@ def __init__(
        if self.union:
            self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)

        self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
        if conditioning_embedding_channels is not None:
            self.input_hint_block = ControlNetConditioningEmbedding(
                conditioning_embedding_channels=conditioning_embedding_channels,
                block_out_channels=(16, 16, 16, 16)
            )
            self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
        else:
            self.input_hint_block = None
            self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))

        self.gradient_checkpointing = False
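For orientation, a minimal sketch of what the new input_hint_block computes when conditioning_embedding_channels is set. ControlNetConditioningEmbedding is the small convolutional encoder reused from the original ControlNet; with a four-entry block_out_channels it downsamples a pixel-space hint by 8x. The 512x512 input and the printed shape below are illustrative assumptions, not values fixed by this diff.

import torch
from diffusers.models.controlnet import ControlNetConditioningEmbedding

# Hypothetical hint block mirroring the configuration above:
# a 3-channel pixel-space hint is mapped to 16 feature channels and
# downsampled 8x (one stride-2 conv between each pair of block_out_channels entries).
hint_block = ControlNetConditioningEmbedding(
    conditioning_embedding_channels=16,
    block_out_channels=(16, 16, 16, 16),
)

hint = torch.randn(1, 3, 512, 512)  # illustrative control image
features = hint_block(hint)
print(features.shape)  # expected: torch.Size([1, 16, 64, 64])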

@@ -269,6 +281,16 @@ def forward(
            )
        hidden_states = self.x_embedder(hidden_states)

        if self.input_hint_block is not None:
            controlnet_cond = self.input_hint_block(controlnet_cond)
            batch_size, channels, height_pw, width_pw = controlnet_cond.shape
            height = height_pw // self.config.patch_size
            width = width_pw // self.config.patch_size
            controlnet_cond = controlnet_cond.reshape(
                batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
            )
            controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
            controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
        # add
        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)

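The reshape/permute/reshape sequence in forward() above is a space-to-depth packing: it turns the hint feature map into one token per patch so the result lines up with the packed latent tokens before controlnet_x_embedder. A standalone sketch of just that step, assuming patch_size = 2 and the illustrative shapes from the previous example:

import torch

# Illustrative values; the real ones come from the model config and the hint block output.
batch_size, channels, patch_size = 1, 16, 2
height_pw, width_pw = 64, 64
controlnet_cond = torch.randn(batch_size, channels, height_pw, width_pw)

height = height_pw // patch_size
width = width_pw // patch_size

# (B, C, H*p, W*p) -> (B, C, H, p, W, p) -> (B, H, W, C, p, p) -> (B, H*W, C*p*p)
controlnet_cond = controlnet_cond.reshape(batch_size, channels, height, patch_size, width, patch_size)
controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)

print(controlnet_cond.shape)  # torch.Size([1, 1024, 64]): one 64-dim token per 2x2 patch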
7 changes: 6 additions & 1 deletion src/diffusers/models/transformers/transformer_flux.py
@@ -402,6 +402,7 @@ def forward(
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModel`] forward method.
@@ -508,7 +509,11 @@ def custom_forward(*inputs):
            if controlnet_block_samples is not None:
                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
                # For Xlabs ControlNet.
                if controlnet_blocks_repeat:
                    hidden_states = hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                else:
                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

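The original interval indexing spreads a short list of ControlNet residuals over contiguous spans of transformer blocks; the new controlnet_blocks_repeat path instead cycles through them block by block, which is how Xlabs-style checkpoints with only a few control blocks are applied. A small sketch of the two index schemes, with 19 transformer blocks and 2 ControlNet samples chosen for illustration:

import numpy as np

num_transformer_blocks = 19  # illustrative; matches Flux's double-stream block count
num_controlnet_samples = 2   # illustrative; e.g. an Xlabs ControlNet with two blocks

interval_control = int(np.ceil(num_transformer_blocks / num_controlnet_samples))

interval_idx = [i // interval_control for i in range(num_transformer_blocks)]    # default path
repeat_idx = [i % num_controlnet_samples for i in range(num_transformer_blocks)]  # controlnet_blocks_repeat=True

print(interval_idx)  # [0]*10 + [1]*9 -> each sample covers a contiguous span of blocks
print(repeat_idx)    # [0, 1, 0, 1, ...] -> samples are cycled across blocks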
62 changes: 35 additions & 27 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Expand Up @@ -739,6 +739,7 @@ def __call__(
        )

        # 3. Prepare control image
        controlnet_blocks_repeat = False
        num_channels_latents = self.transformer.config.in_channels // 4
        if isinstance(self.controlnet, FluxControlNetModel):
            control_image = self.prepare_image(
@@ -752,19 +753,22 @@
            )
            height, width = control_image.shape[-2:]

            # vae encode
            control_image = self.vae.encode(control_image).latent_dist.sample()
            control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor

            # pack
            height_control_image, width_control_image = control_image.shape[2:]
            control_image = self._pack_latents(
                control_image,
                batch_size * num_images_per_prompt,
                num_channels_latents,
                height_control_image,
                width_control_image,
            )
            if self.controlnet.input_hint_block is None:
                # vae encode
                control_image = self.vae.encode(control_image).latent_dist.sample()
                control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor

                # pack
                height_control_image, width_control_image = control_image.shape[2:]
                control_image = self._pack_latents(
                    control_image,
                    batch_size * num_images_per_prompt,
                    num_channels_latents,
                    height_control_image,
                    width_control_image,
                )
            else:
                controlnet_blocks_repeat = True

            # Here we ensure that `control_mode` has the same length as the control_image.
            if control_mode is not None:
@@ -776,7 +780,7 @@ def __call__(
        elif isinstance(self.controlnet, FluxMultiControlNetModel):
            control_images = []

            for control_image_ in control_image:
            for i, control_image_ in enumerate(control_image):
                control_image_ = self.prepare_image(
                    image=control_image_,
                    width=width,
@@ -788,19 +792,22 @@
                )
                height, width = control_image_.shape[-2:]

                # vae encode
                control_image_ = self.vae.encode(control_image_).latent_dist.sample()
                control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor

                # pack
                height_control_image, width_control_image = control_image_.shape[2:]
                control_image_ = self._pack_latents(
                    control_image_,
                    batch_size * num_images_per_prompt,
                    num_channels_latents,
                    height_control_image,
                    width_control_image,
                )
                if self.controlnet.nets[i].input_hint_block is None:
                    # vae encode
                    control_image_ = self.vae.encode(control_image_).latent_dist.sample()
                    control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor

                    # pack
                    height_control_image, width_control_image = control_image_.shape[2:]
                    control_image_ = self._pack_latents(
                        control_image_,
                        batch_size * num_images_per_prompt,
                        num_channels_latents,
                        height_control_image,
                        width_control_image,
                    )
                else:
                    controlnet_blocks_repeat = True

                control_images.append(control_image_)

@@ -925,6 +932,7 @@ def __call__(
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                    controlnet_blocks_repeat=controlnet_blocks_repeat,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
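Taken together, a hedged end-to-end sketch of how these changes would be exercised from the pipeline. The checkpoint ids below are assumptions for illustration; any Xlabs-style Flux ControlNet converted to the diffusers layout (i.e. one that carries an input_hint_block) should follow the same code path.

import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

# Assumed checkpoint ids; substitute a real Xlabs-style Flux ControlNet
# converted to the diffusers format.
controlnet = FluxControlNetModel.from_pretrained(
    "XLabs-AI/flux-controlnet-canny-diffusers", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

control_image = load_image("canny.png")  # placeholder path for a canny edge map

# Because this ControlNet has an `input_hint_block`, the pipeline skips the VAE
# encode/pack path for `control_image` and passes `controlnet_blocks_repeat=True`
# to the transformer, as in the diff above.
image = pipe(
    prompt="a photo of a futuristic city at dusk",
    control_image=control_image,
    controlnet_conditioning_scale=0.7,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("output.png")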