Skip to content

Commit

Permalink
Multilabel code restructuring with aggregation/scorer functions (#509)
Browse files Browse the repository at this point in the history
_find_label_issues_multilabel uses EMA instead of mean-pooling when computing label quality scores

Co-authored-by: Elías Snorrason <eliassno@gmail.com>
Co-authored-by: Jonas Mueller <1390638+jwmueller@users.noreply.github.com>
  • Loading branch information
3 people committed Oct 28, 2022
1 parent 00d7c21 commit b12d76b
Show file tree
Hide file tree
Showing 9 changed files with 696 additions and 332 deletions.
61 changes: 18 additions & 43 deletions cleanlab/count.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
get_num_classes,
is_torch_dataset,
is_tensorflow_dataset,
int2onehot,
_binarize_pred_probs_slice,
)
from cleanlab.internal.multilabel_utils import stack_complement, get_onehot_num_classes

from cleanlab.internal.latent_algebra import (
compute_inv_noise_matrix,
compute_py,
Expand Down Expand Up @@ -229,20 +229,13 @@ def _calibrate_confident_joint_multilabel(confident_joint: np.ndarray, labels: l
calibrated_cj : np.ndarray
An array of shape ``(K, 2, 2)`` of type float representing a valid
estimate of the joint *counts* of noisy and true labels in a one-vs-rest fashion."""
try:
y_one = int2onehot(labels)
except TypeError:
raise ValueError(
"wrong format for labels, should be a list of list[indices], please check the documentation in find_label_issues for further information"
)
y_one, num_classes = get_onehot_num_classes(labels)
num_classes = len(confident_joint)
calibrate_confident_joint_list: np.ndarray = np.ndarray(
shape=(num_classes, 2, 2), dtype=np.int64
)
for class_num in range(0, num_classes):
calibrate_confident_joint_list[class_num] = calibrate_confident_joint(
confident_joint[class_num], labels=y_one[:, class_num]
)
for class_num, (cj, y) in enumerate(zip(confident_joint, y_one.T)):
calibrate_confident_joint_list[class_num] = calibrate_confident_joint(cj, labels=y)

return calibrate_confident_joint_list

Expand Down Expand Up @@ -336,13 +329,7 @@ def _estimate_joint_multilabel(labels, pred_probs, *, confident_joint=None) -> n
An array of shape ``(K, 2, 2)`` representing an
estimate of the true joint distribution of noisy and true labels for each class, in a one-vs-rest format employed for multi-label settings.
"""
num_classes = get_num_classes(labels=labels, pred_probs=pred_probs)
try:
y_one = int2onehot(labels)
except TypeError:
raise ValueError(
"wrong format for labels, should be a list of list[indices], please check the documentation in find_label_issues for further information"
)
y_one, num_classes = get_onehot_num_classes(labels, pred_probs)
if confident_joint is None:
calibrated_cj = compute_confident_joint(
labels,
Expand All @@ -353,10 +340,10 @@ def _estimate_joint_multilabel(labels, pred_probs, *, confident_joint=None) -> n
else:
calibrated_cj = confident_joint
calibrated_cf: np.ndarray = np.ndarray((num_classes, 2, 2))
for class_num in range(num_classes):
pred_probabilitites = _binarize_pred_probs_slice(pred_probs, class_num)
for class_num, (label, pred_prob) in enumerate(zip(y_one.T, pred_probs.T)):
pred_probabilitites = stack_complement(pred_prob)
calibrated_cf[class_num] = estimate_joint(
labels=y_one[:, class_num],
labels=label,
pred_probs=pred_probabilitites,
confident_joint=calibrated_cj[class_num],
)
Expand Down Expand Up @@ -599,20 +586,14 @@ def _compute_confident_joint_multi_label(
where `indices_off_diagonal` is a list of arrays (one per class) and each array contains the indices of examples counted in off-diagonals of confident joint for that class.
"""

num_classes = get_num_classes(labels=labels, pred_probs=pred_probs)
try:
y_one = int2onehot(labels)
except TypeError:
raise ValueError(
"wrong format for labels, should be a list of list[indices], please check the documentation in find_label_issues for further information"
)
y_one, num_classes = get_onehot_num_classes(labels, pred_probs)
confident_joint_list: np.ndarray = np.ndarray(shape=(num_classes, 2, 2), dtype=np.int64)
indices_off_diagonal = []
for class_num in range(0, num_classes):
pred_probabilitites = _binarize_pred_probs_slice(pred_probs, class_num)
for class_num, (label, pred_prob) in enumerate(zip(y_one.T, pred_probs.T)):
pred_probabilitites = stack_complement(pred_prob)
if return_indices_of_off_diagonals:
cj, ind = compute_confident_joint(
labels=y_one[:, class_num],
labels=label,
pred_probs=pred_probabilitites,
multi_label=False,
thresholds=thresholds,
Expand All @@ -622,7 +603,7 @@ def _compute_confident_joint_multi_label(
indices_off_diagonal.append(ind)
else:
cj = compute_confident_joint(
labels=y_one[:, class_num],
labels=label,
pred_probs=pred_probabilitites,
multi_label=False,
thresholds=thresholds,
Expand Down Expand Up @@ -1395,17 +1376,11 @@ def _get_confident_thresholds_multilabel(
confident_thresholds : np.ndarray
An array of shape ``(K, 2, 2)`` where `K` is the number of classes, in a one-vs-rest format.
"""
num_classes = get_num_classes(labels=labels, pred_probs=pred_probs)
try:
y_one = int2onehot(labels)
except TypeError:
raise ValueError(
"wrong format for labels, should be a list of list[indices], please check the documentation in find_label_issues for further information"
)
y_one, num_classes = get_onehot_num_classes(labels, pred_probs)
confident_thresholds: np.ndarray = np.ndarray((num_classes, 2))
for class_num in range(num_classes):
pred_probabilitites = _binarize_pred_probs_slice(pred_probs, class_num)
for class_num, (label, pred_prob) in enumerate(zip(y_one.T, pred_probs.T)):
pred_probabilitites = stack_complement(pred_prob)
confident_thresholds[class_num] = get_confident_thresholds(
pred_probs=pred_probabilitites, labels=y_one[:, class_num]
pred_probs=pred_probabilitites, labels=labels
)
return confident_thresholds
55 changes: 21 additions & 34 deletions cleanlab/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@
from cleanlab.count import calibrate_confident_joint
from cleanlab.rank import (
order_label_issues,
get_label_quality_scores,
)
import cleanlab.internal.multilabel_scorer as ml_scorer

from cleanlab.internal.validation import assert_valid_inputs
from cleanlab.internal.util import (
value_counts,
round_preserving_row_totals,
int2onehot,
get_num_classes,
_binarize_pred_probs_slice,
)
from cleanlab.internal.multilabel_utils import stack_complement, get_onehot_num_classes

# tqdm is a module used to print time-to-complete when multiprocessing is used.
# This module is not necessary, and therefore is not a package dependency, but
Expand Down Expand Up @@ -446,16 +446,15 @@ def _find_label_issues_multilabel(
else:
label_issues_list, labels_list, pred_probs_list = per_class_issues
label_issues_idx = reduce(np.union1d, label_issues_list)
num_classes = get_num_classes(labels=labels, pred_probs=pred_probs)
label_quality_scores = np.zeros(len(labels))
for i in range(0, num_classes):
label_quality_scores += get_label_quality_scores(
labels=labels_list[i],
pred_probs=pred_probs_list[i],
method=return_indices_ranked_by,
**rank_by_kwargs,
)
label_quality_scores /= num_classes
y_one, num_classes = get_onehot_num_classes(labels, pred_probs)
label_quality_scores = ml_scorer.get_label_quality_scores(
labels=y_one,
pred_probs=pred_probs,
method=ml_scorer.MultilabelScorer(
base_scorer=ml_scorer.ClassLabelScorer.from_str(return_indices_ranked_by),
),
base_scorer_kwargs=rank_by_kwargs,
)
label_quality_scores_issues = label_quality_scores[label_issues_idx]
return label_issues_idx[np.argsort(label_quality_scores_issues)]

Expand Down Expand Up @@ -537,13 +536,7 @@ class 0, 1, ..., K-1. They need not sum to 1.0
`return_indices_ranked_by`.
"""
num_classes = get_num_classes(labels=labels, pred_probs=pred_probs)
try:
y_one = int2onehot(labels)
except TypeError:
raise ValueError(
"wrong format for labels, should be a list of list[indices], please check the documentation in find_label_issues for further information"
)
y_one, num_classes = get_onehot_num_classes(labels, pred_probs)
if return_indices_ranked_by is None:
bissues = np.zeros(y_one.shape).astype(bool)
else:
Expand All @@ -559,14 +552,14 @@ class 0, 1, ..., K-1. They need not sum to 1.0
confident_joint = None
elif confident_joint_shape != (num_classes, 2, 2):
raise ValueError("confident_joint should be of shape (num_classes, 2, 2)")
for class_num in range(0, num_classes):
pred_probabilitites = _binarize_pred_probs_slice(pred_probs, class_num)
for class_num, (label, pred_prob) in enumerate(zip(y_one.T, pred_probs.T)):
pred_probabilitites = stack_complement(pred_prob)
if confident_joint is None:
conf = None
else:
conf = confident_joint[class_num]
binary_label_issues = find_label_issues(
labels=y_one[:, class_num],
labels=label,
pred_probs=pred_probabilitites,
return_indices_ranked_by=return_indices_ranked_by,
frac_noise=frac_noise,
Expand All @@ -584,7 +577,7 @@ class 0, 1, ..., K-1. They need not sum to 1.0
bissues[:, class_num] = binary_label_issues
else:
label_issues_list.append(binary_label_issues)
labels_list.append(y_one[:, class_num])
labels_list.append(label)
pred_probs_list.append(pred_probabilitites)
if return_indices_ranked_by is None:
return bissues
Expand Down Expand Up @@ -741,18 +734,12 @@ def _find_predicted_neq_given_multilabel(labels, pred_probs) -> np.ndarray:
labeled with high confidence.
"""
try:
y_one = int2onehot(labels)
except TypeError:
raise ValueError(
"wrong format for labels, should be a list of list[indices], please check the documentation in find_label_issues for further information"
)
num_classes = get_num_classes(labels=labels, pred_probs=pred_probs)
y_one, num_classes = get_onehot_num_classes(labels, pred_probs)
pred_neq: np.ndarray = np.zeros(y_one.shape).astype(bool)
for class_num in range(num_classes):
pred_probabilitites = _binarize_pred_probs_slice(pred_probs, class_num)
for class_num, (label, pred_prob) in enumerate(zip(y_one.T, pred_probs.T)):
pred_probabilitites = stack_complement(pred_prob)
pred_neq[:, class_num] = find_predicted_neq_given(
labels=y_one[:, class_num], pred_probs=pred_probabilitites
labels=label, pred_probs=pred_probabilitites
)
return pred_neq.sum(axis=1) >= 1

Expand Down

0 comments on commit b12d76b

Please sign in to comment.