From 57bcfd0a99a7880c45eef4e20a325c81add8f330 Mon Sep 17 00:00:00 2001
From: Ulyana
Date: Fri, 7 Oct 2022 19:43:40 -0700
Subject: [PATCH] Mypy typechecking fix for num_label_errors (#500)

---
 cleanlab/count.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/cleanlab/count.py b/cleanlab/count.py
index 4ac6a693a8..b404add6bb 100644
--- a/cleanlab/count.py
+++ b/cleanlab/count.py
@@ -113,15 +113,21 @@ class 0, 1, ..., K-1. `pred_probs` should have been computed using 3 (or
 
     if confident_joint is None:
         # Original non-calibrated counts of confidently correctly and incorrectly labeled examples.
-        confident_joint = compute_confident_joint(
-            labels=labels, pred_probs=pred_probs, calibrate=False
+        computed_confident_joint = compute_confident_joint(
+            labels=labels,
+            pred_probs=pred_probs,
+            calibrate=False,
         )
+    else:
+        computed_confident_joint = confident_joint
+
+    assert isinstance(computed_confident_joint, np.ndarray)
 
-    if estimation_method is "off_diagonal":
-        num_issues = np.sum(confident_joint) - np.trace(confident_joint)
-    elif estimation_method is "off_diagonal_calibrated":
+    if estimation_method == "off_diagonal":
+        num_issues = np.sum(computed_confident_joint) - np.trace(computed_confident_joint)
+    elif estimation_method == "off_diagonal_calibrated":
         # Estimate_joint calibrates the row sums to match the prior distribution of given labels and normalizes to sum to 1
-        joint = estimate_joint(labels, pred_probs, confident_joint=confident_joint)
+        joint = estimate_joint(labels, pred_probs, confident_joint=computed_confident_joint)
         frac_issues = 1.0 - joint.trace()
         num_issues = np.rint(frac_issues * len(labels)).astype(int)
     else: