Skip to content

Commit

Permalink
merge conflicts and broken tests
Browse files: browse the repository at this point in the history
  • Loading branch information
jwmueller committed Oct 30, 2022
1 parent 8d164a0 commit c82a3d3
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion cleanlab/count.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,7 +1359,7 @@ class 0, 1, ..., K-1. `pred_probs` should have been computed using 3 (or
confident_thresholds : np.ndarray
An array of shape ``(K, )`` where K is the number of classes."""

labels = cleanlab.internal.validation.labels_to_array(labels)
labels = labels_to_array(labels)
all_classes = range(pred_probs.shape[1])
unique_classes = get_unique_classes(labels)
BIG_VALUE = 2
Expand Down
6 changes: 1 addition & 5 deletions cleanlab/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ class 0, 1, ..., K-1.
else:
assert n_jobs >= 1

# Number of examples in each class of labels
if multi_label:
return _find_label_issues_multilabel(
labels,
Expand All @@ -258,10 +257,7 @@ class 0, 1, ..., K-1.
verbose,
)

else:
label_counts = value_counts(labels)

# Number of classes
# Else this is standard multi-class classification
K = get_num_classes(
labels=labels, pred_probs=pred_probs, label_matrix=confident_joint, multi_label=multi_label
)
Expand Down
2 changes: 1 addition & 1 deletion cleanlab/internal/multilabel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_onehot_num_classes(labels, pred_probs=None):
"""Returns OneHot encoding of MultiLabel Data, and number of classes"""
num_classes = get_num_classes(labels=labels, pred_probs=pred_probs)
try:
y_one = int2onehot(labels)
y_one = int2onehot(labels, K=num_classes)
except TypeError:
raise ValueError(
"wrong format for labels, should be a list of list[indices], please check the documentation in find_label_issues for further information"
Expand Down
2 changes: 2 additions & 0 deletions cleanlab/internal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def round_preserving_row_totals(confident_joint) -> np.ndarray:
).astype(int)


# TODO: move to multilabel_utils.py
def int2onehot(labels, K) -> np.ndarray:
"""Convert list of lists to a onehot matrix for multi-labels
Expand All @@ -282,6 +283,7 @@ def int2onehot(labels, K) -> np.ndarray:
return mlb.fit_transform(labels)


# TODO: move to multilabel_utils.py
def onehot2int(onehot_matrix) -> list:
"""Convert a onehot matrix for multi-labels to a list of lists of ints
Expand Down
7 changes: 4 additions & 3 deletions tests/test_filter_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def make_multilabel_data(
seed=1,
):
np.random.seed(seed=seed)
m = len(means) + len(
num_classes = len(means)
m = num_classes + len(
box_multilabels
) # number of classes by treating each multilabel as 1 unique label
n = sum(sizes)
Expand Down Expand Up @@ -172,7 +173,7 @@ def make_multi(X, Y, bx1, by1, bx2, by2, label_list):
ps = np.bincount(labels_idx) / float(len(labels_idx))
inv = compute_inv_noise_matrix(py, noise_matrix, ps=ps)

y_train = int2onehot(noisy_labels)
y_train = int2onehot(noisy_labels, K=num_classes)
clf = MultiOutputClassifier(LogisticRegression())
pyi = cross_val_predict(clf, X_train, y_train, method="predict_proba")
pred_probs = np.zeros(y_train.shape)
Expand Down Expand Up @@ -310,7 +311,7 @@ def test_calibrate_joint_multilabel():
labels=dataset["labels"],
multi_label=True,
)
y_one = int2onehot(dataset["labels"])
y_one = int2onehot(dataset["labels"], K=dataset["pred_probs"].shape[1])
# Check calibration
for class_num in range(0, len(calibrated_cj)):
label_counts = np.bincount(y_one[:, class_num])
Expand Down
3 changes: 2 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,9 @@ def test_round_preserving_sum():


def test_one_hot():
num_classes = 4
labels = [[0], [0, 1], [0, 1], [2], [0, 2, 3]]
assert onehot2int(int2onehot(labels)) == labels
assert onehot2int(int2onehot(labels, K=num_classes)) == labels


def test_num_unique():
Expand Down

0 comments on commit c82a3d3

Please sign in to comment.