From 2b25e97b3d1d5c395986bdaba0e355b7893d69bd Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 28 May 2023 12:22:57 +0200 Subject: [PATCH] Bugfix: check_tensor_sanity use of .numpy() only possible in eager mode. Added workaround unsing tf.print, but couldn't get it to work with the logger object. Therefore the two cases are handled differently for now, but removing the if statement would be favorable. --- bayesflow/helper_functions.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/bayesflow/helper_functions.py b/bayesflow/helper_functions.py index c143d466a..82ec876a7 100644 --- a/bayesflow/helper_functions.py +++ b/bayesflow/helper_functions.py @@ -29,13 +29,20 @@ def check_tensor_sanity(tensor, logger): """Tests for the present of NaNs and Infs in a tensor.""" - - if tf.reduce_any(tf.math.is_nan(tensor)): - num_na = tf.reduce_sum(tf.cast(tf.math.is_nan(tensor), tf.int8)).numpy() - logger.warn(f"Warning! Returned estimates contain {num_na} nan values!") - if tf.reduce_any(tf.math.is_inf(tensor)): - num_inf = tf.reduce_sum(tf.cast(tf.math.is_inf(tensor), tf.int8)).numpy() - logger.warn(f"Warning! Returned estimates contain {num_inf} inf values!") + if tf.executing_eagerly(): + if tf.reduce_any(tf.math.is_nan(tensor)): + num_na = tf.reduce_sum(tf.cast(tf.math.is_nan(tensor), tf.int8)).numpy() + logger.warn(f"Warning! Returned estimates contain {num_na} nan values!") + if tf.reduce_any(tf.math.is_inf(tensor)): + num_inf = tf.reduce_sum(tf.cast(tf.math.is_inf(tensor), tf.int8)).numpy() + logger.warn(f"Warning! Returned estimates contain {num_inf} inf values!") + else: + if tf.reduce_any(tf.math.is_nan(tensor)): + num_na = tf.reduce_sum(tf.cast(tf.math.is_nan(tensor), tf.int8)) + tf.print("Warning! Returned estimates contain", num_na, "nan values!") + if tf.reduce_any(tf.math.is_inf(tensor)): + num_inf = tf.reduce_sum(tf.cast(tf.math.is_inf(tensor), tf.int8)) + tf.print(f"Warning! Returned estimates contain", num_inf, "inf values!") def merge_left_into_right(left_dict, right_dict):