diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py
index de762acabebf..8006d0e1c127 100644
--- a/src/diffusers/models/unet_3d_condition.py
+++ b/src/diffusers/models/unet_3d_condition.py
@@ -21,6 +21,7 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
+from .attention_processor import AttentionProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .transformer_temporal import TransformerTemporalModel
@@ -249,6 +250,32 @@ def __init__(
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )
 
+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                processors[f"{name}.processor"] = module.processor
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):
         r"""
         Enable sliced attention computation.
@@ -259,24 +286,24 @@ def set_attention_slice(self, slice_size):
         Args:
             slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
         """
         sliceable_head_dims = []
 
-        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
             if hasattr(module, "set_attention_slice"):
                 sliceable_head_dims.append(module.sliceable_head_dim)
 
             for child in module.children():
-                fn_recursive_retrieve_slicable_dims(child)
+                fn_recursive_retrieve_sliceable_dims(child)
 
         # retrieve number of attention layers
         for module in self.children():
-            fn_recursive_retrieve_slicable_dims(module)
+            fn_recursive_retrieve_sliceable_dims(module)
 
-        num_slicable_layers = len(sliceable_head_dims)
+        num_sliceable_layers = len(sliceable_head_dims)
 
         if slice_size == "auto":
             # half the attention head size is usually a good trade-off between
@@ -284,9 +311,9 @@ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
             slice_size = [dim // 2 for dim in sliceable_head_dims]
         elif slice_size == "max":
             # make smallest slice possible
-            slice_size = num_slicable_layers * [1]
+            slice_size = num_sliceable_layers * [1]
 
-        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
 
         if len(slice_size) != len(sliceable_head_dims):
             raise ValueError(
@@ -314,6 +341,37 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Parameters:
+            `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                of **all** `Attention` layers.
+            In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            module.gradient_checkpointing = value