Skip to content

Commit

Permalink
Mypy typechecking fix for num_label_errors (#500)
Browse files Browse the repository at this point in the history
  • Loading branch information
ulya-tkch committed Oct 8, 2022
1 parent c14c84f commit 57bcfd0
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions cleanlab/count.py
Expand Up @@ -113,15 +113,21 @@ class 0, 1, ..., K-1. `pred_probs` should have been computed using 3 (or

if confident_joint is None:
# Original non-calibrated counts of confidently correctly and incorrectly labeled examples.
confident_joint = compute_confident_joint(
labels=labels, pred_probs=pred_probs, calibrate=False
computed_confident_joint = compute_confident_joint(
labels=labels,
pred_probs=pred_probs,
calibrate=False,
)
else:
computed_confident_joint = confident_joint

assert isinstance(computed_confident_joint, np.ndarray)

if estimation_method is "off_diagonal":
num_issues = np.sum(confident_joint) - np.trace(confident_joint)
elif estimation_method is "off_diagonal_calibrated":
if estimation_method == "off_diagonal":
num_issues = np.sum(computed_confident_joint) - np.trace(computed_confident_joint)
elif estimation_method == "off_diagonal_calibrated":
# Estimate_joint calibrates the row sums to match the prior distribution of given labels and normalizes to sum to 1
joint = estimate_joint(labels, pred_probs, confident_joint=confident_joint)
joint = estimate_joint(labels, pred_probs, confident_joint=computed_confident_joint)
frac_issues = 1.0 - joint.trace()
num_issues = np.rint(frac_issues * len(labels)).astype(int)
else:
Expand Down

0 comments on commit 57bcfd0

Please sign in to comment.