Commit 4010478
Add support for multi-label + missing classes
cgnorthcutt committed Oct 28, 2022
1 parent 81c736f commit 4010478
Showing 4 changed files with 48 additions and 17 deletions.
5 changes: 2 additions & 3 deletions cleanlab/count.py
@@ -41,6 +41,7 @@
     append_extra_datapoint,
     train_val_split,
     get_num_classes,
+    get_unique_classes,
     is_torch_dataset,
     is_tensorflow_dataset,
 )
@@ -1299,9 +1300,7 @@ class 0, 1, ..., K-1. `pred_probs` should have been computed using 3 (or
     An array of shape ``(K, )`` where K is the number of classes."""
 
     all_classes = range(pred_probs.shape[1])
-    unique_classes = (
-        np.unique([l for lst in labels for l in lst]) if multi_label else np.unique(labels)
-    )
+    unique_classes = get_unique_classes(labels, multi_label=multi_label)
     # labels must be a valid np.ndarray.
     labels = cleanlab.internal.validation.labels_to_array(labels)
     # When all_classes != unique_classes the class threshold for the missing classes is set to
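(Aside: `get_unique_classes` is a new internal helper; the sketch below only mirrors the inline expression this hunk deletes, so treat its exact signature as an assumption. Missing classes show up as the gap between `all_classes` and `unique_classes`.)

```python
import numpy as np

def get_unique_classes(labels, multi_label=False):
    # Mirrors the inline logic removed above: flatten the per-example
    # label lists for multi-label data before taking the unique values.
    if multi_label:
        return np.unique([l for lst in labels for l in lst])
    return np.unique(labels)

# Single-label: class 1 never occurs, so it is "missing" relative to pred_probs.
print(get_unique_classes(np.array([0, 0, 2, 2])))                   # [0 2]
# Multi-label: labels are lists of class indices per example.
print(get_unique_classes([[0, 1], [2], [0, 2]], multi_label=True))  # [0 1 2]
```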
18 changes: 9 additions & 9 deletions cleanlab/filter.py
@@ -274,7 +274,7 @@ class 0, 1, ..., K-1.
     # Prepare multiprocessing shared data
     if n_jobs > 1:
         if multi_label:
-            _labels = RawArray("I", int2onehot(labels).flatten())  # type: ignore
+            _labels = RawArray("I", int2onehot(labels, K).flatten())  # type: ignore
         else:
             _labels = RawArray("I", labels)  # type: ignore
         _label_counts = RawArray("I", label_counts)
@@ -367,7 +367,7 @@ class 0, 1, ..., K-1.
     # Remove label issues if given label == model prediction
     if multi_label:
         pred = _multiclass_crossval_predict(labels, pred_probs)
-        labels = MultiLabelBinarizer().fit_transform(labels)
+        labels = int2onehot(labels, K)
     else:
         pred = pred_probs.argmax(axis=1)
     for i, pred_label in enumerate(pred):
@@ -732,7 +732,7 @@ def _prune_by_count(k, args=None) -> np.ndarray:
 
     label_issues_mask = np.zeros(len(pred_probs), dtype=bool)
     pred_probs_k = pred_probs[:, k]
-    K = len(label_counts)
+    K = get_num_classes(labels, pred_probs, multi_label=multi_label)
     if label_counts[k] <= min_examples_per_class:  # No prune if not at least min_examples_per_class
         warnings.warn(
             f"May not flag all label issues in class: {k}, it has too few examples (see `min_examples_per_class` argument)"
@@ -751,7 +751,7 @@ def _prune_by_count(k, args=None) -> np.ndarray:
     return label_issues_mask
 
 
-def _multiclass_crossval_predict(labels, pyx) -> np.ndarray:
+def _multiclass_crossval_predict(labels, pred_probs) -> np.ndarray:
     """Returns a numpy 2D array of one-hot encoded
     multiclass predictions. Each row in the array
     provides the predictions for a particular example.
@@ -764,25 +764,25 @@ def _multiclass_crossval_predict(labels, pyx) -> np.ndarray:
         These are multiclass labels. Each list in the list contains all the
         labels for that example.
-    pyx : np.ndarray (shape (N, K))
+    pred_probs : np.ndarray (shape (N, K))
         P(label=k|x) is a matrix with K model-predicted probabilities.
         Each row of this matrix corresponds to an example `x` and contains the model-predicted
         probabilities that `x` belongs to each possible class.
         The columns must be ordered such that these probabilities correspond to class 0,1,2,...
         `pred_probs` should have been computed using 3 (or higher) fold cross-validation."""
 
     from sklearn.metrics import f1_score
 
     boundaries = np.arange(0.05, 0.9, 0.05)
-    labels_one_hot = MultiLabelBinarizer().fit_transform(labels)
+    K = get_num_classes(labels=labels, pred_probs=pred_probs, multi_label=True)
+    labels_one_hot = int2onehot(labels, K)
     f1s = [
         f1_score(
             labels_one_hot,
-            (pyx > boundary).astype(np.uint8),
+            (pred_probs > boundary).astype(np.uint8),
             average="micro",
         )
         for boundary in boundaries
     ]
     boundary = boundaries[np.argmax(f1s)]
-    pred = (pyx > boundary).astype(np.uint8)
+    pred = (pred_probs > boundary).astype(np.uint8)
     return pred
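The boundary search in this function can be reproduced standalone. A minimal sketch with toy data (one-hot encoding done inline instead of via `int2onehot`):

```python
import numpy as np
from sklearn.metrics import f1_score

# Toy multi-label problem: 3 examples, K = 3 classes.
labels = [[0, 1], [2], [1, 2]]
pred_probs = np.array([
    [0.9, 0.7, 0.1],
    [0.2, 0.1, 0.8],
    [0.1, 0.6, 0.9],
])
K = pred_probs.shape[1]
labels_one_hot = np.zeros((len(labels), K), dtype=np.uint8)
for i, lst in enumerate(labels):
    labels_one_hot[i, lst] = 1

# Sweep candidate probability thresholds and keep the one with the best
# micro-averaged F1 against the given labels, as the diff above does.
boundaries = np.arange(0.05, 0.9, 0.05)
f1s = [
    f1_score(labels_one_hot, (pred_probs > b).astype(np.uint8), average="micro")
    for b in boundaries
]
best = boundaries[np.argmax(f1s)]
pred = (pred_probs > best).astype(np.uint8)  # one-hot multi-label predictions
```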
13 changes: 8 additions & 5 deletions cleanlab/internal/util.py
@@ -176,7 +176,8 @@ def value_counts(x, *, num_classes=None, multi_label=False) -> np.ndarray:
     if num_classes <= max(unique_classes):
         raise ValueError(f"Required: num_classes > max(x), but {num_classes} <= {max(x)}.")
     # Add zero counts for all missing classes in [0, 1,..., num_classes-1]
-    missing_classes = get_missing_classes(x, num_classes=num_classes, multi_label=multi_label)
+    # multi_label=False regardless because x was flattened.
+    missing_classes = get_missing_classes(x, num_classes=num_classes, multi_label=False)
     missing_counts = [(z, 0) for z in missing_classes]
     # Return counts with zeros for all missing classes.
     return np.array(list(zip(*sorted(list(zip(unique_classes, counts)) + missing_counts)))[1])
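The zero-padding this hunk touches is easy to check in isolation. A minimal sketch, assuming `x` is already a flat array of class indices (which is exactly why `multi_label=False` is now hard-coded):

```python
import numpy as np

# Class 1 never occurs in x, but with num_classes known it still
# gets an explicit zero count rather than silently disappearing.
x = np.array([0, 0, 2, 2, 2])
num_classes = 3
unique_classes, counts = np.unique(x, return_counts=True)
full_counts = np.zeros(num_classes, dtype=int)
full_counts[unique_classes] = counts
print(full_counts)  # [2 0 3]
```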
@@ -195,7 +196,6 @@ def get_missing_classes(labels, *, pred_probs=None, num_classes=None, multi_labe
     """Find which classes are present in ``pred_probs`` but not present in ``labels``.
     See ``count.compute_confident_joint`` for parameter docstrings."""
 
     if pred_probs is None and num_classes is None:
         raise ValueError("Both pred_probs and num_classes are None. You must provide exactly one.")
     if pred_probs is not None and num_classes is not None:
@@ -263,18 +263,21 @@ def round_preserving_row_totals(confident_joint) -> np.ndarray:
     ).astype(int)
 
 
-def int2onehot(labels) -> np.ndarray:
+def int2onehot(labels, K) -> np.ndarray:
     """Convert list of lists to a onehot matrix for multi-labels
     Parameters
     ----------
     labels: list of lists of integers
         e.g. [[0,1], [3], [1,2,3], [1], [2]]
-        All integers from 0,1,...,K-1 must be represented."""
+        All integers from 0,1,...,K-1 must be represented.
+    K: int
+        The number of classes."""
 
     from sklearn.preprocessing import MultiLabelBinarizer
 
-    mlb = MultiLabelBinarizer()
+    mlb = MultiLabelBinarizer(classes=range(K))
     return mlb.fit_transform(labels)


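The `classes=range(K)` pin above is the core of the multi-label missing-class fix: a bare `MultiLabelBinarizer` infers the class set from the labels it sees, so an absent class silently loses its column. For example:

```python
from sklearn.preprocessing import MultiLabelBinarizer

labels = [[0, 1], [3], [1, 3]]  # class 2 never appears
K = 4

# Inferred classes drop the missing class 2 -> only 3 columns:
print(MultiLabelBinarizer().fit_transform(labels).shape)                  # (3, 3)
# Pinning the class set keeps a column for every class 0..K-1:
print(MultiLabelBinarizer(classes=range(K)).fit_transform(labels).shape)  # (3, 4)
```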
29 changes: 29 additions & 0 deletions tests/test_filter_count.py
@@ -531,3 +531,32 @@ def test_missing_classes():
     pred_probs = np.array([[0.9, 0.1, 0.0], [0.8, 0.1, 0.1], [0.1, 0.0, 0.9], [0.95, 0.0, 0.05]])
     issues = filter.find_label_issues(labels, pred_probs)
     assert np.all(issues == np.array([False, False, False, True]))
+
+
+def test_removing_class_consistent_results():
+    # Note that only one label is class 1 (we're going to change it to class 2 later...)
+    labels = np.array([0, 0, 0, 0, 1, 2, 2, 2])
+    # Third example is a label error
+    pred_probs = np.array([
+        [0.9, 0.1, 0.0], [0.8, 0.1, 0.1], [0.1, 0.0, 0.9], [0.9, 0.0, 0.1],
+        [0.1, 0.3, 0.6], [0.1, 0.0, 0.9], [0.1, 0.0, 0.9], [0.1, 0.0, 0.9],
+    ])
+    cj_with1 = count.compute_confident_joint(labels, pred_probs)
+    issues_with1 = filter.find_label_issues(labels, pred_probs)
+
+    labels_no1 = np.array([0, 0, 0, 0, 2, 2, 2, 2])  # change 1 to 2 (class 1 is missing!)
+    cj_no1 = count.compute_confident_joint(labels_no1, pred_probs)
+    issues_no1 = filter.find_label_issues(labels_no1, pred_probs)
+
+    assert np.all(issues_with1 == issues_no1)
+    assert np.all(cj_with1 == [
+        [3, 0, 1],
+        [0, 1, 0],
+        [0, 0, 3],
+    ])
+    # Check that the 1, 1 entry goes away and moves to 2, 2 (since we changed label 1 to 2)
+    assert np.all(cj_no1 == [
+        [3, 0, 1],
+        [0, 0, 0],
+        [0, 0, 4],
+    ])
