
Set default value of target_modules to be None in LoraConfig #269

Merged
merged 5 commits into foundation-model-stack:main
Aug 1, 2024

Conversation

@willmj (Collaborator) commented Jul 29, 2024

Description of the change

Original issue (from Slack):

By default, target_modules=None should work; when it is not specified, PEFT uses the default values HF defines for each architecture. In PEFT's LoraConfig the default is None.
In our config we should also set it to None in tuning/config/peft_config.py and ensure it still works with None through unit tests.

In short, this changes default_factory=lambda: ["q_proj", "v_proj"] to default=None in the LoraConfig class.
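
A minimal sketch of the updated field, mirroring the diff further down (any metadata on the real field is omitted here):

# Sketch of the relevant fields in tuning/config/peft_config.py after this change.
# target_modules now defaults to None so PEFT can resolve per-architecture defaults.
from dataclasses import dataclass, field
from typing import List

@dataclass
class LoraConfig:
    r: int = 8
    lora_alpha: int = 32
    # previously: field(default_factory=lambda: ["q_proj", "v_proj"])
    target_modules: List[str] = field(default=None)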

Related issue number

N/A

How to verify the PR

Run unit tests.
Additionally, you can run tuning on a model and then check that target_modules in the output model's adapter_config.json matches the intended target_modules for that architecture, for example with a quick check like the sketch below.
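
This is only an illustrative check, not part of the PR; the path is an example, and the adapter config may sit in a checkpoint subdirectory of your --output_dir.

# Load the tuned adapter's config and confirm target_modules matches the
# architecture defaults (q_proj/v_proj for a llama-type model).
import json

with open("outputs/lora-tuning/adapter_config.json") as f:
    adapter_config = json.load(f)

assert set(adapter_config["target_modules"]) == {"q_proj", "v_proj"}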

Was the PR tested

  • I have added >=1 unit test(s) for every new method I have added.
  • I have ensured all unit tests pass

If you have any questions or concerns, please respond below. Thanks!

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
@willmj (Collaborator, Author) commented Jul 30, 2024

Example run of TinyLlama-v0 with LoRA (using the llama model_type):

python tuning/sft_trainer.py \
--model_name_or_path Maykeye/TinyLLama-v0 \
--training_data_path tests/data/twitter_complaints_small.json \
--output_dir outputs/lora-tuning \
--num_train_epochs 5 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--learning_rate 1e-5 \
--response_template "\n### Label:" \
--dataset_text_field "output" \
--use_flash_attn false \
--torch_dtype "float32" \
--peft_method "lora" \
--r 8 \
--lora_dropout 0.05 \
--lora_alpha 16 

Results in the following adapter_config.json file:

{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "Maykeye/TinyLLama-v0",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 16,
  "lora_dropout": 0.05,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 8,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "v_proj",
    "q_proj"
  ],
  "task_type": "CAUSAL_LM",
  "use_dora": false,
  "use_rslora": false
}

This matches what we expect for a llama-type model.

I also tried running with gpt2-medium:

python tuning/sft_trainer.py \
--model_name_or_path openai-community/gpt2-medium \
--training_data_path tests/data/twitter_complaints_small.json \
--output_dir outputs/lora-tuning-gpt \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--learning_rate 1e-5 \
--response_template "\n### Label:" \
--dataset_text_field "output" \
--use_flash_attn false \
--torch_dtype "float32" \
--peft_method "lora" \
--r 8 \
--lora_dropout 0.05 \
--lora_alpha 16

I get the following error:

ValueError: Target modules {'q_proj', 'v_proj'} not found in the base model. Please check the target modules and try again.

This was not resolved until I added --target_modules "c_attn" to the command. With that flag, adapter_config.json contained the expected output:

{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "openai-community/gpt2-medium",
  "bias": "none",
  "fan_in_fan_out": true,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 16,
  "lora_dropout": 0.05,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 8,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "c_attn"
  ],
  "task_type": "CAUSAL_LM",
  "use_dora": false,
  "use_rslora": false
}

@willmj willmj marked this pull request as draft July 30, 2024 17:50
@willmj (Collaborator, Author) commented Aug 1, 2024

With target_modules defaulting to None, target modules that are not explicitly specified are filled in automatically by the peft library based on the model architecture.
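
For reference, a quick way to see what PEFT would pick for a given architecture is to read its internal mapping table; the import path is PEFT-internal and may move between versions, so treat this as a sketch:

# Peek at PEFT's per-architecture LoRA target-module defaults.
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as LORA_DEFAULTS

print(LORA_DEFAULTS["llama"])  # expected: ["q_proj", "v_proj"]
print(LORA_DEFAULTS["gpt2"])   # expected: ["c_attn"]
print(LORA_DEFAULTS["opt"])    # expected: ["q_proj", "v_proj"]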

TinyLlama-v0 test:

python tuning/sft_trainer.py \
--model_name_or_path Maykeye/TinyLLama-v0 \
--training_data_path tests/data/twitter_complaints_small.json \
--output_dir outputs/lora-tuning \
--num_train_epochs 5 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--learning_rate 1e-5 \
--response_template "\n### Label:" \
--dataset_text_field "output" \
--use_flash_attn false \
--torch_dtype "float32" \
--peft_method "lora" \
--r 8 \
--lora_dropout 0.05 \
--lora_alpha 16 

Result:

"target_modules": [
    "q_proj",
    "v_proj"
  ],

gpt2-medium test:

python tuning/sft_trainer.py \
--model_name_or_path openai-community/gpt2-medium \
--training_data_path tests/data/twitter_complaints_small.json \
--output_dir outputs/lora-tuning-gpt \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--learning_rate 1e-5 \
--response_template "\n### Label:" \
--dataset_text_field "output" \
--use_flash_attn false \
--torch_dtype "float32" \
--peft_method "lora" \
--r 8 \
--lora_dropout 0.05 \
--lora_alpha 16

Result:

"target_modules": [
    "c_attn"
  ],

opt-125m test:

python tuning/sft_trainer.py \
--model_name_or_path facebook/opt-125m \
--training_data_path tests/data/twitter_complaints_small.json \
--output_dir outputs/lora-tuning-opt \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--learning_rate 1e-5 \
--response_template "\n### Label:" \
--dataset_text_field "output" \
--use_flash_attn false \
--torch_dtype "float32" \
--peft_method "lora" \
--r 8 \
--lora_dropout 0.05 \
--lora_alpha 16

Result:

"target_modules": [
    "v_proj",
    "q_proj"
  ],

@willmj willmj marked this pull request as ready for review August 1, 2024 17:21
@@ -46,7 +46,7 @@ class LoraConfig:
     r: int = 8
     lora_alpha: int = 32
     target_modules: List[str] = field(
-        default_factory=lambda: ["q_proj", "v_proj"],
+        default=None,
@Ssukriti (Collaborator) commented on the diff, Aug 1, 2024

It would be good to have a unit test that uses the llama model with this LoRA config, does LoRA tuning, fetches the adapter config, and ensures the adapter config still has q_proj and v_proj.

@anhuong ^

By default, if we set None, HF should use the q_proj etc. default values per model architecture. We should test that this is indeed the case and that target_modules is an optional field.
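
A rough sketch of what such a test could assert, using the same tiny llama model as the runs above and calling PEFT directly (the test that actually lands would likely go through sft_trainer and the repo's own LoraConfig instead):

# With target_modules=None, PEFT should fill in the llama defaults (q_proj, v_proj).
from peft import LoraConfig as PeftLoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

def test_lora_target_modules_default_per_architecture():
    model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0")
    peft_config = PeftLoraConfig(
        r=8, lora_alpha=16, target_modules=None, task_type="CAUSAL_LM"
    )
    peft_model = get_peft_model(model, peft_config)
    resolved = set(peft_model.peft_config["default"].target_modules)
    assert resolved == {"q_proj", "v_proj"}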

@Ssukriti (Collaborator) commented Aug 1, 2024

Excellent testing above: #269 (comment)

Would be good to capture it via a unit test too; maybe it is already captured :) Let's double-check and add one if not.

@aluu317 (Collaborator) left a comment

This is great! Thanks @willmj. Excellent testing, and I agree with Sukriti about running the logical test via the code if possible!

@willmj willmj merged commit 59cc20b into foundation-model-stack:main Aug 1, 2024
7 checks passed
@willmj willmj deleted the 1143 branch August 1, 2024 20:55