Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mse #1110

Merged
merged 14 commits into from
Oct 24, 2023
26 changes: 25 additions & 1 deletion tests/compare_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def cosine(gt: np.ndarray, pred: np.ndarray, *args):

result = (gt @ pred) / (np.linalg.norm(gt, 2) * np.linalg.norm(pred, 2))

# When tensor gt is a multiple of tensor pred, their similarity is also 1.
return -1 if math.isnan(result) else result


Expand All @@ -63,6 +62,29 @@ def euclidean(gt: np.ndarray, pred: np.ndarray, *args):
return np.linalg.norm(gt.reshape(-1) - pred.reshape(-1))


# def mse(gt: np.ndarray, pred: np.ndarray, *args):
# return np.mean((gt - pred) ** 2)

def divide(gt: np.ndarray, pred: np.ndarray):

# remove the zero values in the same location.
gt_mask = np.equal(gt, 0)
pred_mask = np.equal(pred, 0)
mask = gt_mask & pred_mask
gt = gt[~mask]
pred = pred[~mask]

# to avoid divide zero.
pred = np.where(np.equal(pred, 0), 1e-7, pred)

result = np.divide(gt, pred)
return result


def mean(gt: np.ndarray):
return np.mean(gt)


def allclose(gt: np.ndarray, pred: np.ndarray, thresh: float):
return np.allclose(gt, pred, atol=thresh)

Expand Down Expand Up @@ -122,6 +144,8 @@ def compare_binfile(result_path: Tuple[str, str],
compare_op = gt
if compare_op(similarity, threshold):
return False, similarity_info
if (mean(divide(gt_arr, pred_arr)) > 1.5 or mean(divide(gt_arr, pred_arr)) < 0.6):
return False, similarity_info, f"\nmaybe a case of multiples"
return True, similarity_info


Expand Down
Loading