Implement saving FSDP with LoRA #295
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
```diff
      buffer_dtype=torch.bfloat16,
  ),
- backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+ backward_prefetch=BackwardPrefetch.BACKWARD_POST,
```
what is the impact of making this change for non-lora usage?
This is a performance/memory tradeoff. We should have it be configurable if possible, but I can limit it to only be this option when LoRA is used.
Backward prefetch vs. postfetch shouldn't impact the correctness of LoRA, but not using prefetch could hurt default training times. I think prefetch should be the default for non-LoRA cases.
+1 James, I'll create a follow-up issue to have this as a configurable setting.
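The follow-up setting discussed above could be plumbed through roughly like this (a hedged sketch: the helper name and the `use_lora` flag are hypothetical; only `BackwardPrefetch` comes from PyTorch):

```python
from torch.distributed.fsdp import BackwardPrefetch


def select_backward_prefetch(use_lora: bool) -> BackwardPrefetch:
    """Pick an FSDP backward-prefetch policy (hypothetical helper).

    BACKWARD_PRE overlaps the next parameter all-gather with gradient
    computation (faster, but higher peak memory); BACKWARD_POST issues
    it afterwards (slower, lower memory).
    """
    if use_lora:
        # This PR's LoRA path favors the lower-memory option.
        return BackwardPrefetch.BACKWARD_POST
    # Keep the faster prefetch as the default for full fine-tuning.
    return BackwardPrefetch.BACKWARD_PRE
```

The returned value would then be passed as the `backward_prefetch` argument when wrapping the model with FSDP.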
```diff
  # ############### Read-only parameters ###############
- MODEL_NAME="instructlab/granite-7b-lab"
+ MODEL_NAME="/home/ec2-user/.cache/huggingface/hub/models--instructlab--granite-7b-lab/snapshots/4fb6a018d68ab813b95c7f470e424a70f2f7e561"
```
won't always be on ec2
This commit adds the ability to save LoRA models when training with FSDP as the distributed backend. This is accomplished by creating a copy of the LoRA model on the CPU, loading in the state dict after gathering it from the distributed model, and saving after merging the adapters back into the original model. Afterwards, the CPU copy is discarded and training continues.

Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>
This commit adds a smoketest for testing LoRA + FSDP.

Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>
JamesKunstle left a comment
LGTM, thanks for adding the test!
Additionally, introduce a max_seq_len parameter to support testing on lower-end hardware.

Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>
Currently we cannot save LoRA models with FSDP. This PR addresses that limitation by instantiating a copy of the model on the CPU, loading in the LoRA settings, loading the state dict after it has been gathered, and finally performing the same save we do elsewhere throughout the codebase.
Resolves #241
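The save flow described above could look roughly like the following (a hedged sketch, not this PR's actual code: `cpu_model`, `output_dir`, and the function name are hypothetical; `FSDP.state_dict_type`, `FullStateDictConfig`, and PEFT's `merge_and_unload` are real APIs):

```python
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)


def save_fsdp_lora_model(fsdp_model, cpu_model, output_dir, rank):
    """Gather the sharded state dict and save a merged LoRA model.

    `cpu_model` is assumed to be a PEFT-wrapped copy of the base
    model kept on the CPU, mirroring the approach in this PR.
    """
    # Gather the full (unsharded) state dict, offloaded to CPU and
    # materialized only on rank 0 to limit memory pressure.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        state = fsdp_model.state_dict()

    if rank == 0:
        # Load the gathered weights into the CPU copy, merge the LoRA
        # adapters back into the base weights, and save as usual.
        cpu_model.load_state_dict(state)
        merged = cpu_model.merge_and_unload()  # PEFT adapter merge
        merged.save_pretrained(output_dir)
        # Discard the CPU copy; distributed training continues.
        del cpu_model, merged
```

This runs inside an active distributed training job, so the sketch assumes `fsdp_model` is already wrapped with FSDP and the process group is initialized.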