Revamp TPU internals to be more efficient #441
Conversation
Thanks for fixing up the TPU support with best practices! Looks great to me! 🔥
Thanks for your PR and running the benchmarks in Colab. Could you link to the documentation that shows those are best practices recommended by the torch XLA team? I haven't seen anything personally.
I'd like to make sure those changes are not speeding up the experience in Colab at the cost of the experience on TPU machines, so I'd like the same benchmarks to be run on a TPU VM before fully approving :-)
@sgugger this PR also modifies the TPU selection in
Thanks for revamping this!
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Revamp TPU internals
What does this add?

- `prepare_model` now uses `MpModelWrapper` to distribute the model across all devices efficiently when used
- `prepare_dataloader` now creates an `MpDeviceLoader`, allowing for the dataloaders to be more efficient
- `DataLoaderShard` no longer does `xm.mark_step`; `MpDeviceLoader` will handle this for us. Instead, if on TPU we set the device as `None` in `prepare_dataloader`, letting `MpDeviceLoader` take over with the `device` when needed (see the sketch after this list)
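For readers less familiar with these torch_xla utilities, here is a minimal sketch of the pattern being adopted, not Accelerate's actual internals: the `nn.Linear` model, the synthetic dataset, and `nprocs=8` are illustrative placeholders.

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
from torch.utils.data import DataLoader, TensorDataset

# Wrap the model once, at module scope, so its weights sit in shared memory
# rather than being copied into every spawned process.
WRAPPED_MODEL = xmp.MpModelWrapper(nn.Linear(128, 2))

def _mp_fn(index):
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)  # move the shared weights to this replica's device
    data = TensorDataset(torch.randn(512, 128), torch.randint(0, 2, (512,)))
    # MpDeviceLoader preloads batches onto the device and calls
    # xm.mark_step() for us at the end of each iteration, which is why
    # DataLoaderShard no longer needs to do it manually.
    loader = pl.MpDeviceLoader(DataLoader(data, batch_size=64), device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    for inputs, labels in loader:  # batches arrive already on `device`
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(model(inputs), labels)
        loss.backward()
        xm.optimizer_step(optimizer)  # all-reduce gradients, then step

if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8)
```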
Who is it for?
Why is it needed?
We currently have a number of "bad practices" in our TPU handling that have become outdated as the XLA API improved. Here are some benchmarks I ran on a high-memory Colab TPU instance with the nlp example script:
- Baseline: *(benchmark output not preserved)*
- W/ `MpModelWrapper` and `MpDeviceLoader`: *(benchmark output not preserved)*
- W/ the previous changes and the `default_tensor_type` change: *(benchmark output not preserved)*

Roughly a 40% speed boost once this is all done. Anecdotally, I also saw some speed increases on the initial launch, but I did not time them.
I also saw a 2x speedup when using the new DataLoader class versus the old one.
About `default_tensor_type`:

Though we do set the device to `bf16`, which helps with automatically converting the tensors to the right types, adding `torch.set_default_tensor_type('torch.FloatTensor')` gives a considerable speedup when training on TPUs.

I'm a bit unsure where it would be best to put this: either somewhere hidden when we initialize the Accelerator, or as a util that gets called when you are training on TPUs. Open to ideas!
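As one possible shape for that util, here is a minimal sketch; `set_tpu_tensor_defaults` is a hypothetical name, not something this PR adds:

```python
import torch

def set_tpu_tensor_defaults():
    # Hypothetical helper wrapping the one-liner discussed above: make newly
    # created host tensors default to float32. Would be called once, before
    # training starts, when running on TPU.
    torch.set_default_tensor_type('torch.FloatTensor')
```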
Anticipated maintenance burden? (What will happen in, say, 3 months if something changes)
This is pretty stable and considered "good practice" when running on TPUs w/ XLA, so it's unlikely these will change. However, after this PR a subsequent PR will be opened to change the `nlp_example` and `cv_example` notebooks, as another best practice is to declare the model outside of `xm.spawn`. This includes the internals to make the model as memory efficient as we can, so that will be the last stage needed.