Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auto_wrap_policy for PEFT with FSDP #2253

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion src/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,7 @@ def get_module_class_from_name(module, name):
return module_class

def set_auto_wrap_policy(self, model):
from peft import PeftModel
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

default_transformer_cls_names_to_wrap = (
Expand All @@ -986,11 +987,36 @@ def set_auto_wrap_policy(self, model):
else:
transformer_cls_to_wrap.add(transformer_cls)

self.auto_wrap_policy = functools.partial(
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls=transformer_cls_to_wrap,
)

# In an FSDP setting PEFT models require individually wrapping unfrozen parameters

if isinstance(model, PeftModel):
print("PEFT wrapping")
fs4r marked this conversation as resolved.
Show resolved Hide resolved
from torch.distributed.fsdp.wrap import (
_or_policy,
lambda_auto_wrap_policy,
transformer_auto_wrap_policy,
fs4r marked this conversation as resolved.
Show resolved Hide resolved
)

def lambda_policy_fn(module):
if (
len(list(module.named_children())) == 0
fs4r marked this conversation as resolved.
Show resolved Hide resolved
and getattr(module, "weight", None) is not None
and module.weight.requires_grad
):
return True
return False

lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, auto_wrap_policy])

self.auto_wrap_policy = auto_wrap_policy

elif auto_wrap_policy == FSDP_AUTO_WRAP_POLICY[1]:
min_num_params = int(os.environ.get("FSDP_MIN_NUM_PARAMS", 0))
if min_num_params > 0:
Expand Down
26 changes: 26 additions & 0 deletions tests/fsdp/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,32 @@ def test_state_dict_type(self):
self.assertTrue(fsdp_plugin.state_dict_config.offload_to_cpu)
self.assertTrue(fsdp_plugin.state_dict_config.rank0_only)

def test_auto_wrap_policy_peft(self):
from peft import LoraConfig, TaskType, get_peft_model
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

peft_config = LoraConfig(
task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
)
model = AutoModel.from_pretrained(BERT_BASE_CASED)
model = get_peft_model(model, peft_config)

env = self.dist_env.copy()
env["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP"
env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "BertLayer"
env["FSDP_USE_ORIG_PARAMS"] = "false"
env["RANK"] = "0"
with mockenv_context(**env): #
fsdp_plugin = FullyShardedDataParallelPlugin()
fsdp_plugin.set_auto_wrap_policy(model)
kwargs = {
"sharding_strategy": fsdp_plugin.sharding_strategy,
"auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
"use_orig_params": fsdp_plugin.use_orig_params,
}
torch.distributed.init_process_group(backend="nccl")
model = FSDP(model, **kwargs)

def test_auto_wrap_policy(self):
model = AutoModel.from_pretrained(BERT_BASE_CASED)
for policy in FSDP_AUTO_WRAP_POLICY:
Expand Down
Loading