According to the PyTorch XLA bfloat16 docs, the environment variable XLA_USE_BF16=1 converts all torch.float and torch.double tensors to bfloat16 on TPU, so in theory the model can fit a 2x batch size.
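For reference, this is roughly how I'm setting it (a minimal sketch of my setup, not my full training script; I believe the flag has to be in the environment before torch_xla initializes, and you could equally just `export XLA_USE_BF16=1` before launching):

```python
# Minimal sketch (illustrative layout only):
# set the flag before any torch / accelerate imports so torch_xla
# picks it up when it initializes.
import os
os.environ["XLA_USE_BF16"] = "1"

import torch                        # noqa: E402
from accelerate import Accelerator  # noqa: E402

accelerator = Accelerator()
device = accelerator.device  # the TPU/XLA device when launched on a TPU VM
```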
Currently I'm using accelerate on TPU VMs (both v3-8 and v2-8), and it works as expected. I then experimented with XLA_USE_BF16=1 and double the batch size (same code base); my code still runs without throwing any error, and it also runs much faster on both v3-8 and v2-8:
| TPU  | dtype    | batch size (1 core) | seconds/epoch (stable) |
|------|----------|---------------------|------------------------|
| v3-8 | float32  | 64                  | 99                     |
| v2-8 | float32  | 32                  | 158                    |
| v3-8 | bfloat16 | 128                 | 61                     |
| v2-8 | bfloat16 | 64                  | 120                    |
But the losses decrease much more slowly. For example, when training in float32, the losses for the first 5 epochs are:
1.48, 1.31, 1.26, 1.19, 1.18
But in bfloat16, the losses for the first 5 epochs are: 3.42, 3.38, 2.94, 2.5, 2.03
I also tried setting fp16 = true in the config, but it throws an error.
So it seems there are some problems (my guess is that gradient scaling is not correctly implemented for bfloat16).
I wonder whether something is wrong on my side, or whether bfloat16 just isn't supported yet?
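For what it's worth, a quick check of the dtype ranges on CPU (sketch below, no TPU needed) shows that bfloat16 keeps roughly float32's exponent range but with a much coarser mantissa, so my gradient-scaling guess above may be off and the slower loss curve could simply be reduced precision:

```python
import torch

# bfloat16 keeps float32's exponent range (so overflow, and hence loss
# scaling, is usually not the issue), but its machine epsilon is far
# larger, i.e. each step carries much less precision than float32.
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max:.2e}, eps={info.eps:.2e}")

# Roughly:
#   torch.float32  max ~ 3.4e38, eps ~ 1.2e-07
#   torch.float16  max ~ 6.6e04, eps ~ 9.8e-04
#   torch.bfloat16 max ~ 3.4e38, eps ~ 7.8e-03
```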
Once again, thank you so much for creating this beautiful library. I tried following torch XLA directly but couldn't figure out how to start training, and even messed up my code base. But with a few simple code changes, accelerate just automagically works out of the box (and is also really fast on TPU).
Hi @Luvata and thanks for the praise. fp16=True is not supported on TPU, so that won't work. I'm sorry to hear your training is not going properly with bfloat16, but as far as I know, there is nothing else to do apart from setting that env variable.
If you encounter any resource that could tell us more on how to train with bfloat16 on TPUs, please do let me know as I would love the support to be as easy as possible :-)
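If you want to double-check that bfloat16 storage is actually active on the device, one sanity check you could try (just a sketch based on my reading of the torch_xla docs, not an accelerate API) is to round-trip a value that float32 can represent exactly but bfloat16 cannot:

```python
import os
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# 1 + 2**-12 is exact in float32 but rounds to 1.0 in bfloat16, so the
# device round-trip should lose the fraction when XLA_USE_BF16=1 is set.
x = torch.tensor([1.0 + 2 ** -12])
roundtrip = x.to(device).cpu()

print("XLA_USE_BF16 =", os.environ.get("XLA_USE_BF16"))
print(x.dtype, roundtrip.dtype)   # both should still report torch.float32
print(torch.equal(x, roundtrip))  # expect False if bf16 storage is active
```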