From 668d029eef2feedb5b63a3028408200a7dd39235 Mon Sep 17 00:00:00 2001 From: Benjamin Fattori Date: Sun, 4 Jun 2023 14:08:16 +0100 Subject: [PATCH] enable bf16 casting --- src/training/train_functions.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/training/train_functions.py b/src/training/train_functions.py index b5024e8..9fe288b 100644 --- a/src/training/train_functions.py +++ b/src/training/train_functions.py @@ -11,11 +11,9 @@ def to_bf16(t): - # return jax.tree_map( - # lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t - # ) - return t - + return jax.tree_map( + lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t + )