From 72a67f9abce0adec9baeff546d1eb4159bf78de2 Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Thu, 29 Feb 2024 01:36:46 +0530
Subject: [PATCH] update

---
 .../models/unets/unet_2d_condition.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index fee7a34fb216..d067767315f1 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -204,7 +204,7 @@ def __init__(
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
         resnet_skip_time_act: bool = False,
-        resnet_out_scale_factor: int = 1.0,
+        resnet_out_scale_factor: float = 1.0,
         time_embedding_type: str = "positional",
         time_embedding_dim: Optional[int] = None,
         time_embedding_act_fn: Optional[str] = None,
@@ -217,7 +217,7 @@ def __init__(
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
         cross_attention_norm: Optional[str] = None,
-        addition_embed_type_num_heads=64,
+        addition_embed_type_num_heads: int = 64,
     ):
         super().__init__()
 
@@ -485,9 +485,9 @@ def _check_config(
         up_block_types: Tuple[str],
         only_cross_attention: Union[bool, Tuple[bool]],
         block_out_channels: Tuple[int],
-        layers_per_block: [int, Tuple[int]],
+        layers_per_block: Union[int, Tuple[int]],
         cross_attention_dim: Union[int, Tuple[int]],
-        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]],
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
         reverse_transformer_layers_per_block: bool,
         attention_head_dim: int,
         num_attention_heads: Optional[Union[int, Tuple[int]]],
@@ -762,7 +762,7 @@ def set_default_attn_processor(self):
 
         self.set_attn_processor(processor)
 
-    def set_attention_slice(self, slice_size):
+    def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
         r"""
         Enable sliced attention computation.
 
@@ -831,7 +831,7 @@ def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
-    def enable_freeu(self, s1, s2, b1, b2):
+    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
 
         The suffixes after the scaling factors represent the stage blocks where they are being applied.
@@ -953,7 +953,7 @@ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Ten
         return class_emb
 
     def get_aug_embed(
-        self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict
+        self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
     ) -> Optional[torch.Tensor]:
         aug_emb = None
         if self.config.addition_embed_type == "text":
@@ -1004,7 +1004,9 @@ def get_aug_embed(
             aug_emb = self.add_embedding(image_embs, hint)
         return aug_emb
 
-    def process_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor, added_cond_kwargs) -> torch.Tensor:
+    def process_encoder_hidden_states(
+        self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+    ) -> torch.Tensor:
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
         elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
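
A minimal usage sketch (not part of the patch) exercising the public signatures this change annotates. The checkpoint id and the FreeU scaling values below are illustrative assumptions, not something the patch prescribes:

# Illustrative sketch only; assumes a standard Stable Diffusion 1.5
# checkpoint layout with the UNet weights under the "unet" subfolder.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # assumed checkpoint
)

# slice_size is Union[str, int, List[int]] with default "auto" per this
# patch; "auto" lets the model choose a slice size for each sliceable layer.
unet.set_attention_slice("auto")

# s1, s2, b1, b2 are plain floats; the values here are the SD1.5 settings
# suggested by the FreeU authors (an assumption, tune for your model).
unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)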