TPU - Bfloat16 support #185

Closed
Luvata opened this issue Oct 15, 2021 · 2 comments

Comments

Luvata commented Oct 15, 2021

According to the PyTorch XLA bfloat16 documentation, setting the environment variable XLA_USE_BF16=1 converts all torch.float and torch.double tensors to bfloat16 on TPU, so in theory the model can fit a 2x larger batch size.
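To make that concrete, here is a minimal sketch of how I understand the variable is meant to be used (a toy model and dataset rather than my real code base, just to show where the flag goes; exporting it in the shell before `accelerate launch` should be equivalent):

```python
# Simplified stand-in for my training script (toy model/data, not the real code base).
# XLA_USE_BF16 needs to be set before torch / torch_xla are imported, so it is
# exported at the very top here; setting it in the shell before `accelerate launch`
# works the same way.
import os
os.environ["XLA_USE_BF16"] = "1"  # torch.float / torch.double -> bfloat16 on TPU

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator


def main():
    accelerator = Accelerator()

    # Toy model and data, just to show the structure of the loop.
    model = torch.nn.Linear(128, 10)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    dataset = TensorDataset(torch.randn(1024, 128), torch.randint(0, 10, (1024,)))
    loader = DataLoader(dataset, batch_size=128)  # doubled per-core batch size under bf16

    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(5):
        for inputs, targets in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            accelerator.backward(loss)
            optimizer.step()
        accelerator.print(f"epoch {epoch}: loss {loss.item():.2f}")


if __name__ == "__main__":
    main()
```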

Currently I'm using accelerate on TPU VMs (both v3-8 and v2-8), and it works as expected. I then experimented with XLA_USE_BF16=1 and double the batch size (same code base); my code still runs without throwing any error, and it also runs much faster on both v3-8 and v2-8:

| TPU | dtype | batch size (1 core) | seconds/epoch (stable) |
|------|----------|---------------------|------------------------|
| v3-8 | float32  | 64  | 99  |
| v2-8 | float32  | 32  | 158 |
| v3-8 | bfloat16 | 128 | 61  |
| v2-8 | bfloat16 | 64  | 120 |

But the losses decrease much more slowly. For example, when training with float32, the losses of the first 5 epochs are:
1.48, 1.31, 1.26, 1.19, 1.18

But with bfloat16, the losses of the first 5 epochs are: 3.42, 3.38, 2.94, 2.5, 2.03

I also tried setting fp16 = true in the config, but it throws an error.

So it seems there are some problems (my guess is that gradient scaling is not correctly implemented for bfloat16).

I wonder if something is wrong on my side, or if bfloat16 just isn't supported yet?

Once again, thank you so much for creating this beautiful library. I tried following torch XLA directly but couldn't figure out how to start training, and even messed up my code base. But with a few simple code changes, accelerate just automagically works out of the box (and is also really fast on TPU).

sgugger (Collaborator) commented Oct 15, 2021

Hi @Luvata and thanks for the praise. fp16=True is not supported on TPU, so that won't work. I'm sorry to hear your training is not going properly with bfloat16, but as far as I know, there is nothing else to do apart from setting that env variable.

If you come across any resource that could tell us more about how to train with bfloat16 on TPUs, please do let me know, as I would love the support to be as easy as possible :-)

Luvata (Author) commented Oct 15, 2021

Thanks @sgugger for your fast reply. I will try training for longer and see if bfloat16 truly affects convergence.

Luvata closed this as completed Oct 15, 2021