src/diffusers/hooks/group_offloading.py (204 changes: 181 additions, 23 deletions)
@@ -578,6 +578,10 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.

Standalone computational layers (Conv2d, Linear, etc.) that are not part of ModuleList/Sequential are treated
individually with leaf-level logic to ensure proper device management. This includes computational layers nested
within container modules.
"""

if config.stream is not None and config.num_blocks_per_group != 1:
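For orientation, here is a hedged usage sketch of the behaviour the new docstring describes. The `TinyVAE` class and its layer names are invented for illustration; `apply_group_offloading` is the public entry point in diffusers, but verify the exact keyword arguments against the version you are running.

```python
import torch

from diffusers.hooks import apply_group_offloading


class TinyVAE(torch.nn.Module):
    """Toy model: blocks live in a ModuleList, while conv_in/conv_out are standalone layers."""

    def __init__(self):
        super().__init__()
        self.conv_in = torch.nn.Conv2d(3, 8, 3, padding=1)   # standalone -> individual hook
        self.blocks = torch.nn.ModuleList(                    # matched -> grouped block offloading
            [torch.nn.Conv2d(8, 8, 3, padding=1) for _ in range(4)]
        )
        self.norm_out = torch.nn.GroupNorm(4, 8)              # non-computational -> unmatched group
        self.conv_out = torch.nn.Conv2d(8, 3, 3, padding=1)   # standalone -> individual hook

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.blocks:
            x = block(x)
        return self.conv_out(self.norm_out(x))


model = TinyVAE()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
)
```

With the change in this PR, `conv_in` and `conv_out` receive their own on/offload hooks instead of staying pinned to the top-level unmatched group.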
@@ -589,11 +593,20 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
unmatched_modules = []
unmatched_computational_layers = []
matched_module_groups = []
for name, submodule in module.named_children():
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
unmatched_modules.append((name, submodule))
modules_with_group_offloading.add(name)
# Check if this is a computational layer that should be handled individually
if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
unmatched_computational_layers.append((name, submodule))
modules_with_group_offloading.add(name)
else:
# This is a container module - recursively find computational layers within it
_find_and_apply_computational_layer_hooks(submodule, name, config, modules_with_group_offloading)
unmatched_modules.append((name, submodule))
# Do NOT add the container name to modules_with_group_offloading here: if we did, parameters
# of non-computational sublayers (like GroupNorm) would be treated as already covered and never
# gathered into the top-level unmatched group below.
[Review comment from a Member on lines +608 to +609: "Could you expand a bit more on this?"]
continue

for i in range(0, len(submodule), config.num_blocks_per_group):
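To make the routing above concrete, here is a small sketch that only prints how each top-level child would be classified. It reuses the toy `TinyVAE` from the earlier sketch and substitutes an assumed tuple of layer types for the module's `_GO_LC_SUPPORTED_PYTORCH_LAYERS` constant.

```python
import torch

# Stand-in for _GO_LC_SUPPORTED_PYTORCH_LAYERS (an assumed subset, for illustration only).
COMPUTATIONAL_LAYERS = (torch.nn.Conv2d, torch.nn.Linear)


def classify_children(module: torch.nn.Module, num_blocks_per_group: int = 2) -> None:
    for name, child in module.named_children():
        if isinstance(child, (torch.nn.ModuleList, torch.nn.Sequential)):
            # Matched: split into groups of num_blocks_per_group consecutive blocks.
            for i in range(0, len(child), num_blocks_per_group):
                last = min(i + num_blocks_per_group, len(child)) - 1
                print(f"{name}[{i}..{last}] -> block group")
        elif isinstance(child, COMPUTATIONAL_LAYERS):
            print(f"{name} -> standalone computational layer (own group)")
        else:
            # Neither a block container nor a computational layer (e.g. GroupNorm):
            # recursed for nested computational layers, remainder joins the unmatched group.
            print(f"{name} -> container / unmatched group")


classify_children(TinyVAE())  # TinyVAE is the toy class from the previous sketch
```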
@@ -622,6 +635,25 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, config=config)

# Apply leaf-level treatment to standalone computational layers at the top level
# Each computational layer gets its own ModuleGroup with hooks registered directly on it
for name, comp_layer in unmatched_computational_layers:
group = ModuleGroup(
modules=[comp_layer],
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=comp_layer,
onload_leader=comp_layer,
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(comp_layer, group, config=config)

# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
# part of any group (as doing so would lead to no VRAM savings).
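Conceptually, each of these per-layer groups makes the layer onload its weights just before its forward runs and offload them right after. The sketch below illustrates that idea with plain PyTorch forward hooks; it is not the library's implementation, which goes through `_apply_group_offloading_hook` and supports streams, pinned memory, and disk offload.

```python
import torch


def attach_naive_offload_hooks(layer, onload_device, offload_device):
    """Illustration only: move weights to the compute device before forward, back after."""

    def pre_hook(module, args):
        module.to(onload_device)

    def post_hook(module, args, output):
        module.to(offload_device)

    layer.register_forward_pre_hook(pre_hook)
    layer.register_forward_hook(post_hook)


# Example: a standalone conv that would otherwise have to ride along with its parent module.
conv = torch.nn.Conv2d(3, 8, 3, padding=1)
attach_naive_offload_hooks(conv, torch.device("cuda"), torch.device("cpu"))  # assumes a CUDA device
```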
@@ -630,28 +662,154 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
parameters = [param for _, param in parameters]
buffers = [buffer for _, buffer in buffers]

# Create a group for the unmatched submodules of the top-level module so that they are on the correct
# device when the forward pass is called.
# Create a group for the remaining unmatched submodules (non-computational containers) of the top-level
# module so that they are on the correct device when the forward pass is called.
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
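A consequence of the comments above: parameters owned directly by the top-level module (or by sublayers that never received their own hook) still need a home, which is what the unmatched group provides. Below is a simplified sketch of the gathering idea, using an invented helper name (`gather_unmatched`) rather than the library's actual gathering utilities.

```python
from typing import List, Set, Tuple

import torch


def gather_unmatched(
    module: torch.nn.Module, hooked: Set[str]
) -> Tuple[List[torch.nn.Parameter], List[torch.Tensor]]:
    """Collect parameters/buffers whose owning submodule did not get its own hook."""

    def has_hooked_parent(qualified_name: str) -> bool:
        atoms = qualified_name.split(".")[:-1]  # drop the parameter/buffer leaf name
        while atoms:
            if ".".join(atoms) in hooked:
                return True
            atoms.pop()
        return False

    params = [p for n, p in module.named_parameters() if not has_hooked_parent(n)]
    bufs = [b for n, b in module.named_buffers() if not has_hooked_parent(n)]
    return params, bufs
```

The new `if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0` guard then simply skips building this group when everything is already covered.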


def _find_and_apply_computational_layer_hooks(
container_module: torch.nn.Module,
container_name: str,
config: GroupOffloadingConfig,
modules_with_group_offloading: Set[str],
) -> None:
r"""
Recursively finds all computational layers within a container module and applies individual hooks to them.
This ensures that standalone Conv2d, Linear, etc. layers nested inside container modules (like Encoder/Decoder)
get proper device management.
"""
for name, submodule in container_module.named_modules():
if name == "": # Skip the container itself
continue

# Only apply hooks to supported computational layers
if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
full_name = f"{container_name}.{name}"
group = ModuleGroup(
modules=[submodule],
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=submodule,
onload_leader=submodule,
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=full_name,
)
_apply_group_offloading_hook(submodule, group, config=config)
modules_with_group_offloading.add(full_name)

# Also handle parameters and buffers at non-leaf levels within the container
# This is similar to what leaf-level offloading does
module_dict = dict(container_module.named_modules())
parameters = []
buffers = []

for name, param in container_module.named_parameters():
# Check if this parameter has a parent that already got a hook
has_parent_with_hook = False
atoms = name.split(".")
while len(atoms) > 0:
parent_name = ".".join(atoms)
full_parent_name = f"{container_name}.{parent_name}"
if full_parent_name in modules_with_group_offloading:
has_parent_with_hook = True
break
atoms.pop()

if not has_parent_with_hook:
parameters.append((name, param))

for name, buffer in container_module.named_buffers():
# Check if this buffer has a parent that already got a hook
has_parent_with_hook = False
atoms = name.split(".")
while len(atoms) > 0:
parent_name = ".".join(atoms)
full_parent_name = f"{container_name}.{parent_name}"
if full_parent_name in modules_with_group_offloading:
has_parent_with_hook = True
break
atoms.pop()

if not has_parent_with_hook:
buffers.append((name, buffer))

# Group parameters and buffers by their immediate parent module and apply hooks
parent_to_parameters = {}
for name, param in parameters:
atoms = name.split(".")
while len(atoms) > 0:
parent_name = ".".join(atoms)
if parent_name in module_dict:
if parent_name in parent_to_parameters:
parent_to_parameters[parent_name].append(param)
else:
parent_to_parameters[parent_name] = [param]
break
atoms.pop()

parent_to_buffers = {}
for name, buffer in buffers:
atoms = name.split(".")
while len(atoms) > 0:
parent_name = ".".join(atoms)
if parent_name in module_dict:
if parent_name in parent_to_buffers:
parent_to_buffers[parent_name].append(buffer)
else:
parent_to_buffers[parent_name] = [buffer]
break
atoms.pop()

parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys())
for name in parent_names:
params = parent_to_parameters.get(name, [])
bufs = parent_to_buffers.get(name, [])
parent_module = module_dict[name]
full_parent_name = f"{container_name}.{name}"

group = ModuleGroup(
modules=[],
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_leader=parent_module,
onload_leader=parent_module,
offload_to_disk_path=config.offload_to_disk_path,
parameters=params,
buffers=bufs,
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=full_parent_name,
)
_apply_group_offloading_hook(parent_module, group, config=config)
modules_with_group_offloading.add(full_parent_name)


def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
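Finally, the tail of the new `_find_and_apply_computational_layer_hooks` helper walks each leftover parameter and buffer name up to its closest owning module and creates one group per parent. A compact sketch of that grouping step (names are illustrative; `dict.setdefault` replaces the explicit if/else above with the same intended behaviour):

```python
from typing import Dict, List, Tuple

import torch


def group_by_immediate_parent(
    container: torch.nn.Module, leftovers: List[Tuple[str, torch.nn.Parameter]]
) -> Dict[str, List[torch.nn.Parameter]]:
    """Map each leftover (qualified_name, param) to the closest ancestor module that owns it."""
    module_dict = dict(container.named_modules())
    parent_to_parameters: Dict[str, List[torch.nn.Parameter]] = {}
    for name, param in leftovers:
        atoms = name.split(".")
        while atoms:
            parent_name = ".".join(atoms)
            if parent_name in module_dict:
                parent_to_parameters.setdefault(parent_name, []).append(param)
                break
            atoms.pop()
    return parent_to_parameters
```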