Skip to content

Commit

Permalink
Fix ZeRO-3 generation context manager (#1617)
Browse files Browse the repository at this point in the history
  • Loading branch information
lewtun committed May 3, 2024
1 parent 75de236 commit 0347f58
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion trl/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union
Expand Down Expand Up @@ -118,6 +119,9 @@ def remove_hooks(model: "DeepSpeedEngine") -> None:
elif model.optimizer is not None:
optimizer_offload = model.optimizer

for param in iter_params(optimizer_offload.module, recurse=True):
param.ds_active_sub_modules.clear()

for hook in optimizer_offload.forward_hooks:
hook.remove()
for hook in optimizer_offload.backward_hooks:
Expand All @@ -127,6 +131,14 @@ def remove_hooks(model: "DeepSpeedEngine") -> None:
optimizer_offload.backward_hooks = []


def get_all_parameters(sub_module, recurse=False):
return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())


def iter_params(module, recurse=False):
return [param for _, param in get_all_parameters(module, recurse)]


def add_hooks(model: "DeepSpeedEngine") -> None:
"""Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
Expand All @@ -141,7 +153,6 @@ def unwrap_model_for_generation(
model: Union["DistributedDataParallel", "DeepSpeedEngine"], accelerator: "Accelerator", is_peft_model: bool = False
) -> Union["PreTrainedModelWrapper", "DeepSpeedEngine"]:
"""Context manager to unwrap a model for generation.
For ZeRO-3 models, we gather the weights once to speed up generation.
"""
unwrapped_model = accelerator.unwrap_model(model)
Expand Down

0 comments on commit 0347f58

Please sign in to comment.