Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ZeRO-3 generation context manager #1617

Merged
merged 1 commit into from
May 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion trl/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union
Expand Down Expand Up @@ -118,6 +119,9 @@ def remove_hooks(model: "DeepSpeedEngine") -> None:
elif model.optimizer is not None:
optimizer_offload = model.optimizer

for param in iter_params(optimizer_offload.module, recurse=True):
param.ds_active_sub_modules.clear()

for hook in optimizer_offload.forward_hooks:
hook.remove()
for hook in optimizer_offload.backward_hooks:
Expand All @@ -127,6 +131,14 @@ def remove_hooks(model: "DeepSpeedEngine") -> None:
optimizer_offload.backward_hooks = []


def get_all_parameters(sub_module, recurse=False):
    """Return one iterator over a module's named and external parameters.

    The entries produced by ``sub_module.named_parameters(recurse=recurse)``
    come first, followed by those of ``sub_module.ds_external_parameters()``.

    Args:
        sub_module: Object exposing ``named_parameters(recurse=...)`` and
            ``ds_external_parameters()``.
        recurse: Forwarded unchanged to ``named_parameters``.

    Returns:
        An ``itertools.chain`` over both iterables, in the order above.
    """
    own_named = sub_module.named_parameters(recurse=recurse)
    external_named = sub_module.ds_external_parameters()
    return itertools.chain(own_named, external_named)


def iter_params(module, recurse=False):
    """Collect the parameter objects yielded by ``get_all_parameters``.

    Each item from ``get_all_parameters(module, recurse)`` is unpacked as a
    two-element pair; the first element is discarded and the second is kept.

    Args:
        module: Passed through to ``get_all_parameters``.
        recurse: Passed through to ``get_all_parameters``.

    Returns:
        A list (not a lazy iterator, despite the function name) holding the
        second element of every pair, in iteration order.
    """
    collected = []
    for _name, tensor in get_all_parameters(module, recurse):
        collected.append(tensor)
    return collected


def add_hooks(model: "DeepSpeedEngine") -> None:
"""Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
Expand All @@ -141,7 +153,6 @@ def unwrap_model_for_generation(
model: Union["DistributedDataParallel", "DeepSpeedEngine"], accelerator: "Accelerator", is_peft_model: bool = False
) -> Union["PreTrainedModelWrapper", "DeepSpeedEngine"]:
"""Context manager to unwrap a model for generation.

For ZeRO-3 models, we gather the weights once to speed up generation.
"""
unwrapped_model = accelerator.unwrap_model(model)
Expand Down
Loading