diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 509c201a4549..c65e6717959c 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -143,7 +143,7 @@ def __init__(
         double_self_attention: bool = False,
         upcast_attention: bool = False,
         norm_elementwise_affine: bool = True,
-        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'layer_norm_i2vgen'
+        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
         norm_eps: float = 1e-5,
         final_dropout: bool = False,
         attention_type: str = "default",
diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index bd632660f46c..8b391eeebfd9 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -92,7 +92,7 @@ def __init__(
         only_cross_attention: bool = False,
         double_self_attention: bool = False,
         upcast_attention: bool = False,
-        norm_type: str = "layer_norm",
+        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
         norm_elementwise_affine: bool = True,
         norm_eps: float = 1e-5,
         attention_type: str = "default",
@@ -100,6 +100,16 @@ def __init__(
         interpolation_scale: float = None,
     ):
         super().__init__()
+        if patch_size is not None:
+            if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
+                raise NotImplementedError(
+                    f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
+                )
+            elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None:
+                raise ValueError(
+                    f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
+                )
+
         self.use_linear_projection = use_linear_projection
         self.num_attention_heads = num_attention_heads
         self.attention_head_dim = attention_head_dim
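
With this guard, an unsupported patch-based configuration now fails at construction time instead of surfacing later in the forward pass. A minimal sketch of the expected behaviour, assuming the Transformer2DModel class defined in transformer_2d.py and purely illustrative argument values (not taken from the diff itself):

# Illustrative check of the new constructor-time validation; the specific
# argument values here are hypothetical.
from diffusers.models.transformers.transformer_2d import Transformer2DModel

# patch_size is set but 'layer_norm' has no patched forward path -> NotImplementedError
try:
    Transformer2DModel(patch_size=2, norm_type="layer_norm")
except NotImplementedError as err:
    print(err)

# 'ada_norm' / 'ada_norm_zero' together with a patch_size also require num_embeds_ada_norm -> ValueError
try:
    Transformer2DModel(patch_size=2, norm_type="ada_norm_zero", num_embeds_ada_norm=None)
except ValueError as err:
    print(err)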