## 8. The k-Nearest Neighbors (k-NN) Classifier

Having discussed the ideal Bayes classifier and the challenges in realizing it, we now turn to a practical, data-driven approach: the **k-Nearest Neighbors (k-NN) classifier**. This method makes predictions for a new data point based on the labels of its 'closest' neighbors in the training data.

Let's assume we have a training dataset $\mathcal{D}_{\text{train}} = \{(\mathbf{x}_n, t_n)\}_{n=1}^N$, where $\mathbf{x}_n \in \mathbb{R}^D$ is the $D$-dimensional feature vector for the $n$-th sample and $t_n$ is its corresponding class label.

Given a new, unseen feature vector $\mathbf{x}_{\text{new}}$ for which we want to predict the class label, the k-NN algorithm proceeds as follows:

### 8.1 Defining Neighborhoods

1.  **Choose an integer $k$**: This $k$ is the number of nearest neighbors to consider. It's a hyperparameter of the model.
2.  **Calculate Distances**: Compute the distance between $\mathbf{x}_{\text{new}}$ and every point $\mathbf{x}_n$ in the training set. The most common distance metric is the **Euclidean distance**:
    $$
    d(\mathbf{x}_{\text{new}}, \mathbf{x}_n) = ||\mathbf{x}_{\text{new}} - \mathbf{x}_n||_2 = \sqrt{\sum_{j=1}^{D} (x_{\text{new},j} - x_{n,j})^2}
    $$
    Other distance metrics (e.g., Manhattan, Minkowski, Mahalanobis) can also be used depending on the nature of the data.
3.  **Identify the $k$ Nearest Neighbors**: Select the $k$ training samples $(\mathbf{x}_n, t_n)$ that have the smallest distances to $\mathbf{x}_{\text{new}}$. Let $\mathcal{N}_k(\mathbf{x}_{\text{new}})$ denote this set of $k$ nearest neighbors.
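The distance computations in step 2 can be sketched in NumPy as follows. This is a minimal illustration, not part of the original notebook: `minkowski_distances` and `mahalanobis_distances` are hypothetical helper names, and the Mahalanobis version assumes a covariance matrix `cov` (for instance, one estimated from the training data), which the text does not specify.

```python
import numpy as np

def minkowski_distances(X_train, x_new, p=2.0):
    """Minkowski distance from x_new to every row of X_train.
    p=2 gives the Euclidean distance, p=1 the Manhattan distance."""
    return np.sum(np.abs(X_train - x_new) ** p, axis=1) ** (1.0 / p)

def mahalanobis_distances(X_train, x_new, cov):
    """Mahalanobis distance sqrt((x_new - x_n)^T S^{-1} (x_new - x_n)) for an assumed covariance S."""
    diff = X_train - x_new                  # shape (N, D)
    cov_inv = np.linalg.inv(cov)            # S^{-1}, shape (D, D)
    # einsum evaluates the quadratic form diff_n^T S^{-1} diff_n for every row n
    return np.sqrt(np.einsum("nd,de,ne->n", diff, cov_inv, diff))

X_train = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]])
x_new = np.array([0.0, 0.0])

euclid = minkowski_distances(X_train, x_new, p=2)      # [0, 5, sqrt(2)]
manhattan = minkowski_distances(X_train, x_new, p=1)   # [0, 7, 2]
# With the identity covariance, the Mahalanobis distance reduces to the Euclidean one
maha = mahalanobis_distances(X_train, x_new, np.eye(2))
```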

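Geometrically, imagine a hypersphere (for non-Euclidean metrics, a ball in that metric) centered at $\mathbf{x}_{\text{new}}$ whose radius is grown just until it encloses $k$ training points; it encloses exactly $k$ points provided no two training points are tied at the $k$-th smallest distance. These $k$ points form the neighborhood $\mathcal{N}_k(\mathbf{x}_{\text{new}})$.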
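Steps 2 and 3, together with the hypersphere picture, can be sketched as below, assuming the Euclidean metric and NumPy. `k_nearest` is a hypothetical helper; it uses `np.argpartition` to find the $k$ smallest distances without fully sorting all $N$ of them, and the $k$-th smallest distance is the radius of the enclosing hypersphere. If several points tie at that distance, which of them are kept is not specified.

```python
import numpy as np

def k_nearest(X_train, x_new, k):
    """Indices of the k training points closest to x_new (Euclidean),
    sorted by increasing distance, together with their distances."""
    distances = np.linalg.norm(X_train - x_new, axis=1)
    # argpartition moves the k smallest distances into the first k slots (in no particular order)
    idx = np.argpartition(distances, k - 1)[:k]
    # sort only those k entries so the neighbors are listed nearest-first
    idx = idx[np.argsort(distances[idx])]
    return idx, distances[idx]

X_train = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0], [-1.0, -1.0]])
x_new = np.array([0.0, 0.0])

idx, d = k_nearest(X_train, x_new, k=3)
# Radius of the smallest ball around x_new that contains all k neighbors
radius = d[-1]
```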

### 8.2 Decision Rule: Majority Voting

Once the $k$ nearest neighbors $\mathcal{N}_k(\mathbf{x}_{\text{new}})$ are identified, the class label for $\mathbf{x}_{\text{new}}$ is determined by a **majority vote** among these $k$ neighbors.
*   For each class $C_j$, count how many of the $k$ neighbors belong to $C_j$. Let this count be $N_j(\mathbf{x}_{\text{new}})$.
*   The predicted class $f_{\text{k-NN}}(\mathbf{x}_{\text{new}})$ is the class that is most frequent among the $k$ neighbors:
    $$
    f_{\text{k-NN}}(\mathbf{x}_{\text{new}}) = \underset{C_j}{\mathrm{argmax}} \ N_j(\mathbf{x}_{\text{new}})
    $$
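A minimal sketch of the majority vote, assuming the neighbor labels have already been encoded as integers $0, \dots, K-1$ (the helper name `majority_vote` is illustrative). `np.bincount` computes every $N_j(\mathbf{x}_{\text{new}})$ in a single call.

```python
import numpy as np

def majority_vote(neighbor_labels, num_classes):
    """k-NN decision rule: count N_j for each class j = 0..K-1 and return the argmax."""
    # minlength gives classes with no votes an explicit zero count
    counts = np.bincount(neighbor_labels, minlength=num_classes)  # counts[j] = N_j(x_new)
    return int(np.argmax(counts)), counts

# Labels of the k = 5 nearest neighbors (classes encoded as 0, 1, 2)
neighbor_labels = np.array([2, 0, 2, 1, 2])
pred, counts = majority_vote(neighbor_labels, num_classes=3)
```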

**Handling Ties:** If two or more classes share the highest count, the tie can be broken arbitrarily, for example by picking the class with the smallest index or by choosing at random among the tied classes. In binary classification (classes $C_1$ and $C_0$), a common convention is to assign $\mathbf{x}_{\text{new}}$ to $C_1$ whenever the count for $C_1$ is $\ge k/2$; choosing an odd $k$ rules out binary ties altogether.
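The two tie-breaking options can be made explicit as in the sketch below (the function name and its `rng` argument are illustrative assumptions). Note that `np.argmax` returns the first maximal entry, so a plain argmax over the counts already implements the smallest-index rule; with $C_0, C_1$ encoded as 0 and 1, this breaks binary ties toward $C_0$, the opposite of the "$\ge k/2$ goes to $C_1$" convention.

```python
import numpy as np

def vote_with_tie_break(neighbor_labels, num_classes, rng=None):
    """Majority vote with explicit tie handling.
    rng=None  -> deterministic: smallest class index among the tied classes
    rng given -> uniformly random choice among the tied classes."""
    counts = np.bincount(neighbor_labels, minlength=num_classes)
    tied = np.flatnonzero(counts == counts.max())   # all classes sharing the top count
    if rng is None:
        return int(tied[0])                         # same choice np.argmax would make
    return int(rng.choice(tied))

# k = 4 neighbors: classes 0 and 1 both receive 2 votes
labels_k4 = np.array([1, 0, 1, 0])
deterministic = vote_with_tie_break(labels_k4, num_classes=2)
randomized = vote_with_tie_break(labels_k4, num_classes=2, rng=np.random.default_rng(0))
```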

**Relationship to Posterior Probabilities:**
For a given $\mathbf{x}_{\text{new}}$, the fraction of neighbors belonging to class $C_j$, i.e., $p_j(\mathbf{x}_{\text{new}}) = N_j(\mathbf{x}_{\text{new}}) / k$, can be seen as a local estimate of the posterior probability $p(C_j | \mathbf{x}_{\text{new}})$. Because dividing every count by the same constant $k$ does not change which class has the largest value, the majority vote is exactly the MAP rule applied to these local estimates.
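Numerically, the posterior estimates and the resulting MAP decision look like this for a made-up set of $k = 7$ neighbor labels (classes encoded as 0, 1, 2). The estimates can only take values that are multiples of $1/k$.

```python
import numpy as np

k = 7
neighbor_labels = np.array([0, 1, 1, 2, 1, 0, 1])   # labels of the k nearest neighbors
counts = np.bincount(neighbor_labels, minlength=3)   # N_j(x_new) = [2, 4, 1]
posterior_est = counts / k                           # p_j(x_new) = N_j / k, sums to 1
# Dividing by the constant k does not change which class is largest,
# so the MAP choice on the estimates equals the majority vote.
map_class = int(np.argmax(posterior_est))
```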

### 8.3 Example: Binary Classification (Classes +1 and -1)

Let's consider a binary classification problem where the labels are $t \in \{+1, -1\}$.
Let $N_{+1}(\mathbf{x}_{\text{new}})$ be the number of neighbors in $\mathcal{N}_k(\mathbf{x}_{\text{new}})$ with label $+1$.
The proportion of neighbors belonging to class $+1$ is $p(\mathbf{x}_{\text{new}}) = N_{+1}(\mathbf{x}_{\text{new}}) / k$.
The k-NN decision rule can be written as:
$$
f_{\text{k-NN}}(\mathbf{x}_{\text{new}}) = \begin{cases} +1 & \text{if } p(\mathbf{x}_{\text{new}}) \ge 0.5 \\ -1 & \text{if } p(\mathbf{x}_{\text{new}}) < 0.5 \end{cases}
$$
(A tie, $p(\mathbf{x}_{\text{new}}) = 0.5$, can only occur when $k$ is even; the rule above assigns it to class $+1$ by convention, though other conventions are possible.)
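Since every label is $\pm 1$, the sum of the neighbor labels is $\sum_{i \in \mathcal{N}_k(\mathbf{x}_{\text{new}})} t_i = N_{+1} - (k - N_{+1}) = k\,(2p(\mathbf{x}_{\text{new}}) - 1)$, so $p(\mathbf{x}_{\text{new}}) \ge 0.5$ holds exactly when that sum is non-negative. The sketch below (hypothetical helper names, NumPy assumed) implements both forms of the rule, including the convention that assigns the tie $p = 0.5$ to $+1$.

```python
import numpy as np

def knn_binary_predict(neighbor_labels):
    """Binary k-NN rule for labels in {+1, -1}: predict +1 iff p(x_new) = N_{+1}/k >= 0.5."""
    neighbor_labels = np.asarray(neighbor_labels)
    p = np.mean(neighbor_labels == 1)        # fraction of neighbors labelled +1
    return 1 if p >= 0.5 else -1

def knn_binary_predict_sum(neighbor_labels):
    """Equivalent form: the sum of the +/-1 labels equals k(2p - 1), so p >= 0.5 iff sum >= 0."""
    return 1 if np.sum(neighbor_labels) >= 0 else -1

odd_k = np.array([1, -1, 1, -1, -1])   # k = 5: p = 0.4, predict -1
even_tie = np.array([1, -1, -1, 1])    # k = 4: p = 0.5, a tie, assigned to +1
```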

### 8.4 Weighted k-Nearest Neighbors (Weighted k-NN)

The standard k-NN rule gives equal weight to all $k$ neighbors. Intuitively, however, closer neighbors should be more informative than farther ones within the neighborhood. **Weighted k-NN** addresses this by assigning a weight $w_i$ to each neighbor $i \in \mathcal{N}_k(\mathbf{x}_{\text{new}})$ that decreases with its distance $d_i = d(\mathbf{x}_{\text{new}}, \mathbf{x}_i)$.

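For example, the weights can be defined as $w_i = 1/d_i$, $w_i = 1/d_i^2$, or $w_i = \exp(-d_i^2/\sigma^2)$ for a Gaussian-like kernel with bandwidth $\sigma$. The inverse-distance choices are undefined when $d_i = 0$ (a training point identical to $\mathbf{x}_{\text{new}}$), so in practice a small constant is added to $d_i$ or such points receive all the weight. The weights are often normalized to sum to 1; this does not change the decision below, but it lets the per-class totals be read as posterior estimates.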
The decision is then made by summing the weights for each class and picking the class with the largest total weight:
$$
f_{\text{Wk-NN}}(\mathbf{x}_{\text{new}}) = \underset{C_j}{\mathrm{argmax}} \sum_{i \in \mathcal{N}_k(\mathbf{x}_{\text{new}}) \text{ and } t_i=C_j} w_i
$$
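Weighted k-NN can improve performance, particularly when $k$ is large: distant neighbors, which may lie in a region dominated by another class, then contribute little to the vote, which can make the result less sensitive to the exact choice of $k$.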
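The weighting schemes and the weighted decision rule can be sketched as follows. This is an illustration under assumptions: `weighted_vote` and its `scheme`, `sigma`, and `eps` arguments are hypothetical, `sigma=1.0` is an arbitrary bandwidth, and `eps` is one simple way to avoid dividing by a zero distance. The toy data are constructed so that a single very close neighbor outvotes four farther ones under inverse-distance weighting, while the unweighted rule follows the majority.

```python
import numpy as np

def weighted_vote(neighbor_labels, neighbor_dists, num_classes,
                  scheme="inverse", sigma=1.0, eps=1e-12):
    """Weighted k-NN decision: sum the weights w_i per class and take the argmax."""
    if scheme == "inverse":          # w_i = 1 / d_i   (eps guards against d_i = 0)
        w = 1.0 / (neighbor_dists + eps)
    elif scheme == "inverse_sq":     # w_i = 1 / d_i^2
        w = 1.0 / (neighbor_dists ** 2 + eps)
    elif scheme == "gaussian":       # w_i = exp(-d_i^2 / sigma^2)
        w = np.exp(-neighbor_dists ** 2 / sigma ** 2)
    else:
        raise ValueError(f"unknown scheme: {scheme}")
    w = w / w.sum()                  # normalizing does not change the argmax
    class_weights = np.bincount(neighbor_labels, weights=w, minlength=num_classes)
    return int(np.argmax(class_weights)), class_weights

# k = 5: one very close neighbor of class 1, four farther neighbors of class 0
labels = np.array([1, 0, 0, 0, 0])
dists = np.array([0.1, 1.0, 1.1, 1.2, 1.3])

uniform_pred = int(np.argmax(np.bincount(labels, minlength=2)))  # plain majority vote
weighted_pred, class_weights = weighted_vote(labels, dists, num_classes=2)
```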

### 8.5 k-NN for Multiclass Classification

The k-NN rule extends naturally to multiclass classification problems with $K > 2$ classes. As stated in Section 8.2, we simply count the number of neighbors belonging to each of the $K$ classes within the neighborhood $\mathcal{N}_k(\mathbf{x}_{\text{new}})$ and assign $\mathbf{x}_{\text{new}}$ to the class with the highest count.
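For many query points at once, the same multiclass rule can be vectorized with broadcasting. The sketch below is a brute-force illustration (hypothetical `knn_predict_batch`, integer labels $0, \dots, K-1$ assumed) that builds the full query-by-training distance matrix, so its memory use grows as (number of queries) × $N$.

```python
import numpy as np

def knn_predict_batch(X_train, y_train, X_query, k):
    """Uniform k-NN for K >= 2 classes, many query points at once.
    Assumes y_train holds integer labels 0..K-1."""
    num_classes = int(y_train.max()) + 1
    # Pairwise Euclidean distances via broadcasting, shape (num_queries, N)
    dists = np.linalg.norm(X_query[:, None, :] - X_train[None, :, :], axis=-1)
    # Indices of the k nearest training points for every query, shape (num_queries, k)
    nn_idx = np.argpartition(dists, k - 1, axis=1)[:, :k]
    nn_labels = y_train[nn_idx]
    # Per-query class counts N_j, shape (num_queries, K)
    counts = (nn_labels[..., None] == np.arange(num_classes)).sum(axis=1)
    return counts.argmax(axis=1)

# Three well-separated 2-D clusters, one per class
X_train = np.array([[0, 0], [0, 1], [1, 0],
                    [5, 5], [5, 6], [6, 5],
                    [0, 9], [1, 9], [0, 8]], dtype=float)
y_train = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
X_query = np.array([[0.5, 0.5], [5.5, 5.5], [0.5, 8.5]])

preds = knn_predict_batch(X_train, y_train, X_query, k=3)
```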



In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
from sklearn.utils.multiclass import unique_labels
from functools import partial # For jitting methods with static args

# Helper pure function for JIT and VMAP (core k-NN logic for one point)
# This function assumes y_train is already mapped to 0..K-1
@partial(jax.jit, static_argnames=("n_neighbors", "weights_mode", "p_norm", "num_classes"))
def _knn_predict_single_kernel(x_new, X_train, y_train_mapped, n_neighbors, weights_mode, p_norm, num_classes):
    distances = jnp.linalg.norm(X_train - x_new, ord=p_norm, axis=1)
    
    # jax.lax.top_k returns the k *largest* entries and their indices, so applying it to the
    # negated distances yields the k smallest distances (cheaper than a full argsort)
    neg_distances = -distances
    top_k_neg_distances, top_k_indices = jax.lax.top_k(neg_distances, k=n_neighbors)
    # -top_k_neg_distances would give the k smallest distances directly, if needed
    
    neighbor_labels_mapped = y_train_mapped[top_k_indices]

    if weights_mode == 'uniform':
        counts = jnp.bincount(neighbor_labels_mapped.astype(jnp.int32), length=num_classes)
        return jnp.argmax(counts) # Returns mapped label
    
    elif weights_mode == 'distance':
        # Distances of the k neighbors (equivalently, -top_k_neg_distances)
        neighbor_distances = distances[top_k_indices]
        
        # An exact duplicate of x_new (distance 0) would get infinite weight under 1/d.
        # Adding a small epsilon keeps the weights finite while still letting such a point dominate the vote.
        weights_arr = 1.0 / (neighbor_distances + 1e-6) 
        
        weighted_counts = jnp.bincount(neighbor_labels_mapped.astype(jnp.int32), weights=weights_arr, length=num_classes)
        return jnp.argmax(weighted_counts) # Returns mapped label
    else:
        # This part won't be JIT-friendly if it raises Python errors based on dynamic values.
        # For JIT, all branches should lead to valid JAX computations.
        # We assume 'weights_mode' is validated before calling the kernel.
        return -1 # Should not happen if validated

@partial(jax.jit, static_argnames=("n_neighbors", "weights_mode", "p_norm", "num_classes"))
def _knn_predict_proba_single_kernel(x_new, X_train, y_train_mapped, n_neighbors, weights_mode, p_norm, num_classes):
    distances = jnp.linalg.norm(X_train - x_new, ord=p_norm, axis=1)
    neg_distances = -distances
    _, top_k_indices = jax.lax.top_k(neg_distances, k=n_neighbors)
    neighbor_labels_mapped = y_train_mapped[top_k_indices]

    if weights_mode == 'uniform':
        counts = jnp.bincount(neighbor_labels_mapped.astype(jnp.int32), length=num_classes)
        return counts / jnp.sum(counts)
    
    elif weights_mode == 'distance':
        neighbor_distances = distances[top_k_indices]
        weights_arr = 1.0 / (neighbor_distances + 1e-6)
        weighted_counts = jnp.bincount(neighbor_labels_mapped.astype(jnp.int32), weights=weights_arr, length=num_classes)
        sum_weighted_counts = jnp.sum(weighted_counts)
        # Avoid division by zero if all weights are zero (e.g., all distances are huge and epsilon is tiny)
        return jax.lax.cond(sum_weighted_counts > 1e-6,
                            lambda: weighted_counts / sum_weighted_counts,
                            lambda: jnp.ones(num_classes) / num_classes) # Fallback to uniform if sum is zero
    else:
        return jnp.zeros(num_classes) # Should not happen


class JAXKNeighborsClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, n_neighbors=5, weights='uniform', p=2):
        self.n_neighbors = n_neighbors
        self.weights = weights
        self.p = p # p for Minkowski distance (p=2 is Euclidean)

    def fit(self, X, y):
        # Check that X and y have correct shape
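        X, y = check_X_y(X, y)
        # Validate hyperparameters up front: the kernels return a placeholder for unknown
        # weight modes, and jax.lax.top_k requires n_neighbors <= number of training samples
        if self.weights not in ('uniform', 'distance'):
            raise ValueError(f"weights must be 'uniform' or 'distance', got {self.weights!r}")
        if not 1 <= self.n_neighbors <= X.shape[0]:
            raise ValueError(f"n_neighbors must be between 1 and {X.shape[0]}, got {self.n_neighbors}")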
        
        # Store JAX arrays
        self._X_train = jnp.asarray(X)
        self._y_train_original = jnp.asarray(y) # Store original y for mapping back if needed
        
        # Store the classes seen during fit
        self.classes_ = unique_labels(y) # NumPy array of original class labels
        self._num_classes = len(self.classes_)

        # Create a mapping from original class labels to 0...K-1
        self._class_to_int_map = {label: i for i, label in enumerate(self.classes_)}
        self._int_to_class_map = {i: label for i, label in enumerate(self.classes_)}
        
        # Mapped y_train (0 to K-1)
        # Convert y to NumPy array first if it's a JAX array for list comprehension
        y_np = np.asarray(y) 
        self._y_train_mapped = jnp.array([self._class_to_int_map[label] for label in y_np])

        self.is_fitted_ = True
        return self

    def _predict_vmapped(self, X_new):
        # Vmap the kernel over the new data points
        # The kernel itself is JITted
        vmap_kernel = jax.vmap(
            _knn_predict_single_kernel, 
            in_axes=(0, None, None, None, None, None, None), # x_new varies, others are fixed
            out_axes=0
        )
        mapped_predictions = vmap_kernel(
            X_new, self._X_train, self._y_train_mapped, 
            self.n_neighbors, self.weights, self.p, self._num_classes
        )
        return mapped_predictions

    def predict(self, X):
        # Check if fit has been called
        check_is_fitted(self)
        # Input validation
        X = check_array(X)
        X_jax = jnp.asarray(X)

        if X_jax.shape[0] == 0:
            return np.array([])

        mapped_predictions = self._predict_vmapped(X_jax)
        
        # Convert mapped predictions (0..K-1) back to original class labels
        # Convert JAX array to NumPy for this mapping step
        mapped_predictions_np = np.asarray(mapped_predictions)
        original_predictions = np.array([self._int_to_class_map[mapped_label] for mapped_label in mapped_predictions_np])
        
        return original_predictions

    def _predict_proba_vmapped(self, X_new):
        vmap_kernel = jax.vmap(
            _knn_predict_proba_single_kernel,
            in_axes=(0, None, None, None, None, None, None),
            out_axes=0
        )
        probas = vmap_kernel(
            X_new, self._X_train, self._y_train_mapped,
            self.n_neighbors, self.weights, self.p, self._num_classes
        )
        return probas

    def predict_proba(self, X):
        check_is_fitted(self)
        X = check_array(X)
        X_jax = jnp.asarray(X)

        if X_jax.shape[0] == 0:
            return np.empty((0, self._num_classes))
            
        # Ensure the columns in predict_proba output match self.classes_ order
        # The _knn_predict_proba_single_kernel already produces probabilities in the 0..K-1 mapped order.
        # So, the output columns directly correspond to the order in self.classes_
        probas_jax = self._predict_proba_vmapped(X_jax)
        return np.asarray(probas_jax) # Convert to NumPy array

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import datasets
# from sklearn.neighbors import KNeighborsClassifier # We'll use our JAX version

# --- Load Iris Dataset ---
iris = datasets.load_iris()
# We only take the first two features for visualization.
X_iris = iris.data[:, :2]
y_iris = iris.target

n_neighbors = 11 # As in the example image

# --- Plotting setup ---
h = .02  # step size in the mesh
cmap_light = ListedColormap(['#A0A0E8', '#A0E8A0', '#E8E8A0']) # Adjusted for potential color diffs
cmap_bold = ['#440154', '#21908d', '#fde725'] # Viridis colors, dark, mid, light

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for i, weights in enumerate(['uniform', 'distance']):
    # Instantiate our JAX k-NN classifier
    clf = JAXKNeighborsClassifier(n_neighbors=n_neighbors, weights=weights)
    clf.fit(X_iris, y_iris)

    # Create a mesh to plot in
    x_min, x_max = X_iris[:, 0].min() - 1, X_iris[:, 0].max() + 1
    y_min, y_max = X_iris[:, 1].min() - 1, X_iris[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    
    # Obtain predictions on the mesh
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)

    # Plot decision boundary
    ax = axes[i]
    ax.contourf(xx, yy, Z, cmap=cmap_light)

    # Plot training points
    for class_idx, color in zip(clf.classes_, cmap_bold):
        idx = np.where(y_iris == class_idx)
        ax.scatter(X_iris[idx, 0], X_iris[idx, 1], c=color, label=iris.target_names[class_idx],
                   edgecolor='k', s=40)

    ax.set_xlim(xx.min(), xx.max())
    ax.set_ylim(yy.min(), yy.max())
    ax.set_title(f"3-Class classification (JAX k-NN)\n(k={n_neighbors}, weights='{weights}')")
    ax.set_xlabel(iris.feature_names[0])
    ax.set_ylabel(iris.feature_names[1])
    ax.legend(title="Classes")

plt.tight_layout()
plt.show()

### 8.6 Voronoi Diagrams (for k=1 NN)

When $k=1$ (i.e., the 1-Nearest Neighbor rule), the classification of a new point $\mathbf{x}_{\text{new}}$ is determined solely by the class of its single closest training point.
In this case, the feature space $\mathbb{R}^D$ can be partitioned into regions, called **Voronoi cells**. Each training point $\mathbf{x}_n$ defines a Voronoi cell $V_n$ consisting of all points in $\mathbb{R}^D$ that are closer to $\mathbf{x}_n$ than to any other training point $\mathbf{x}_m$ ($m \neq n$).
$$
V_n = \{ \mathbf{x} \in \mathbb{R}^D \mid ||\mathbf{x} - \mathbf{x}_n|| \le ||\mathbf{x} - \mathbf{x}_m|| \text{ for all } m \neq n \}
$$
The collection of these cells forms a **Voronoi diagram** (or Voronoi tessellation).
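For the 1-NN classifier, all points $\mathbf{x}_{\text{new}}$ falling into the Voronoi cell $V_n$ (associated with training point $\mathbf{x}_n$) are assigned the class label $t_n$. Each decision region of the 1-NN classifier is therefore a union of Voronoi cells, and its decision boundary consists of those Voronoi cell boundaries that separate training points with different labels; boundaries between cells of same-class points do not affect the decision.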
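A one-dimensional toy example makes the Voronoi picture concrete. Below, three training points at 0, 1, and 3 (a made-up setup, NumPy assumed) have labels 0, 0, and 1. The Voronoi cell boundaries lie at the midpoints 0.5 and 2.0, but the 1-NN label changes only at 2.0, because the cells of the two class-0 points together form a single decision region.

```python
import numpy as np

# Three training points on a line with labels 0, 0, 1 (a 1-D Voronoi diagram)
x_train = np.array([0.0, 1.0, 3.0])
t_train = np.array([0, 0, 1])

grid = np.linspace(-1.0, 4.0, 501)
# Voronoi cell index of each grid point = index of its nearest training point
cell = np.argmin(np.abs(grid[:, None] - x_train[None, :]), axis=1)
# 1-NN prediction: the label of the training point that owns the cell
pred = t_train[cell]

# Cell boundaries: where the nearest training point changes (near 0.5 and 2.0, up to grid resolution)
cell_changes = grid[1:][np.diff(cell) != 0]
# Decision boundary: where the predicted label changes (only near 2.0, since cells 0 and 1 share label 0)
label_changes = grid[1:][np.diff(pred) != 0]
```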


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import Voronoi, voronoi_plot_2d
from sklearn.neighbors import KNeighborsClassifier # To color regions based on 1-NN

# --- 1. Generate Synthetic Data (similar to the book's figure) ---
np.random.seed(42) # For reproducibility
N_points = 100
# Points randomly distributed in a [0,1] x [0,1] square
points = np.random.rand(N_points, 2)

# Assign classes (e.g., randomly or based on some simple rule)
# For simplicity, let's assign them randomly for this visual example
# In a real scenario, these would be your training data labels
labels = np.random.randint(0, 2, N_points)

# Colors for the classes (similar to red/green, or any two distinct colors)
class_colors_points = ['#d62728', '#2ca02c'] # Red, Green (matplotlib default)
class_colors_regions = ['#ff9999', '#99ff99'] # Lighter red, Lighter green for regions

# --- 2. Compute Voronoi Tessellation ---
vor = Voronoi(points)

# --- 3. Plotting ---
fig, ax = plt.subplots(figsize=(8, 8))

# --- Method 1: Using voronoi_plot_2d and then filling regions ---
# This method is good for drawing the lines and points, but filling by class requires extra steps.
# voronoi_plot_2d(vor, ax=ax, show_vertices=False, line_colors='blue', line_width=1, point_size=0)

# --- Method 2: Color regions using 1-NN predictions on a meshgrid (More direct for 1-NN vis) ---
# This shows the 1-NN decision regions: each region is a union of Voronoi cells (adjacent cells of same-class points merge).

# Create a 1-NN classifier using the generated points and labels
clf_1nn = KNeighborsClassifier(n_neighbors=1)
clf_1nn.fit(points, labels)

# Create a mesh to plot in
h = .01  # step size in the mesh
x_min, x_max = points[:, 0].min() - 0.1, points[:, 0].max() + 0.1
y_min, y_max = points[:, 1].min() - 0.1, points[:, 1].max() + 0.1
# Ensure plot limits are roughly [0,1] if points are in [0,1] for better visual consistency
x_min_plot, x_max_plot = -0.05, 1.05
y_min_plot, y_max_plot = -0.05, 1.05

xx, yy = np.meshgrid(np.arange(x_min_plot, x_max_plot, h),
                     np.arange(y_min_plot, y_max_plot, h))

# Obtain 1-NN predictions on the mesh
Z = clf_1nn.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# Plot the decision regions (filled contours)
custom_cmap = plt.cm.colors.ListedColormap(class_colors_regions)
ax.contourf(xx, yy, Z, cmap=custom_cmap, alpha=0.8) # alpha for slight transparency

# Now, overlay the Voronoi lines explicitly using voronoi_plot_2d
# This ensures the lines are exactly where the Voronoi boundaries are.
voronoi_plot_2d(vor, ax=ax, show_vertices=False, show_points=False, # We'll plot points manually
                  line_colors='blue', line_width=1.5, line_alpha=0.7)


# Plot the original "seed" points
for i in range(len(np.unique(labels))):
    ax.scatter(points[labels == i, 0], points[labels == i, 1], 
               c=class_colors_points[i], edgecolor='k', s=50, 
               label=f'Class {i}')

ax.set_title(f'Voronoi Diagram (1-NN Decision Regions)\n{N_points} random points', fontsize=14)
ax.set_xlabel('Feature 1 (x-axis)', fontsize=12)
ax.set_ylabel('Feature 2 (y-axis)', fontsize=12)
ax.legend(loc='upper right')
ax.set_xlim([x_min_plot, x_max_plot])
ax.set_ylim([y_min_plot, y_max_plot])
ax.set_aspect('equal', adjustable='box') # Ensure aspect ratio is equal for a proper Voronoi look
plt.grid(False) # Turn off grid if not desired, like in the book's figure
plt.show()


### 8.7 Characteristics of k-NN

*   **Non-parametric:** k-NN is a non-parametric method: it does not assume a fixed functional form for the underlying data distribution (e.g., Gaussian), and its "model" (the stored training set) grows with the amount of data. The decision boundary can therefore be highly flexible and adapt to the local structure of the data.
*   **Discriminative:** It directly approximates the decision rule (or $p(C_j|\mathbf{x})$ locally) without modeling $p(\mathbf{x}|C_j)$ or $p(C_j)$.
*   **Instance-based or Memory-based Learning:** k-NN is often called an instance-based (or memory-based) learner because it stores the entire training dataset and uses it directly to make predictions. There is no distinct training phase in the traditional sense of learning parameters: "fitting" amounts to storing the data, and the main computation (finding neighbors) happens at prediction time.
*   **Lazy Learner:** Because the computation is deferred until a prediction is requested, k-NN is also known as a lazy learner (as opposed to eager learners like logistic regression or SVMs, which build a model explicitly during training).

While k-NN is simple to understand and implement, it has some important challenges, which we will discuss later.