# Data Valuation with Nearest Neighbor Explainers

This notebook explains how explainers of nearest-neighbor (NN) models can be used for Data Valuation, the task of evaluating the usefulness of individual training data points in classification problems.
When explaining NN models, a game is defined by first choosing an explanation point $x_\text{explain}$ and a class $y_\text{explain}$. The training data points $\mathcal{D} \subseteq \mathcal{X} \times \mathcal{Y}$ are the game's players, and the utility $\nu(S)$ of a coalition $S \subseteq \mathcal{D}$ is based on the probability of the model predicting class $y_\text{explain}$ on $x_\text{explain}$ if its training data were limited to $S$.
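
To make the game concrete, here is a minimal sketch of one possible utility for an unweighted $k$-NN model, together with a brute-force Shapley computation over all coalitions. The names `knn_utility` and `exact_shapley`, the toy data, and the convention that the empty coalition has utility 0 are assumptions made for illustration; the utility used by the explainers in this notebook may differ in detail.

```python
from itertools import combinations
from math import comb

import numpy as np


def knn_utility(coalition, X, y, x_explain, y_explain, k):
    """Share of the k votes that the coalition's nearest points cast for y_explain."""
    idx = np.asarray(coalition, dtype=int)
    if idx.size == 0:
        return 0.0  # assumption: the empty coalition has utility 0
    dists = np.linalg.norm(X[idx] - x_explain, axis=1)
    nearest = idx[np.argsort(dists, kind="stable")[:k]]
    # Dividing by k (not by the coalition size) means small coalitions cast fewer votes
    return float(np.sum(y[nearest] == y_explain)) / k


def exact_shapley(n_players, value_fn):
    """Exact Shapley values by enumerating every coalition (only feasible for tiny games)."""
    phi = np.zeros(n_players)
    for i in range(n_players):
        others = [j for j in range(n_players) if j != i]
        for size in range(n_players):
            weight = 1.0 / (n_players * comb(n_players - 1, size))
            for coalition in combinations(others, size):
                gain = value_fn(coalition + (i,)) - value_fn(coalition)
                phi[i] += weight * gain
    return phi


# Hypothetical 1-D toy data: point 1 carries the "wrong" label near x_toy
X_toy = np.array([[0.0], [1.0], [2.0], [3.0]])
y_toy = np.array([1, 0, 1, 1])
x_toy = np.array([0.1])

phi = exact_shapley(4, lambda S: knn_utility(S, X_toy, y_toy, x_toy, 1, k=2))
print(phi)
```

The enumeration visits every coalition, so its cost grows exponentially with the number of training points. This is why specialized algorithms are needed for realistic datasets.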

There is support for explaining the `KNeighborsClassifier` model (with `'uniform'` or `'distance'` weights) and the `RadiusNeighborsClassifier` model from the `scikit-learn` library.
The algorithms are based on the publications from [Jia et al. (2019)](https://doi.org/10.48550/arXiv.1908.08619/), [Wang et al. (2024)](https://doi.org/10.48550/arXiv.1908.08619)
and [Wang et al. (2023)](https://doi.org/10.48550/arXiv.2308.15709), respectively.

Let's start by generating a synthetic classification dataset and fitting a simple `KNeighborsClassifier` to it.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from plot_helpers import plot_datasets
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier, RadiusNeighborsClassifier

X_train, y_train = make_classification(
    n_samples=30,
    n_features=2,
    n_redundant=0,
    n_clusters_per_class=1,
    n_informative=2,
    n_classes=2,
    random_state=45,
)

fig, ax = plt.subplots(figsize=(6, 6))
plot_datasets(ax, X_train, y_train)

model = KNeighborsClassifier(n_neighbors=3)
model.fit(X_train, y_train)

x_explain = np.array([[-0.75, -0.4]])
y_explain_pred = model.predict(x_explain)[0]
print(f"Prediction: class {y_explain_pred}")

y_explain_proba = model.predict_proba(x_explain)[0]
print(f"Prediction probabilities: {y_explain_proba}")

## Using the `KNNExplainer` for Unweighted $k$-Nearest Neighbor Models

To explain the prediction, we create an explainer for the model by passing it to the constructor of `Explainer`, which automatically dispatches to the appropriate subclass, `KNNExplainer`.

In [None]:
from shapiq import Explainer

explainer = Explainer(model, class_index=y_explain_pred, max_order=1)
print(type(explainer))

Note that we set `class_index=y_explain_pred`, since for now, we want to quantify the contribution of the training data to the class that was actually predicted. (We could also set a different class index if we wished to see how much the data points contribute to shifting the prediction towards another class.)

Now we can get an explanation for the prediction we saw above:

In [None]:
iv = explainer.explain(x_explain)
print(iv)
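
For intuition about how such values can be computed without enumerating coalitions, the following is a sketch of the sorting-based recursion described by Jia et al. (2019) for unweighted $k$-NN. It assumes the utility is the share of the $k$ votes cast for the explained class; the function name `knn_shapley_recursive` and the toy data are hypothetical, and this is not necessarily how `KNNExplainer` is implemented internally.

```python
import numpy as np


def knn_shapley_recursive(X_train, y_train, x_explain, y_explain, k):
    """Shapley values of all training points for one explanation point.

    Sketch of the recursion from Jia et al. (2019): sort the points by
    distance, start at the farthest point, and walk towards the nearest.
    """
    n = X_train.shape[0]
    dists = np.linalg.norm(X_train - x_explain, axis=1)
    order = np.argsort(dists, kind="stable")  # nearest point first
    match = (y_train[order] == y_explain).astype(float)

    values_sorted = np.zeros(n)
    values_sorted[-1] = match[-1] / n  # farthest point
    for pos in range(n - 2, -1, -1):
        rank = pos + 1  # 1-based rank of the current point
        step = (match[pos] - match[pos + 1]) / k * min(k, rank) / rank
        values_sorted[pos] = values_sorted[pos + 1] + step

    values = np.zeros(n)
    values[order] = values_sorted  # undo the sorting
    return values


# Hypothetical 1-D toy data: point 1 carries the "wrong" label near x_toy
X_toy = np.array([[0.0], [1.0], [2.0], [3.0]])
y_toy = np.array([1, 0, 1, 1])
x_toy = np.array([0.1])

values = knn_shapley_recursive(X_toy, y_toy, x_toy, y_explain=1, k=2)
print(values)
```

The cost is dominated by the sort, so a single explanation point takes $O(N \log N)$ time for $N$ training points. The values sum to the utility of the full training set, here the share of the two nearest neighbors that vote for class 1.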

## Explaining Weighted $k$-Nearest Neighbor and Threshold Nearest Neighbor Models

There are separate explainers for weighted $k$-NN and threshold NN models, which are selected automatically when an `Explainer` is instantiated with a corresponding model:

In [None]:
wknn_model = KNeighborsClassifier(n_neighbors=3, weights="distance")
wknn_model.fit(X_train, y_train)
wknn_explainer = Explainer(wknn_model, max_order=1)
print(type(wknn_explainer))

tnn_model = RadiusNeighborsClassifier()
tnn_model.fit(X_train, y_train)
tnn_explainer = Explainer(tnn_model, max_order=1)
print(type(tnn_explainer))

They can be used in the same way:

In [None]:
print(wknn_explainer.explain(x_explain))
print(tnn_explainer.explain(x_explain))
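
The two model variants differ in how neighbors vote, which changes the quantity the game's utility is built on. The sketch below mimics the class probability of a distance-weighted $k$-NN model (inverse-distance weights, as with `weights="distance"` in `scikit-learn`) and of a radius-based model with uniform weights. The helper names and toy data are hypothetical, and the utilities used in Wang et al. (2023, 2024) may be defined differently in detail.

```python
import numpy as np


def weighted_knn_proba(X, y, x, target, k=3):
    """Inverse-distance weighted vote share for `target` among the k nearest points."""
    d = np.linalg.norm(X - x, axis=1)
    nearest = np.argsort(d, kind="stable")[:k]
    w = 1.0 / d[nearest]  # assumes no training point coincides with x
    return float(w[y[nearest] == target].sum() / w.sum())


def radius_proba(X, y, x, target, radius=1.0):
    """Vote share for `target` among all points within `radius` of x."""
    d = np.linalg.norm(X - x, axis=1)
    inside = d <= radius
    if not inside.any():
        return 0.0  # assumption: no neighbors in range counts as probability 0
    return float(np.mean(y[inside] == target))


# Hypothetical 1-D toy data
X_toy = np.array([[0.0], [1.0], [2.0], [3.0]])
y_toy = np.array([1, 0, 1, 1])
x_toy = np.array([0.25])

p_weighted = weighted_knn_proba(X_toy, y_toy, x_toy, target=1, k=3)
p_radius = radius_proba(X_toy, y_toy, x_toy, target=1, radius=1.0)
print(p_weighted, p_radius)
```

In the weighted case a close neighbor counts for more than a distant one, whereas in the radius-based case the number of voters itself depends on the explanation point.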

## Identifying Corrupted Training Samples

We can estimate the usefulness of each point of a training data set by calculating Shapley values for a set of test data points and averaging the results. This will allow us to identify potentially mislabeled data points.

First, let's create a classification data set and split it into train and test sets. We will corrupt the training data by changing the class of a few randomly selected data points.

In [None]:
from sklearn.model_selection import train_test_split

X, y = make_classification(
    n_samples=100,
    n_features=2,
    n_redundant=0,
    n_clusters_per_class=1,
    n_informative=2,
    n_classes=2,
    flip_y=0,
    random_state=49,
    class_sep=1.5,
)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

y_train_corrupted = y_train.copy()
n_corrupt = 7
rng = np.random.default_rng(seed=43)
corrupted = rng.choice(np.arange(X_train.shape[0]), size=n_corrupt, replace=False)
# Since our only class indices are 0 and 1, this is a quick way to flip the class
y_train_corrupted[corrupted] = 1 - y_train[corrupted]

fig, ax = plt.subplots(figsize=(6, 6))
plot_datasets(ax, X_train, y_train_corrupted, X_test, y_test)
# Mark corrupted datapoints
ax.scatter(
    X_train[corrupted, 0],
    X_train[corrupted, 1],
    marker="o",
    edgecolors="#b1170c",
    facecolors="none",
    s=100,
);

Now, we can use the `KNNExplainer` to compute the training points' Shapley values based on the entire test dataset by averaging the Shapley values computed using each test point.
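
Written out, with $\phi_i(x, y)$ denoting the Shapley value of training point $i$ in the game defined by explanation point $x$ and class $y$ (notation introduced here for illustration), the averaged value is

$$\bar{\phi}_i = \frac{1}{|\mathcal{D}_\text{test}|} \sum_{(x, y) \in \mathcal{D}_\text{test}} \phi_i(x, y).$$

Because Shapley values are linear in the utility, $\bar{\phi}_i$ is also the Shapley value of point $i$ in the game whose utility is the average of the per-test-point utilities.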

In [None]:
from shapiq.explainer.nn.iv_utils import interaction_values_to_array

# Train the model with the corrupted training data
model = KNeighborsClassifier(n_neighbors=5)
model.fit(X_train, y_train_corrupted)

sv_test = np.zeros(X_train.shape[0], dtype=np.float64)

for x_test_current, y_test_current in zip(X_test, y_test, strict=True):
    explainer = Explainer(model, class_index=y_test_current, max_order=1)
    iv = explainer.explain(x_test_current)
    sv_current = interaction_values_to_array(iv)
    sv_test += sv_current

sv_test /= X_test.shape[0]

We can reasonably assume that the corrupted training data points will on average make the model's prediction worse, resulting in negative Shapley values. So let's filter out just those indices where the Shapley value is below zero and compare with our original array of corrupted indices:

In [None]:
print(f"Corrupted: {np.sort(corrupted)}")  # Sort for easier comparison
print(f"Negative Shapley values: {np.where(sv_test < 0)[0]}")

We have identified the set of corrupted samples almost exactly. The fact that the point with index 20 was missed, however, shows that this method is not foolproof: it yields an estimate, not a guarantee.
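
To put a number on how well such a filter works, one could compute precision and recall of the flagged indices against the known corrupted ones. The sketch below uses made-up values and indices purely for illustration (they are not the output of the cells above), and `detection_scores` is a hypothetical helper.

```python
import numpy as np


def detection_scores(values, corrupted_idx, threshold=0.0):
    """Precision and recall of flagging points whose value lies below `threshold`."""
    flagged = set(np.flatnonzero(np.asarray(values) < threshold).tolist())
    truth = {int(i) for i in corrupted_idx}
    hits = flagged & truth
    precision = len(hits) / len(flagged) if flagged else 0.0
    recall = len(hits) / len(truth) if truth else 0.0
    missed = sorted(truth - flagged)  # corrupted points the filter did not catch
    return precision, recall, missed


# Made-up averaged Shapley values and corrupted indices
toy_values = np.array([0.02, -0.01, 0.03, -0.04, 0.01, 0.00])
toy_corrupted = [1, 3, 4]

precision, recall, missed = detection_scores(toy_values, toy_corrupted)
print(f"precision={precision:.2f}, recall={recall:.2f}, missed={missed}")
```

With the real arrays, `detection_scores(sv_test, corrupted)` would report the same quantities for the experiment above. Varying `threshold` trades precision against recall.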