Merge pull request #21169 from justinjfu/splash_precision_fix
Disable bfloat16 on long seq lengths for splash attention kernel test
justinjfu committed May 13, 2024
2 parents 35a512d + ebb9184 commit e4f3b3f
Showing 1 changed file with 7 additions and 1 deletion.
tests/pallas/splash_attention_kernel_test.py (8 changes: 7 additions & 1 deletion)
@@ -224,7 +224,13 @@ def sequence_length_strategy(draw: Draw) -> tuple[int, int]:
 def attention_strategy(draw: Draw) -> tuple[int, int, int, np.dtype]:
   q_seq_len, kv_seq_len = draw(sequence_length_strategy())
   head_dim = draw(hps.sampled_from([128, 256]))
-  dtype = draw(hps.sampled_from([np.dtype("float32"), np.dtype(jnp.bfloat16)]))
+  if q_seq_len >= 4096 and kv_seq_len >= 4096:
+    # Do not draw bfloat16 on longer sequence lengths, as this increases
+    # the risk of numerical precision errors causing false positives in
+    # tests.
+    dtype = np.dtype("float32")
+  else:
+    dtype = draw(hps.sampled_from([np.dtype("float32"), np.dtype(jnp.bfloat16)]))
   return q_seq_len, kv_seq_len, head_dim, dtype
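For context, a minimal sketch of how a composite Hypothesis strategy like this one is typically consumed by a test. The simplified sequence-length draws (standing in for sequence_length_strategy), the composite decorator placement, and the test body are assumptions for illustration, not part of this commit.

# Hypothetical usage sketch (not from this commit): a composite strategy feeding
# a Hypothesis test, with the same float32-only constraint on long sequences.
import hypothesis as hp
import hypothesis.strategies as hps
import jax.numpy as jnp
import numpy as np

@hps.composite
def attention_strategy(draw) -> tuple[int, int, int, np.dtype]:
  # Assumed stand-in for sequence_length_strategy(), which is not shown here.
  q_seq_len = draw(hps.sampled_from([1024, 4096, 8192]))
  kv_seq_len = draw(hps.sampled_from([1024, 4096, 8192]))
  head_dim = draw(hps.sampled_from([128, 256]))
  if q_seq_len >= 4096 and kv_seq_len >= 4096:
    # Long sequences accumulate more rounding error, so only draw float32.
    dtype = np.dtype("float32")
  else:
    dtype = draw(hps.sampled_from([np.dtype("float32"), np.dtype(jnp.bfloat16)]))
  return q_seq_len, kv_seq_len, head_dim, dtype

@hp.given(attention_strategy())
def test_dtype_constraint(params):
  q_seq_len, kv_seq_len, head_dim, dtype = params
  # bfloat16 should never be drawn when both sequence lengths are >= 4096.
  if q_seq_len >= 4096 and kv_seq_len >= 4096:
    assert dtype == np.dtype("float32")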


