From b1ded1338f6cbada047d22d3bce4aad9e5818830 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 19 Dec 2024 22:37:06 +0100 Subject: [PATCH 1/3] set supports gradient checkpointing to true where necessary; add missing no split modules --- src/diffusers/models/modeling_utils.py | 2 +- .../models/transformers/cogvideox_transformer_3d.py | 1 + src/diffusers/models/transformers/latte_transformer_3d.py | 1 + src/diffusers/models/transformers/transformer_allegro.py | 3 +++ .../models/transformers/transformer_cogview3plus.py | 1 + .../models/transformers/transformer_hunyuan_video.py | 6 ++++++ src/diffusers/models/transformers/transformer_ltx.py | 1 + src/diffusers/models/transformers/transformer_sd3.py | 1 + 8 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0f9c9203c926..44e88d494461 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1196,7 +1196,7 @@ def _get_signature_keys(cls, obj): # Adapted from `transformers` modeling_utils.py def _get_no_split_modules(self, device_map: str): """ - Get the modules of the model that should not be spit when using device_map. We iterate through the modules to + Get the modules of the model that should not be split when using device_map. We iterate through the modules to get the underlying `_no_split_modules`. Args: diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index b47d439774cc..e83c5be75b44 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -210,6 +210,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): """ _supports_gradient_checkpointing = True + _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index d34ccfd20108..7a7127439ebc 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -27,6 +27,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock"] """ A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, offical code: diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index fe9c7290b063..fc42e1bfb4a0 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -221,6 +221,9 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin): Scaling factor to apply in 3D positional embeddings across time dimension. """ + _supports_gradient_checkpointing = True + _no_split_modules = ["AllegroTransformerBlock"] + @register_to_config def __init__( self, diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index 94d852f6df4b..369509a3a35e 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -166,6 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True + _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 089389b5f9ad..feabcfac9bff 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -540,6 +540,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): """ _supports_gradient_checkpointing = True + _no_split_modules = [ + "HunyuanVideoTransformerBlock", + "HunyuanVideoSingleTransformerBlock", + "HunyuanVideoPatchEmbed", + "HunyuanVideoTokenRefiner", + ] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 2ed8520a5d75..fdfe4d76b0a3 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -295,6 +295,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin """ _supports_gradient_checkpointing = True + _no_split_modules = ["LTXTransformerBlock"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 79c4069e9a37..a0f76d9ec596 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -125,6 +125,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi """ _supports_gradient_checkpointing = True + _no_split_modules = ["JointTransformerBlock", "SD3SingleTransformerBlock"] @register_to_config def __init__( From ec07f8c6865b7979431ab769377309e5ffe74255 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 19 Dec 2024 22:51:26 +0100 Subject: [PATCH 2/3] fix cogvideox tests --- .../models/transformers/test_models_transformer_cogvideox.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_cogvideox.py b/tests/models/transformers/test_models_transformer_cogvideox.py index 4c13b54e0620..73b83b9eb514 100644 --- a/tests/models/transformers/test_models_transformer_cogvideox.py +++ b/tests/models/transformers/test_models_transformer_cogvideox.py @@ -71,7 +71,7 @@ def prepare_init_args_and_inputs_for_common(self): "out_channels": 4, "time_embed_dim": 2, "text_embed_dim": 8, - "num_layers": 1, + "num_layers": 2, "sample_width": 8, "sample_height": 8, "sample_frames": 8, @@ -130,7 +130,7 @@ def prepare_init_args_and_inputs_for_common(self): "out_channels": 4, "time_embed_dim": 2, "text_embed_dim": 8, - "num_layers": 1, + "num_layers": 2, "sample_width": 8, "sample_height": 8, "sample_frames": 8, From b93532a409f3bafa4837023c09df0a1d86e6f8c6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 19 Dec 2024 23:01:21 +0100 Subject: [PATCH 3/3] update --- src/diffusers/models/transformers/latte_transformer_3d.py | 1 - src/diffusers/models/transformers/transformer_allegro.py | 1 - src/diffusers/models/transformers/transformer_ltx.py | 1 - src/diffusers/models/transformers/transformer_sd3.py | 1 - .../models/transformers/test_models_transformer_cogview3plus.py | 2 +- 5 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 7a7127439ebc..d34ccfd20108 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -27,7 +27,6 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True - _no_split_modules = ["BasicTransformerBlock"] """ A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, offical code: diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index fc42e1bfb4a0..81039fd49e0d 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -222,7 +222,6 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True - _no_split_modules = ["AllegroTransformerBlock"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index fdfe4d76b0a3..2ed8520a5d75 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -295,7 +295,6 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin """ _supports_gradient_checkpointing = True - _no_split_modules = ["LTXTransformerBlock"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index a0f76d9ec596..79c4069e9a37 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -125,7 +125,6 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi """ _supports_gradient_checkpointing = True - _no_split_modules = ["JointTransformerBlock", "SD3SingleTransformerBlock"] @register_to_config def __init__( diff --git a/tests/models/transformers/test_models_transformer_cogview3plus.py b/tests/models/transformers/test_models_transformer_cogview3plus.py index eda9813808e9..ec6c58a6734c 100644 --- a/tests/models/transformers/test_models_transformer_cogview3plus.py +++ b/tests/models/transformers/test_models_transformer_cogview3plus.py @@ -71,7 +71,7 @@ def prepare_init_args_and_inputs_for_common(self): init_dict = { "patch_size": 2, "in_channels": 4, - "num_layers": 1, + "num_layers": 2, "attention_head_dim": 4, "num_attention_heads": 2, "out_channels": 4,