Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lora config data model and utilities #270

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

Ssukriti
Copy link
Collaborator

@Ssukriti Ssukriti commented Nov 17, 2023

For ##136

This PR is the first step to adding a data model and utility to initialize LoraConfig . It adds support for get_peft_config to return LoraConfig or PEFTTuningConfig

It also includes a refactor to

  1. move common code needed for peft and Lora config initialization to the peft_config file .
  2. create modular functions
  3. move out parts specific to pefT Config to a separate private function ultimately called from within get_peft_config based on tuning type

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
@@ -1150,34 +1155,6 @@ def _execute_train_loop(
)
return {"loss": training_loss_tracker}

@classmethod
def _filter_params_for_prompt_config(cls, prompt_config, params):
"""Utility function to filter out required parameters for prompt_config
Copy link
Collaborator Author

@Ssukriti Ssukriti Nov 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this belongs in peft_config as it a utility to create any Config, so I moved it. I also think create_hf_tuning_config( in thsi file belongs there, but I did not move it because I dont know the rational behind keeping create_hf_tuning_config here in prompt_tuning.py

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's more of a legacy reason than anything, all of the code for tuning config stuff was originally written in this file as part of this module, and the refactoring to pull some of it out happened later. I don't think there was a deep reason for leaving create_hf_tuning_config here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, I will move it

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
)
# Ensure that our verbalizer is a string and will not render to a hardcoded string
error.value_check(
"<NLP83837412E>",
Copy link
Collaborator Author

@Ssukriti Ssukriti Nov 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

verbalizer is also not used, we are checking if its a valid verbalizer , but that should be done in train of peft_prompt_tuning. it has nothing to do with the Config, so let me know if ok to remove (would API change again, hence didnt do)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, where is verbalizer not getting used?

Copy link
Collaborator Author

@Ssukriti Ssukriti Nov 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are only checking type of verbalizer here. it is not used to create the PEFT config and is also an extra parameter for this function. I can remove it if we can make API changes with next release

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added deprecation warning

@@ -216,5 +276,15 @@ def get_peft_config(
tuning_config=tuning_config,
output_model_types=output_model_types,
)
return peft_config
Copy link
Collaborator Author

@Ssukriti Ssukriti Nov 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this create_hf_tuning_config function should ideally be moved here. I dont know rationale for why its in peft_prompt_tuning.py . Let me know if we can move it here and then get rid of cls argument (would be API breaking )

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
# Lora attention dimension.
r: int
# The names of the modules to apply Lora to.
target_modules: Union[List[str], str]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add an example of this in the comment as well? Also wondering how will a user know about these modules?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the same from HF. There are ways to know modules of a model by loading model with transformers and printing it. I will have to include that in actual examples that we commit to caikit NLP when feature is complete. we might have to expose a helper function in caikit NLP to get_module_names(model) or something . but I think I can implement that in follow up PRs when we get to adding examples.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would also need to figure out how do we expose these functions to cloud users then, also we would need to explain them what these modules are and what are its implication. Any suggestions on how this would be exposed to cloud users?

Copy link
Collaborator Author

@Ssukriti Ssukriti Dec 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean ya the consuming products would have to document what parameters mean (they do that already). Model factsheets include the architecture and layers
Screenshot 2023-12-03 at 10 55 45 PM

users dont have to run any function if they know model architecture which I think should be exposed to users already from factsheet

Besides, it is optional parameter and upto cloud products if they want to expose it or not (though it is commonly tuned parameter from blog posts). I will make it more clear that it can be None as well and is Optional (I think I have to make some changes in the data model to make it Optional type - I will do that)

from library perspective we can still expose it and have further discussions down the road . Even if the parameter is left unused from cloud and not exposed to users to begin with, its not a big deal to have it in library

Copy link
Collaborator Author

@Ssukriti Ssukriti Dec 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

other ways of doing it would be to document model architectures from a common page for supported models; or if cloud products want to create a functionality to obtain model architecture in real time, we can have that discussion independently

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, when users have their own model they will most likely know the architecture

caikit_nlp/data_model/generation.py Outdated Show resolved Hide resolved
)
# Ensure that our verbalizer is a string and will not render to a hardcoded string
error.value_check(
"<NLP83837412E>",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, where is verbalizer not getting used?

caikit_nlp/modules/text_generation/peft_config.py Outdated Show resolved Hide resolved
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Copy link
Collaborator

@alex-jw-brooks alex-jw-brooks left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general I think this looks good! Some small questions, but once those are cleared up I think we can get it merged, thanks Sukriti!

@@ -1150,34 +1155,6 @@ def _execute_train_loop(
)
return {"loss": training_loss_tracker}

@classmethod
def _filter_params_for_prompt_config(cls, prompt_config, params):
"""Utility function to filter out required parameters for prompt_config
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's more of a legacy reason than anything, all of the code for tuning config stuff was originally written in this file as part of this module, and the refactoring to pull some of it out happened later. I don't think there was a deep reason for leaving create_hf_tuning_config here

base_model,
cls=None,
torch_dtype=None,
verbalizer="{{input}}",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason these defaults are now being set here? I.e., the calling module is always going to pass them in positionally and override these defaults, right?

Copy link
Collaborator Author

@Ssukriti Ssukriti Dec 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we want to deprecate all these 3 parameters (once I move create_hf_tuning function to this file , we wont need cls either. Other 2 are unused). So I just set them here , so we can stop setting them from calling module and remove them in next major release. The 3 will not be needed for Lora either and dont have to be set from calling module.

config_kwargs = tuning_config.to_dict()
log.info("<NLP61012781I>", f"Parameters used: {config_kwargs}")
config_params = _filter_params_for_prompt_config(tuning_config, config_kwargs)
del config_params["output_model_types"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am probably missing something - what is the reason for deleting this?

Copy link
Collaborator Author

@Ssukriti Ssukriti Dec 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have this parameter in our TuningConfig data model (it was there in prompt tuning, so I copied it to Lora Tuning. I didnt dig into why we need it) , but it is not in the HF prompttuningConfig or Loraconfig , so we have to remove it before passing it to get the HF TuningConfig. The way it is removed for prompt tuning currently in https://github.com/caikit/caikit-nlp/blob/main/caikit_nlp/modules/text_generation/peft_prompt_tuning.py#L798C16-L798C16 , is by copying all the relevant parameters in another dict which does not include output_model_types.

I chose to delete instead as only 1 parameter was different. I should have checked if key exists before deleting though to avoid key error. I will add that check, thanks for observing this

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
@Ssukriti Ssukriti marked this pull request as draft December 8, 2023 23:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants