Skip to content

Commit

Permalink
Merge branch 'cleanlab:master' into spurious_correlations
Browse files Browse the repository at this point in the history
  • Loading branch information
01PrathamS committed Nov 12, 2023
2 parents fb5aa35 + 5fbf6c0 commit f35b3c1
Show file tree
Hide file tree
Showing 17 changed files with 883 additions and 34 deletions.
3 changes: 2 additions & 1 deletion cleanlab/datalab/internal/issue_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ def _resolve_required_args(self, pred_probs, features, knn_graph):
"label": {"pred_probs": pred_probs, "features": features},
"outlier": {"pred_probs": pred_probs, "features": features, "knn_graph": knn_graph},
"near_duplicate": {"features": features, "knn_graph": knn_graph},
"non_iid": {"features": features, "knn_graph": knn_graph},
"non_iid": {"pred_probs": pred_probs, "features": features, "knn_graph": knn_graph},
"data_valuation": {"knn_graph": knn_graph},
}

args_dict = {
Expand Down
1 change: 1 addition & 0 deletions cleanlab/datalab/internal/issue_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .outlier import OutlierIssueManager
from .noniid import NonIIDIssueManager
from .imbalance import ClassImbalanceIssueManager
from .data_valuation import DataValuationIssueManager
145 changes: 145 additions & 0 deletions cleanlab/datalab/internal/issue_manager/data_valuation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright (C) 2017-2023 Cleanlab Inc.
# This file is part of cleanlab.
#
# cleanlab is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cleanlab is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with cleanlab. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Optional,
Union,
)

import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

from cleanlab.datalab.internal.issue_manager import IssueManager

if TYPE_CHECKING: # pragma: no cover
import pandas as pd
from cleanlab.datalab.datalab import Datalab


class DataValuationIssueManager(IssueManager):
    """Flag examples whose KNN-Shapley data valuation score is below a threshold.

    Scores are computed by ``_knn_shapley_score`` from a k-nearest-neighbor graph
    and the labels held by the associated Datalab instance. This manager never
    builds a knn graph itself: the graph must be passed to ``find_issues`` or
    already be stored in the Datalab statistics under ``"weighted_knn_graph"``.
    """

    description: ClassVar[
        str
    ] = """
Examples that contribute minimally to a model's training
receive lower valuation scores.
"""

    issue_name: ClassVar[str] = "data_valuation"
    # Declared without a value here; presumably populated by the IssueManager
    # base class -- TODO confirm against the base class definition.
    issue_score_key: ClassVar[str]
    verbosity_levels: ClassVar[Dict[int, List[str]]] = {
        0: [],
        1: [],
        2: [],
        3: ["average_data_valuation"],
    }

    # Examples scoring strictly below this value are flagged.
    # NOTE(review): ``_knn_shapley_score`` returns ``0.5 * (mean + 1)`` of
    # non-negative entries, i.e. values >= 0.5, so this default appears unable to
    # flag any example -- confirm the intended default with the authors.
    DEFAULT_THRESHOLDS = 1e-6

    def __init__(
        self,
        datalab: Datalab,
        threshold: Optional[float] = None,
        k: int = 10,
        **kwargs,
    ):
        """Store the configuration for this issue check.

        Parameters
        ----------
        datalab :
            The Datalab instance whose labels and statistics are read.
        threshold :
            Scores strictly below this value are flagged as issues.
            Falls back to ``DEFAULT_THRESHOLDS`` when None.
        k :
            Stored on the instance as ``self.k``; not read elsewhere in this class.
        **kwargs :
            Accepted for call compatibility and ignored.
        """
        super().__init__(datalab)
        self.k = k
        self.threshold = threshold if threshold is not None else self.DEFAULT_THRESHOLDS

    def find_issues(
        self,
        **kwargs,
    ) -> None:
        """Calculate the data valuation score with a provided or existing knn graph.

        Based on KNN-Shapley value described in https://arxiv.org/abs/1911.07128
        The larger the score, the more valuable the data point is, the more
        contribution it will make to the model's training.

        Populates ``self.issues`` (one row per example with the boolean issue
        flag and the score), ``self.summary`` and ``self.info``.

        Raises
        ------
        AssertionError
            If no usable knn graph is available or the Datalab has no labels.
        """
        knn_graph = self._process_knn_graph_from_inputs(kwargs)
        labels = self.datalab.labels
        assert knn_graph is not None, "knn_graph must be already calculated by other issue managers"
        # Validate before reshaping: calling ``.reshape`` on a missing labels
        # object would fail with an unrelated AttributeError first.
        assert labels is not None, "labels must be provided"
        labels = labels.reshape(-1, 1)

        scores = _knn_shapley_score(knn_graph, labels)

        self.issues = pd.DataFrame(
            {
                f"is_{self.issue_name}_issue": scores < self.threshold,
                self.issue_score_key: scores,
            },
        )
        self.summary = self.make_summary(score=scores.mean())

        self.info = self.collect_info(self.issues)

    def _process_knn_graph_from_inputs(self, kwargs: Dict[str, Any]) -> Union[csr_matrix, None]:
        """Pick the knn graph to use, preferring ``kwargs`` over the Datalab statistics.

        Returns None when no graph is available, or when ``kwargs["k"]`` asks for
        more neighbors than the chosen graph stores per row on average.
        """
        knn_graph_kwargs: Optional[csr_matrix] = kwargs.get("knn_graph", None)
        knn_graph_stats = self.datalab.get_info("statistics").get("weighted_knn_graph", None)

        knn_graph: Optional[csr_matrix] = None
        if knn_graph_kwargs is not None:
            knn_graph = knn_graph_kwargs
        elif knn_graph_stats is not None:
            knn_graph = knn_graph_stats

        if isinstance(knn_graph, csr_matrix) and kwargs.get("k", 0) > (
            knn_graph.nnz // knn_graph.shape[0]
        ):
            # The graph holds fewer neighbors per example than requested, so it is
            # treated as unavailable. Nothing is recomputed here; the caller
            # (``find_issues``) rejects a missing graph.
            knn_graph = None
        return knn_graph

    def collect_info(self, issues: pd.DataFrame) -> dict:
        """Summarize the issues table: number of flagged examples and mean score."""
        return {
            "num_low_valuation_issues": sum(issues[f"is_{self.issue_name}_issue"]),
            "average_data_valuation": issues[self.issue_score_key].mean(),
        }


def _knn_shapley_score(knn_graph: csr_matrix, labels: np.ndarray) -> np.ndarray:
"""Compute the Shapley values of data points based on a knn graph."""
N = labels.shape[0]
scores = np.zeros((N, N))
dist = knn_graph.indices.reshape(N, -1)
k = dist.shape[1]

for i, y in enumerate(labels):
idx = dist[i][::-1]
ans = labels[idx]
scores[idx[k - 1]][i] = float(ans[k - 1] == y) / k
cur = k - 2
for j in range(k - 1):
scores[idx[cur]][i] = scores[idx[cur + 1]][i] + float(
int(ans[cur] == y) - int(ans[cur + 1] == y)
) / k * (min(cur, k - 1) + 1) / (cur + 1)
cur -= 1
return 0.5 * (np.mean(scores, axis=1) + 1)
135 changes: 109 additions & 26 deletions cleanlab/datalab/internal/issue_manager/noniid.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,38 +124,119 @@ def __init__(
self.seed = seed
self.significance_threshold = significance_threshold

def find_issues(self, features: Optional[npt.NDArray] = None, **kwargs) -> None:
# TODO: Temporary flag introduced to decide on storing knn graphs based on pred_probs.
# Revisit and finalize the implementation.
self._skip_storing_knn_graph_for_pred_probs: bool = False

@staticmethod
def _determine_features(
features: Optional[npt.NDArray],
pred_probs: Optional[np.ndarray],
) -> npt.NDArray:
"""
Determines the feature array to be used for the non-IID check. Prioritizing the original features array over pred_probs.
Parameters
----------
features :
Original feature array or None.
pred_probs :
Predicted probabilities array or None.
Returns
-------
features_to_use :
Either the original feature array or the predicted probabilities array,
intended to be used for the non-IID check.
Raises
------
ValueError :
If both `features` and `pred_probs` are None.
"""
if features is not None:
return features

if pred_probs is not None:
return pred_probs

raise ValueError(
"If a knn_graph is not provided, either 'features' or 'pred_probs' must be provided to fit a new knn."
)

def _setup_knn(
self,
features: Optional[npt.NDArray],
pred_probs: Optional[np.ndarray],
knn_graph: Optional[csr_matrix],
metric_changes: bool,
) -> Optional[NearestNeighbors]:
"""
Selects features (or pred_probs if features are None) and sets up a NearestNeighbors object if needed.
Parameters
----------
features :
Original feature array or None.
pred_probs :
Predicted probabilities array or None.
knn_graph :
A precomputed KNN-graph stored in a csr_matrix or None. If None, a new NearestNeighbors object will be created.
metric_changes :
Whether the metric used to compute the KNN-graph has changed.
This is a result of comparing the metric of a pre-existing KNN-graph and the metric specified by the user.
Returns
-------
knn :
A NearestNeighbors object or None.
"""
if features is None and pred_probs is not None:
self._skip_storing_knn_graph_for_pred_probs = True
features_to_use = self._determine_features(features, pred_probs)

if self.metric is None:
self.metric = "cosine" if features_to_use.shape[1] > 3 else "euclidean"

if knn_graph is not None and not metric_changes:
return None

knn = NearestNeighbors(n_neighbors=self.k, metric=self.metric)

if self.metric != knn.metric:
warnings.warn(
f"Metric {self.metric} does not match metric {knn.metric} used to fit knn. "
"Most likely an existing NearestNeighbors object was passed in, but a different "
"metric was specified."
)
self.metric = knn.metric

try:
check_is_fitted(knn)
except NotFittedError:
knn.fit(features_to_use)

return knn

def find_issues(
self,
features: Optional[npt.NDArray] = None,
pred_probs: Optional[np.ndarray] = None,
**kwargs,
) -> None:
knn_graph = self._process_knn_graph_from_inputs(kwargs)
old_knn_metric = self.datalab.get_info("statistics").get("knn_metric")
metric_changes = self.metric and self.metric != old_knn_metric

knn = None # Won't be used if knn_graph is not None
metric_changes = bool(self.metric and self.metric != old_knn_metric)
knn = self._setup_knn(features, pred_probs, knn_graph, metric_changes)

if knn_graph is None or metric_changes:
if features is None:
raise ValueError(
"If a knn_graph is not provided, features must be provided to fit a new knn."
)

if self.metric is None:
self.metric = "cosine" if features.shape[1] > 3 else "euclidean"
knn = NearestNeighbors(n_neighbors=self.k, metric=self.metric)

if self.metric and self.metric != knn.metric:
warnings.warn(
f"Metric {self.metric} does not match metric {knn.metric} used to fit knn. "
"Most likely an existing NearestNeighbors object was passed in, but a different "
"metric was specified."
)
self.metric = knn.metric

try:
check_is_fitted(knn)
except NotFittedError:
knn.fit(features)

self.neighbor_index_choices = self._get_neighbors(knn=knn)
else:
self._skip_storing_knn_graph_for_pred_probs = False
self.neighbor_index_choices = self._get_neighbors(knn_graph=knn_graph)

self.num_neighbors = self.k
Expand Down Expand Up @@ -234,6 +315,8 @@ def collect_info(
def _build_statistics_dictionary(self, knn_graph: csr_matrix) -> Dict[str, Dict[str, Any]]:
statistics_dict: Dict[str, Dict[str, Any]] = {"statistics": {}}

if self._skip_storing_knn_graph_for_pred_probs:
return statistics_dict
# Add the knn graph as a statistic if necessary
graph_key = "weighted_knn_graph"
old_knn_graph = self.datalab.get_info("statistics").get(graph_key, None)
Expand Down
Loading

0 comments on commit f35b3c1

Please sign in to comment.