FIX DeepSpeed recursion error (#1892)
The error occurred when an attribute was accessed before __init__ had run.
ret-1 committed Jul 3, 2024
1 parent 018a1f4 commit 31c0d85
Showing 14 changed files with 93 additions and 1 deletion.
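
The fix is the same two-line guard in each __getattr__ shown below. For reference, here is a minimal, self-contained sketch of the failure mode; the Wrapper class and attribute names are made up for illustration (they are not the actual PEFT classes), and constructing the object via __new__ stands in for DeepSpeed probing attributes before __init__ has finished.

# Illustrative sketch only (hypothetical Wrapper class, not part of PEFT):
# why an unguarded __getattr__ recurses when an attribute is probed before
# __init__ has run, and how the guard added in this commit stops it.
import torch.nn as nn


class Wrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            if name == "model":  # the guard: re-raise instead of recursing
                raise
            # Without the guard, looking up self.model before __init__ has
            # assigned it calls __getattr__("model") again, forever.
            return getattr(self.model, name)


# Emulate DeepSpeed probing an attribute before __init__ has run:
obj = Wrapper.__new__(Wrapper)
print(hasattr(obj, "some_attribute"))  # False with the guard; RecursionError without it
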
@@ -194,6 +194,8 @@ def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
+            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
             return getattr(self.model, name)


2 changes: 2 additions & 0 deletions src/peft/mixed_model.py
@@ -191,6 +191,8 @@ def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
+            if name == "base_model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
             return getattr(self.base_model, name)

     def forward(self, *args: Any, **kwargs: Any):
2 changes: 2 additions & 0 deletions src/peft/peft_model.py
@@ -680,6 +680,8 @@ def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
+            if name == "base_model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
             return getattr(self.base_model, name)

     @contextmanager
2 changes: 2 additions & 0 deletions src/peft/tuners/adalora/model.py
@@ -229,6 +229,8 @@ def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
+            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
             return getattr(self.model, name)

     def forward(self, *args, **kwargs):
2 changes: 2 additions & 0 deletions src/peft/tuners/adaption_prompt/model.py
@@ -158,4 +158,6 @@ def __getattr__(self, name: str):
         except AttributeError:
             # This is necessary as e.g. causal models have various methods that we
             # don't want to re-implement here.
+            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
             return getattr(self.model, name)
2 changes: 2 additions & 0 deletions src/peft/tuners/boft/model.py
@@ -207,6 +207,8 @@ def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
+            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
             return getattr(self.model, name)

     def get_peft_config_as_dict(self, inference: bool = False):
2 changes: 2 additions & 0 deletions src/peft/tuners/ia3/model.py
@@ -227,6 +227,8 @@ def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
+            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
             return getattr(self.model, name)

     def get_peft_config_as_dict(self, inference: bool = False):
2 changes: 2 additions & 0 deletions src/peft/tuners/ln_tuning/model.py
@@ -72,6 +72,8 @@ def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
+            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
             return getattr(self.model, name)

     # TODO: here need to handle the modules_to_save rather than the target_modules
2 changes: 2 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -355,6 +355,8 @@ def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
+            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
             return getattr(self.model, name)

     def get_peft_config_as_dict(self, inference: bool = False):
2 changes: 2 additions & 0 deletions src/peft/tuners/lycoris_utils.py
@@ -200,6 +200,8 @@ def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
+            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
             return getattr(self.model, name)

     @staticmethod
2 changes: 2 additions & 0 deletions src/peft/tuners/mixed/model.py
@@ -183,6 +183,8 @@ def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
+            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
             return getattr(self.model, name)

     def _set_adapter_layers(self, enabled=True):
2 changes: 2 additions & 0 deletions src/peft/tuners/poly/model.py
@@ -114,6 +114,8 @@ def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
+            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
             return getattr(self.model, name)

     def get_peft_config_as_dict(self, inference: bool = False):
2 changes: 2 additions & 0 deletions src/peft/tuners/vera/model.py
@@ -329,6 +329,8 @@ def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
+            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
             return getattr(self.model, name)

     def get_peft_config_as_dict(self, inference: bool = False):
68 changes: 67 additions & 1 deletion tests/test_initialization.py
@@ -20,7 +20,21 @@
 from scipy import stats
 from torch import nn

-from peft import AdaLoraConfig, LoraConfig, PeftModel, PromptTuningConfig, VeraConfig, get_peft_model
+from peft import (
+    AdaLoraConfig,
+    LoraConfig,
+    PeftMixedModel,
+    PeftModel,
+    PeftModelForCausalLM,
+    PeftModelForFeatureExtraction,
+    PeftModelForQuestionAnswering,
+    PeftModelForSeq2SeqLM,
+    PeftModelForSequenceClassification,
+    PeftModelForTokenClassification,
+    PromptTuningConfig,
+    VeraConfig,
+    get_peft_model,
+)
 from peft.utils import infer_device


@@ -601,3 +615,55 @@ def test_vera_mixing_save_projection_raises(self):
         )
         with pytest.raises(ValueError, match=msg):
             model.add_adapter("other", config1)
+
+
+class TestNoInfiniteRecursionDeepspeed:
+    # see #1892 for details
+    classes = [
+        PeftModel,
+        PeftMixedModel,
+        PeftModelForSequenceClassification,
+        PeftModelForQuestionAnswering,
+        PeftModelForTokenClassification,
+        PeftModelForCausalLM,
+        PeftModelForSeq2SeqLM,
+        PeftModelForFeatureExtraction,
+    ]
+
+    @pytest.fixture
+    def wrap_init(self):
+        # emulates the wrapper from DeepSpeed
+        import functools
+
+        def decorator(f):
+            @functools.wraps(f)
+            def wrapper(self, *args, **kwargs):
+                hasattr(self, "abc")  # any hasattr will do
+                f(self, *args, **kwargs)
+
+            return wrapper
+
+        return decorator
+
+    @pytest.fixture
+    def model(self):
+        class MyModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(10, 10)
+                # to emulate LMs:
+                self.prepare_inputs_for_generation = None
+                self._prepare_encoder_decoder_kwargs_for_generation = None
+
+        return MyModule()
+
+    @pytest.mark.parametrize("cls", classes)
+    def test_no_infinite_recursion(self, cls, model, wrap_init):
+        original_init = cls.__init__
+        try:
+            cls.__init__ = wrap_init(cls.__init__)
+            # this would trigger an infinite loop before the fix in 1892
+            cls(model, LoraConfig(target_modules=["linear"]))
+        finally:
+            # ensure there are no side effects of this test
+            cls.__init__ = original_init

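The new test can be exercised on its own with something like pytest tests/test_initialization.py -k TestNoInfiniteRecursionDeepspeed; before this commit, each parametrized case would recurse until Python's recursion limit was hit instead of constructing the wrapped model.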