Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/diffusers/models/unets/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def __init__(
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: int = 1.0,
resnet_out_scale_factor: float = 1.0,
time_embedding_type: str = "positional",
time_embedding_dim: Optional[int] = None,
time_embedding_act_fn: Optional[str] = None,
Expand All @@ -217,7 +217,7 @@ def __init__(
class_embeddings_concat: bool = False,
mid_block_only_cross_attention: Optional[bool] = None,
cross_attention_norm: Optional[str] = None,
addition_embed_type_num_heads=64,
addition_embed_type_num_heads: int = 64,
):
super().__init__()

Expand Down Expand Up @@ -485,9 +485,9 @@ def _check_config(
up_block_types: Tuple[str],
only_cross_attention: Union[bool, Tuple[bool]],
block_out_channels: Tuple[int],
layers_per_block: [int, Tuple[int]],
layers_per_block: Union[int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
reverse_transformer_layers_per_block: bool,
attention_head_dim: int,
num_attention_heads: Optional[Union[int, Tuple[int]]],
Expand Down Expand Up @@ -762,7 +762,7 @@ def set_default_attn_processor(self):

self.set_attn_processor(processor)

def set_attention_slice(self, slice_size):
def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was done to match the docs, @yiyixuxu. If we don't want `slice_size` to default to `"auto"`, I can update the docs accordingly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's OK.

r"""
Enable sliced attention computation.

Expand Down Expand Up @@ -831,7 +831,7 @@ def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value

def enable_freeu(self, s1, s2, b1, b2):
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

The suffixes after the scaling factors represent the stage blocks where they are being applied.
Expand Down Expand Up @@ -953,7 +953,7 @@ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Ten
return class_emb

def get_aug_embed(
self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict
self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
) -> Optional[torch.Tensor]:
aug_emb = None
if self.config.addition_embed_type == "text":
Expand Down Expand Up @@ -1004,7 +1004,9 @@ def get_aug_embed(
aug_emb = self.add_embedding(image_embs, hint)
return aug_emb

def process_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor, added_cond_kwargs) -> torch.Tensor:
def process_encoder_hidden_states(
self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
) -> torch.Tensor:
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
Expand Down