Skip to content
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"Wan22AutoBlocks",
"WanAutoBlocks",
"WanModularPipeline",
]
Expand Down Expand Up @@ -1090,6 +1091,7 @@
QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
Wan22AutoBlocks,
WanAutoBlocks,
WanModularPipeline,
)
Expand Down
15 changes: 14 additions & 1 deletion src/diffusers/guiders/adaptive_projected_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -88,6 +88,19 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
data_batches.append(data_batch)
return data_batches

def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

Expand Down
15 changes: 14 additions & 1 deletion src/diffusers/guiders/adaptive_projected_guidance_mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -99,6 +99,19 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
data_batches.append(data_batch)
return data_batches

def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

Expand Down
10 changes: 10 additions & 0 deletions src/diffusers/guiders/auto_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,16 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
data_batches.append(data_batch)
return data_batches

def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

Expand Down
12 changes: 11 additions & 1 deletion src/diffusers/guiders/classifier_free_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -99,6 +99,16 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
data_batches.append(data_batch)
return data_batches

def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

Expand Down
12 changes: 11 additions & 1 deletion src/diffusers/guiders/classifier_free_zero_star_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -85,6 +85,16 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
data_batches.append(data_batch)
return data_batches

def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

Expand Down
10 changes: 10 additions & 0 deletions src/diffusers/guiders/frequency_decoupled_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,16 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
data_batches.append(data_batch)
return data_batches

def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

Expand Down
50 changes: 50 additions & 0 deletions src/diffusers/guiders/guider_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")

def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.")

def __call__(self, data: List["BlockState"]) -> Any:
if not all(hasattr(d, "noise_pred") for d in data):
raise ValueError("Expected all data to have `noise_pred` attribute.")
Expand Down Expand Up @@ -234,6 +239,51 @@ def _prepare_batch(
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)

@classmethod
def _prepare_batch_from_block_state(
cls,
input_fields: Dict[str, Union[str, Tuple[str, str]]],
data: "BlockState",
tuple_index: int,
identifier: str,
) -> "BlockState":
"""
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
`BaseGuidance` class. It prepares the batch based on the provided tuple index.

Args:
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation. If a string is provided, it will be used as the
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
length 2 is provided, the first element must be the conditional data identifier and the second element
must be the unconditional data identifier or None.
data (`BlockState`):
The input data to be prepared.
tuple_index (`int`):
The index to use when accessing input fields that are tuples.

Returns:
`BlockState`: The prepared batch of data.
"""
from ..modular_pipelines.modular_pipeline import BlockState

data_batch = {}
for key, value in input_fields.items():
try:
if isinstance(value, str):
data_batch[key] = getattr(data, value)
elif isinstance(value, tuple):
data_batch[key] = getattr(data, value[tuple_index])
else:
# We've already checked that value is a string or a tuple of strings with length 2
pass
except AttributeError:
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)

@classmethod
@validate_hf_hub_args
def from_pretrained(
Expand Down
20 changes: 20 additions & 0 deletions src/diffusers/guiders/perturbed_attention_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,26 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
data_batches.append(data_batch)
return data_batches

def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
def forward(
self,
Expand Down
20 changes: 20 additions & 0 deletions src/diffusers/guiders/skip_layer_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,26 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
data_batches.append(data_batch)
return data_batches

def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

def forward(
self,
pred_cond: torch.Tensor,
Expand Down
20 changes: 20 additions & 0 deletions src/diffusers/guiders/smoothed_energy_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,26 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
data_batches.append(data_batch)
return data_batches

def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

def forward(
self,
pred_cond: torch.Tensor,
Expand Down
12 changes: 11 additions & 1 deletion src/diffusers/guiders/tangential_classifier_free_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -74,6 +74,16 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
data_batches.append(data_batch)
return data_batches

def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/modular_pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"InsertableDict",
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"]
_import_structure["flux"] = [
"FluxAutoBlocks",
"FluxModularPipeline",
Expand Down Expand Up @@ -90,7 +90,7 @@
QwenImageModularPipeline,
)
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from .wan import WanAutoBlocks, WanModularPipeline
from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
else:
import sys

Expand Down
Loading
Loading