Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamp TPU internals to be more efficient #441

Merged
merged 45 commits into from
Jun 14, 2022
Merged

Revamp TPU internals to be more efficient #441

merged 45 commits into from
Jun 14, 2022

Conversation

muellerzr
Copy link
Collaborator

@muellerzr muellerzr commented Jun 11, 2022

Revamp TPU internals

What does this add?

  • Make prepare_model use MpModelWrapper to distribute the model across all devices efficiently when used
  • Changes prepare_dataloader to create an MpDeviceLoader allowing for the dataloaders to be more efficient.
  • Changes DataLoaderShard to no longer do xm.mark_step, MpDeviceLoader will handle this for us. Instead if on TPU we set the device as None in prepare_dataloader letting MpDeviceLoader take over with the device when needed.
  • Allow for FP16 and BF16 precision types on the TPU via AcceleratorState since they are now supported

Who is it for?

  • Users of TPUs

Why is it needed?

We currently have a number of "bad practices" in TPU handling due to improvements in the xla API. Here are some benchmarks I ran on a high-memory colab TPU instance with the nlp example script:

Baseline:

  • Post warmup: 56 seconds

W/ MpModelWrapper and MpDeviceLoader:

  • Post warmup: 48 seconds

W/ previous and default_tensor_type change:

  • Post warmup: 34 to 36 seconds

Roughly a 40% boost to speed once this is all done. I also saw some speed increases on the initial launch as well anecdotally, but I did not time them.

I also saw a 2x speedup when using the new DataLoader class vs the old one

About default_tensor_type:

Though we do set the device to bf16 which helps with automatically converting the tensors to the right types, if we add torch.set_default_tensor_type('torch.FloatTensor') there is a considerable speedup when it comes to training TPUs.

I'm a bit unsure as to where it would be best to put this, as either something hidden when we initialize the Accelerator or as a util that should get called when you are training on TPUs, open to ideas

Anticipated maintenance burden? (What will happen in say, 3 months if something changes)

This is pretty stable and considered as "good practices" when running on TPUs w/ XLA, so it's unlikely these will change. However after this PR a subsequent PR will be opened to change the nlp_example and cv_example notebook, as another best practice is to declare the model outside of xm.spawn. This includes the internals to make the model as memory efficient as we can, so that will be the last stage needed.

@muellerzr muellerzr added enhancement New feature or request TPU Bug or feature on TPU platforms labels Jun 11, 2022
@muellerzr muellerzr requested a review from sgugger June 11, 2022 15:48
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jun 11, 2022

The documentation is not available anymore as the PR was closed or merged.

@tmabraham
Copy link
Contributor

Thanks for fixing up the TPU support with best practices! Looks great to me! 🔥

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your PR and running the benchmarks in Colab. Could you link to the documentation that shows those are best practices recommended by the torch XLA team? I haven't seen anything personally.

I'd like to make sure those changes are not speeding up the experience in Colab at the cost of the experience on TPU machines, so I'd like the same benchmarks to be run on a TPU VM before fully approving :-)

src/accelerate/accelerator.py Show resolved Hide resolved
src/accelerate/data_loader.py Show resolved Hide resolved
@muellerzr muellerzr requested a review from sgugger June 13, 2022 16:13
@muellerzr
Copy link
Collaborator Author

@sgugger this PR also modifies the TPU selection in AcceleratorState to actually use mixed precision types.

@muellerzr muellerzr marked this pull request as draft June 13, 2022 20:09
@muellerzr muellerzr marked this pull request as ready for review June 14, 2022 20:43
@muellerzr
Copy link
Collaborator Author

We're passing 😄
image

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for revamping this!

@muellerzr muellerzr requested a review from sgugger June 14, 2022 21:11
src/accelerate/state.py Outdated Show resolved Hide resolved
muellerzr and others added 2 commits June 14, 2022 17:33
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@muellerzr muellerzr merged commit 29eef23 into main Jun 14, 2022
@muellerzr muellerzr deleted the tpu-fixes branch June 14, 2022 21:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request TPU Bug or feature on TPU platforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants