Skip to content

Commit

Permalink
update knn shapely score transformation
Browse files Browse the repository at this point in the history
- Make the _knn_shapley_score function easily testable.
- Adjust hard-coded scores in tests to new transformation.
- Add new test class for property based test that asserts that the raw scores are never negative.

Resolves cleanlab#1127
  • Loading branch information
elisno committed Jun 12, 2024
1 parent b0ca6e5 commit 721614b
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 16 deletions.
19 changes: 11 additions & 8 deletions cleanlab/data_valuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,19 @@
from cleanlab.internal.neighbor.knn_graph import create_knn_graph_and_index


def _knn_shapley_score(knn_graph: csr_matrix, labels: np.ndarray, k: int) -> np.ndarray:
"""Compute the Shapley values of data points based on a knn graph."""
def _knn_shapley_score(neighbor_indices: np.ndarray, labels: np.ndarray, k: int) -> np.ndarray:
"""Compute the Shapley values of data points based on neighbor indices in a knn graph."""
N = labels.shape[0]
scores = np.zeros((N, N))
dist = knn_graph.indices.reshape(N, -1)

for y, s, dist_i in zip(labels, scores, dist):
idx = dist_i[::-1]
for y, s, neighbors_i in zip(labels, scores, neighbor_indices):
idx = neighbors_i[::-1]
ans = labels[idx]
s[idx[k - 1]] = float(ans[k - 1] == y)
ans_matches = (ans == y).flatten()
for j in range(k - 2, -1, -1):
s[idx[j]] = s[idx[j + 1]] + float(int(ans_matches[j]) - int(ans_matches[j + 1]))
return 0.5 * (np.mean(scores / k, axis=0) + 1)
return np.mean(scores / k, axis=0)


def data_shapley_knn(
Expand Down Expand Up @@ -91,7 +90,7 @@ def data_shapley_knn(
An array of transformed Data Shapley values for each data point, calibrated to indicate their relative importance.
These scores have been adjusted to fall within 0 to 1.
Values closer to 1 indicate data points that are highly influential and positively contribute to a trained ML model's performance.
Conversely, scores below 0.5 indicate data points estimated to negatively impact model performance.
Conversely, scores below 0.0 indicate data points estimated to negatively impact model performance. This function clips negative scores to 0.0.
Raises
------
Expand All @@ -113,4 +112,8 @@ def data_shapley_knn(
# Use provided knn_graph or compute it from features
if knn_graph is None:
knn_graph, _ = create_knn_graph_and_index(features, n_neighbors=k, metric=metric)
return _knn_shapley_score(knn_graph, labels, k)

num_examples = labels.shape[0]
distances = knn_graph.indices.reshape(num_examples, -1)
scores = _knn_shapley_score(distances, labels, k)
return np.maximum(scores, 0)
4 changes: 2 additions & 2 deletions cleanlab/datalab/internal/issue_manager/data_valuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class DataValuationIssueManager(IssueManager):
Since the original knn-shapley value is in [-1, 1], we transform it to [0, 1] by:
.. math::
0.5 \times (\text{shapley} + 1)
max(\text{shapley}, 0)
here shapley is the original knn-shapley value.
"""
Expand All @@ -94,7 +94,7 @@ class DataValuationIssueManager(IssueManager):
3: ["average_data_valuation"],
}

DEFAULT_THRESHOLD = 0.5
DEFAULT_THRESHOLD = 1e-6

def __init__(
self,
Expand Down
8 changes: 3 additions & 5 deletions tests/datalab/datalab/test_datalab.py
Original file line number Diff line number Diff line change
Expand Up @@ -1669,11 +1669,9 @@ def test_all_identical_dataset(self):
assert data_valuation_issues["is_data_valuation_issue"].sum() == 0

# For a full knn-graph, all data points have the same value. Here, they all contribute the same value.
# The score of 54/99 is a value that works for 11 identical data points.
# TODO: Find a reasonable test for larger dataset, with k much smaller than N. Hard to guarantee a score of 0.5.
np.testing.assert_allclose(
data_valuation_issues["data_valuation_score"].to_numpy(), 54 / 99
)
# The score of 1/11 is a value that works for 11 identical data points.
# TODO: Find a reasonable test for larger dataset, with k much smaller than N. Hard to guarantee a score of 0.0.
np.testing.assert_allclose(data_valuation_issues["data_valuation_score"].to_numpy(), 1 / 11)


class TestIssueManagersReuseKnnGraph:
Expand Down
58 changes: 57 additions & 1 deletion tests/test_data_valuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@

import numpy as np
import pytest
from hypothesis import given, settings, strategies as st
from hypothesis.strategies import composite
from hypothesis.extra.numpy import arrays

from sklearn.neighbors import NearestNeighbors

from cleanlab.data_valuation import data_shapley_knn
from cleanlab.data_valuation import _knn_shapley_score, data_shapley_knn
from cleanlab.internal.neighbor.knn_graph import create_knn_graph_and_index


class TestDataValuation:
Expand Down Expand Up @@ -52,3 +56,55 @@ def test_data_shapley_knn_with_knn_graph(self, labels, knn_graph):
assert shapley.shape == (100,)
assert np.all(shapley >= 0)
assert np.all(shapley <= 1)


@composite
def valid_data(draw):
"""
A custom strategy to generate valid labels, features, and k such that:
- labels and features have the same length
- k is less than the length of labels and features
"""
# Generate a valid length for labels and features
length = draw(st.integers(min_value=11, max_value=1000))

# Generate labels and features of the same length
labels = draw(
arrays(
dtype=np.int32,
shape=length,
elements=st.integers(min_value=0, max_value=length - 1),
)
)
features = draw(
arrays(
dtype=np.float64,
shape=(length, draw(st.integers(min_value=2, max_value=50))),
elements=st.floats(min_value=-1.0, max_value=1.0),
)
)

# Generate k such that it is less than the length of labels and features
k = draw(st.integers(min_value=1, max_value=length - 1))

return labels, features, k


class TestDataShapleyKNNScore:
"""This test class prioritizes testing the raw/untransformed outputs of the _knn_shapley_score function."""

@settings(
max_examples=1000, deadline=None
) # Increase the number of examples to test more cases
@given(valid_data())
def test_knn_shapley_score_property(self, data):
labels, features, k = data

knn_graph, _ = create_knn_graph_and_index(features, n_neighbors=k)
neighbor_indices = knn_graph.indices.reshape(-1, k)

scores = _knn_shapley_score(neighbor_indices, labels, k)

assert scores.shape == (len(labels),)
assert np.all(scores >= 0), "Shapley scores should never be negative."
assert np.all(scores <= 1), "Shapley scores should be between 0 and 1."

0 comments on commit 721614b

Please sign in to comment.