
[NeuralChat] support full parameters finetuning (#824)
* support full parameters finetuning.
lkk12014402 committed Dec 6, 2023
1 parent f2ac75f commit 2b5411f
Showing 3 changed files with 202 additions and 60 deletions.
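
For context, the new full_finetune flag switches NeuralChat finetuning from attaching a PEFT adapter to updating all model parameters. Below is a minimal usage sketch mirroring the unit tests added in this commit; the import paths and the train_file name are my assumptions from the repository layout, not taken from this diff.

# Sketch only: import paths and the dataset filename are assumptions.
from intel_extension_for_transformers.neural_chat.config import (
    ModelArguments,
    DataArguments,
    TrainingArguments,
    FinetuningArguments,
    TextGenerationFinetuningConfig,
)
from intel_extension_for_transformers.neural_chat.chatbot import finetune_model

finetune_cfg = TextGenerationFinetuningConfig(
    model_args=ModelArguments(model_name_or_path="facebook/opt-125m"),
    data_args=DataArguments(train_file="alpaca_data.json"),  # placeholder file
    training_args=TrainingArguments(output_dir="./tmp", do_train=True, max_steps=3),
    # full_finetune requires 16- or 32-bit training and cannot be combined with qlora.
    finetune_args=FinetuningArguments(
        device="cpu", full_finetune=True, bits=16, do_lm_eval=False
    ),
)
finetune_model(finetune_cfg)
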
152 changes: 95 additions & 57 deletions intel_extension_for_transformers/llm/finetuning/finetuning.py
@@ -284,6 +284,15 @@ def finetune(self):
raise NotImplementedError(
f"Unsupported bits {finetune_args.bits}, only support 4 and 8 now."
)
if finetune_args.full_finetune:
raise ValueError(
f"qlora and full_finetune can't be True at the same time."
)
elif finetune_args.full_finetune:
if finetune_args.bits not in [16, 32]:
raise ValueError(
f"full finetune only support 16 and 32 bits."
)

config = self.load_model_config(self.model_args)
if config.architectures[0].endswith("ForCausalLM") \
@@ -482,6 +491,50 @@ def concatenate_data(dataset, max_seq_length):
)

if training_args.do_train:
# PEFT settings
if finetune_args.peft == "lora":
if finetune_args.lora_all_linear:
target_modules = self.find_all_linear_names(model)
else:
target_modules = finetune_args.lora_target_modules

peft_config = LoraConfig(
r=finetune_args.lora_rank,
lora_alpha=finetune_args.lora_alpha,
lora_dropout=finetune_args.lora_dropout,
target_modules=target_modules,
bias="none",
task_type=TaskType.CAUSAL_LM,
)
elif finetune_args.peft == "llama_adapter":
peft_config = AdaptionPromptConfig(
adapter_layers=finetune_args.adapter_layers,
adapter_len=finetune_args.adapter_len,
task_type="CAUSAL_LM",
)
elif finetune_args.peft == "ptun":
peft_config = PromptEncoderConfig(
num_virtual_tokens=finetune_args.num_virtual_tokens,
encoder_hidden_size=finetune_args.ptun_hidden_size,
task_type="CAUSAL_LM",
)
elif finetune_args.peft == "prefix":
peft_config = PrefixTuningConfig(
num_virtual_tokens=finetune_args.num_virtual_tokens,
task_type="CAUSAL_LM",
)
elif finetune_args.peft == "prompt":
peft_config = PromptTuningConfig(
num_virtual_tokens=finetune_args.num_virtual_tokens,
task_type="CAUSAL_LM",
)
if not finetune_args.full_finetune:
# PEFT settings
if finetune_args.peft == "lora":
if finetune_args.lora_all_linear:
target_modules = self.find_all_linear_names(model)
else:
target_modules = finetune_args.lora_target_modules

peft_config = LoraConfig(
r=finetune_args.lora_rank,
lora_alpha=finetune_args.lora_alpha,
lora_dropout=finetune_args.lora_dropout,
target_modules=target_modules,
bias="none",
task_type=TaskType.CAUSAL_LM,
)
elif finetune_args.peft == "llama_adapter":
peft_config = AdaptionPromptConfig(
adapter_layers=finetune_args.adapter_layers,
adapter_len=finetune_args.adapter_len,
task_type="CAUSAL_LM",
)
elif finetune_args.peft == "ptun":
peft_config = PromptEncoderConfig(
num_virtual_tokens=finetune_args.num_virtual_tokens,
encoder_hidden_size=finetune_args.ptun_hidden_size,
task_type="CAUSAL_LM",
)
elif finetune_args.peft == "prefix":
peft_config = PrefixTuningConfig(
num_virtual_tokens=finetune_args.num_virtual_tokens,
task_type="CAUSAL_LM",
)
elif finetune_args.peft == "prompt":
peft_config = PromptTuningConfig(
num_virtual_tokens=finetune_args.num_virtual_tokens,
task_type="CAUSAL_LM",
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

model = get_peft_model(model, peft_config)
if model_dtype == torch.bfloat16:
model = model.to(model_dtype)
model.print_trainable_parameters()

if finetune_args.device != 'hpu':
# Initialize our Trainer
@@ -806,24 +817,33 @@ def preprocess_logits_for_metrics(logits, labels):
else:
raise ValueError("Must provide model_name_or_path to load a pretrained Seq2SeqLM model.")

# PEFT settings
if finetune_args.peft == "lora":
if finetune_args.lora_all_linear:
target_modules = self.find_all_linear_names(model)
else:
target_modules = finetune_args.lora_target_modules
peft_config = LoraConfig(
r=finetune_args.lora_rank,
lora_alpha=finetune_args.lora_alpha,
lora_dropout=finetune_args.lora_dropout,
target_modules=target_modules,
bias="none",
task_type=TaskType.SEQ_2_SEQ_LM,
if finetune_args.qlora:
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=training_args.gradient_checkpointing
)

# model = prepare_model_for_int8_training(model)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
if not finetune_args.full_finetune:
# PEFT settings
if finetune_args.peft == "lora":
if finetune_args.lora_all_linear:
target_modules = self.find_all_linear_names(model)
else:
target_modules = finetune_args.lora_target_modules
peft_config = LoraConfig(
r=finetune_args.lora_rank,
lora_alpha=finetune_args.lora_alpha,
lora_dropout=finetune_args.lora_dropout,
target_modules=target_modules,
bias="none",
task_type=TaskType.SEQ_2_SEQ_LM,
)

# model = prepare_model_for_int8_training(model)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

if model_dtype == torch.bfloat16:
model = model.to(model_dtype)

if training_args.do_eval and not training_args.do_train:
config = PeftConfig.from_pretrained(model_args.model_name_or_path)
@@ -839,9 +859,26 @@ def preprocess_logits_for_metrics(logits, labels):
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8)

# Create Trainer instance
trainer = Seq2SeqTrainer(

if finetune_args.device != 'hpu':
# Create Trainer instance
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
compute_metrics=compute_metrics,
preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
)
else:
from optimum.habana import GaudiConfig, GaudiSeq2SeqTrainer # pylint: disable=E0611 E0401
gaudi_config = GaudiConfig()
gaudi_config.use_fused_adam = True
gaudi_config.use_fused_clip_norm = True
trainer = GaudiSeq2SeqTrainer(
model=model,
gaudi_config=gaudi_config,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset if training_args.do_train else None,
@@ -850,6 +887,7 @@ def preprocess_logits_for_metrics(logits, labels):
preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
)


# Training
if training_args.do_train:
checkpoint = None
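The Seq2Seq hunk above mirrors the causal-LM path's device dispatch: a plain transformers Seq2SeqTrainer when the device is not 'hpu', and a GaudiSeq2SeqTrainer from optimum-habana otherwise. The following condensed sketch isolates that pattern; it assumes optimum-habana is installed for the HPU branch and omits the data_collator and metrics hooks shown in the diff.

# Condensed sketch of the trainer dispatch used above; not library code.
def build_seq2seq_trainer(model, training_args, train_dataset, eval_dataset, device):
    if device != "hpu":
        from transformers import Seq2SeqTrainer
        return Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )
    # HPU branch, as in the diff; requires optimum-habana.
    from optimum.habana import GaudiConfig, GaudiSeq2SeqTrainer  # pylint: disable=E0611 E0401
    gaudi_config = GaudiConfig()
    gaudi_config.use_fused_adam = True
    gaudi_config.use_fused_clip_norm = True
    return GaudiSeq2SeqTrainer(
        model=model,
        gaudi_config=gaudi_config,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
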
4 changes: 4 additions & 0 deletions intel_extension_for_transformers/neural_chat/config.py
@@ -342,6 +342,10 @@ class FinetuningArguments:
default=4,
metadata={"help": "How many bits to use."}
)
full_finetune: bool = field(
default=False,
metadata={"help": "Finetune the entire model without adapters."}
)

@dataclass
class TTSDatasetArguments:
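The full_finetune field added above is checked against qlora and bits in finetuning.py. As a standalone illustration only, the guard reduces to the rules below; the helper is hypothetical and not part of the library.

# Hypothetical helper distilling the validation added in finetuning.py.
def validate_finetune_flags(qlora: bool, full_finetune: bool, bits: int) -> None:
    if qlora:
        if bits not in [4, 8]:
            raise NotImplementedError(f"Unsupported bits {bits}, only support 4 and 8 now.")
        if full_finetune:
            raise ValueError("qlora and full_finetune can't be True at the same time.")
    elif full_finetune and bits not in [16, 32]:
        raise ValueError("full_finetune only supports 16 and 32 bits.")

validate_finetune_flags(qlora=False, full_finetune=True, bits=16)  # passes
validate_finetune_flags(qlora=True, full_finetune=True, bits=4)    # raises ValueError
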
@@ -49,7 +49,7 @@ def tearDownClass(self):
shutil.rmtree('./tmp', ignore_errors=True)
os.remove(test_data_file)

def test_finetune_clm(self):
def test_finetune_clm_lora(self):
model_args = ModelArguments(model_name_or_path="facebook/opt-125m")
data_args = DataArguments(train_file=test_data_file)
training_args = TrainingArguments(
@@ -85,7 +85,70 @@ def test_finetune_clm_qlora(self):
)
finetune_model(finetune_cfg)

def test_finetune_seq2seq(self):
def test_finetune_clm_full_finetuning(self):
model_args = ModelArguments(model_name_or_path="facebook/opt-125m")
data_args = DataArguments(train_file=test_data_file)
training_args = TrainingArguments(
output_dir='./tmp',
do_train=True,
max_steps=3,
overwrite_output_dir=True
)
finetune_args = FinetuningArguments(device='cpu', full_finetune=True,
bits=16, do_lm_eval=False)
finetune_cfg = TextGenerationFinetuningConfig(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetune_args=finetune_args,
)
finetune_model(finetune_cfg)

def test_finetune_clm_full_finetune_invalid_bits(self):
model_args = ModelArguments(model_name_or_path="facebook/opt-125m")
data_args = DataArguments(train_file=test_data_file)
training_args = TrainingArguments(
output_dir='./tmp',
do_train=True,
max_steps=3,
overwrite_output_dir=True
)
finetune_args = FinetuningArguments(device='cpu', full_finetune=True, do_lm_eval=False)
finetune_cfg = TextGenerationFinetuningConfig(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetune_args=finetune_args,
)
# full_finetune with the default 4-bit setting should be rejected
with self.assertRaises(ValueError):
finetune_model(finetune_cfg)

def test_finetune_clm_qlora_full_finetune_conflict(self):
model_args = ModelArguments(model_name_or_path="facebook/opt-125m")
data_args = DataArguments(train_file=test_data_file)
training_args = TrainingArguments(
output_dir='./tmp',
do_train=True,
max_steps=3,
overwrite_output_dir=True
)
finetune_args = FinetuningArguments(device='cpu', qlora=True,
full_finetune=True, do_lm_eval=False)
finetune_cfg = TextGenerationFinetuningConfig(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetune_args=finetune_args,
)
# qlora and full_finetune are mutually exclusive and should raise
with self.assertRaises(ValueError):
finetune_model(finetune_cfg)

def test_finetune_seq2seq_lora(self):
model_args = ModelArguments(model_name_or_path="google/flan-t5-small")
data_args = DataArguments(train_file=test_data_file)
training_args = Seq2SeqTrainingArguments(
@@ -103,5 +166,42 @@ def test_finetune_seq2seq(self):
)
finetune_model(finetune_cfg)

def test_finetune_seq2seq_qlora(self):
model_args = ModelArguments(model_name_or_path="google/flan-t5-small")
data_args = DataArguments(train_file=test_data_file)
training_args = Seq2SeqTrainingArguments(
output_dir='./tmp',
do_train=True,
max_steps=3,
overwrite_output_dir=True
)
finetune_args = FinetuningArguments(device='cpu', qlora=True)
finetune_cfg = TextGenerationFinetuningConfig(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetune_args=finetune_args,
)
finetune_model(finetune_cfg)

def test_finetune_seq2seq_full_finetuning(self):
model_args = ModelArguments(model_name_or_path="google/flan-t5-small")
data_args = DataArguments(train_file=test_data_file)
training_args = Seq2SeqTrainingArguments(
output_dir='./tmp',
do_train=True,
max_steps=3,
overwrite_output_dir=True
)
finetune_args = FinetuningArguments(device='cpu', full_finetune=True,
bits=16, do_lm_eval=False)
finetune_cfg = TextGenerationFinetuningConfig(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetune_args=finetune_args,
)
finetune_model(finetune_cfg)

if __name__ == "__main__":
unittest.main()
