From 77b7238b907c85f21ec756176351d85faf2aeaad Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Wed, 1 May 2024 09:08:55 +0530
Subject: [PATCH] fix the fsdp peft autowrap policy (#1694)

* fix the fsdp peft autowrap policy

* address comment wrt backwards compatibility
---
 src/peft/utils/other.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index 964501790a..2b49e82daf 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -393,6 +393,11 @@ def fsdp_auto_wrap_policy(model):
     import os
 
     from accelerate import FullyShardedDataParallelPlugin
+
+    if hasattr(FullyShardedDataParallelPlugin, "get_module_class_from_name"):
+        get_module_class_from_name = FullyShardedDataParallelPlugin.get_module_class_from_name
+    else:
+        from accelerate.utils.dataclasses import get_module_class_from_name
     from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
 
     from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
@@ -405,7 +410,7 @@ def fsdp_auto_wrap_policy(model):
     ).split(",")
     transformer_cls_to_wrap = {PrefixEncoder, PromptEncoder, PromptEmbedding}
     for layer_class in transformer_cls_names_to_wrap:
-        transformer_cls = FullyShardedDataParallelPlugin.get_module_class_from_name(model, layer_class)
+        transformer_cls = get_module_class_from_name(model, layer_class)
         if transformer_cls is None:
             raise Exception("Could not find the transformer layer class to wrap in the model.")
         else:
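
Note (not part of the patch): the helper touched here is normally wired into an Accelerate FSDP run by
assigning its return value to the plugin's auto_wrap_policy. The following is a minimal usage sketch under
assumptions not stated in the patch (a prompt-tuning PEFT model and FSDP already enabled via
`accelerate config`); the model name and config values are illustrative placeholders.

    # Illustrative sketch, assuming FSDP is configured through `accelerate config`.
    from accelerate import Accelerator
    from peft import PromptTuningConfig, TaskType, get_peft_model
    from peft.utils.other import fsdp_auto_wrap_policy
    from transformers import AutoModelForCausalLM

    accelerator = Accelerator()

    # Placeholder base model and PEFT config for illustration only.
    base_model = AutoModelForCausalLM.from_pretrained("gpt2")
    peft_config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8)
    model = get_peft_model(base_model, peft_config)

    # The policy built here is what the patched helper returns; with this fix it
    # resolves get_module_class_from_name from either accelerate location.
    if getattr(accelerator.state, "fsdp_plugin", None) is not None:
        accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)

    model = accelerator.prepare(model)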