Skip to content

Commit

Permalink
enable bf16 casting
Browse files Browse the repository at this point in the history
  • Loading branch information
fattorib committed Jun 4, 2023
1 parent a117722 commit 668d029
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions src/training/train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@


def to_bf16(t):
# return jax.tree_map(
# lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t
# )
return t

return jax.tree_map(
lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t
)



Expand Down

0 comments on commit 668d029

Please sign in to comment.