Skip to content

Commit

Permalink
add mse (#1110)
Browse files Browse the repository at this point in the history
* add mse

* fix mse

* fix compare_util

* fix compare_util

---------

Co-authored-by: hejunchao <hejunchao@canaan-creative.com>
Co-authored-by: HeJunchao100813 <HeJunchao100813@users.noreply.github.com>
  • Loading branch information
3 people committed Oct 24, 2023
1 parent 3a62f98 commit 355b8c4
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion tests/compare_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def cosine(gt: np.ndarray, pred: np.ndarray, *args):

result = (gt @ pred) / (np.linalg.norm(gt, 2) * np.linalg.norm(pred, 2))

# When tensor gt is a multiple of tensor pred, their similarity is also 1.
return -1 if math.isnan(result) else result


Expand All @@ -63,6 +62,29 @@ def euclidean(gt: np.ndarray, pred: np.ndarray, *args):
return np.linalg.norm(gt.reshape(-1) - pred.reshape(-1))


# def mse(gt: np.ndarray, pred: np.ndarray, *args):
# return np.mean((gt - pred) ** 2)

def divide(gt: np.ndarray, pred: np.ndarray):

# remove the zero values in the same location.
gt_mask = np.equal(gt, 0)
pred_mask = np.equal(pred, 0)
mask = gt_mask & pred_mask
gt = gt[~mask]
pred = pred[~mask]

# to avoid divide zero.
pred = np.where(np.equal(pred, 0), 1e-7, pred)

result = np.divide(gt, pred)
return result


def mean(gt: np.ndarray):
return np.mean(gt)


def allclose(gt: np.ndarray, pred: np.ndarray, thresh: float):
return np.allclose(gt, pred, atol=thresh)

Expand Down Expand Up @@ -122,6 +144,8 @@ def compare_binfile(result_path: Tuple[str, str],
compare_op = gt
if compare_op(similarity, threshold):
return False, similarity_info
if (mean(divide(gt_arr, pred_arr)) > 1.5 or mean(divide(gt_arr, pred_arr)) < 0.6):
return False, similarity_info, f"\nmaybe a case of multiples"
return True, similarity_info


Expand Down

0 comments on commit 355b8c4

Please sign in to comment.