Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Adalora] Add adalora 4bit #598

Merged
merged 1 commit into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 83 additions & 8 deletions src/peft/tuners/adalora.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import re
import warnings
from dataclasses import dataclass, field
Expand All @@ -9,6 +8,7 @@
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D

from ..import_utils import is_bnb_4bit_available, is_bnb_available
from ..utils import (
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
PeftType,
Expand All @@ -24,10 +24,6 @@
)


def is_bnb_available():
return importlib.util.find_spec("bitsandbytes") is not None


if is_bnb_available():
import bitsandbytes as bnb

Expand Down Expand Up @@ -128,7 +124,9 @@ def add_adapter(self, adapter_name, config=None):
def _find_and_replace(self, adapter_name):
lora_config = self.peft_config[adapter_name]
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
if loaded_in_8bit and not is_bnb_available():
loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False)

if (loaded_in_8bit or loaded_in_4bit) and not is_bnb_available():
raise ImportError(
"To use Lora with 8-bit quantization, please install the `bitsandbytes` package. "
"You can install it with `pip install bitsandbytes`."
Expand Down Expand Up @@ -173,6 +171,18 @@ def _find_and_replace(self, adapter_name):
new_module = SVDLinear8bitLt(
adapter_name, target.in_features, target.out_features, bias=bias, **kwargs
)
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
fourbit_kwargs = kwargs.copy()
fourbit_kwargs.update(
{
"compute_dtype": target.compute_dtype,
"compress_statistics": target.weight.compress_statistics,
"quant_type": target.weight.quant_type,
}
)
new_module = SVDLinear4bit(
adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs
)
else:
if isinstance(target, torch.nn.Linear):
in_features, out_features = target.in_features, target.out_features
Expand Down Expand Up @@ -230,7 +240,10 @@ def forward(self, *args, **kwargs):
I.requires_grad = False
num_param += 1
regu_loss += torch.norm(para_cov - I, p="fro")
regu_loss = regu_loss / num_param
if num_param > 0:
regu_loss = regu_loss / num_param
else:
regu_loss = 0
outputs.loss += orth_reg_weight * regu_loss
return outputs

Expand Down Expand Up @@ -507,7 +520,69 @@ def forward(self, x: torch.Tensor):
* self.scaling[self.active_adapter]
/ (self.ranknum[self.active_adapter] + 1e-5)
)
result += output
result = result + output
return result

class SVDLinear4bit(bnb.nn.Linear4bit, AdaLoraLayer):
# Low-rank matrix for SVD-based adaptation
def __init__(
self,
adapter_name,
in_features,
out_features,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs,
):
bnb.nn.Linear4bit.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
compute_dtype=kwargs.get("compute_dtype", torch.float32),
compress_statistics=kwargs.get("compress_statistics", True),
quant_type=kwargs.get("quant_type", "nf4"),
)
AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False

init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name

def forward(self, x: torch.Tensor):
result = super().forward(x)

if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
return result
elif self.r[self.active_adapter] > 0:
if not torch.is_autocast_enabled():
expected_dtype = result.dtype

if x.dtype != torch.float32:
x = x.float()
output = (
(
self.lora_dropout[self.active_adapter](x)
@ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]).T
@ self.lora_B[self.active_adapter].T
).to(expected_dtype)
* self.scaling[self.active_adapter]
/ (self.ranknum[self.active_adapter] + 1e-5)
)
else:
output = (
(
self.lora_dropout[self.active_adapter](x)
@ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]).T
@ self.lora_B[self.active_adapter].T
)
* self.scaling[self.active_adapter]
/ (self.ranknum[self.active_adapter] + 1e-5)
)
result = result + output
return result


Expand Down
75 changes: 71 additions & 4 deletions tests/test_gpu_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@
WhisperTokenizer,
)

from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
from peft import (
AdaLoraConfig,
LoraConfig,
get_peft_model,
prepare_model_for_int8_training,
prepare_model_for_kbit_training,
)

from .testing_utils import require_bitsandbytes, require_torch_gpu, require_torch_multi_gpu

Expand Down Expand Up @@ -80,10 +86,10 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->

@require_torch_gpu
@require_bitsandbytes
class PeftInt8GPUExampleTests(unittest.TestCase):
class PeftBnbGPUExampleTests(unittest.TestCase):
r"""
A single GPU int8 test suite, this will test if training fits correctly on a single GPU device (1x NVIDIA T4 16GB)
using bitsandbytes.
A single GPU int8 + fp4 test suite, this will test if training fits correctly on a single GPU device (1x NVIDIA T4
16GB) using bitsandbytes.

The tests are the following:

Expand Down Expand Up @@ -168,6 +174,67 @@ def test_causal_lm_training(self):
# assert loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

@pytest.mark.single_gpu_tests
@require_torch_gpu
def test_4bit_adalora_causalLM(self):
r"""
Tests the 4bit training with adalora
"""
model_id = "facebook/opt-350m"

model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

peft_config = AdaLoraConfig(
init_r=6,
target_r=4,
tinit=50,
tfinal=100,
deltaT=5,
beta1=0.3,
beta2=0.3,
orth_reg_weight=0.2,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)

model = get_peft_model(model, peft_config)

data = load_dataset("ybelkada/english_quotes_copy")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)

with tempfile.TemporaryDirectory() as tmp_dir:
trainer = Trainer(
model=model,
train_dataset=data["train"],
args=TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
warmup_steps=2,
max_steps=3,
learning_rate=2e-4,
fp16=True,
logging_steps=1,
output_dir=tmp_dir,
),
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False
trainer.train()

model.cpu().save_pretrained(tmp_dir)

self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
self.assertTrue("adapter_model.bin" in os.listdir(tmp_dir))

# assert loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

@pytest.mark.multi_gpu_tests
@require_torch_multi_gpu
def test_causal_lm_training_mutli_gpu(self):
Expand Down
Loading