
Implement saving FSDP with LoRA #295

Merged
mergify[bot] merged 3 commits into instructlab:main from RobotSail:lora-4
Nov 13, 2024

Conversation

@RobotSail
Member

@RobotSail RobotSail commented Oct 23, 2024

Currently we cannot save LoRA models with FSDP. This PR addresses that limitation by instantiating a copy of the model on the CPU, loading the LoRA settings into it, loading the state dict after it has been gathered from the distributed model, and finally performing the same save that we do elsewhere throughout the codebase.
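A minimal sketch of the save path described above, using the `torch.distributed.fsdp` and `peft` APIs. The `make_cpu_lora_model` factory and the function name are illustrative assumptions, not the exact code merged in this PR:

```python
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)


def save_lora_fsdp(fsdp_model, make_cpu_lora_model, output_dir, rank):
    # 1. Gather the full (unsharded) state dict, offloaded to CPU on rank 0.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        state_dict = fsdp_model.state_dict()

    if rank == 0:
        # 2. Instantiate a CPU copy of the model with the same LoRA config
        #    (make_cpu_lora_model is a hypothetical factory for that copy).
        cpu_model = make_cpu_lora_model()
        # 3. Load the gathered weights into the CPU copy.
        cpu_model.load_state_dict(state_dict)
        # 4. Merge the LoRA adapters into the base weights (peft PeftModel
        #    API) and save as usual.
        merged = cpu_model.merge_and_unload()
        merged.save_pretrained(output_dir)
        # 5. Discard the CPU copy; training continues on the FSDP model.
        del cpu_model, merged
```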

Resolves #241

@mergify
Contributor

mergify Bot commented Oct 25, 2024

This pull request has merge conflicts that must be resolved before it can be merged. @RobotSail please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

      buffer_dtype=torch.bfloat16,
  ),
- backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+ backward_prefetch=BackwardPrefetch.BACKWARD_POST,
Collaborator

What is the impact of making this change for non-LoRA usage?

Member Author

This is a performance/memory tradeoff. We should have it be configurable if possible, but I can limit it to only be this option when LoRA is used.

Contributor

Backward prefetch vs. postfetch shouldn't affect the correctness of LoRA, but not using prefetch could hurt default training times. I think prefetch should remain the default for non-LoRA cases.

Member Author

+1 James, I'll create a follow-up issue to have this as a configurable setting.
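The tradeoff discussed above could be captured in a small helper like the following sketch for the proposed follow-up. The function name and the string values are assumptions; in real code they would map to `torch.distributed.fsdp.BackwardPrefetch` members:

```python
def select_backward_prefetch(lora_enabled: bool) -> str:
    """Pick an FSDP backward prefetch policy (hypothetical helper).

    BACKWARD_PRE prefetches the next parameter shards during the current
    backward computation (faster, higher peak memory); BACKWARD_POST waits
    until the current computation finishes (slower, lower peak memory),
    which leaves headroom for the CPU-side LoRA save path.
    """
    return "BACKWARD_POST" if lora_enabled else "BACKWARD_PRE"
```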


Comment thread .pylintrc
Comment thread src/instructlab/training/utils.py
Comment thread tests/smoketest.sh Outdated

  # ############### Read-only parameters ###############
- MODEL_NAME="instructlab/granite-7b-lab"
+ MODEL_NAME="/home/ec2-user/.cache/huggingface/hub/models--instructlab--granite-7b-lab/snapshots/4fb6a018d68ab813b95c7f470e424a70f2f7e561"
Contributor

won't always be on ec2

Member Author

I removed it
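One way to keep a smoketest path portable across machines is an environment-variable override that falls back to the Hugging Face model ID. A sketch; the `MODEL_NAME_OVERRIDE` variable name is an assumption, not part of this PR:

```shell
# Use a locally cached snapshot if the caller provides one,
# otherwise fall back to the hub model ID.
MODEL_NAME="${MODEL_NAME_OVERRIDE:-instructlab/granite-7b-lab}"
echo "$MODEL_NAME"
```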

@mergify
Contributor

mergify Bot commented Nov 7, 2024

This pull request has merge conflicts that must be resolved before it can be merged. @RobotSail please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Nov 7, 2024
This commit implements saving LoRA models when training with FSDP as the distributed backend.
This is accomplished by creating a copy of the LoRA model on the CPU,
loading in the state dict after gathering it from the distributed model,
and saving after merging the adapters back into the original model.
Afterwards, the CPU copy is discarded and training continues.

Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>
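The adapter-merge step mentioned in the commit message folds the low-rank update into the base weight: the effective LoRA weight is W + (alpha / r) * B @ A, so after merging, the adapter matrices can be discarded. A toy, dependency-free illustration of that arithmetic (shapes and values are made up for the example):

```python
def matmul(X, Y):
    # Plain-Python matrix multiply for the illustration.
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)]
            for row in X]

def merge_lora(W, A, B, alpha, r):
    # Merged weight: W + (alpha / r) * (B @ A)
    scale = alpha / r
    BA = matmul(B, A)
    return [[w + scale * d for w, d in zip(wr, dr)] for wr, dr in zip(W, BA)]

# 2x2 base weight with a rank-1 adapter (r=1, alpha=2)
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 2.0]]            # shape (r, in_features)
B = [[0.5], [0.25]]         # shape (out_features, r)
merged = merge_lora(W, A, B, alpha=2, r=1)
# merged == [[2.0, 2.0], [0.5, 2.0]]
```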
This commit adds a smoketest for testing LoRA + FSDP.

Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>
Contributor

@JamesKunstle JamesKunstle left a comment


LGTM, thanks for adding the test!

Additionally, introduce a max_seq_len parameter to support testing on lower-end hardware.

Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>
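Exposing the maximum sequence length as a flag keeps the smoketest runnable on lower-end hardware. A sketch of what such a parameter could look like; the flag name and default are assumptions, not the exact interface added by this commit:

```python
import argparse

parser = argparse.ArgumentParser(description="LoRA + FSDP smoketest (sketch)")
parser.add_argument(
    "--max-seq-len",
    type=int,
    default=4096,
    help="Truncate training samples to this many tokens "
         "(lower it on memory-constrained hardware).",
)

# Example invocation for a small GPU:
args = parser.parse_args(["--max-seq-len", "512"])
```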
@mergify mergify Bot removed the one-approval label Nov 13, 2024
@nathan-weinberg nathan-weinberg removed the request for review from aldopareja November 13, 2024 20:52
@mergify mergify Bot merged commit 8a49747 into instructlab:main Nov 13, 2024
@RobotSail
Member Author

Follow-up issue: #345


Labels

CI/CD Affects CI/CD configuration testing Relates to testing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Get LoRA working with FSDP

4 participants