Hi, I prepared the pg19 dataset and ran `run_eval_needle.sh` on NVIDIA V100/A100 GPUs with different parameter settings. I have some questions about the experiment results and hope someone can help.

The model made correct predictions with `mesh_dim='1,2,2,1'` and `dtype=float32` or `bf16`. However, when sequence parallelism was adopted with `mesh_dim='1,1,2,2'`, the predictions were wrong:
| mesh_dim | fp32 | fp64 | fp16 | bf16 |
|----------|------|------|------|------|
| 1,2,2,1  | correct | wrong | wrong | correct |
| 1,1,2,2  | wrong | wrong | wrong | wrong |
On further investigation, I found that the attention outputs start to contain `nan` values from some middle layer onward, and the final outputs of the network are all `nan`:
```
...
attn_output shape: (array(1, dtype=int32), array(8192, dtype=int32), array(4096, dtype=int32)),
attn_output: [[[-0.00195352 -0.00402817  0.00310729 ...  0.00402906 -0.00059524
   -0.00334746]
  [-0.00195352 -0.00402817  0.00310729 ...  0.00402906 -0.00059524
   -0.00334746]
  [-0.00195352 -0.00402817  0.00310729 ...  0.00402906 -0.00059524
   -0.00334746]
  ...
  [        nan         nan         nan ...         nan         nan
           nan]
  [        nan         nan         nan ...         nan         nan
           nan]
  [        nan         nan         nan ...         nan         nan
           nan]]]
...
attn_output shape: (array(1, dtype=int32), array(1, dtype=int32), array(4096, dtype=int32)), attn_output: [[[nan nan nan ... nan nan nan]]]
...
```
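As a side note, to localize where the `nan` values first appear I used standard JAX NaN debugging rather than anything specific to this repo; a minimal sketch (the placement of `report_nans` inside the block is hypothetical):

```python
import jax
import jax.numpy as jnp

# Option 1: make JAX raise as soon as any op produces a NaN. This is slower
# but points at the first offending operation instead of the final output.
jax.config.update("jax_debug_nans", True)

# Option 2: explicitly report NaN counts from inside a jitted function,
# e.g. right after the attention output is computed in the block.
def report_nans(name, x):
    jax.debug.print("{name}: nan_count={n}", name=name, n=jnp.isnan(x).sum())

# Hypothetical placement inside the attention block:
# attn_output = attention_fn(...)
# report_nans("attn_output", attn_output)
```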
Inference was also very slow with `mesh_dim='1,1,2,2'`. By counting the printed attention outputs, I found that `FlaxLLaMABlock` was called far more times with `mesh_dim='1,1,2,2'` than with `mesh_dim='1,2,2,1'`:
| mesh_dim | FlaxLLaMABlock calls |
|----------|----------------------|
| 1,2,2,1  | 544 |
| 1,1,2,2  | 65568 |
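For reference, the counts above were obtained by counting the debug prints in the captured log; a minimal sketch of how I counted them (the log path is just a placeholder):

```python
# Each FlaxLLaMABlock call prints one "attn_output shape:" line, so counting
# those lines in the captured log gives the number of block invocations.
def count_block_calls(log_path: str = "eval_needle.log") -> int:
    with open(log_path) as f:
        return sum(1 for line in f if "attn_output shape:" in line)

print(count_block_calls())
```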
Has anyone seen similar behavior, and could you share any insight on these questions?