diff --git a/tests/compare_util.py b/tests/compare_util.py index 5e1eb56ba0..7d5a300118 100644 --- a/tests/compare_util.py +++ b/tests/compare_util.py @@ -51,7 +51,6 @@ def cosine(gt: np.ndarray, pred: np.ndarray, *args): result = (gt @ pred) / (np.linalg.norm(gt, 2) * np.linalg.norm(pred, 2)) - # When tensor gt is a multiple of tensor pred, their similarity is also 1. return -1 if math.isnan(result) else result @@ -63,6 +62,29 @@ def euclidean(gt: np.ndarray, pred: np.ndarray, *args): return np.linalg.norm(gt.reshape(-1) - pred.reshape(-1)) +# def mse(gt: np.ndarray, pred: np.ndarray, *args): +# return np.mean((gt - pred) ** 2) + +def divide(gt: np.ndarray, pred: np.ndarray): + + # remove the zero values in the same location. + gt_mask = np.equal(gt, 0) + pred_mask = np.equal(pred, 0) + mask = gt_mask & pred_mask + gt = gt[~mask] + pred = pred[~mask] + + # to avoid divide zero. + pred = np.where(np.equal(pred, 0), 1e-7, pred) + + result = np.divide(gt, pred) + return result + + +def mean(gt: np.ndarray): + return np.mean(gt) + + def allclose(gt: np.ndarray, pred: np.ndarray, thresh: float): return np.allclose(gt, pred, atol=thresh) @@ -122,6 +144,8 @@ def compare_binfile(result_path: Tuple[str, str], compare_op = gt if compare_op(similarity, threshold): return False, similarity_info + if (mean(divide(gt_arr, pred_arr)) > 1.5 or mean(divide(gt_arr, pred_arr)) < 0.6): + return False, similarity_info, f"\nmaybe a case of multiples" return True, similarity_info