Add compatibility for tensorflow and pytorch Dataset objects (#311)
* validation_func docstring

* torch,tf compatibility+tests

* keras test

* skip tests if python < 3.7

* pytorch numpy int bug on windows

* make tensorflow test work on windows

* move tf env variable setting

* pytorch test increase epochs

* install cpu-tensorflow on windows CI

* torch test optimizer to adam

* fix bugs in shuffled TF dataset

* dummy unit test for TF on windows

* dummy code for TF windows testing

* deal with np.int bug on windows

* remove windows debugging code

* docstrings for new functionality

* address merge conflicts

* reformat after merge

* addressed comments
jwmueller committed Jul 28, 2022
1 parent a0a8e15 commit b93fdeb
Showing 8 changed files with 637 additions and 80 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
@@ -27,6 +27,11 @@ jobs:
python-version: ${{ matrix.python }}
- name: Install development dependencies
run: pip install -r requirements-dev.txt
- name: Overwrite tensorflow version on Windows
if: matrix.os == 'windows-latest'
run: |
pip uninstall -y tensorflow
pip install tensorflow-cpu
- name: Install cleanlab
run: pip install -e .
- name: Test with coverage
90 changes: 56 additions & 34 deletions cleanlab/classification.py
@@ -285,31 +285,39 @@ def fit(
Parameters
----------
X : np.array or pd.DataFrame
X : np.ndarray or DatasetLike
Data features (i.e. training inputs for ML), typically an array of shape ``(N, ...)``,
where N is the number of examples. Sparse matrices are supported.
If not an array or DataFrame, then ``X`` must support list-based indexing:
where N is the number of examples.
Supported `DatasetLike` types beyond ``np.ndarray`` include:
``pd.DataFrame``, ``scipy.sparse.csr_matrix``, ``torch.utils.data.Dataset``, ``tensorflow.data.Dataset``,
or any dataset object ``X`` that supports list-based indexing:
``X[index_list]`` to select a subset of training examples.
The classifier that this instance was initialized with,
``clf``, must be able to fit() and predict() data with this format.
``clf``, must be able to fit() and predict() data of this format.
labels : np.array or pd.Series
Note
----
If providing `X` as a ``tensorflow.data.Dataset``,
make sure ``shuffle()`` has been called before ``batch()`` (if shuffling)
and no other order-destroying operation (e.g. ``repeat()``) has been applied.
labels : array_like
An array of shape ``(N,)`` of noisy labels, i.e. some labels may be erroneous.
Elements must be in the set 0, 1, ..., K-1, where K is the number of classes.
Supported `array_like` types include: ``np.ndarray``, ``pd.Series``, or ``list``.
pred_probs : np.array, optional
pred_probs : np.ndarray, optional
An array of shape ``(N, K)`` of model-predicted probabilities,
``P(label=k|x)``. Each row of this matrix corresponds
to an example `x` and contains the model-predicted probabilities that
`x` belongs to each possible class, for each of the K classes. The
columns must be ordered such that these probabilities correspond to
class 0, 1, ..., K-1. `pred_probs` should have been computed using 3 (or
higher) fold cross-validation.
columns must be ordered such that these probabilities correspond to class 0, 1, ..., K-1.
`pred_probs` should be :ref:`out-of-sample, eg. computed via cross-validation <pred_probs_cross_val>`.
Note
----
If you are not sure, leave ``pred_probs=None`` (the default) and it
will be computed for you using cross-validation with the model.
will be computed for you using cross-validation with the provided model.
thresholds : array_like, optional
An array of shape ``(K, 1)`` or ``(K,)`` of per-class threshold
@@ -324,27 +332,27 @@ class 0, 1, ..., K-1. `pred_probs` should have been computed using 3 (or
k. This is not used for pruning/filtering, only for estimating the
noise rates using confident counts.
noise_matrix : np.array, optional
noise_matrix : np.ndarray, optional
An array of shape ``(K, K)`` representing the conditional probability
matrix ``P(label=k_s | true label=k_y)``, the
fraction of examples in every class, labeled as every other class.
Assumes columns of `noise_matrix` sum to 1.
inverse_noise_matrix : np.array, optional
inverse_noise_matrix : np.ndarray, optional
An array of shape ``(K, K)`` representing the conditional probability
matrix ``P(true label=k_y | label=k_s)``,
the estimated fraction of observed examples in each class ``k_s``
that are mislabeled examples from every other class ``k_y``.
Assumes columns of `inverse_noise_matrix` sum to 1.
label_issues : pd.DataFrame or np.array, optional
label_issues : pd.DataFrame or np.ndarray, optional
Specifies the label issues for each example in dataset.
If ``pd.DataFrame``, must be formatted as the one returned by:
:py:meth:`CleanLearning.find_label_issues
<cleanlab.classification.CleanLearning.find_label_issues>` or
:py:meth:`CleanLearning.get_label_issues
<cleanlab.classification.CleanLearning.get_label_issues>`.
If ``np.array``, must contain either boolean `label_issues_mask` as output by:
If ``np.ndarray``, must contain either boolean `label_issues_mask` as output by:
default :py:func:`filter.find_label_issues <cleanlab.filter.find_label_issues>`,
or integer indices as output by
:py:func:`filter.find_label_issues <cleanlab.filter.find_label_issues>`
@@ -356,10 +364,11 @@ class 0, 1, ..., K-1. `pred_probs` should have been computed using 3 (or
Caution: If you provide `label_issues` without having previously called
:py:meth:`self.find_label_issues<cleanlab.classification.CleanLearning.find_label_issues>`,
e.g. as a ``np.array``, then some functionality like training with sample weights may be disabled.
e.g. as a ``np.ndarray``, then some functionality like training with sample weights may be disabled.
sample_weight : array-like of shape (N,), optional
Array of weights that are assigned to individual samples.
sample_weight : array_like, optional
Array of weights with shape ``(N,)`` that are assigned to individual samples,
assuming total number of examples in dataset is `N`.
If not provided, samples may still be weighted by the estimated noise in the class they are labeled as.
clf_kwargs : dict, optional
@@ -375,36 +384,42 @@ class 0, 1, ..., K-1. `pred_probs` should have been computed using 3 (or
validation_func : callable, optional
Optional callable function that takes two arguments, `X_val`, `y_val`, and returns a dict
of keyword arguments passed into to `clf.fit()` which may be functions of the validation
of keyword arguments passed into ``clf.fit()`` which may be functions of the validation
data in each cross-validation fold. Specifies how to map the validation data split in each
cross-validation fold into the appropriate format to pass into `clf`'s ``fit()`` method.
e.g. if your model's ``fit()`` method is call using `clf.fit(X, y, X_validation, y_validation)`,
then you could set `validation_func = f` where
`def f(X_val, y_val): return {"X_validation": X_val, "y_validation": y_val}`
cross-validation fold into the appropriate format to pass into `clf`'s ``fit()`` method, assuming
``clf.fit()`` can utilize validation data if it is appropriately passed in (e.g. for early stopping).
E.g. if your model's ``fit()`` method is called using ``clf.fit(X, y, X_validation, y_validation)``,
then you could set ``validation_func = f`` where
``def f(X_val, y_val): return {"X_validation": X_val, "y_validation": y_val}``
Note that `validation_func` will be ignored in the final call to `clf.fit()` on the
cleaned subset of the data. This argument is only for allowing `clf` to access the
validation data in each cross-validation fold (eg. for early-stopping or hyperparameter-selection
purposes). If you want to pass in validation data even in the final training of `clf.fit()`
purposes). If you want to pass in validation data even in the final training call to ``clf.fit()``
on the cleaned data subset, you should explicitly pass in that data yourself
(eg. via `clf_final_kwargs` or `clf_kwargs`).
y: np.array or pd.Series, optional
Alternative argument that can be specified instead of `labels`. Specifying `y` has the same effect as specifying `labels`, and is offered as an alternative for compatibility with sklearn.
y: array_like, optional
Alternative argument that can be specified instead of `labels`.
Specifying `y` has the same effect as specifying `labels`,
and is offered as an alternative for compatibility with sklearn.
Returns
-------
CleanLearning
``self`` - Fitted estimator that has all the same methods as any sklearn estimator.
self : CleanLearning
Fitted estimator that has all the same methods as any sklearn estimator.
After calling ``self.fit()``, this estimator also stores extra attributes such as:
After calling ``self.fit()``, this estimator also stores a few extra useful attributes, in particular
`self.label_issues_df`: a ``pd.DataFrame`` accessible via
* *self.label_issues_df*: a ``pd.DataFrame`` accessible via
:py:meth:`get_label_issues
<cleanlab.classification.CleanLearning.get_label_issues>`
of similar format as the one returned by: :py:meth:`CleanLearning.find_label_issues<cleanlab.classification.CleanLearning.find_label_issues>`.
See documentation of :py:meth:`CleanLearning.find_label_issues<cleanlab.classification.CleanLearning.find_label_issues>`
for column descriptions.
After calling ``self.fit()``, `self.label_issues_df` may also contain an extra column:
* *sample_weight*: Numeric values that were used to weight examples during
@@ -418,6 +433,13 @@ class 0, 1, ..., K-1. `pred_probs` should have been computed using 3 (or
In other words, examples with label issues were removed, so this weights the data proportionally
so that the classifier trains as if it had all the true labels,
not just the subset of cleaned data left after pruning out the label issues.
Note
----
If ``CleanLearning.fit()`` does not work for your data/model, you can run the same procedure yourself:
* Utilize :ref:`cross-validation <pred_probs_cross_val>` to get out-of-sample `pred_probs` for each example.
* Call :py:func:`filter.find_label_issues <cleanlab.filter.find_label_issues>` with `pred_probs`.
* Filter the examples with detected issues and train your model on the remaining data.
"""

if labels is not None and y is not None:
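To make the `validation_func` contract from the docstring in this hunk concrete, here is a minimal sketch in plain Python. `ToyClf` and the way one fold is imitated by hand are assumptions made for illustration only; the keyword names `X_validation` and `y_validation` are taken from the docstring's own example, and none of this is cleanlab's internal code.

```python
class ToyClf:
    """Hypothetical classifier whose fit() accepts optional validation data."""

    def __init__(self):
        self.n_train = 0
        self.n_val = 0

    def fit(self, X, y, X_validation=None, y_validation=None):
        self.n_train = len(X)
        # A real model might use the validation split for early stopping.
        self.n_val = 0 if X_validation is None else len(X_validation)
        return self


def validation_func(X_val, y_val):
    # Map the fold's held-out split onto the keyword names ToyClf.fit() expects.
    return {"X_validation": X_val, "y_validation": y_val}


# Imitate a single cross-validation fold by hand: indices 0-3 train, 4-5 held out.
X = [[0.1], [0.2], [0.3], [0.4], [0.5], [0.6]]
labels = [0, 1, 0, 1, 0, 1]
train_idx, holdout_idx = [0, 1, 2, 3], [4, 5]

X_train = [X[i] for i in train_idx]
y_train = [labels[i] for i in train_idx]
X_holdout = [X[i] for i in holdout_idx]
y_holdout = [labels[i] for i in holdout_idx]

# Within a fold, the returned dict is forwarded to fit() as extra keyword arguments.
extra_kwargs = validation_func(X_holdout, y_holdout)
fold_clf = ToyClf().fit(X_train, y_train, **extra_kwargs)

# The final fit on the cleaned data does not receive these keyword arguments.
final_clf = ToyClf().fit(X_train, y_train)
```

As the docstring notes, validation data for the final training call would have to be supplied explicitly, for example through `clf_final_kwargs` or `clf_kwargs`.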
Expand Down Expand Up @@ -566,7 +588,7 @@ def predict(self, *args, **kwargs):
Parameters
----------
X : np.array
X : np.ndarray
An array of shape ``(N, ...)`` of test data."""

return self.clf.predict(*args, **kwargs)
@@ -577,7 +599,7 @@ def predict_proba(self, *args, **kwargs):
Parameters
----------
X : np.array
X : np.ndarray
An array of shape ``(N, ...)`` of test data."""

return self.clf.predict_proba(*args, **kwargs)
@@ -588,13 +610,13 @@ def score(self, X, y, sample_weight=None):
Parameters
----------
X : np.array
X : np.ndarray
An array of shape ``(N, ...)`` of test data.
y : np.array
y : np.ndarray
An array of shape ``(N,)`` or ``(N, 1)`` of test labels.
sample_weight : np.array, optional
sample_weight : np.ndarray, optional
An array of shape ``(N,)`` or ``(N, 1)`` used to weight each example when computing the score."""

if hasattr(self.clf, "score"):
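The updated `fit()` docstring in this file says `X` may be any dataset object that supports list-based indexing, ``X[index_list]``. A minimal sketch of such an object follows. `ListIndexableDataset` is a hypothetical name, and the exact set of methods cleanlab requires beyond list-based indexing is not stated in this diff, so `__len__` is included here as an assumption.

```python
import operator


class ListIndexableDataset:
    """Hypothetical container supporting X[i], X[index_list] and len(X)."""

    def __init__(self, examples):
        self._examples = list(examples)

    def __len__(self):
        return len(self._examples)

    def __getitem__(self, index):
        try:
            # Succeeds for plain ints and other integer-like scalars.
            position = operator.index(index)
        except TypeError:
            # Otherwise treat the index as an iterable of integer positions.
            return ListIndexableDataset(self._examples[i] for i in index)
        return self._examples[position]


X = ListIndexableDataset(["a", "b", "c", "d", "e"])
subset = X[[0, 2, 4]]  # selects a subset of examples, as cross-validation would
subset_items = [subset[i] for i in range(len(subset))]
```

Returning the same container type from a list index keeps the subset usable wherever the full dataset was, which matters because the classifier must be able to `fit()` and `predict()` on data of this format.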
54 changes: 34 additions & 20 deletions cleanlab/count.py
@@ -40,6 +40,8 @@
append_extra_datapoint,
train_val_split,
get_num_classes,
is_torch_dataset,
is_tensorflow_dataset,
)
from cleanlab.internal.latent_algebra import (
compute_inv_noise_matrix,
@@ -740,35 +742,47 @@ def estimate_confident_joint_and_cv_pred_proba(
pred_probs = np.zeros(shape=(len(labels), num_classes))

# Split X and labels into "cv_n_folds" stratified folds.
for k, (cv_train_idx, cv_holdout_idx) in enumerate(kf.split(X, labels)):
clf_copy = sklearn.base.clone(clf)
# CV indices only require labels: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedKFold.html
# Only split based on labels because X may have various formats:
for k, (cv_train_idx, cv_holdout_idx) in enumerate(kf.split(X=labels, y=labels)):
try:
clf_copy = sklearn.base.clone(clf) # fresh untrained copy of the model
except Exception:
raise ValueError(
"`clf` must be clonable via: sklearn.base.clone(clf). "
"You can either implement instance method `clf.get_params()` to produce a fresh untrained copy of this model, "
"or you can implement the cross-validation outside of cleanlab "
"and pass in the obtained `pred_probs` to skip cleanlab's internal cross-validation"
)

# Select the training and holdout cross-validated sets.
X_train_cv, X_holdout_cv, s_train_cv, s_holdout_cv = train_val_split(
X, labels, cv_train_idx, cv_holdout_idx
)

# Ensure no missing classes in training set.
train_cv_classes = set(s_train_cv)
all_classes = set(range(num_classes))
# dict with keys: which classes missing, values: index of holdout data from this class that is duplicated:
missing_class_inds = {}
if len(train_cv_classes) != len(all_classes):
missing_classes = all_classes.difference(train_cv_classes)
warnings.warn(
"Duplicated some data across multiple folds to ensure training does not fail "
f"because these classes do not have enough data for proper cross-validation: {missing_classes}."
)
for missing_class in missing_classes:
# Duplicate one instance of missing_class from holdout data to the training data:
holdout_inds = np.where(s_holdout_cv == missing_class)[0]
dup_idx = holdout_inds[0]
s_train_cv = np.append(s_train_cv, s_holdout_cv[dup_idx])
# labels are always np.array so don't have to consider .iloc above
X_train_cv = append_extra_datapoint(
to_data=X_train_cv, from_data=X_holdout_cv, index=dup_idx
is_tf_or_torch_dataset = is_torch_dataset(X) or is_tensorflow_dataset(X)
if not is_tf_or_torch_dataset:
# Ensure no missing classes in training set.
train_cv_classes = set(s_train_cv)
all_classes = set(range(num_classes))
if len(train_cv_classes) != len(all_classes):
missing_classes = all_classes.difference(train_cv_classes)
warnings.warn(
"Duplicated some data across multiple folds to ensure training does not fail "
f"because these classes do not have enough data for proper cross-validation: {missing_classes}."
)
missing_class_inds[missing_class] = dup_idx
for missing_class in missing_classes:
# Duplicate one instance of missing_class from holdout data to the training data:
holdout_inds = np.where(s_holdout_cv == missing_class)[0]
dup_idx = holdout_inds[0]
s_train_cv = np.append(s_train_cv, s_holdout_cv[dup_idx])
# labels are always np.array so don't have to consider .iloc above
X_train_cv = append_extra_datapoint(
to_data=X_train_cv, from_data=X_holdout_cv, index=dup_idx
)
missing_class_inds[missing_class] = dup_idx

# Map validation data into appropriate format to pass into classifier clf
if validation_func is None:
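The hunk above changes two things: folds are now computed from `labels` alone, and the step that duplicates a holdout example into the training split when a class is missing is skipped for torch/TF datasets. The sketch below illustrates both ideas in plain Python. The round-robin `stratified_folds` helper is a simplified stand-in, not scikit-learn's `StratifiedKFold`, and the list appends stand in for `np.append` and `append_extra_datapoint`.

```python
from collections import defaultdict


def stratified_folds(labels, n_folds):
    """Assign each example index to a fold, round-robin within each class."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    folds = [[] for _ in range(n_folds)]
    for indices in by_class.values():
        for position, idx in enumerate(indices):
            folds[position % n_folds].append(idx)
    return [sorted(fold) for fold in folds]


labels = [0, 0, 0, 0, 1, 1, 1, 2]  # class 2 has only one example
num_classes = 3
# Only the labels are needed to build the folds; the features are never touched.
folds = stratified_folds(labels, n_folds=2)

patched = []  # (training indices, classes that had to be patched) per fold
for holdout_idx in folds:
    held_out = set(holdout_idx)
    train_idx = [i for i in range(len(labels)) if i not in held_out]
    train_labels = [labels[i] for i in train_idx]
    missing = sorted(set(range(num_classes)) - set(train_labels))
    for missing_class in missing:
        # Copy one holdout example of the missing class into the training split.
        dup_idx = next(i for i in holdout_idx if labels[i] == missing_class)
        train_idx.append(dup_idx)
        train_labels.append(labels[dup_idx])
    patched.append((train_idx, missing))
```

Because the split depends only on `labels`, the same fold indices can be applied to any `X` format. The duplication step, by contrast, needs to append a single example to `X_train_cv`, which is presumably why the commit restricts it to inputs that are not torch or TensorFlow datasets.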
80 changes: 80 additions & 0 deletions cleanlab/experimental/keras.py
@@ -0,0 +1,80 @@
# Copyright (C) 2017-2022 Cleanlab Inc.
# This file is part of cleanlab.
#
# cleanlab is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cleanlab is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with cleanlab. If not, see <https://www.gnu.org/licenses/>.

"""
A wrapper class you can use to make any Keras model compatible with cleanlab and sklearn.
Most of the instance methods of this class are the same as the ones for any Keras model,
see the Keras documentation for details.
This is a good example of making any bespoke neural network compatible with cleanlab.
You must have Tensorflow installed: https://www.tensorflow.org/install
Note: Tensorflow is only compatible with Python versions >= 3.7: https://www.tensorflow.org/install/pip#software_requirements
Tips:
* If this class lacks certain functionality, you can alternatively try scikeras: https://github.com/adriangb/scikeras
* To call ``fit()`` on a Tensorflow Dataset object with a Keras model, the Dataset should already be batched.
"""

import tensorflow as tf
import numpy as np


class KerasWrapper:
"""
KerasWrapper is instantiated in the same way as a ``tf.keras.models.Sequential`` object,
except for extra argument:
* *compile_kwargs*: dict of args to pass into ``model.compile()``
"""

def __init__(
self,
layers=None,
name=None,
compile_kwargs={"loss": tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)},
):
self.layers = layers
self.name = name
self.compile_kwargs = compile_kwargs
self.net = None

def get_params(self, deep=True):
return {"layers": self.layers, "name": self.name, "compile_kwargs": self.compile_kwargs}

def fit(self, X, y=None, **kwargs):
"""Note that ``X`` dataset object must already contain the labels as is required for standard Keras fit.
You can provide the labels again here as argument ``y`` to be compatible with sklearn, but they are ignored.
"""
self.net = tf.keras.models.Sequential(self.layers, self.name)
self.net.compile(**self.compile_kwargs)
self.net.fit(X, **kwargs)

def predict_proba(self, X, apply_softmax=True, **kwargs):
"""Set `apply_softmax` to True if your network outputs logits rather than probabilities."""
if self.net is None:
raise ValueError("must call fit() before predict()")
pred_probs = self.net.predict(X, **kwargs)
if apply_softmax:
pred_probs = tf.nn.softmax(pred_probs, axis=1)
return pred_probs

def predict(self, X, **kwargs):
pred_probs = self.predict_proba(X, **kwargs)
return np.argmax(pred_probs, axis=1)

def summary(self, **kwargs):
self.net.summary(**kwargs)
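The wrapper above relies on three conventions: `get_params()` returns only constructor arguments so a fresh copy can be built, the network is created inside `fit()`, and `predict_proba()` applies a softmax to logits. The following sketch shows the same pattern without Keras. `ToyWrapper`, its `weights` parameter, and `fresh_copy` are hypothetical; `fresh_copy` is a simplified illustration of the idea behind `sklearn.base.clone`, not its actual implementation.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


class ToyWrapper:
    """Hypothetical stand-in shaped like KerasWrapper, minus Keras."""

    def __init__(self, weights=None, name=None):
        self.weights = weights  # constructor arguments only, no trained state
        self.name = name
        self.net = None

    def get_params(self, deep=True):
        return {"weights": self.weights, "name": self.name}

    def fit(self, X, y=None):
        # Build the "network" lazily so an unfitted copy carries no state.
        self.net = lambda row: [w * row[0] for w in self.weights]
        return self

    def predict_proba(self, X, apply_softmax=True):
        if self.net is None:
            raise ValueError("must call fit() before predict_proba()")
        outputs = [self.net(row) for row in X]
        return [softmax(o) for o in outputs] if apply_softmax else outputs

    def predict(self, X):
        probs = self.predict_proba(X)
        # argmax over each row of class probabilities
        return [max(range(len(p)), key=p.__getitem__) for p in probs]


def fresh_copy(estimator):
    # Rebuild the object from its constructor parameters only.
    return type(estimator)(**estimator.get_params())


model = ToyWrapper(weights=[1.0, -1.0], name="toy").fit([[0.0]])
copy = fresh_copy(model)  # untrained: copy.net is None
predictions = model.predict([[2.0], [-3.0]])
```

Keeping trained state out of `get_params()` is what lets each cross-validation fold start from an untrained model.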