improve readability of LoRA code #409

Merged · 1 commit · Jun 8, 2023
187 changes: 99 additions & 88 deletions src/peft/tuners/lora.py
@@ -187,32 +187,22 @@ def add_adapter(self, adapter_name, config=None):
if self.peft_config[adapter_name].inference_mode:
_freeze_adapter(self.model, adapter_name)

def _find_and_replace(self, adapter_name):
lora_config = self.peft_config[adapter_name]
def _check_quantization_dependency(self):
loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False)
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
if (loaded_in_4bit or loaded_in_8bit) and not is_bnb_available():
raise ImportError(
"To use Lora with 8-bit or 4-bit quantization, please install the `bitsandbytes` package. "
"You can install it with `pip install bitsandbytes`."
)
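For context, the two flags checked here are attached to the model by Transformers when it is loaded with bitsandbytes quantization. A minimal sketch of how they come to exist (the checkpoint name is illustrative; running this requires `bitsandbytes` and a CUDA device):

```python
from transformers import AutoModelForCausalLM

# load_in_8bit=True makes transformers route the weights through
# bitsandbytes' Linear8bitLt modules and set is_loaded_in_8bit=True
# on the returned model (load_in_4bit behaves analogously).
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", load_in_8bit=True)
print(getattr(model, "is_loaded_in_8bit", False))  # True
```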
is_target_modules_in_base_model = False
kwargs = {
"r": lora_config.r,
"lora_alpha": lora_config.lora_alpha,
"lora_dropout": lora_config.lora_dropout,
"fan_in_fan_out": lora_config.fan_in_fan_out,
"init_lora_weights": lora_config.init_lora_weights,
}
key_list = [key for key, _ in self.model.named_modules()]
is_using_layer_indexes = getattr(lora_config, "layers_to_transform", None) is not None
layer_indexing_pattern = getattr(lora_config, "layers_pattern", None)

for key in key_list:
if isinstance(lora_config.target_modules, str):
target_module_found = re.fullmatch(lora_config.target_modules, key)
else:
target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
def _check_target_module_exists(self, lora_config, key):
if isinstance(lora_config.target_modules, str):
target_module_found = re.fullmatch(lora_config.target_modules, key)
else:
target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
is_using_layer_indexes = getattr(lora_config, "layers_to_transform", None) is not None
layer_indexing_pattern = getattr(lora_config, "layers_pattern", None)

if is_using_layer_indexes and target_module_found:
layers_pattern = COMMON_LAYERS_PATTERN if layer_indexing_pattern is None else layer_indexing_pattern
@@ -230,80 +220,101 @@ def _find_and_replace(self, adapter_name):
break
else:
target_module_found = False
return target_module_found
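A standalone sketch of the two matching modes implemented above, with a hypothetical `DummyConfig` standing in for the relevant `LoraConfig` fields: a string `target_modules` is treated as a regex that must match the full dotted module path, while a list matches by suffix.

```python
import re

class DummyConfig:
    # list -> suffix match; a str such as r".*\.q_proj" would be a regex instead
    target_modules = ["q_proj", "v_proj"]

def target_module_found(config, key):
    if isinstance(config.target_modules, str):
        return bool(re.fullmatch(config.target_modules, key))
    return any(key.endswith(target_key) for target_key in config.target_modules)

print(target_module_found(DummyConfig(), "model.layers.0.self_attn.q_proj"))  # True
print(target_module_found(DummyConfig(), "model.layers.0.mlp.up_proj"))       # False
```

(The layer-index filtering via `layers_to_transform` and `layers_pattern` is omitted here for brevity.)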

if target_module_found:
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
parent, target, target_name = _get_submodules(self.model, key)
if hasattr(target, "bias"):
bias = target.bias is not None
def _create_new_module(self, lora_config, adapter_name, target):
bias = hasattr(target, "bias") and target.bias is not None
kwargs = {
"r": lora_config.r,
"lora_alpha": lora_config.lora_alpha,
"lora_dropout": lora_config.lora_dropout,
"fan_in_fan_out": lora_config.fan_in_fan_out,
"init_lora_weights": lora_config.init_lora_weights,
}
loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False)
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)

if isinstance(target, LoraLayer):
target.update_layer(
adapter_name,
lora_config.r,
lora_config.lora_alpha,
lora_config.lora_dropout,
lora_config.init_lora_weights,
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
eightbit_kwargs = kwargs.copy()
eightbit_kwargs.update(
{
"has_fp16_weights": target.state.has_fp16_weights,
"memory_efficient_backward": target.state.memory_efficient_backward,
"threshold": target.state.threshold,
"index": target.index,
}
)
new_module = Linear8bitLt(
adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs
)
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
fourbit_kwargs = kwargs.copy()
fourbit_kwargs.update(
{
"compute_dtype": target.compute_dtype,
"compress_statistics": target.weight.compress_statistics,
"quant_type": target.weight.quant_type,
}
)
new_module = Linear4bit(adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs)
elif isinstance(target, torch.nn.Embedding):
embedding_kwargs = kwargs.copy()
embedding_kwargs.pop("fan_in_fan_out", None)
in_features, out_features = target.num_embeddings, target.embedding_dim
new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
else:
if isinstance(target, torch.nn.Linear):
in_features, out_features = target.in_features, target.out_features
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
else:
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
eightbit_kwargs = kwargs.copy()
eightbit_kwargs.update(
{
"has_fp16_weights": target.state.has_fp16_weights,
"memory_efficient_backward": target.state.memory_efficient_backward,
"threshold": target.state.threshold,
"index": target.index,
}
)
new_module = Linear8bitLt(
adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs
)
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
fourbit_kwargs = kwargs.copy()
fourbit_kwargs.update(
{
"compute_dtype": target.compute_dtype,
"compress_statistics": target.weight.compress_statistics,
"quant_type": target.weight.quant_type,
}
)
new_module = Linear4bit(
adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs
)
elif isinstance(target, torch.nn.Embedding):
embedding_kwargs = kwargs.copy()
embedding_kwargs.pop("fan_in_fan_out", None)
in_features, out_features = target.num_embeddings, target.embedding_dim
new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
else:
if isinstance(target, torch.nn.Linear):
in_features, out_features = target.in_features, target.out_features
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
elif isinstance(target, Conv1D):
in_features, out_features = (
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
)
if not kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
"Setting fan_in_fan_out to True."
)
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
else:
raise ValueError(
f"Target module {target} is not supported. "
f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
)
new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs)
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
elif isinstance(target, Conv1D):
in_features, out_features = (
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
)
if not kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
"Setting fan_in_fan_out to True."
)
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
else:
raise ValueError(
f"Target module {target} is not supported. "
f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
)
new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs)

return new_module
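The `fan_in_fan_out` handling in the `Linear`/`Conv1D` branches exists because the two classes store their weights transposed relative to each other, which the LoRA matmul has to account for. A quick shape check:

```python
import torch
from transformers.pytorch_utils import Conv1D

linear = torch.nn.Linear(16, 32)
conv1d = Conv1D(32, 16)  # Conv1D(nf=out_features, nx=in_features), as used in GPT-2

print(linear.weight.shape)  # torch.Size([32, 16]) -- (out_features, in_features)
print(conv1d.weight.shape)  # torch.Size([16, 32]) -- (in_features, out_features)
```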

def _find_and_replace(self, adapter_name):
lora_config = self.peft_config[adapter_name]
self._check_quantization_dependency()
is_target_modules_in_base_model = False
key_list = [key for key, _ in self.model.named_modules()]

for key in key_list:
if not self._check_target_module_exists(lora_config, key):
continue

is_target_modules_in_base_model = True
parent, target, target_name = _get_submodules(self.model, key)

if isinstance(target, LoraLayer):
target.update_layer(
adapter_name,
lora_config.r,
lora_config.lora_alpha,
lora_config.lora_dropout,
lora_config.init_lora_weights,
)
else:
new_module = self._create_new_module(lora_config, adapter_name, target)
self._replace_module(parent, target_name, new_module, target)

self._replace_module(parent, target_name, new_module, target)
if not is_target_modules_in_base_model:
raise ValueError(
f"Target modules {lora_config.target_modules} not found in the base model. "