Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 65 additions & 7 deletions src/diffusers/models/unet_3d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .transformer_temporal import TransformerTemporalModel
Expand Down Expand Up @@ -249,6 +250,32 @@ def __init__(
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)

@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}

def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor

for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

return processors

for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)

return processors

# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
Expand All @@ -259,34 +286,34 @@ def set_attention_slice(self, slice_size):
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []

def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)

for child in module.children():
fn_recursive_retrieve_slicable_dims(child)
fn_recursive_retrieve_sliceable_dims(child)

# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_slicable_dims(module)
fn_recursive_retrieve_sliceable_dims(module)

num_slicable_layers = len(sliceable_head_dims)
num_sliceable_layers = len(sliceable_head_dims)

if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_slicable_layers * [1]
slice_size = num_sliceable_layers * [1]

slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
Expand Down Expand Up @@ -314,6 +341,37 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)

# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Parameters:
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `Attention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:

"""
count = len(self.attn_processors.keys())

if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)

def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))

for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value
Expand Down