
feat: automatically select LoRA modules when none are provided#166

Merged
mergify[bot] merged 1 commit into instructlab:main from RobotSail:auto-lora
Oct 4, 2024

Conversation

@RobotSail
Member

In the current version of the training library, we have the default value of target_modules set to
a list of layer names which are implementation-specific and may not reflect what a given model actually
uses for its layer names. Furthermore, the default is also a subset of all projection layers in most models,
and the recommendation is generally to use all of these layers when injecting low-rank adapters.

This commit resolves that issue by introducing logic to automatically resolve the target modules
and default to using all of them when they are not provided. This commit also adds validation logic
which indicates when some of the provided modules do not exist in the model. To go a step further,
the training library will also now error out when none of the provided target modules exist in the model,
supplying the user with additional context on which modules exist and how they could resolve the error.

Signed-off-by: Oleg S ec2-user@ip-10-0-24-47.us-east-2.compute.internal
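The resolution and validation flow described in the PR description can be sketched roughly as follows. This is a minimal illustration, not the library's actual code: the function name `resolve_target_modules` is hypothetical, and the output of `model.named_modules()` is simulated with plain strings.

```python
def resolve_target_modules(module_names, requested=None):
    """Resolve LoRA target modules from a model's module names.

    module_names: iterable of dotted module names, as yielded by
    a torch model's named_modules().
    requested: user-provided target modules, or None to auto-select.
    """
    # Collect the final component of every module name ending in "_proj".
    proj_layers = {name.split(".")[-1] for name in module_names if name.endswith("_proj")}

    if requested is None:
        # Default: target every projection layer found in the model.
        return sorted(proj_layers)

    valid = [m for m in requested if m in proj_layers]
    missing = [m for m in requested if m not in proj_layers]
    if not valid:
        # None of the requested modules exist: fail with context.
        raise ValueError(
            f"None of the requested target modules {requested} exist in the model; "
            f"available projection layers are {sorted(proj_layers)}"
        )
    if missing:
        # Some requested modules are absent: warn but continue.
        print(f"warning: ignoring modules not found in model: {missing}")
    return valid

# Module names as they might appear in a llama-style model:
names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.0.self_attn.o_proj",
    "model.layers.0.mlp.gate_proj",
]
print(resolve_target_modules(names))
print(resolve_target_modules(names, ["q_proj", "bogus_proj"]))
```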

@RobotSail
Member Author

resolves #164

"""
Given a pretrained model, returns all of the projection layers (matching '_proj')
"""
proj_layers = {name.split(".")[-1] for name, _ in model.named_modules() if name.endswith("_proj")}
Collaborator

If this is for llama only it's fine, but in general models do not always use the naming k_proj, v_proj, etc.

Collaborator

Another alternative, if you want to target all the linears, is to check isinstance(mod, torch.nn.Linear)
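The isinstance-based alternative can be illustrated with stand-in classes. In a real model these would be torch.nn.Linear and torch.nn.Embedding, and the pairs would come from model.named_modules(); the module names here are made up.

```python
# Stand-ins for torch.nn module classes, so the sketch runs without torch.
class Linear:
    pass

class Embedding:
    pass

# Simulated (name, module) pairs as model.named_modules() would yield them.
modules = [
    ("model.embed_tokens", Embedding()),
    ("model.layers.0.self_attn.q_proj", Linear()),
    ("model.layers.0.mlp.c_fc", Linear()),  # a linear layer NOT named *_proj
]

# Target every linear layer regardless of its naming convention.
linear_layers = {name.split(".")[-1] for name, mod in modules if isinstance(mod, Linear)}
print(sorted(linear_layers))
```

Note that this picks up `c_fc`, which a suffix match on `_proj` would miss.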

Member Author

@fabianlim The models we actively support are listed here:

assert model.__class__.__name__ in [

When I looked at their list of layers, all of them had k_proj, q_proj, v_proj, o_proj, so the assumption is true at least for supported models.

You're right though, we could go down the path of targeting all linear layers. I just have two questions about this approach:

  1. How would this affect the memory requirements?
  2. What would be the impact on training times? If we are targeting more modules for LoRA, we would potentially be dropping more pretrained weights in favor of our LoRA approximations - how would this impact the loss curve?

Collaborator

If it is as you said, that you are targeting all of the proj layers, then it should be equivalent to putting a LoRA adapter on all linears

Member Author

@fabianlim Not necessarily. Some models use Linear layers which are not explicitly labeled as projections. For example, in starcoder-3b, these account for roughly 1.2B parameters:

[screenshot: starcoder-3b module listing showing Linear layers without `_proj` names]
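The point can be illustrated by tallying the parameters of linear layers whose names do not end in `_proj`. This is a toy sketch with invented shapes, not starcoder-3b's real architecture:

```python
# Toy (name -> weight_shape) pairs standing in for a model's linear layers.
linear_weights = {
    "model.layers.0.attn.q_proj": (2048, 2048),
    "model.layers.0.attn.v_proj": (2048, 2048),
    "model.layers.0.mlp.c_fc":    (2048, 8192),  # not a *_proj name
    "model.layers.0.mlp.c_out":   (8192, 2048),  # not a *_proj name
}

# Parameters that a suffix match on "_proj" would never target:
non_proj_params = sum(
    rows * cols
    for name, (rows, cols) in linear_weights.items()
    if not name.endswith("_proj")
)
print(non_proj_params)
```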

Contributor

we should NOT use model-specific names to do LoRA, since model architectures are subject to change, and we might even start supporting too many of them.

@RobotSail RobotSail requested review from aldopareja and removed request for cdoern August 19, 2024 14:15
@RobotSail RobotSail force-pushed the auto-lora branch 3 times, most recently from e2ab9dd to dd1cb74 Compare August 19, 2024 19:14
@ktam3 ktam3 added the jira label Aug 27, 2024
Collaborator

@Maxusmusti Maxusmusti left a comment

LGTM

@mergify mergify bot added the one-approval label Aug 29, 2024
Contributor

@aldopareja aldopareja left a comment

I think this looks alright, just a minor comment that can be ignored for the moment.

"""
Given a pretrained model, returns all of the projection layers (matching '_proj')
"""
proj_layers = {name.split(".")[-1] for name, _ in model.named_modules() if name.endswith("_proj")}
Contributor

we should NOT use model-specific names to do LoRA, since model architectures are subject to change, and we might even start supporting too many of them.

)
if train_args.lora.target_modules:
    command.extend(train_args.lora.target_modules)
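How this guard might fit into assembling the launch command can be sketched as follows. The flag names here are hypothetical, not necessarily the library's actual CLI:

```python
def build_lora_args(target_modules):
    """Append LoRA target-module flags only when the user supplied any."""
    command = ["--lora-r", "4"]  # hypothetical base LoRA flags
    if target_modules:
        command.append("--lora-target-modules")
        command.extend(target_modules)
    return command

print(build_lora_args(["q_proj", "v_proj"]))
print(build_lora_args([]))  # no flag appended; the library auto-selects
```

Skipping the flag entirely when the list is empty is what lets the downstream auto-selection logic kick in.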
Contributor

Should we have this only for granite models? How about non-granite models?

Member Author

Specifying the target modules? It should be fine for all models, since we may want to target different modules depending on what we want them to learn.

@mergify mergify bot removed the one-approval label Aug 29, 2024
Contributor

@JamesKunstle JamesKunstle left a comment

lgtm

@mergify
Contributor

mergify bot commented Sep 26, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @RobotSail please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 26, 2024
@RobotSail
Member Author

@mergify rebase

@mergify
Contributor

mergify bot commented Sep 30, 2024

rebase

☑️ Nothing to do

Details
  • -conflict [📌 rebase requirement]
  • -closed [📌 rebase requirement]
  • queue-position = -1 [📌 rebase requirement]
  • any of:
    • #commits-behind > 0 [📌 rebase requirement]
    • #commits > 1 [📌 rebase requirement]
    • -linear-history [📌 rebase requirement]

@JamesKunstle
Contributor

@RobotSail Is this good to go post-rebase?

@RobotSail
Member Author

@JamesKunstle Yeah, it should be. Mergify didn't merge it automatically.

@ktam3 ktam3 linked an issue Oct 2, 2024 that may be closed by this pull request
@mergify mergify bot removed the needs-rebase label Oct 4, 2024
@mergify mergify bot merged commit 8fc555c into instructlab:main Oct 4, 2024
Successfully merging this pull request may close these issues.

Provide advisory when lora target modules dont exist

6 participants