
Adding FSDP Support to Training Library #213

Merged

Maxusmusti merged 47 commits into main from ap/accelerate-fsdp-tmp2 on Sep 26, 2024
Conversation

@aldopareja (Contributor) commented Sep 18, 2024

Adds support for FSDP and FSDP w/ CPU Offloading.

Introduces accelerate as a distributed backend abstraction (for FSDP/DeepSpeed).
Also fixes the mistral template and cleans up data processing.

-Mustafa

@mergify mergify bot added the ci-failure label Sep 18, 2024
@aldopareja aldopareja force-pushed the ap/accelerate-fsdp-tmp2 branch from 560c2ec to 0b4d516 on September 18, 2024 19:11
@mergify mergify bot added the ci-failure and dependencies (Pull requests that update a dependency file) labels and removed the ci-failure label Sep 18, 2024
@Maxusmusti Maxusmusti changed the title Ap/accelerate fsdp tmp2 Adding FSDP Support to Training Library Sep 24, 2024
@mergify mergify bot added the ci-failure and CI/CD (Affects CI/CD configuration) labels and removed the ci-failure label Sep 24, 2024
This was referenced Sep 24, 2024
mergify bot (Contributor) commented Sep 25, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @aldopareja please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>
@mergify mergify bot added documentation Improvements or additions to documentation ci-failure labels Sep 25, 2024
…ining_backend to TrainingArgs.distributed_backend and DistributedTrainingBackend to DistributedBackend

Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>
@RobotSail RobotSail force-pushed the ap/accelerate-fsdp-tmp2 branch from e2b4ae4 to 95eb2c0 on September 25, 2024 17:24
@mergify mergify bot removed the ci-failure label Sep 25, 2024
Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com>
@mergify mergify bot added ci-failure and removed ci-failure labels Sep 25, 2024
Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com>
@mergify mergify bot added ci-failure and removed ci-failure labels Sep 25, 2024
cpu_offload_optimizer_pin_memory=False,
)
)
fsdp_options: FSDPOptions = Field(
Contributor:
does this need to be a factory? I think it can just be an assignment

Collaborator:
I'm following the current convention set by DeepSpeedOptions in the file, so imo if we want to change this, we should make a follow-up PR that updates both of them
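To make the factory-vs-assignment question concrete, here is a minimal stdlib-dataclass sketch of the general motivation for `default_factory` (class and field names are illustrative stand-ins, not the PR's actual pydantic models): a factory gives each instance its own mutable default instead of sharing one object.

```python
from dataclasses import dataclass, field


@dataclass
class FSDPOptions:
    cpu_offload_params: bool = False


@dataclass
class TrainingArgs:
    # default_factory constructs a fresh FSDPOptions per instance,
    # so mutating one TrainingArgs never leaks into another.
    fsdp_options: FSDPOptions = field(default_factory=FSDPOptions)


a, b = TrainingArgs(), TrainingArgs()
a.fsdp_options.cpu_offload_params = True
# b still holds its own independent default
```

Pydantic handles mutable defaults more leniently than stdlib dataclasses (which is why a plain assignment can also work there), so this only illustrates the convention being followed, not a hard requirement.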

reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
Contributor:
Do we want to expose this ever? It adds a bit of memory overhead in exchange for some performance; I think it's customarily left as a default.

Collaborator:

This is a good point. I think it's fine for now, but I will open an issue to track it, since I'm not sure how the performance hit compares to the memory gain for us. It might be a nice bonus trick to avoid offloading in some configurations if performance isn't horrendous.

Collaborator:

Tracked in #228
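For reference, a minimal sketch of where this option sits in the FSDP constructor (not the PR's exact code; `module` is an assumed, already-constructed `nn.Module`, and a distributed process group must already be initialized). `BACKWARD_PRE` prefetches the next parameter shard before the current gradient computation, overlapping communication with compute at the cost of holding an extra shard in memory; `BACKWARD_POST` prefetches afterward and uses less memory.

```python
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
)

model = FSDP(
    module,  # assumption: an already-built nn.Module, with dist initialized
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # the default chosen in this PR
)
```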

    }
    return ds_config


def setup_optimizer(args, model):
    if args.distributed_training_framework == "fsdp":
Contributor:

The typical way to do this is via this pattern:

Suggested change:
-if args.distributed_training_framework == "fsdp":
+if DistributedBackend(args.distributed_training_framework) == DistributedBackend.FSDP:

This collects "magic strings" like "fsdp" into the Enum object.

Collaborator:

Note: it actually has to be DistributedBackend.FSDP.value, since by this point the args have gone through the main_ds argparse post-torchrun and args.distributed_training_framework is just a string

Collaborator:

Fixed in latest commit
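The pattern that resolves this thread can be sketched as follows (an illustrative stand-in for the library's `DistributedBackend` enum, with member names taken from the discussion above, not copied from the PR). Because the value has round-tripped through torchrun and argparse, it arrives as a plain string, so the comparison is against the member's `.value`:

```python
from enum import Enum


class DistributedBackend(Enum):
    FSDP = "fsdp"
    DEEPSPEED = "deepspeed"


# Post-argparse, the framework is a bare string, not an enum member:
framework = "fsdp"  # stand-in for args.distributed_training_framework
use_fsdp = framework == DistributedBackend.FSDP.value
```

Comparing against `DistributedBackend.FSDP` directly would always be `False` here, since a `str` never equals an `Enum` member.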

model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95)
)
accelerator = setup_accelerator(args, model, grad_accum)
if args.distributed_training_framework == "fsdp":
Contributor:

Same enum trick here

Collaborator:

Note: it actually has to be DistributedBackend.FSDP.value, since by this point the args have gone through the main_ds argparse post-torchrun and args.distributed_training_framework is just a string

Collaborator:

Fixed in latest commit

),
lr_scheduler=lr_scheduler,
dist_init_required=True,
model, optimizer, _, lr_scheduler = accelerator.prepare(
Contributor:

I see here that we're "double preparing" the model- is that okay? Is Accelerate smart enough to handle this?

Collaborator:

Yes, I have verified that it is, originally I had some conditionals to avoid it but accelerate was one step ahead

global_grad_norm = accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
Contributor:

I haven't seen this here conventionally, only at the top of the training loop. I guess it can be either place. I also see that this is where they put it in the docs.

Collaborator:

If it aint broke 🤷🏻‍♂️

Contributor:

++
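The two placements discussed above are equivalent in practice, which a tiny pure-Python mock (stand-in objects, not the PR's real optimizer) makes visible: clearing grads at the end of step N produces the same interior call sequence as clearing them at the top of step N+1, since either way gradients are zeroed before the next `backward()`.

```python
class MockOptimizer:
    """Records call order so the two loop shapes can be compared."""

    def __init__(self):
        self.calls = []

    def step(self):
        self.calls.append("step")

    def zero_grad(self):
        self.calls.append("zero_grad")


def loop_zero_at_end(opt, steps):
    for _ in range(steps):
        # loss.backward() would run here
        opt.step()
        opt.zero_grad()  # placement used in this PR


def loop_zero_at_start(opt, steps):
    for _ in range(steps):
        opt.zero_grad()  # the "top of the loop" convention
        # loss.backward() would run here
        opt.step()


a, b = MockOptimizer(), MockOptimizer()
loop_zero_at_end(a, 2)
loop_zero_at_start(b, 2)
# The alternating step/zero_grad pattern is identical; only the loop
# boundaries differ, and neither leaves stale grads before a backward pass.
```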

@JamesKunstle (Contributor) left a comment:

IMO nothing that I noticed is blocking an approval. The only thing that I really want is for this PR to be rebased as a single commit so the history is a bit neater. Once that's done I'll approve!

Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com>
@JamesKunstle (Contributor) left a comment:

lgtm!

Labels: CI/CD, dependencies, documentation, hold

5 participants