From 500051ff1503ff712b2c3be07f78cca3721ef58c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 4 Mar 2024 11:27:12 +0530 Subject: [PATCH 01/32] fix PyTorch classes and start deprecsation cycles. --- src/diffusers/models/activations.py | 5 +---- src/diffusers/models/attention.py | 2 +- src/diffusers/models/attention_processor.py | 7 ++----- src/diffusers/models/lora.py | 6 ++++++ 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index ce645f70e524..eaa8ed50e562 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -18,7 +18,6 @@ from torch import nn from ..utils import USE_PEFT_BACKEND -from .lora import LoRACompatibleLinear ACTIVATION_FUNCTIONS = { @@ -87,9 +86,7 @@ class GEGLU(nn.Module): def __init__(self, dim_in: int, dim_out: int, bias: bool = True): super().__init__() - linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear - - self.proj = linear_cls(dim_in, dim_out * 2, bias=bias) + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) def gelu(self, gate: torch.Tensor) -> torch.Tensor: if gate.device.type != "mps": diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index a4b3ee58a865..be0ab13b0e86 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -644,7 +644,7 @@ def __init__( if inner_dim is None: inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim - linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + linear_cls = nn.Linear if activation_fn == "gelu": act_fn = GELU(dim, inner_dim, bias=bias) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 5ec8876fc114..15e3c5740789 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -23,7 +23,7 @@ from ..utils import USE_PEFT_BACKEND, deprecate, logging from ..utils.import_utils import is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph -from .lora import LoRACompatibleLinear, LoRALinearLayer +from .lora import LoRALinearLayer logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -181,10 +181,7 @@ def __init__( f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" ) - if USE_PEFT_BACKEND: - linear_cls = nn.Linear - else: - linear_cls = LoRACompatibleLinear + linear_cls = nn.Linear self.linear_cls = linear_cls self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 1bbc96c6f5a7..4e9e0c07ca75 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -204,6 +204,9 @@ def __init__( ): super().__init__() + deprecation_message = "Use of `LoRALinearLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." + deprecate("LoRALinearLayer", "1.0.0", deprecation_message) + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. @@ -264,6 +267,9 @@ def __init__( ): super().__init__() + deprecation_message = "Use of `LoRAConv2dLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." 
+ deprecate("LoRAConv2dLayer", "1.0.0", deprecation_message) + self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) # according to the official kohya_ss trainer kernel_size are always fixed for the up layer # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129 From 9061ebbc2b4570e1dfe27e0c6c7b37bde5b2a4fd Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 4 Mar 2024 14:04:57 +0530 Subject: [PATCH 02/32] remove args crafting for accommodating scale. --- src/diffusers/models/attention_processor.py | 80 +++++++++------------ 1 file changed, 35 insertions(+), 45 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 15e3c5740789..d328f0a83c80 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -20,7 +20,7 @@ from torch import nn from ..image_processor import IPAdapterMaskProcessor -from ..utils import USE_PEFT_BACKEND, deprecate, logging +from ..utils import deprecate, logging from ..utils.import_utils import is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph from .lora import LoRALinearLayer @@ -742,8 +742,6 @@ def __call__( ) -> torch.Tensor: residual = hidden_states - args = () if USE_PEFT_BACKEND else (scale,) - if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -761,15 +759,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) @@ -780,7 +778,7 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -915,8 +913,6 @@ def __call__( ) -> torch.Tensor: residual = hidden_states - args = () if USE_PEFT_BACKEND else (scale,) - hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape @@ -929,17 +925,17 @@ def __call__( hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) if not attn.only_cross_attention: - key = attn.to_k(hidden_states, *args) - value = attn.to_v(hidden_states, *args) + key = 
attn.to_k(hidden_states) + value = attn.to_v(hidden_states) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) @@ -953,7 +949,7 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -985,8 +981,6 @@ def __call__( ) -> torch.Tensor: residual = hidden_states - args = () if USE_PEFT_BACKEND else (scale,) - hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape @@ -999,7 +993,7 @@ def __call__( hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query, out_dim=4) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) @@ -1008,8 +1002,8 @@ def __call__( encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4) if not attn.only_cross_attention: - key = attn.to_k(hidden_states, *args) - value = attn.to_v(hidden_states, *args) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) key = attn.head_to_batch_dim(key, out_dim=4) value = attn.head_to_batch_dim(value, out_dim=4) key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) @@ -1026,7 +1020,7 @@ def __call__( hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -1133,8 +1127,6 @@ def __call__( ) -> torch.FloatTensor: residual = hidden_states - args = () if USE_PEFT_BACKEND else (scale,) - if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -1162,15 +1154,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query).contiguous() key = attn.head_to_batch_dim(key).contiguous() @@ -1183,7 +1175,7 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -1239,16 +1231,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - args = () if USE_PEFT_BACKEND else (scale,) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) + key = 
attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -1268,7 +1259,7 @@ def __call__( hidden_states = hidden_states.to(query.dtype) # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -1334,17 +1325,16 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - args = () if USE_PEFT_BACKEND else (scale,) if encoder_hidden_states is None: - qkv = attn.to_qkv(hidden_states, *args) + qkv = attn.to_qkv(hidden_states) split_size = qkv.shape[-1] // 3 query, key, value = torch.split(qkv, split_size, dim=-1) else: if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) - kv = attn.to_kv(encoder_hidden_states, *args) + kv = attn.to_kv(encoder_hidden_states) split_size = kv.shape[-1] // 2 key, value = torch.split(kv, split_size, dim=-1) @@ -1365,7 +1355,7 @@ def __call__( hidden_states = hidden_states.to(query.dtype) # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -1856,7 +1846,7 @@ def __init__( self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor: self_cls_name = self.__class__.__name__ deprecate( self_cls_name, @@ -1874,7 +1864,7 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **k attn._modules.pop("processor") attn.processor = AttnProcessor() - return attn.processor(attn, hidden_states, *args, **kwargs) + return attn.processor(attn, hidden_states, **kwargs) class LoRAAttnProcessor2_0(nn.Module): @@ -1917,7 +1907,7 @@ def __init__( self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor: self_cls_name = self.__class__.__name__ deprecate( self_cls_name, @@ -1935,7 +1925,7 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **k attn._modules.pop("processor") attn.processor = AttnProcessor2_0() - return attn.processor(attn, hidden_states, *args, **kwargs) + return attn.processor(attn, hidden_states, **kwargs) class LoRAXFormersAttnProcessor(nn.Module): @@ -1996,7 +1986,7 @@ def __init__( self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor: self_cls_name 
= self.__class__.__name__ deprecate( self_cls_name, @@ -2014,7 +2004,7 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **k attn._modules.pop("processor") attn.processor = XFormersAttnProcessor() - return attn.processor(attn, hidden_states, *args, **kwargs) + return attn.processor(attn, hidden_states, **kwargs) class LoRAAttnAddedKVProcessor(nn.Module): @@ -2055,7 +2045,7 @@ def __init__( self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor: self_cls_name = self.__class__.__name__ deprecate( self_cls_name, @@ -2073,7 +2063,7 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **k attn._modules.pop("processor") attn.processor = AttnAddedKVProcessor() - return attn.processor(attn, hidden_states, *args, **kwargs) + return attn.processor(attn, hidden_states, **kwargs) class IPAdapterAttnProcessor(nn.Module): From 76b7cbdf11c9688ca9676b38f827a3f1f12616c5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 4 Mar 2024 14:13:45 +0530 Subject: [PATCH 03/32] remove scale check in feedforward. --- src/diffusers/models/activations.py | 5 +---- src/diffusers/models/attention.py | 8 +------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index eaa8ed50e562..678ce45dfd01 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -17,8 +17,6 @@ import torch.nn.functional as F from torch import nn -from ..utils import USE_PEFT_BACKEND - ACTIVATION_FUNCTIONS = { "swish": nn.SiLU(), @@ -95,8 +93,7 @@ def gelu(self, gate: torch.Tensor) -> torch.Tensor: return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) def forward(self, hidden_states, scale: float = 1.0): - args = () if USE_PEFT_BACKEND else (scale,) - hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1) + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) return hidden_states * self.gelu(gate) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index be0ab13b0e86..dfbedde6a6b1 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -17,12 +17,10 @@ import torch.nn.functional as F from torch import nn -from ..utils import USE_PEFT_BACKEND from ..utils.torch_utils import maybe_allow_in_graph from .activations import GEGLU, GELU, ApproximateGELU from .attention_processor import Attention from .embeddings import SinusoidalPositionalEmbedding -from .lora import LoRACompatibleLinear from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm @@ -667,10 +665,6 @@ def __init__( self.net.append(nn.Dropout(dropout)) def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: - compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) for module in self.net: - if isinstance(module, compatible_cls): - hidden_states = module(hidden_states, scale) - else: - hidden_states = module(hidden_states) + hidden_states = module(hidden_states) return hidden_states From 93b5106f09fe58fb2025281e4d8d3e2dbe235487 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 4 Mar 2024 14:28:40 +0530 Subject: 
[PATCH 04/32] assert against nn.Linear and not CompatibleLinear. --- tests/models/test_layers_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py index f08388921a4f..b5a5bec471a6 100644 --- a/tests/models/test_layers_utils.py +++ b/tests/models/test_layers_utils.py @@ -22,7 +22,6 @@ from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU from diffusers.models.embeddings import get_timestep_embedding -from diffusers.models.lora import LoRACompatibleLinear from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D from diffusers.models.transformers.transformer_2d import Transformer2DModel from diffusers.utils.testing_utils import ( @@ -482,7 +481,7 @@ def test_spatial_transformer_default_ff_layers(self): assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout - assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear dim = 32 inner_dim = 128 @@ -506,7 +505,7 @@ def test_spatial_transformer_geglu_approx_ff_layers(self): assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout - assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear dim = 32 inner_dim = 128 From 80953172f1adac99ccf19674297d391b564507e5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Mar 2024 14:07:58 +0530 Subject: [PATCH 05/32] remove conv_cls and lineaR_cls. 
--- src/diffusers/models/downsampling.py | 2 +- src/diffusers/models/embeddings.py | 5 ++--- src/diffusers/models/resnet.py | 7 +++---- src/diffusers/models/transformers/transformer_2d.py | 5 ++--- src/diffusers/models/upsampling.py | 2 +- .../pipelines/wuerstchen/modeling_wuerstchen_common.py | 10 ++++------ .../pipelines/wuerstchen/modeling_wuerstchen_prior.py | 7 +++---- 7 files changed, 16 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index 80fb065a6f4c..4573c0aa45a7 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -103,7 +103,7 @@ def __init__( self.padding = padding stride = 2 self.name = name - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + conv_cls = nn.Conv2d if norm_type == "ln_norm": self.norm = nn.LayerNorm(channels, eps, elementwise_affine) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 49f385d5f493..c15ff24cbcda 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -18,10 +18,9 @@ import torch from torch import nn -from ..utils import USE_PEFT_BACKEND, deprecate +from ..utils import deprecate from .activations import get_activation from .attention_processor import Attention -from .lora import LoRACompatibleLinear def get_timestep_embedding( @@ -200,7 +199,7 @@ def __init__( sample_proj_bias=True, ): super().__init__() - linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + linear_cls = nn.Linear self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 84cb31f430a0..295bf78e2564 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -30,7 +30,6 @@ KDownsample2D, downsample_2d, ) -from .lora import LoRACompatibleConv, LoRACompatibleLinear from .normalization import AdaGroupNorm from .upsampling import ( # noqa FirUpsample2D, @@ -102,7 +101,7 @@ def __init__( self.output_scale_factor = output_scale_factor self.time_embedding_norm = time_embedding_norm - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + conv_cls = nn.Conv2d if groups_out is None: groups_out = groups @@ -267,8 +266,8 @@ def __init__( self.time_embedding_norm = time_embedding_norm self.skip_time_act = skip_time_act - linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear + conv_cls = nn.Conv2d if groups_out is None: groups_out = groups diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index bd632660f46c..efe06ffd546e 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -22,7 +22,6 @@ from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version from ..attention import BasicTransformerBlock from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection -from ..lora import LoRACompatibleConv, LoRACompatibleLinear from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle @@ -105,8 +104,8 @@ def __init__( self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv - linear_cls = nn.Linear if USE_PEFT_BACKEND else 
LoRACompatibleLinear + conv_cls = nn.Conv2d + linear_cls = nn.Linear # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` # Define whether input is continuous or discrete depending on configuration diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index 080103504c53..b9f485f89c21 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -111,7 +111,7 @@ def __init__( self.use_conv_transpose = use_conv_transpose self.name = name self.interpolate = interpolate - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + conv_cls = nn.Conv2d if norm_type == "ln_norm": self.norm = nn.LayerNorm(channels, eps, elementwise_affine) diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py index 7ad4e2f54e7c..3641a916c1db 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py @@ -17,8 +17,6 @@ import torch.nn as nn from ...models.attention_processor import Attention -from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear -from ...utils import USE_PEFT_BACKEND class WuerstchenLayerNorm(nn.LayerNorm): @@ -34,7 +32,7 @@ def forward(self, x): class TimestepBlock(nn.Module): def __init__(self, c, c_timestep): super().__init__() - linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + linear_cls = nn.Linear self.mapper = linear_cls(c_timestep, c * 2) def forward(self, x, t): @@ -46,8 +44,8 @@ class ResBlock(nn.Module): def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): super().__init__() - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv - linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + conv_cls = nn.Conv2d + linear_cls = nn.Linear self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) @@ -81,7 +79,7 @@ class AttnBlock(nn.Module): def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): super().__init__() - linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + linear_cls = nn.Linear self.self_attn = self_attn self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index c44c259ab0b4..8cc294eaf79a 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -28,9 +28,8 @@ AttnAddedKVProcessor, AttnProcessor, ) -from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear from ...models.modeling_utils import ModelMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version +from ...utils import is_torch_version from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm @@ -41,8 +40,8 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft @register_to_config def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1): super().__init__() - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv - linear_cls = nn.Linear if 
USE_PEFT_BACKEND else LoRACompatibleLinear + conv_cls = nn.Conv2d + linear_cls = nn.Linear self.c_r = c_r self.projection = conv_cls(c_in, c, kernel_size=1) From 7e96549671d0096e6119cff2d8d0785124fa8d0f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Mar 2024 14:09:04 +0530 Subject: [PATCH 06/32] remove scale --- src/diffusers/models/activations.py | 2 +- src/diffusers/models/attention.py | 2 +- src/diffusers/models/downsampling.py | 2 +- src/diffusers/models/embeddings.py | 2 +- src/diffusers/models/lora.py | 4 ++-- src/diffusers/models/unets/unet_2d_blocks.py | 14 +++++++------- .../versatile_diffusion/modeling_text_unet.py | 2 +- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index 678ce45dfd01..7eceeafabe28 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -92,7 +92,7 @@ def gelu(self, gate: torch.Tensor) -> torch.Tensor: # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) - def forward(self, hidden_states, scale: float = 1.0): + def forward(self, hidden_states): hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) return hidden_states * self.gelu(gate) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index dfbedde6a6b1..00c8899e0536 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -664,7 +664,7 @@ def __init__( if final_dropout: self.net.append(nn.Dropout(dropout)) - def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for module in self.net: hidden_states = module(hidden_states) return hidden_states diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index 4573c0aa45a7..0555261a2a70 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -131,7 +131,7 @@ def __init__( else: self.conv = conv - def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels if self.norm is not None: diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index c15ff24cbcda..b6659490b905 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -257,7 +257,7 @@ class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" def __init__( - self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + self, embedding_size: int = 256, set_W_to_weight=True, log=True, flip_sin_to_cos=False ): super().__init__() self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 4e9e0c07ca75..1ce090bd00bc 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -366,7 +366,7 @@ def _unfuse_lora(self): self.w_up = None self.w_down = None - def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.padding_mode != "zeros": hidden_states = F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode) padding = (0, 0) @@ -448,7 +448,7 @@ 
def _unfuse_lora(self): self.w_up = None self.w_down = None - def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.lora_layer is None: out = super().forward(hidden_states) return out diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 9ebf6982ca82..99faca321d46 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -1348,7 +1348,7 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () @@ -1447,7 +1447,7 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=None, scale=scale) @@ -1545,7 +1545,7 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb=None, scale=scale) cross_attention_kwargs = {"scale": scale} @@ -1816,7 +1816,7 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () @@ -2058,7 +2058,7 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () @@ -2684,7 +2684,7 @@ def __init__( self.resolution_idx = resolution_idx def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None ) -> torch.FloatTensor: for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=temb, scale=scale) @@ -2784,7 +2784,7 @@ def __init__( self.resolution_idx = resolution_idx def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None ) -> torch.FloatTensor: for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb=temb, scale=scale) diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 97c533be5864..9b8399f0952d 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1589,7 +1589,7 @@ def __init__( 
self.gradient_checkpointing = False def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () From 59181e443463eeb52c16ace6a22cd1a7c8890a6c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Mar 2024 14:16:42 +0530 Subject: [PATCH 07/32] =?UTF-8?q?=F0=9F=91=8B=20scale.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/diffusers/models/downsampling.py | 10 +---- src/diffusers/models/embeddings.py | 2 +- src/diffusers/models/lora.py | 4 +- src/diffusers/models/unets/unet_2d_blocks.py | 38 ++++++++----------- .../versatile_diffusion/modeling_text_unet.py | 4 +- 5 files changed, 22 insertions(+), 36 deletions(-) diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index 0555261a2a70..6776fd363692 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -18,8 +18,6 @@ import torch.nn as nn import torch.nn.functional as F -from ..utils import USE_PEFT_BACKEND -from .lora import LoRACompatibleConv from .normalization import RMSNorm from .upsampling import upfirdn2d_native @@ -143,13 +141,7 @@ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels - if not USE_PEFT_BACKEND: - if isinstance(self.conv, LoRACompatibleConv): - hidden_states = self.conv(hidden_states, scale) - else: - hidden_states = self.conv(hidden_states) - else: - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states) return hidden_states diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b6659490b905..c15ff24cbcda 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -257,7 +257,7 @@ class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" def __init__( - self, embedding_size: int = 256, set_W_to_weight=True, log=True, flip_sin_to_cos=False + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False ): super().__init__() self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 1ce090bd00bc..4e9e0c07ca75 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -366,7 +366,7 @@ def _unfuse_lora(self): self.w_up = None self.w_down = None - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: if self.padding_mode != "zeros": hidden_states = F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode) padding = (0, 0) @@ -448,7 +448,7 @@ def _unfuse_lora(self): self.w_up = None self.w_down = None - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: if self.lora_layer is None: out = super().forward(hidden_states) return out diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 99faca321d46..246fe76ff551 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -1370,13 +1370,13 @@ def 
custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = resnet(hidden_states, temb) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale=scale) + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -1449,11 +1449,11 @@ def __init__( def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None, scale=scale) + hidden_states = resnet(hidden_states, temb=None) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale) + hidden_states = downsampler(hidden_states) return hidden_states @@ -1547,13 +1547,12 @@ def __init__( def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None, scale=scale) - cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, **cross_attention_kwargs) + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale) + hidden_states = downsampler(hidden_states) return hidden_states @@ -1838,13 +1837,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb, scale) + hidden_states = resnet(hidden_states, temb) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb, scale) + hidden_states = downsampler(hidden_states, temb) output_states = output_states + (hidden_states,) @@ -2080,7 +2079,7 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb, scale) + hidden_states = resnet(hidden_states, temb) output_states += (hidden_states,) @@ -2683,11 +2682,9 @@ def __init__( self.resolution_idx = resolution_idx - def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None - ) -> torch.FloatTensor: + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=temb, scale=scale) + hidden_states = resnet(hidden_states, temb=temb) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -2783,17 +2780,14 @@ def __init__( self.resolution_idx = resolution_idx - def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None - ) -> torch.FloatTensor: + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=temb, scale=scale) - cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs) + hidden_states = resnet(hidden_states, temb=temb) + hidden_states = attn(hidden_states, temb=temb) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, 
scale=scale) + hidden_states = upsampler(hidden_states) return hidden_states diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 9b8399f0952d..32997eced67a 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1611,13 +1611,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = resnet(hidden_states, temb) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale=scale) + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) From 1413847dc5c94928f43e8e7360712ee61a38d790 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Mar 2024 17:05:06 +0530 Subject: [PATCH 08/32] fix: unet2dcondition --- src/diffusers/models/unets/unet_2d_condition.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index d067767315f1..fb40d6ea31b4 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -1226,7 +1226,7 @@ def forward( **additional_residuals, ) else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) if is_adapter and len(down_intrablock_additional_residuals) > 0: sample += down_intrablock_additional_residuals.pop(0) @@ -1297,7 +1297,6 @@ def forward( temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, - scale=lora_scale, ) # 6. 
post-process From 05567383f18e041b1eb4288df12bdffb30b73df2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Mar 2024 17:08:02 +0530 Subject: [PATCH 09/32] fix attention.py --- src/diffusers/models/attention.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 00c8899e0536..2b40a377c7e2 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -25,7 +25,7 @@ def _chunked_feed_forward( - ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None + ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int ): # "feed_forward_chunk_size" can be used to save memory if hidden_states.shape[chunk_dim] % chunk_size != 0: @@ -34,18 +34,10 @@ def _chunked_feed_forward( ) num_chunks = hidden_states.shape[chunk_dim] // chunk_size - if lora_scale is None: - ff_output = torch.cat( - [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], - dim=chunk_dim, - ) - else: - # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete - ff_output = torch.cat( - [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], - dim=chunk_dim, - ) - + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) return ff_output @@ -393,10 +385,10 @@ def forward( if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory ff_output = _chunked_feed_forward( - self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale + self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size ) else: - ff_output = self.ff(norm_hidden_states, scale=lora_scale) + ff_output = self.ff(norm_hidden_states) if self.norm_type == "ada_norm_zero": ff_output = gate_mlp.unsqueeze(1) * ff_output From f56f9f3de209e62975d1f4d852d5325889a93e93 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Mar 2024 17:18:22 +0530 Subject: [PATCH 10/32] fix: attention.py again --- src/diffusers/models/attention.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 2b40a377c7e2..87e5ffc72564 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -24,9 +24,7 @@ from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm -def _chunked_feed_forward( - ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int -): +def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): # "feed_forward_chunk_size" can be used to save memory if hidden_states.shape[chunk_dim] % chunk_size != 0: raise ValueError( @@ -316,10 +314,7 @@ def forward( if self.pos_embed is not None: norm_hidden_states = self.pos_embed(norm_hidden_states) - # 1. Retrieve lora scale. - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - - # 2. Prepare GLIGEN inputs + # 1. 
Prepare GLIGEN inputs cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} gligen_kwargs = cross_attention_kwargs.pop("gligen", None) @@ -338,7 +333,7 @@ def forward( if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) - # 2.5 GLIGEN Control + # 1.2 GLIGEN Control if gligen_kwargs is not None: hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) @@ -384,9 +379,7 @@ def forward( if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory - ff_output = _chunked_feed_forward( - self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size - ) + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) else: ff_output = self.ff(norm_hidden_states) From 99971891a39d37e1a11ca2221ddabfec322153ec Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Mar 2024 17:34:56 +0530 Subject: [PATCH 11/32] fix: unet_2d_blocks. --- src/diffusers/models/unets/unet_2d_blocks.py | 49 +++++++------------- 1 file changed, 18 insertions(+), 31 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 246fe76ff551..9938c6f44e00 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -844,8 +844,7 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: @@ -882,7 +881,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) return hidden_states @@ -982,7 +981,6 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - lora_scale = cross_attention_kwargs.get("scale", 1.0) if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. 
@@ -995,7 +993,7 @@ def forward( # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask - hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( @@ -1006,7 +1004,7 @@ def forward( ) # resnet - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) return hidden_states @@ -1112,22 +1110,19 @@ def forward( ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - lora_scale = cross_attention_kwargs.get("scale", 1.0) - output_states = () for resnet, attn in zip(self.resnets, self.attentions): - cross_attention_kwargs.update({"scale": lora_scale}) - hidden_states = resnet(hidden_states, temb, scale=lora_scale) - hidden_states = attn(hidden_states, **cross_attention_kwargs) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: if self.downsample_type == "resnet": - hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale) + hidden_states = downsampler(hidden_states, temb=temb) else: - hidden_states = downsampler(hidden_states, scale=lora_scale) + hidden_states = downsampler(hidden_states) output_states += (hidden_states,) @@ -1238,8 +1233,6 @@ def forward( ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): @@ -1270,7 +1263,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1288,7 +1281,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale=lora_scale) + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -1957,8 +1950,6 @@ def forward( output_states = () cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - lora_scale = cross_attention_kwargs.get("scale", 1.0) - if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. 
mask = None if encoder_hidden_states is None else encoder_attention_mask @@ -1990,7 +1981,7 @@ def custom_forward(*inputs): **cross_attention_kwargs, ) else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, @@ -2003,7 +1994,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb, scale=lora_scale) + hidden_states = downsampler(hidden_states, temb) output_states = output_states + (hidden_states,) @@ -2165,7 +2156,6 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: @@ -2195,7 +2185,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, ) else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2439,7 +2429,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -2493,7 +2482,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2505,7 +2494,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -3248,7 +3237,6 @@ def forward( ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - lora_scale = cross_attention_kwargs.get("scale", 1.0) if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. 
mask = None if encoder_hidden_states is None else encoder_attention_mask @@ -3286,7 +3274,7 @@ def custom_forward(*inputs): **cross_attention_kwargs, ) else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, @@ -3297,7 +3285,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb, scale=lora_scale) + hidden_states = upsampler(hidden_states, temb) return hidden_states @@ -3492,7 +3480,6 @@ def forward( if res_hidden_states_tuple is not None: hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: @@ -3521,7 +3508,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, ) else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, From 4f66db0ad74e0f733d199f0eda9e3bdc1ebeebfa Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Mar 2024 17:40:37 +0530 Subject: [PATCH 12/32] fix-copies. --- .../versatile_diffusion/modeling_text_unet.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 32997eced67a..419faf80c09e 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1999,7 +1999,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -2053,7 +2052,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2065,7 +2064,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -2330,8 +2329,7 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: @@ -2368,7 +2366,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) return hidden_states @@ -2469,7 +2467,6 
@@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - lora_scale = cross_attention_kwargs.get("scale", 1.0) if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. @@ -2482,7 +2479,7 @@ def forward( # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask - hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( @@ -2493,6 +2490,6 @@ def forward( ) # resnet - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) return hidden_states From fd348d1fc100e74f39915a780a1b3e588d604a35 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Mar 2024 17:54:54 +0530 Subject: [PATCH 13/32] more fixes. --- .../models/transformers/transformer_2d.py | 29 ++++--------------- src/diffusers/models/unets/unet_3d_blocks.py | 16 ++++------ .../versatile_diffusion/modeling_text_unet.py | 8 ++--- 3 files changed, 14 insertions(+), 39 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index efe06ffd546e..df887e1962a3 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -19,7 +19,7 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version +from ...utils import BaseOutput, deprecate, is_torch_version from ..attention import BasicTransformerBlock from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection from ..modeling_utils import ModelMixin @@ -316,9 +316,6 @@ def forward( encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - # Retrieve lora scale. - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - # 1. 
Input if self.is_input_continuous: batch, _, height, width = hidden_states.shape @@ -326,21 +323,13 @@ def forward( hidden_states = self.norm(hidden_states) if not self.use_linear_projection: - hidden_states = ( - self.proj_in(hidden_states, scale=lora_scale) - if not USE_PEFT_BACKEND - else self.proj_in(hidden_states) - ) + hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = ( - self.proj_in(hidden_states, scale=lora_scale) - if not USE_PEFT_BACKEND - else self.proj_in(hidden_states) - ) + hidden_states = self.proj_in(hidden_states) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) @@ -403,17 +392,9 @@ def custom_forward(*inputs): if self.is_input_continuous: if not self.use_linear_projection: hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = ( - self.proj_out(hidden_states, scale=lora_scale) - if not USE_PEFT_BACKEND - else self.proj_out(hidden_states) - ) + hidden_states = self.proj_out(hidden_states) else: - hidden_states = ( - self.proj_out(hidden_states, scale=lora_scale) - if not USE_PEFT_BACKEND - else self.proj_out(hidden_states) - ) + hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index e1ada1021b3a..4b3709ce29ec 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -1175,8 +1175,6 @@ def forward( ): output_states = () - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) for i, (resnet, attn, motion_module) in enumerate(blocks): if self.training and self.gradient_checkpointing: @@ -1206,7 +1204,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1228,7 +1226,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale=lora_scale) + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -1355,7 +1353,6 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, num_frames: int = 1, ) -> torch.FloatTensor: - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -1410,7 +1407,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1426,7 +1423,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) 
+ hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -1687,8 +1684,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, num_frames: int = 1, ) -> torch.FloatTensor: - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + hidden_states = self.resnets[0](hidden_states, temb) blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) for attn, resnet, motion_module in blocks: @@ -1737,7 +1733,7 @@ def custom_forward(*inputs): hidden_states, num_frames=num_frames, )[0] - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) return hidden_states diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 419faf80c09e..200b96e7b129 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1333,7 +1333,7 @@ def forward( **additional_residuals, ) else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) if is_adapter and len(down_intrablock_additional_residuals) > 0: sample += down_intrablock_additional_residuals.pop(0) @@ -1728,8 +1728,6 @@ def forward( ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): @@ -1760,7 +1758,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1778,7 +1776,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale=lora_scale) + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) From 10c4232e706b9008358eabd729eb3956a394995b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Mar 2024 18:09:01 +0530 Subject: [PATCH 14/32] fix: resnet.py --- src/diffusers/models/resnet.py | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 295bf78e2564..851c77425ed4 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -165,12 +165,12 @@ def forward( if hidden_states.shape[0] >= 64: input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor, scale=scale) - hidden_states = self.upsample(hidden_states, scale=scale) + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) elif self.downsample is not None: - input_tensor = self.downsample(input_tensor, scale=scale) - hidden_states = self.downsample(hidden_states, scale=scale) + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else 
self.conv1(hidden_states) @@ -342,37 +342,29 @@ def forward( input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() input_tensor = ( - self.upsample(input_tensor, scale=scale) - if isinstance(self.upsample, Upsample2D) - else self.upsample(input_tensor) + self.upsample(input_tensor) if isinstance(self.upsample, Upsample2D) else self.upsample(input_tensor) ) hidden_states = ( - self.upsample(hidden_states, scale=scale) - if isinstance(self.upsample, Upsample2D) - else self.upsample(hidden_states) + self.upsample(hidden_states) if isinstance(self.upsample, Upsample2D) else self.upsample(hidden_states) ) elif self.downsample is not None: input_tensor = ( - self.downsample(input_tensor, scale=scale) + self.downsample(input_tensor) if isinstance(self.downsample, Downsample2D) else self.downsample(input_tensor) ) hidden_states = ( - self.downsample(hidden_states, scale=scale) + self.downsample(hidden_states) if isinstance(self.downsample, Downsample2D) else self.downsample(hidden_states) ) - hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) + hidden_states = self.conv1(hidden_states) if self.time_emb_proj is not None: if not self.skip_time_act: temb = self.nonlinearity(temb) - temb = ( - self.time_emb_proj(temb, scale)[:, :, None, None] - if not USE_PEFT_BACKEND - else self.time_emb_proj(temb)[:, :, None, None] - ) + temb = self.time_emb_proj(temb)[:, :, None, None] if self.time_embedding_norm == "default": if temb is not None: @@ -392,12 +384,10 @@ def forward( hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) + hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - input_tensor = ( - self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor) - ) + input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor From 6fe19d960cfe7f533611a91e94e371841b9914f1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Mar 2024 18:37:55 +0530 Subject: [PATCH 15/32] more fixes --- src/diffusers/models/resnet.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 851c77425ed4..41d0147a244a 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -20,7 +20,6 @@ import torch.nn as nn import torch.nn.functional as F -from ..utils import USE_PEFT_BACKEND from .activations import get_activation from .attention_processor import SpatialNorm from .downsampling import ( # noqa @@ -172,19 +171,17 @@ def forward( input_tensor = self.downsample(input_tensor) hidden_states = self.downsample(hidden_states) - hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) + hidden_states = self.conv1(hidden_states) hidden_states = self.norm2(hidden_states, temb) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) + hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - input_tensor = ( - self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor) - ) + input_tensor = self.conv_shortcut(input_tensor) 
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor From c256b77bbac48bb258cafe07703d6451d509769b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Mar 2024 22:15:16 +0530 Subject: [PATCH 16/32] fix i2vgenxl unet. --- src/diffusers/models/unets/unet_i2vgen_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py index a096f842ab6c..5c5c6a2cc5ec 100644 --- a/src/diffusers/models/unets/unet_i2vgen_xl.py +++ b/src/diffusers/models/unets/unet_i2vgen_xl.py @@ -89,7 +89,7 @@ def forward( if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) - ff_output = self.ff(hidden_states, scale=1.0) + ff_output = self.ff(hidden_states) hidden_states = ff_output + hidden_states if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) From 45030fb316e4e4141825ade3eb8579cdc9438f7b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Mar 2024 23:17:22 +0530 Subject: [PATCH 17/32] depcrecate scale gently. --- src/diffusers/models/unets/unet_2d_blocks.py | 109 +++++++++++++------ 1 file changed, 77 insertions(+), 32 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 9938c6f44e00..b7912126703e 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from torch import nn -from ...utils import is_torch_version, logging +from ...utils import deprecate, is_torch_version, logging from ...utils.torch_utils import apply_freeu from ..activations import get_activation from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 @@ -1341,8 +1341,12 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = None ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + if scale is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + output_states = () for resnet in self.resnets: @@ -1440,7 +1444,11 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + def forward(self, hidden_states: torch.FloatTensor, scale: float = None) -> torch.FloatTensor: + if scale is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=None) @@ -1538,7 +1546,11 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + def forward(self, hidden_states: torch.FloatTensor, scale: float = None) -> torch.FloatTensor: + if scale is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." 
+ deprecate("scale", "1.0.0", deprecation_message) + for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb=None) hidden_states = attn(hidden_states) @@ -1638,16 +1650,19 @@ def forward( skip_sample: Optional[torch.FloatTensor] = None, scale: float = 1.0, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: + if scale is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + output_states = () for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb, scale=scale) - cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, **cross_attention_kwargs) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) output_states += (hidden_states,) if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb, scale=scale) + hidden_states = self.resnet_down(hidden_states, temb) for downsampler in self.downsamplers: skip_sample = downsampler(skip_sample) @@ -1723,16 +1738,20 @@ def forward( hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, skip_sample: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + scale: float = None, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: output_states = () + if scale is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb, scale) + hidden_states = resnet(hidden_states, temb) output_states += (hidden_states,) if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb, scale) + hidden_states = self.resnet_down(hidden_states, temb) for downsampler in self.downsamplers: skip_sample = downsampler(skip_sample) @@ -1808,8 +1827,12 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = None ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + if scale is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + output_states = () for resnet in self.resnets: @@ -2305,24 +2328,27 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, - scale: float = 1.0, + scale: float = None, ) -> torch.FloatTensor: + if scale is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." 
+ deprecate("scale", "1.0.0", deprecation_message) + for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb, scale=scale) - cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, **cross_attention_kwargs) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: if self.upsample_type == "resnet": - hidden_states = upsampler(hidden_states, temb=temb, scale=scale) + hidden_states = upsampler(hidden_states, temb=temb) else: - hidden_states = upsampler(hidden_states, scale=scale) + hidden_states = upsampler(hidden_states) return hidden_states @@ -2555,8 +2581,12 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, - scale: float = 1.0, + scale: float = None, ) -> torch.FloatTensor: + if scale is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -2600,11 +2630,11 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = resnet(hidden_states, temb) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size, scale=scale) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -2883,16 +2913,19 @@ def forward( skip_sample=None, scale: float = 1.0, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + if scale is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = resnet(hidden_states, temb) - cross_attention_kwargs = {"scale": scale} - hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs) + hidden_states = self.attentions[0](hidden_states) if skip_sample is not None: skip_sample = self.upsampler(skip_sample) @@ -2906,7 +2939,7 @@ def forward( skip_sample = skip_sample + skip_sample_states - hidden_states = self.resnet_up(hidden_states, temb, scale=scale) + hidden_states = self.resnet_up(hidden_states, temb) return hidden_states, skip_sample @@ -2991,13 +3024,17 @@ def forward( skip_sample=None, scale: float = 1.0, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + if scale is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." 
+ deprecate("scale", "1.0.0", deprecation_message) + for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = resnet(hidden_states, temb) if skip_sample is not None: skip_sample = self.upsampler(skip_sample) @@ -3011,7 +3048,7 @@ def forward( skip_sample = skip_sample + skip_sample_states - hidden_states = self.resnet_up(hidden_states, temb, scale=scale) + hidden_states = self.resnet_up(hidden_states, temb) return hidden_states, skip_sample @@ -3091,8 +3128,12 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, - scale: float = 1.0, + scale: float = None, ) -> torch.FloatTensor: + if scale is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -3116,11 +3157,11 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = resnet(hidden_states, temb) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb, scale=scale) + hidden_states = upsampler(hidden_states, temb) return hidden_states @@ -3346,8 +3387,12 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, - scale: float = 1.0, + scale: float = None, ) -> torch.FloatTensor: + if scale is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + res_hidden_states_tuple = res_hidden_states_tuple[-1] if res_hidden_states_tuple is not None: hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) @@ -3370,7 +3415,7 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = resnet(hidden_states, temb) if self.upsamplers is not None: for upsampler in self.upsamplers: From 6b5212b0199930a1009efd83ad41807ec460cd7e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 5 Mar 2024 23:20:37 +0530 Subject: [PATCH 18/32] fix-copies --- .../versatile_diffusion/modeling_text_unet.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 200b96e7b129..43fa485fed8b 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1840,8 +1840,12 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, - scale: float = 1.0, + scale: float = None, ) -> torch.FloatTensor: + if scale is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." 
+ deprecate("scale", "1.0.0", deprecation_message) + is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -1885,11 +1889,11 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = resnet(hidden_states, temb) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size, scale=scale) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states From ee645c7f66c45e6f3d214691f29eab7c9b4ecd4e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 8 Mar 2024 09:34:04 +0530 Subject: [PATCH 19/32] Apply suggestions from code review Co-authored-by: YiYi Xu --- src/diffusers/models/resnet.py | 8 ++------ src/diffusers/models/unets/unet_2d_blocks.py | 4 ++-- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 41d0147a244a..af2f65e8e69f 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -339,21 +339,17 @@ def forward( input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() input_tensor = ( - self.upsample(input_tensor) if isinstance(self.upsample, Upsample2D) else self.upsample(input_tensor) + self.upsample(input_tensor) ) hidden_states = ( - self.upsample(hidden_states) if isinstance(self.upsample, Upsample2D) else self.upsample(hidden_states) + self.upsample(hidden_states) ) elif self.downsample is not None: input_tensor = ( self.downsample(input_tensor) - if isinstance(self.downsample, Downsample2D) - else self.downsample(input_tensor) ) hidden_states = ( self.downsample(hidden_states) - if isinstance(self.downsample, Downsample2D) - else self.downsample(hidden_states) ) hidden_states = self.conv1(hidden_states) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index b7912126703e..f9284d3554cd 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -1341,9 +1341,9 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = None + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - if scale is not None: + if len(args) > 0 or kwargs.get("scale",None) is not None: deprecation_message = "Use of `scale` is deprecated. Please remove the argument." 
deprecate("scale", "1.0.0", deprecation_message) From bfdfc2028d3e2a0a369c0799bfdf0085bf07a61e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Mar 2024 09:39:00 +0530 Subject: [PATCH 20/32] quality --- src/diffusers/models/resnet.py | 16 ++++------------ src/diffusers/models/unets/unet_2d_blocks.py | 2 +- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index af2f65e8e69f..d68aebe5c4e0 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -338,19 +338,11 @@ def forward( if hidden_states.shape[0] >= 64: input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() - input_tensor = ( - self.upsample(input_tensor) - ) - hidden_states = ( - self.upsample(hidden_states) - ) + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) elif self.downsample is not None: - input_tensor = ( - self.downsample(input_tensor) - ) - hidden_states = ( - self.downsample(hidden_states) - ) + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) hidden_states = self.conv1(hidden_states) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index f9284d3554cd..bfb26b54b1b8 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -1343,7 +1343,7 @@ def __init__( def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - if len(args) > 0 or kwargs.get("scale",None) is not None: + if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "Use of `scale` is deprecated. Please remove the argument." deprecate("scale", "1.0.0", deprecation_message) From d4fa31d160666fe8763b9022b4af37c691ddf89d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Mar 2024 10:06:09 +0530 Subject: [PATCH 21/32] throw warning when scale is passed to the the BasicTransformerBlock class. --- src/diffusers/models/attention.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 87e5ffc72564..a18d637e55e2 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -17,6 +17,7 @@ import torch.nn.functional as F from torch import nn +from ..utils import logging from ..utils.torch_utils import maybe_allow_in_graph from .activations import GEGLU, GELU, ApproximateGELU from .attention_processor import Attention @@ -24,6 +25,9 @@ from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm +logger = logging.get_logger(__name__) + + def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): # "feed_forward_chunk_size" can be used to save memory if hidden_states.shape[chunk_dim] % chunk_size != 0: @@ -287,6 +291,10 @@ def forward( class_labels: Optional[torch.LongTensor] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warn.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated.") + # Notice that normalization is always applied before the real computation in the following blocks. # 0. 
Self-Attention batch_size = hidden_states.shape[0] From 15ef1e774c1fbac817d3fc05301d32bf201c9b03 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Mar 2024 10:20:38 +0530 Subject: [PATCH 22/32] remove scale from signature. --- src/diffusers/models/attention_processor.py | 42 ++++++++++++-- src/diffusers/models/resnet.py | 23 ++++---- src/diffusers/models/unets/unet_2d_blocks.py | 56 +++++++++++-------- src/diffusers/models/unets/unet_3d_blocks.py | 26 ++++++--- src/diffusers/models/upsampling.py | 22 +++----- .../versatile_diffusion/modeling_text_unet.py | 5 +- 6 files changed, 108 insertions(+), 66 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d328f0a83c80..ac24b2081421 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -738,8 +738,13 @@ def __call__( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + residual = hidden_states if attn.spatial_norm is not None: @@ -909,8 +914,13 @@ def __call__( hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) @@ -977,8 +987,13 @@ def __call__( hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) @@ -1123,8 +1138,13 @@ def __call__( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + residual = hidden_states if attn.spatial_norm is not None: @@ -1206,8 +1226,13 @@ def __call__( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." 
+ deprecate("scale", "1.0.0", deprecation_message) + residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -1300,8 +1325,13 @@ def __call__( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index d68aebe5c4e0..f5c5a30afc00 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -20,6 +20,7 @@ import torch.nn as nn import torch.nn.functional as F +from ..utils import deprecate from .activations import get_activation from .attention_processor import SpatialNorm from .downsampling import ( # noqa @@ -147,12 +148,11 @@ def __init__( bias=conv_shortcut_bias, ) - def forward( - self, - input_tensor: torch.FloatTensor, - temb: torch.FloatTensor, - scale: float = 1.0, - ) -> torch.FloatTensor: + def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + hidden_states = input_tensor hidden_states = self.norm1(hidden_states, temb) @@ -322,12 +322,11 @@ def __init__( bias=conv_shortcut_bias, ) - def forward( - self, - input_tensor: torch.FloatTensor, - temb: torch.FloatTensor, - scale: float = 1.0, - ) -> torch.FloatTensor: + def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + hidden_states = input_tensor hidden_states = self.norm1(hidden_states) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index bfb26b54b1b8..a0b2604392e9 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -1444,8 +1444,8 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states: torch.FloatTensor, scale: float = None) -> torch.FloatTensor: - if scale is not None: + def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "Use of `scale` is deprecated. Please remove the argument." deprecate("scale", "1.0.0", deprecation_message) @@ -1546,8 +1546,8 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states: torch.FloatTensor, scale: float = None) -> torch.FloatTensor: - if scale is not None: + def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "Use of `scale` is deprecated. Please remove the argument." 
deprecate("scale", "1.0.0", deprecation_message) @@ -1648,9 +1648,10 @@ def forward( hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, skip_sample: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: - if scale is not None: + if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "Use of `scale` is deprecated. Please remove the argument." deprecate("scale", "1.0.0", deprecation_message) @@ -1738,14 +1739,15 @@ def forward( hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, skip_sample: Optional[torch.FloatTensor] = None, - scale: float = None, + *args, + **kwargs, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: - output_states = () - - if scale is not None: + if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "Use of `scale` is deprecated. Please remove the argument." deprecate("scale", "1.0.0", deprecation_message) + output_states = () + for resnet in self.resnets: hidden_states = resnet(hidden_states, temb) output_states += (hidden_states,) @@ -1827,9 +1829,9 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = None + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - if scale is not None: + if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "Use of `scale` is deprecated. Please remove the argument." deprecate("scale", "1.0.0", deprecation_message) @@ -2328,9 +2330,10 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, - scale: float = None, + *args, + **kwargs, ) -> torch.FloatTensor: - if scale is not None: + if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "Use of `scale` is deprecated. Please remove the argument." deprecate("scale", "1.0.0", deprecation_message) @@ -2581,9 +2584,10 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, - scale: float = None, + *args, + **kwargs, ) -> torch.FloatTensor: - if scale is not None: + if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "Use of `scale` is deprecated. Please remove the argument." deprecate("scale", "1.0.0", deprecation_message) @@ -2911,9 +2915,10 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, skip_sample=None, - scale: float = 1.0, + *args, + **kwargs, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - if scale is not None: + if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "Use of `scale` is deprecated. Please remove the argument." deprecate("scale", "1.0.0", deprecation_message) @@ -3022,9 +3027,10 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, skip_sample=None, - scale: float = 1.0, + *args, + **kwargs, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - if scale is not None: + if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "Use of `scale` is deprecated. Please remove the argument." 
deprecate("scale", "1.0.0", deprecation_message) @@ -3128,9 +3134,10 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, - scale: float = None, + *args, + **kwargs, ) -> torch.FloatTensor: - if scale is not None: + if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "Use of `scale` is deprecated. Please remove the argument." deprecate("scale", "1.0.0", deprecation_message) @@ -3387,9 +3394,10 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, - scale: float = None, + *args, + **kwargs, ) -> torch.FloatTensor: - if scale is not None: + if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "Use of `scale` is deprecated. Please remove the argument." deprecate("scale", "1.0.0", deprecation_message) diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 4b3709ce29ec..27f7aad80790 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -17,7 +17,7 @@ import torch from torch import nn -from ...utils import is_torch_version +from ...utils import deprecate, is_torch_version from ...utils.torch_utils import apply_freeu from ..attention import Attention from ..resnet import ( @@ -1005,9 +1005,14 @@ def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, - scale: float = 1.0, num_frames: int = 1, + *args, + **kwargs, ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + output_states = () blocks = zip(self.resnets, self.motion_modules) @@ -1029,18 +1034,18 @@ def custom_forward(*inputs): ) else: hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, scale + create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = resnet(hidden_states, temb) hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale=scale) + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -1504,9 +1509,14 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, upsample_size=None, - scale: float = 1.0, num_frames: int = 1, + *args, + **kwargs, ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." 
+ deprecate("scale", "1.0.0", deprecation_message) + is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -1556,12 +1566,12 @@ def custom_forward(*inputs): ) else: - hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = resnet(hidden_states, temb) hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size, scale=scale) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index b9f485f89c21..3b39e665153f 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -18,8 +18,7 @@ import torch.nn as nn import torch.nn.functional as F -from ..utils import USE_PEFT_BACKEND -from .lora import LoRACompatibleConv +from ..utils import deprecate from .normalization import RMSNorm @@ -141,11 +140,12 @@ def __init__( self.Conv2d_0 = conv def forward( - self, - hidden_states: torch.FloatTensor, - output_size: Optional[int] = None, - scale: float = 1.0, + self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, *args, **kwargs ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecate("scale", "1.0.0", deprecation_message) + assert hidden_states.shape[1] == self.channels if self.norm is not None: @@ -180,15 +180,9 @@ def forward( # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": - if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND: - hidden_states = self.conv(hidden_states, scale) - else: - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states) else: - if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND: - hidden_states = self.Conv2d_0(hidden_states, scale) - else: - hidden_states = self.Conv2d_0(hidden_states) + hidden_states = self.Conv2d_0(hidden_states) return hidden_states diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 43fa485fed8b..0d26226db3d2 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1840,9 +1840,10 @@ def forward( res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, - scale: float = None, + *args, + **kwargs, ) -> torch.FloatTensor: - if scale is not None: + if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "Use of `scale` is deprecated. Please remove the argument." 
deprecate("scale", "1.0.0", deprecation_message) From d0375fa883178a95f7daa806bba70fedcb2b678d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Mar 2024 10:25:17 +0530 Subject: [PATCH 23/32] cross_attention_kwargs, very nice catch by Yiyi --- src/diffusers/models/unets/unet_2d_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index a0b2604392e9..8aeebdcf04f2 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -1114,7 +1114,7 @@ def forward( for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) + hidden_states = attn(hidden_states, **cross_attention_kwargs) output_states = output_states + (hidden_states,) if self.downsamplers is not None: From 8f76caafa8b797c6415bc70cd9b95b916f34cdc9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Mar 2024 15:10:34 +0530 Subject: [PATCH 24/32] fix: logger.warn --- src/diffusers/models/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 00a856e8f746..e17f67f36524 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -293,7 +293,7 @@ def forward( ) -> torch.FloatTensor: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: - logger.warn.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated.") + logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated.") # Notice that normalization is always applied before the real computation in the following blocks. # 0. Self-Attention From 402ec90b33950048bed6e7b23585cf175130fb62 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Mar 2024 20:25:28 +0530 Subject: [PATCH 25/32] make deprecation message clearer. --- src/diffusers/models/attention_processor.py | 12 +++++----- src/diffusers/models/resnet.py | 4 ++-- src/diffusers/models/unets/unet_2d_blocks.py | 24 +++++++++---------- src/diffusers/models/unets/unet_3d_blocks.py | 4 ++-- src/diffusers/models/upsampling.py | 2 +- .../versatile_diffusion/modeling_text_unet.py | 2 +- 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ac24b2081421..7382e6686454 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -742,7 +742,7 @@ def __call__( **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{attn.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states @@ -918,7 +918,7 @@ def __call__( **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{attn.__class__.__name__}` class, it won't have any effect." 
deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states @@ -991,7 +991,7 @@ def __call__( **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{attn.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states @@ -1142,7 +1142,7 @@ def __call__( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{attn.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states @@ -1230,7 +1230,7 @@ def __call__( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{attn.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states @@ -1329,7 +1329,7 @@ def __call__( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{attn.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index f5c5a30afc00..96e48f61542e 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -150,7 +150,7 @@ def __init__( def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) hidden_states = input_tensor @@ -324,7 +324,7 @@ def __init__( def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." 
deprecate("scale", "1.0.0", deprecation_message) hidden_states = input_tensor diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 8aeebdcf04f2..e7449f2df6bb 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -1344,7 +1344,7 @@ def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) output_states = () @@ -1446,7 +1446,7 @@ def __init__( def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) for resnet in self.resnets: @@ -1548,7 +1548,7 @@ def __init__( def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) for resnet, attn in zip(self.resnets, self.attentions): @@ -1652,7 +1652,7 @@ def forward( **kwargs, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) output_states = () @@ -1743,7 +1743,7 @@ def forward( **kwargs, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) output_states = () @@ -1832,7 +1832,7 @@ def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. 
Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) output_states = () @@ -2334,7 +2334,7 @@ def forward( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) for resnet, attn in zip(self.resnets, self.attentions): @@ -2588,7 +2588,7 @@ def forward( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) is_freeu_enabled = ( @@ -2919,7 +2919,7 @@ def forward( **kwargs, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) for resnet in self.resnets: @@ -3031,7 +3031,7 @@ def forward( **kwargs, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) for resnet in self.resnets: @@ -3138,7 +3138,7 @@ def forward( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) for resnet in self.resnets: @@ -3398,7 +3398,7 @@ def forward( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." 
deprecate("scale", "1.0.0", deprecation_message) res_hidden_states_tuple = res_hidden_states_tuple[-1] diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 27f7aad80790..5dcc0eeee7a0 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -1010,7 +1010,7 @@ def forward( **kwargs, ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) output_states = () @@ -1514,7 +1514,7 @@ def forward( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) is_freeu_enabled = ( diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index 3b39e665153f..4f79d39f65cb 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -143,7 +143,7 @@ def forward( self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, *args, **kwargs ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) assert hidden_states.shape[1] == self.channels diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 0d26226db3d2..2c86a58bff74 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1844,7 +1844,7 @@ def forward( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument." + deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." deprecate("scale", "1.0.0", deprecation_message) is_freeu_enabled = ( From de01273c52117ceea0a0d749813757a1e5f4e092 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Mar 2024 21:03:48 +0530 Subject: [PATCH 26/32] address final comments. 
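This rewording points users at the supported path: pass `scale` through `cross_attention_kwargs` when calling the pipeline, rather than threading it into each `forward()`. A minimal sketch of that calling pattern, assuming the PEFT backend is installed; the model id and LoRA repo below are placeholders, not taken from this patch:

    import torch
    from diffusers import StableDiffusionPipeline

    # Placeholder checkpoints, for illustration only.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_lora_weights("some-user/some-lora")

    # `scale` is applied at call time via cross_attention_kwargs;
    # it is no longer forwarded as an argument to the blocks' forward().
    image = pipe(
        "an astronaut riding a horse",
        cross_attention_kwargs={"scale": 0.5},
    ).images[0]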
--- src/diffusers/models/attention_processor.py | 12 ++++++------ src/diffusers/models/resnet.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7382e6686454..44fbd584cd7c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -742,7 +742,7 @@ def __call__( **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{attn.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states @@ -918,7 +918,7 @@ def __call__( **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{attn.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states @@ -991,7 +991,7 @@ def __call__( **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{attn.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states @@ -1142,7 +1142,7 @@ def __call__( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{attn.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states @@ -1230,7 +1230,7 @@ def __call__( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{attn.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. 
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states @@ -1329,7 +1329,7 @@ def __call__( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{attn.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 96e48f61542e..ec75861e2da0 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -150,7 +150,7 @@ def __init__( def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) hidden_states = input_tensor @@ -324,7 +324,7 @@ def __init__( def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = f"Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) hidden_states = input_tensor From f56254e5035528618a293f9c8e6b6ed626424244 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 12 Mar 2024 08:11:07 +0530 Subject: [PATCH 27/32] maintain same deprecation message and also add it to activations.
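For reference, the guard this commit copies into `GEGLU.forward()` (and keeps identical across the other blocks) looks roughly like the sketch below. `ToyBlock` is a made-up module used only to illustrate the pattern; `deprecate` is the helper from `diffusers.utils` that the patch already calls:

    import torch
    from torch import nn
    from diffusers.utils import deprecate

    class ToyBlock(nn.Module):
        # Illustrative stand-in for GEGLU/ResnetBlock2D-style modules.
        def __init__(self, dim: int):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            # Accept the legacy `scale` argument, warn via `deprecate`, then ignore it.
            if len(args) > 0 or kwargs.get("scale", None) is not None:
                deprecation_message = (
                    "The `scale` argument is deprecated and will be ignored. Please remove it, "
                    "as passing it will raise an error in the future. `scale` should directly be "
                    "passed while calling the underlying pipeline component i.e., via "
                    "`cross_attention_kwargs`."
                )
                deprecate("scale", "1.0.0", deprecation_message)
            return self.proj(hidden_states)

    block = ToyBlock(8)
    out = block(torch.randn(1, 8), scale=1.0)  # emits the deprecation warning; output is unaffected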
--- src/diffusers/models/activations.py | 8 ++++++- src/diffusers/models/unets/unet_2d_blocks.py | 24 +++++++++---------- src/diffusers/models/unets/unet_3d_blocks.py | 4 ++-- src/diffusers/models/upsampling.py | 2 +- .../versatile_diffusion/modeling_text_unet.py | 2 +- 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index 7eceeafabe28..cec83bdded9e 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -17,6 +17,8 @@ import torch.nn.functional as F from torch import nn +from ..utils import deprecate + ACTIVATION_FUNCTIONS = { "swish": nn.SiLU(), @@ -92,7 +94,11 @@ def gelu(self, gate: torch.Tensor) -> torch.Tensor: # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) - def forward(self, hidden_states): + def forward(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) return hidden_states * self.gelu(gate) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index e7449f2df6bb..2eac81f356ab 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -1344,7 +1344,7 @@ def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) output_states = () @@ -1446,7 +1446,7 @@ def __init__( def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) for resnet in self.resnets: @@ -1548,7 +1548,7 @@ def __init__( def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. 
Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) for resnet, attn in zip(self.resnets, self.attentions): @@ -1652,7 +1652,7 @@ def forward( **kwargs, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) output_states = () @@ -1743,7 +1743,7 @@ def forward( **kwargs, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) output_states = () @@ -1832,7 +1832,7 @@ def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) output_states = () @@ -2334,7 +2334,7 @@ def forward( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
deprecate("scale", "1.0.0", deprecation_message) for resnet, attn in zip(self.resnets, self.attentions): @@ -2588,7 +2588,7 @@ def forward( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) is_freeu_enabled = ( @@ -2919,7 +2919,7 @@ def forward( **kwargs, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) for resnet in self.resnets: @@ -3031,7 +3031,7 @@ def forward( **kwargs, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) for resnet in self.resnets: @@ -3138,7 +3138,7 @@ def forward( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) for resnet in self.resnets: @@ -3398,7 +3398,7 @@ def forward( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
deprecate("scale", "1.0.0", deprecation_message) res_hidden_states_tuple = res_hidden_states_tuple[-1] diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 5dcc0eeee7a0..14b737ec9d57 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -1010,7 +1010,7 @@ def forward( **kwargs, ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) output_states = () @@ -1514,7 +1514,7 @@ def forward( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) is_freeu_enabled = ( diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index 4f79d39f65cb..4ecf6ebc26d2 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -143,7 +143,7 @@ def forward( self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, *args, **kwargs ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) assert hidden_states.shape[1] == self.channels diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 2c86a58bff74..6dc77cbd04b2 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1844,7 +1844,7 @@ def forward( **kwargs, ) -> torch.FloatTensor: if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "Use of `scale` is deprecated. Please remove the argument. Even if you pass it to the `forward()` of the `{self.__class__.__name__}` class, it won't have any effect." + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. 
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) is_freeu_enabled = ( From 0ffc8e569ebf991850bfdb94985bb741bfd426c7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 12 Mar 2024 08:30:09 +0530 Subject: [PATCH 28/32] address yiyi --- src/diffusers/models/attention.py | 2 +- src/diffusers/models/unets/unet_2d_blocks.py | 27 ++++++++++++++++++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index e17f67f36524..b2bbc3c1a8b0 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -293,7 +293,7 @@ def forward( ) -> torch.FloatTensor: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: - logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated.") + logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") # Notice that normalization is always applied before the real computation in the following blocks. # 0. Self-Attention diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 2eac81f356ab..322f91c7acb0 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -844,6 +844,10 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: @@ -981,6 +985,8 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. @@ -1109,6 +1115,8 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") output_states = () @@ -1972,8 +1980,11 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: - output_states = () cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + + output_states = () if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. 
@@ -2073,8 +2084,12 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + output_states = () for resnet in self.resnets: @@ -2180,6 +2195,10 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -3284,6 +3303,8 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. @@ -3664,6 +3685,8 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") # 1. Self-Attention if self.add_self_attention: From 69bbe935c077fe25f68e675ddeabee79333e4430 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 12 Mar 2024 08:31:58 +0530 Subject: [PATCH 29/32] fix copies --- .../deprecated/versatile_diffusion/modeling_text_unet.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 6dc77cbd04b2..f305605f4db0 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -2332,6 +2332,10 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. 
`scale` will be ignored.") + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: @@ -2470,6 +2474,8 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. From 416b4120df1526532f7089cbf36ab503b289b896 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 13 Mar 2024 07:17:44 +0530 Subject: [PATCH 30/32] Apply suggestions from code review Co-authored-by: YiYi Xu --- src/diffusers/models/attention.py | 5 ++++- src/diffusers/models/downsampling.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index b2bbc3c1a8b0..b79ea7aaad97 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -656,7 +656,10 @@ def __init__( if final_dropout: self.net.append(nn.Dropout(dropout)) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) for module in self.net: hidden_states = module(hidden_states) return hidden_states diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index 6776fd363692..d3b4b89a6ccd 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -129,7 +129,10 @@ def __init__( else: self.conv = conv - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) assert hidden_states.shape[1] == self.channels if self.norm is not None: From 83181b46874d66e25704d4c978e28b2f9ed0a75f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 13 Mar 2024 07:28:08 +0530 Subject: [PATCH 31/32] more depcrecation --- src/diffusers/models/attention.py | 4 ++-- src/diffusers/models/downsampling.py | 1 + .../models/transformers/transformer_2d.py | 8 ++++++- src/diffusers/models/unets/unet_2d_blocks.py | 22 +++++++++++++------ src/diffusers/models/unets/unet_3d_blocks.py | 17 +++++++++++++- .../versatile_diffusion/modeling_text_unet.py | 4 ++-- 6 files changed, 43 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index b79ea7aaad97..3d45cfa828a3 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from torch import nn -from ..utils import logging +from ..utils import deprecate, logging from ..utils.torch_utils import maybe_allow_in_graph from .activations import GEGLU, GELU, ApproximateGELU from .attention_processor import Attention @@ -293,7 +293,7 @@ def forward( ) -> torch.FloatTensor: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: - logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") # Notice that normalization is always applied before the real computation in the following blocks. # 0. Self-Attention diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index d3b4b89a6ccd..9ae28e950e83 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -18,6 +18,7 @@ import torch.nn as nn import torch.nn.functional as F +from ..utils import deprecate from .normalization import RMSNorm from .upsampling import upfirdn2d_native diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index 417c94506dea..555ea4f63808 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -19,13 +19,16 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import BaseOutput, deprecate, is_torch_version +from ...utils import BaseOutput, deprecate, is_torch_version, logging from ..attention import BasicTransformerBlock from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + @dataclass class Transformer2DModelOutput(BaseOutput): """ @@ -303,6 +306,9 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. 
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 5ca0ea17dc23..b9e9e63bbc18 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -846,7 +846,7 @@ def forward( ) -> torch.FloatTensor: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: - logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): @@ -986,7 +986,7 @@ def forward( ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if cross_attention_kwargs.get("scale", None) is not None: - logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. @@ -1116,7 +1116,7 @@ def forward( ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if cross_attention_kwargs.get("scale", None) is not None: - logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") output_states = () @@ -1239,6 +1239,10 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, additional_residuals: Optional[torch.FloatTensor] = None, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + output_states = () blocks = list(zip(self.resnets, self.attentions)) @@ -1982,7 +1986,7 @@ def forward( ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if cross_attention_kwargs.get("scale", None) is not None: - logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") output_states = () @@ -2197,7 +2201,7 @@ def forward( ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if cross_attention_kwargs.get("scale", None) is not None: - logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. 
`scale` will be ignored.") output_states = () @@ -2477,6 +2481,10 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -3304,7 +3312,7 @@ def forward( ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if cross_attention_kwargs.get("scale", None) is not None: - logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. @@ -3686,7 +3694,7 @@ def forward( ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if cross_attention_kwargs.get("scale", None) is not None: - logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") # 1. Self-Attention if self.add_self_attention: diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 14b737ec9d57..a48f1841c683 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -17,7 +17,7 @@ import torch from torch import nn -from ...utils import deprecate, is_torch_version +from ...utils import deprecate, is_torch_version, logging from ...utils.torch_utils import apply_freeu from ..attention import Attention from ..resnet import ( @@ -35,6 +35,9 @@ ) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + def get_down_block( down_block_type: str, num_layers: int, @@ -1178,6 +1181,10 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, additional_residuals: Optional[torch.FloatTensor] = None, ): + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + output_states = () blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) @@ -1358,6 +1365,10 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, num_frames: int = 1, ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -1694,6 +1705,10 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, num_frames: int = 1, ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. 
`scale` will be ignored.") + hidden_states = self.resnets[0](hidden_states, temb) blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 00c595eccd04..8402028f1eed 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -2334,7 +2334,7 @@ def forward( ) -> torch.FloatTensor: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: - logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): @@ -2475,7 +2475,7 @@ def forward( ) -> torch.FloatTensor: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if cross_attention_kwargs.get("scale", None) is not None: - logger.warn("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. From 938c9b03dd78548b23cb52db9cafcfff92c0c508 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 13 Mar 2024 07:29:47 +0530 Subject: [PATCH 32/32] fix-copies --- .../deprecated/versatile_diffusion/modeling_text_unet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 8402028f1eed..62a3a8728a2a 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -2002,6 +2002,10 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None)