diff --git a/bayesflow/helper_functions.py b/bayesflow/helper_functions.py index c143d466a..82ec876a7 100644 --- a/bayesflow/helper_functions.py +++ b/bayesflow/helper_functions.py @@ -29,13 +29,20 @@ def check_tensor_sanity(tensor, logger): """Tests for the present of NaNs and Infs in a tensor.""" - - if tf.reduce_any(tf.math.is_nan(tensor)): - num_na = tf.reduce_sum(tf.cast(tf.math.is_nan(tensor), tf.int8)).numpy() - logger.warn(f"Warning! Returned estimates contain {num_na} nan values!") - if tf.reduce_any(tf.math.is_inf(tensor)): - num_inf = tf.reduce_sum(tf.cast(tf.math.is_inf(tensor), tf.int8)).numpy() - logger.warn(f"Warning! Returned estimates contain {num_inf} inf values!") + if tf.executing_eagerly(): + if tf.reduce_any(tf.math.is_nan(tensor)): + num_na = tf.reduce_sum(tf.cast(tf.math.is_nan(tensor), tf.int8)).numpy() + logger.warn(f"Warning! Returned estimates contain {num_na} nan values!") + if tf.reduce_any(tf.math.is_inf(tensor)): + num_inf = tf.reduce_sum(tf.cast(tf.math.is_inf(tensor), tf.int8)).numpy() + logger.warn(f"Warning! Returned estimates contain {num_inf} inf values!") + else: + if tf.reduce_any(tf.math.is_nan(tensor)): + num_na = tf.reduce_sum(tf.cast(tf.math.is_nan(tensor), tf.int8)) + tf.print("Warning! Returned estimates contain", num_na, "nan values!") + if tf.reduce_any(tf.math.is_inf(tensor)): + num_inf = tf.reduce_sum(tf.cast(tf.math.is_inf(tensor), tf.int8)) + tf.print(f"Warning! Returned estimates contain", num_inf, "inf values!") def merge_left_into_right(left_dict, right_dict):