From 3744793cafdb39840e36b4b9e317fef126c19229 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Fri, 21 Feb 2025 11:04:34 +0100 Subject: [PATCH] mc_calibration does not crash with VariableArray objects --- bayesflow/utils/comp_utils.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/bayesflow/utils/comp_utils.py b/bayesflow/utils/comp_utils.py index b21b3f03c..4d4b5c550 100644 --- a/bayesflow/utils/comp_utils.py +++ b/bayesflow/utils/comp_utils.py @@ -1,4 +1,5 @@ import numpy as np +from keras import ops from sklearn.calibration import calibration_curve @@ -16,9 +17,9 @@ def expected_calibration_error(m_true, m_pred, num_bins=10): Parameters ---------- - m_true : np.ndarray of shape (num_sim, num_models) + m_true : array of shape (num_sim, num_models) The one-hot-encoded true model indices. - m_pred : tf.tensor of shape (num_sim, num_models) + m_pred : array of shape (num_sim, num_models) The predicted posterior model probabilities. num_bins : int, optional, default: 10 The number of bins to use for the calibration curves (and marginal histograms). @@ -32,11 +33,9 @@ def expected_calibration_error(m_true, m_pred, num_bins=10): Each list contains two arrays of length (num_bins) with the predicted and true probabilities for each bin. """ - # Convert tf.Tensors to numpy, if passed - if type(m_true) is not np.ndarray: - m_true = m_true.numpy() - if type(m_pred) is not np.ndarray: - m_pred = m_pred.numpy() + # Convert tensors to numpy, if passed + m_true = ops.convert_to_numpy(m_true) + m_pred = ops.convert_to_numpy(m_pred) # Extract number of models and prepare containers n_models = m_true.shape[1]