From 9312b366de60678241236e395c796a2460238384 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Thu, 23 Mar 2023 11:09:41 +0100
Subject: [PATCH 1/4] Helper function to disable custom attention processors.

---
 src/diffusers/models/controlnet.py             |  9 +++++++-
 src/diffusers/models/unet_2d_condition.py      | 21 +++++++------------
 src/diffusers/models/unet_3d_condition.py      |  9 +++++++-
 .../versatile_diffusion/modeling_text_unet.py  |  9 +++++++-
 tests/models/test_models_unet_2d_condition.py  |  4 ++--
 .../stable_diffusion/test_stable_diffusion.py  |  5 ++---
 .../test_stable_diffusion.py                   |  5 ++---
 tests/test_modeling_common.py                  | 17 +++++++--------
 8 files changed, 45 insertions(+), 34 deletions(-)

diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py
index ac6e64e4c779..bb608ad82a7a 100644
--- a/src/diffusers/models/controlnet.py
+++ b/src/diffusers/models/controlnet.py
@@ -20,7 +20,7 @@
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
@@ -368,6 +368,13 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):
         r"""
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 79a361763c76..cdfa77c269ea 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -21,7 +21,7 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
@@ -401,19 +401,12 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
             )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
+ 
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
 
     def set_attention_slice(self, slice_size):
         r"""
diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py
index 8006d0e1c127..ec8865f31031 100644
--- a/src/diffusers/models/unet_3d_condition.py
+++ b/src/diffusers/models/unet_3d_condition.py
@@ -21,7 +21,7 @@
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .transformer_temporal import TransformerTemporalModel
@@ -372,6 +372,13 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
             module.gradient_checkpointing = value
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index dd5410dbc0b0..5122b9be67c6 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -7,7 +7,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
 from ...models.attention import Attention
-from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor
+from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor, AttnProcessor
 from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
@@ -505,6 +505,13 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     def set_attention_slice(self, slice_size):
         r"""
         Enable sliced attention computation.
diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index ab6f12085e0f..4953d58188cd 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -22,7 +22,7 @@
 from parameterized import parameterized
 
 from diffusers import UNet2DConditionModel
-from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.utils import (
     floats_tensor,
     load_hf_numpy,
@@ -531,7 +531,7 @@ def test_lora_on_off(self):
         with torch.no_grad():
             sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
 
-        model.set_attn_processor(AttnProcessor())
+        model.set_default_attn_processor()
 
         with torch.no_grad():
             new_sample = model(**inputs_dict).sample
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index 33ef9368586e..f4e8113a298f 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -35,7 +35,6 @@
     UNet2DConditionModel,
     logging,
 )
-from diffusers.models.attention_processor import AttnProcessor
 from diffusers.utils import load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
 
@@ -843,7 +842,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
             "CompVis/stable-diffusion-v1-4",
             torch_dtype=torch.float16,
         )
-        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.unet.set_default_attn_processor()
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         outputs = pipe(**inputs)
@@ -856,7 +855,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
             "CompVis/stable-diffusion-v1-4",
             torch_dtype=torch.float16,
         )
-        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.unet.set_default_attn_processor()
 
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
index 481c265cbee4..fa3c3d628e4f 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
@@ -32,7 +32,6 @@
     UNet2DConditionModel,
     logging,
 )
-from diffusers.models.attention_processor import AttnProcessor
 from diffusers.utils import load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
 
@@ -410,7 +409,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
             "stabilityai/stable-diffusion-2-base",
             torch_dtype=torch.float16,
         )
-        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.unet.set_default_attn_processor()
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         outputs = pipe(**inputs)
@@ -423,7 +422,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
             "stabilityai/stable-diffusion-2-base",
             torch_dtype=torch.float16,
         )
-        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.unet.set_default_attn_processor()
 
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index e880950a7914..932c147027d3 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -25,7 +25,6 @@
 from requests.exceptions import HTTPError
 
 from diffusers.models import UNet2DConditionModel
-from diffusers.models.attention_processor import AttnProcessor
 from diffusers.training_utils import EMAModel
 from diffusers.utils import torch_device
 
@@ -106,16 +105,16 @@ def test_from_save_pretrained(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
         model = self.model_class(**init_dict)
-        if hasattr(model, "set_attn_processor"):
-            model.set_attn_processor(AttnProcessor())
+        if hasattr(model, "set_default_attn_processor"):
+            model.set_default_attn_processor()
         model.to(torch_device)
         model.eval()
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname)
             new_model = self.model_class.from_pretrained(tmpdirname)
-            if hasattr(new_model, "set_attn_processor"):
-                new_model.set_attn_processor(AttnProcessor())
+            if hasattr(new_model, "set_default_attn_processor"):
+                new_model.set_default_attn_processor()
             new_model.to(torch_device)
 
         with torch.no_grad():
@@ -135,16 +134,16 @@ def test_from_save_pretrained_variant(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
         model = self.model_class(**init_dict)
-        if hasattr(model, "set_attn_processor"):
-            model.set_attn_processor(AttnProcessor())
+        if hasattr(model, "set_default_attn_processor"):
+            model.set_default_attn_processor()
         model.to(torch_device)
         model.eval()
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname, variant="fp16")
             new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
-            if hasattr(new_model, "set_attn_processor"):
-                new_model.set_attn_processor(AttnProcessor())
+            if hasattr(new_model, "set_default_attn_processor"):
+                new_model.set_default_attn_processor()
 
         # non-variant cannot be loaded
         with self.assertRaises(OSError) as error_context:

From b6f91293d213acdb528aeba95873b06e59b52633 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Thu, 23 Mar 2023 11:37:43 +0100
Subject: [PATCH 2/4] Restore code deleted by mistake.

---
 src/diffusers/models/unet_2d_condition.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index cdfa77c269ea..49559a8723f8 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -402,6 +402,19 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
             )
 
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
     def set_default_attn_processor(self):
         """
         Disables custom attention processors and sets the default attention implementation.
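Note on the code restored in PATCH 2/4: fn_recursive_attn_processor is the dispatch that set_default_attn_processor builds on. set_attn_processor either broadcasts a single processor instance to every attention module, or consumes a dict keyed by fully qualified module path ("<name>.processor"). A minimal sketch of both call forms follows; the small UNet config is illustrative (it mirrors the repository's test configs) and is not part of the patch:

    from diffusers import UNet2DConditionModel
    from diffusers.models.attention_processor import AttnProcessor

    # A tiny, randomly initialized UNet is enough to show the traversal.
    model = UNet2DConditionModel(
        sample_size=32,
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=32,
    )

    # Keys are the module paths fn_recursive_attn_processor builds while
    # walking named_children(), each suffixed with ".processor".
    print(list(model.attn_processors))

    # Form 1: a single instance is broadcast to every attention module
    # (this is exactly what set_default_attn_processor does internally).
    model.set_attn_processor(AttnProcessor())

    # Form 2: a dict must supply one entry per key shown above.
    model.set_attn_processor({name: AttnProcessor() for name in model.attn_processors})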
From c742ff030c16ab898fdc1695ee0d800fd5a5e4de Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Thu, 23 Mar 2023 11:38:52 +0100
Subject: [PATCH 3/4] Format

---
 src/diffusers/models/unet_2d_condition.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 49559a8723f8..47a5e34c757a 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -401,7 +401,7 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
             f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
             f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
         )
- 
+
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
                 if not isinstance(processor, dict):

From bf493a2edcb51b6d8b01f1f1f418c37499c5b714 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Thu, 23 Mar 2023 11:41:27 +0100
Subject: [PATCH 4/4] Fix modeling_text_unet copy.

---
 .../pipelines/versatile_diffusion/modeling_text_unet.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 5122b9be67c6..b6b60962f038 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -505,7 +505,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
     def set_default_attn_processor(self):
         """
         Disables custom attention processors and sets the default attention implementation.
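End result of the series: callers can reset any of these models to the default attention implementation with a single method call, without importing AttnProcessor themselves, as the updated tests show. A short usage sketch under those assumptions; the checkpoint id is the one exercised in the tests above:

    from diffusers import UNet2DConditionModel
    from diffusers.models.attention_processor import AttnProcessor

    unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

    # Before this series, restoring the default processors meant naming the class:
    unet.set_attn_processor(AttnProcessor())

    # With PATCH 1/4 applied, the helper hides both the import and the instance:
    unet.set_default_attn_processor()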