From ad1fc3747324da45d499838a341ceb89e61f31af Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Fri, 21 Nov 2025 11:22:10 +0530 Subject: [PATCH 1/2] fix: group offloading to support standalone computational layers in block-level offloading --- src/diffusers/hooks/group_offloading.py | 204 +++++++++++++++++++++--- 1 file changed, 181 insertions(+), 23 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 38f291f5203c..4978e48d2d0f 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -578,6 +578,10 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf r""" This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks. + + Standalone computational layers (Conv2d, Linear, etc.) that are not part of ModuleList/Sequential are treated + individually with leaf-level logic to ensure proper device management. This includes computational layers nested + within container modules. 
""" if config.stream is not None and config.num_blocks_per_group != 1: @@ -589,11 +593,20 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf # Create module groups for ModuleList and Sequential blocks modules_with_group_offloading = set() unmatched_modules = [] + unmatched_computational_layers = [] matched_module_groups = [] for name, submodule in module.named_children(): if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): - unmatched_modules.append((name, submodule)) - modules_with_group_offloading.add(name) + # Check if this is a computational layer that should be handled individually + if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS): + unmatched_computational_layers.append((name, submodule)) + modules_with_group_offloading.add(name) + else: + # This is a container module - recursively find computational layers within it + _find_and_apply_computational_layer_hooks(submodule, name, config, modules_with_group_offloading) + unmatched_modules.append((name, submodule)) + # Do NOT add the container name to modules_with_group_offloading here, because we need + # parameters from non-computational sublayers (like GroupNorm) to be gathered continue for i in range(0, len(submodule), config.num_blocks_per_group): @@ -622,6 +635,25 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf for group_module in group.modules: _apply_group_offloading_hook(group_module, group, config=config) + # Apply leaf-level treatment to standalone computational layers at the top level + # Each computational layer gets its own ModuleGroup with hooks registered directly on it + for name, comp_layer in unmatched_computational_layers: + group = ModuleGroup( + modules=[comp_layer], + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, + offload_leader=comp_layer, + onload_leader=comp_layer, + non_blocking=config.non_blocking, + 
stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, + onload_self=True, + group_id=name, + ) + _apply_group_offloading_hook(comp_layer, group, config=config) + # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately # when the forward pass of this module is called. This is because the top-level module is not # part of any group (as doing so would lead to no VRAM savings). @@ -630,28 +662,154 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf parameters = [param for _, param in parameters] buffers = [buffer for _, buffer in buffers] - # Create a group for the unmatched submodules of the top-level module so that they are on the correct - # device when the forward pass is called. + # Create a group for the remaining unmatched submodules (non-computational containers) of the top-level + # module so that they are on the correct device when the forward pass is called. unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules] - unmatched_group = ModuleGroup( - modules=unmatched_modules, - offload_device=config.offload_device, - onload_device=config.onload_device, - offload_to_disk_path=config.offload_to_disk_path, - offload_leader=module, - onload_leader=module, - parameters=parameters, - buffers=buffers, - non_blocking=False, - stream=None, - record_stream=False, - onload_self=True, - group_id=f"{module.__class__.__name__}_unmatched_group", - ) - if config.stream is None: - _apply_group_offloading_hook(module, unmatched_group, config=config) - else: - _apply_lazy_group_offloading_hook(module, unmatched_group, config=config) + if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0: + unmatched_group = ModuleGroup( + modules=unmatched_modules, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, + offload_leader=module, + 
onload_leader=module, + parameters=parameters, + buffers=buffers, + non_blocking=False, + stream=None, + record_stream=False, + onload_self=True, + group_id=f"{module.__class__.__name__}_unmatched_group", + ) + if config.stream is None: + _apply_group_offloading_hook(module, unmatched_group, config=config) + else: + _apply_lazy_group_offloading_hook(module, unmatched_group, config=config) + + +def _find_and_apply_computational_layer_hooks( + container_module: torch.nn.Module, + container_name: str, + config: GroupOffloadingConfig, + modules_with_group_offloading: Set[str], +) -> None: + r""" + Recursively finds all computational layers within a container module and applies individual hooks to them. + This ensures that standalone Conv2d, Linear, etc. layers nested inside container modules (like Encoder/Decoder) + get proper device management. + """ + for name, submodule in container_module.named_modules(): + if name == "": # Skip the container itself + continue + + # Only apply hooks to supported computational layers + if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS): + full_name = f"{container_name}.{name}" + group = ModuleGroup( + modules=[submodule], + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, + offload_leader=submodule, + onload_leader=submodule, + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, + onload_self=True, + group_id=full_name, + ) + _apply_group_offloading_hook(submodule, group, config=config) + modules_with_group_offloading.add(full_name) + + # Also handle parameters and buffers at non-leaf levels within the container + # This is similar to what leaf-level offloading does + module_dict = dict(container_module.named_modules()) + parameters = [] + buffers = [] + + for name, param in container_module.named_parameters(): + # Check if this parameter has a parent that 
already got a hook + has_parent_with_hook = False + atoms = name.split(".") + while len(atoms) > 0: + parent_name = ".".join(atoms) + full_parent_name = f"{container_name}.{parent_name}" + if full_parent_name in modules_with_group_offloading: + has_parent_with_hook = True + break + atoms.pop() + + if not has_parent_with_hook: + parameters.append((name, param)) + + for name, buffer in container_module.named_buffers(): + # Check if this buffer has a parent that already got a hook + has_parent_with_hook = False + atoms = name.split(".") + while len(atoms) > 0: + parent_name = ".".join(atoms) + full_parent_name = f"{container_name}.{parent_name}" + if full_parent_name in modules_with_group_offloading: + has_parent_with_hook = True + break + atoms.pop() + + if not has_parent_with_hook: + buffers.append((name, buffer)) + + # Group parameters and buffers by their immediate parent module and apply hooks + parent_to_parameters = {} + for name, param in parameters: + atoms = name.split(".") + while len(atoms) > 0: + parent_name = ".".join(atoms) + if parent_name in module_dict: + if parent_name in parent_to_parameters: + parent_to_parameters[parent_name].append(param) + else: + parent_to_parameters[parent_name] = [param] + break + atoms.pop() + + parent_to_buffers = {} + for name, buffer in buffers: + atoms = name.split(".") + while len(atoms) > 0: + parent_name = ".".join(atoms) + if parent_name in module_dict: + if parent_name in parent_to_buffers: + parent_to_buffers[parent_name].append(buffer) + else: + parent_to_buffers[parent_name] = [buffer] + break + atoms.pop() + + parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys()) + for name in parent_names: + params = parent_to_parameters.get(name, []) + bufs = parent_to_buffers.get(name, []) + parent_module = module_dict[name] + full_parent_name = f"{container_name}.{name}" + + group = ModuleGroup( + modules=[], + offload_device=config.offload_device, + onload_device=config.onload_device, + 
offload_leader=parent_module, + onload_leader=parent_module, + offload_to_disk_path=config.offload_to_disk_path, + parameters=params, + buffers=bufs, + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, + onload_self=True, + group_id=full_parent_name, + ) + _apply_group_offloading_hook(parent_module, group, config=config) + modules_with_group_offloading.add(full_parent_name) def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: From 59b6b678295214b70f6ecaa3f95129b76baf50d8 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Fri, 21 Nov 2025 11:37:10 +0530 Subject: [PATCH 2/2] test: for models with standalone and deeply nested layers in block-level offloading --- tests/hooks/test_group_offloading.py | 298 +++++++++++++++++++++++++++ 1 file changed, 298 insertions(+) diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py index 96cbecfbf530..9099fb49afcb 100644 --- a/tests/hooks/test_group_offloading.py +++ b/tests/hooks/test_group_offloading.py @@ -149,6 +149,146 @@ def post_forward(self, module, output): return output +# Model simulating VAE structure with standalone computational layers +class DummyVAELikeModel(ModelMixin): + def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: + super().__init__() + + # Encoder container (a top-level Sequential, handled by the regular block grouping) + self.encoder = torch.nn.Sequential( + torch.nn.Linear(in_features, hidden_features), + torch.nn.ReLU(), + ) + + # Standalone Conv2d layer (simulates quant_conv) + self.quant_conv = torch.nn.Conv2d(1, 1, kernel_size=1) + + # Decoder container with nested ModuleList + self.decoder = DecoderWithNestedBlocks(hidden_features, hidden_features) + + # Standalone Conv2d layer (simulates post_quant_conv) + self.post_quant_conv = torch.nn.Conv2d(1, 1, kernel_size=1) + + #
Output projection + self.linear_out = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Encode + x = self.encoder(x) + + # Reshape for conv operations + batch_size = x.shape[0] + x_reshaped = x.view(batch_size, 1, -1, 1) + + # Apply standalone conv layers + x_reshaped = self.quant_conv(x_reshaped) + x_reshaped = self.post_quant_conv(x_reshaped) + + # Reshape back + x = x_reshaped.view(batch_size, -1) + + # Decode + x = self.decoder(x) + + # Output + x = self.linear_out(x) + return x + + +class DecoderWithNestedBlocks(torch.nn.Module): + def __init__(self, in_features: int, out_features: int) -> None: + super().__init__() + + # Standalone computational layer (not inside a ModuleList/Sequential) + self.conv_in = torch.nn.Linear(in_features, in_features) + + # Nested ModuleList (like VAE's decoder.up_blocks) + self.up_blocks = torch.nn.ModuleList( + [torch.nn.Linear(in_features, in_features), torch.nn.Linear(in_features, in_features)] + ) + + # Non-computational layer + self.norm = torch.nn.LayerNorm(in_features) + + self.conv_out = torch.nn.Linear(in_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv_in(x) + for block in self.up_blocks: + x = block(x) + x = self.norm(x) + x = self.conv_out(x) + return x + + +# Model with only standalone computational layers at top level +class DummyModelWithStandaloneLayers(ModelMixin): + def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: + super().__init__() + + self.layer1 = torch.nn.Linear(in_features, hidden_features) + self.activation = torch.nn.ReLU() + self.layer2 = torch.nn.Linear(hidden_features, hidden_features) + self.layer3 = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer1(x) + x = self.activation(x) + x = self.layer2(x) + x = self.layer3(x) + return x + + +# Model with deeply nested structure +class
DummyModelWithDeeplyNestedBlocks(ModelMixin): + def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: + super().__init__() + + self.input_layer = torch.nn.Linear(in_features, hidden_features) + self.container = ContainerWithNestedModuleList(hidden_features) + self.output_layer = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.input_layer(x) + x = self.container(x) + x = self.output_layer(x) + return x + + +class ContainerWithNestedModuleList(torch.nn.Module): + def __init__(self, features: int) -> None: + super().__init__() + + # Top-level computational layer + self.proj_in = torch.nn.Linear(features, features) + + # Nested container with ModuleList + self.nested_container = NestedContainer(features) + + # Another top-level computational layer + self.proj_out = torch.nn.Linear(features, features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj_in(x) + x = self.nested_container(x) + x = self.proj_out(x) + return x + + +class NestedContainer(torch.nn.Module): + def __init__(self, features: int) -> None: + super().__init__() + + self.blocks = torch.nn.ModuleList([torch.nn.Linear(features, features), torch.nn.Linear(features, features)]) + self.norm = torch.nn.LayerNorm(features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + x = self.norm(x) + return x + + @require_torch_accelerator class GroupOffloadTests(unittest.TestCase): in_features = 64 @@ -362,3 +502,161 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm): self.assertLess( cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}" ) + + def test_vae_like_model_with_standalone_conv_layers(self): + """Test that models with standalone Conv2d layers (like VAE) work with block-level offloading.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = 
DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + out_ref = model_ref(x) + out = model(x) + + self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model.") + + def test_vae_like_model_without_streams(self): + """Test VAE-like model with block-level offloading but without streams.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + out_ref = model_ref(x) + out = model(x) + + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams." 
+ ) + + def test_model_with_only_standalone_layers(self): + """Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + out_ref = model_ref(x) + out = model(x) + + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for model with standalone layers." + ) + + def test_model_with_deeply_nested_blocks(self): + """Test models with deeply nested structure where ModuleList is not at top level.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + out_ref = model_ref(x) + out = model(x) + + self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for deeply nested model.") + + @parameterized.expand([("block_level",), ("leaf_level",)]) + def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str): + """Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading.""" + if 
torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + out_ref = model_ref(x) + out = model(x) + + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), + f"Outputs do not match for standalone Conv layers with {offload_type}.", + ) + + def test_multiple_invocations_with_vae_like_model(self): + """Test that multiple forward passes work correctly with VAE-like model.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + for i in range(5): + out_ref = model_ref(x) + out = model(x) + self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.") + + def test_nested_container_parameters_offloading(self): + """Test that parameters from non-computational layers in nested containers are handled correctly.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), 
strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + for i in range(3): + out_ref = model_ref(x) + out = model(x) + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), + f"Outputs do not match at iteration {i} for nested parameters.", + )