fix the fsdp peft autowrap policy (#1694)
* fix the fsdp peft autowrap policy

* address comment wrt backwards compatibility
pacman100 committed May 1, 2024
1 parent 3edcebf commit 77b7238
Showing 1 changed file with 6 additions and 1 deletion.
src/peft/utils/other.py (6 additions, 1 deletion)
@@ -393,6 +393,11 @@ def fsdp_auto_wrap_policy(model):
     import os
 
     from accelerate import FullyShardedDataParallelPlugin
+
+    if hasattr(FullyShardedDataParallelPlugin, "get_module_class_from_name"):
+        get_module_class_from_name = FullyShardedDataParallelPlugin.get_module_class_from_name
+    else:
+        from accelerate.utils.dataclasses import get_module_class_from_name
     from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
 
     from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
@@ -405,7 +410,7 @@ def fsdp_auto_wrap_policy(model):
     ).split(",")
     transformer_cls_to_wrap = {PrefixEncoder, PromptEncoder, PromptEmbedding}
     for layer_class in transformer_cls_names_to_wrap:
-        transformer_cls = FullyShardedDataParallelPlugin.get_module_class_from_name(model, layer_class)
+        transformer_cls = get_module_class_from_name(model, layer_class)
         if transformer_cls is None:
            raise Exception("Could not find the transformer layer class to wrap in the model.")
        else:
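
For context, a minimal usage sketch of how the patched helper is typically consumed: fsdp_auto_wrap_policy builds an FSDP auto-wrap policy for a PEFT model and is assigned to accelerate's FSDP plugin. The base model, LoRA config, and plugin access below are illustrative assumptions for the sketch, not part of this commit.

# Illustrative usage sketch (not part of this commit): attach the PEFT auto-wrap
# policy to accelerate's FSDP plugin. The base model and LoRA config here are
# assumptions chosen for the example.
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model
from peft.utils.other import fsdp_auto_wrap_policy
from transformers import AutoModelForCausalLM

accelerator = Accelerator()  # assumes FSDP was enabled via `accelerate config`

base_model = AutoModelForCausalLM.from_pretrained("gpt2")
model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM"))

if getattr(accelerator.state, "fsdp_plugin", None) is not None:
    # With this fix, fsdp_auto_wrap_policy resolves get_module_class_from_name
    # whether it is exposed on FullyShardedDataParallelPlugin or in
    # accelerate.utils.dataclasses, depending on the installed accelerate version.
    accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)

model = accelerator.prepare(model)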
