## 1-Nearest Neighbor with KD Tree
As a lazy learner, kNN defers all computation to the inference step. A brute-force nearest neighbor search over m training samples with d features costs **O(m * d)** distance computations per query, or **O(m * d + m * log k)** when a heap of size k tracks the k nearest neighbors, so inference becomes expensive as the training set grows. A Kd tree is an alternative that moves part of the work to training time: the tree is built once and then reused for every query. Note that the K in Kd tree (the number of feature dimensions) is unrelated to the k in kNN (the number of neighbors).<br><br>
This notebook implements a Kd tree to perform 1-NN search and classify the handwritten digits dataset available [here](https://archive.ics.uci.edu/dataset/80/optical+recognition+of+handwritten+digits), loaded below through scikit-learn's `load_digits`.
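
For comparison, the brute-force search that a Kd tree is meant to speed up can be sketched as follows. This is a minimal illustration and not part of the notebook's pipeline: the function name `brute_force_1nn` and the toy arrays are assumptions made for the example.

```python
import numpy as np


def brute_force_1nn(train_X, train_y, query):
    """Return (distance, label) of the training sample closest to `query`.

    Hypothetical helper for illustration: it computes all m distances,
    so every query touches every training sample.
    """
    # Euclidean distance from the query to every training sample
    dists = np.sqrt(np.sum((train_X - query) ** 2, axis=1))
    nearest_idx = int(np.argmin(dists))
    return float(dists[nearest_idx]), train_y[nearest_idx]


# toy data, made up for the example
toy_X = np.array([[0.0, 0.0], [1.0, 1.0], [0.2, 0.1]])
toy_y = np.array([0, 1, 2])
dist, label = brute_force_1nn(toy_X, toy_y, np.array([0.25, 0.1]))
print(dist, label)
```

Nothing is precomputed here, which is why the cost is paid in full at every query.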

#### Import libraries and data

In [1]:
import os
import sys
import time
from math import floor

import numpy as np
import pandas as pd
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

root_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))
sys.path.append(root_dir_path)
import utils

pd.set_option("display.max_colwidth", None)

In [2]:
# load the dataset
digits_data = load_digits()
X = digits_data.data
y = digits_data.target
# split the dataset
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)
X_train[0:2, :]

array([[ 0.,  0.,  5., 13., 13.,  8.,  0.,  0.,  0.,  0., 16., 11., 13.,
        16.,  6.,  0.,  0.,  1., 16.,  5.,  2., 14.,  9.,  0.,  0.,  0.,
         9., 16., 16., 15.,  0.,  0.,  0.,  0., 10., 16., 14., 14.,  0.,
         0.,  0.,  5., 15.,  4.,  0., 16.,  6.,  0.,  0.,  6., 14.,  7.,
         6., 16.,  4.,  0.,  0.,  0.,  7., 15., 16., 10.,  0.,  0.],
       [ 0.,  0.,  3., 14., 16., 14.,  0.,  0.,  0.,  0., 13., 13., 13.,
        16.,  2.,  0.,  0.,  0.,  1.,  0.,  9., 15.,  0.,  0.,  0.,  0.,
         9., 12., 15., 16., 10.,  0.,  0.,  4., 16., 16., 16., 11.,  3.,
         0.,  0.,  0.,  4.,  9., 14.,  2.,  0.,  0.,  0.,  0.,  2., 15.,
         9.,  0.,  0.,  0.,  0.,  0.,  4., 13.,  1.,  0.,  0.,  0.]])

In [3]:
X_train.shape, X_test.shape

((1257, 64), (540, 64))

In [4]:
# unique values and their counts
unique_values, counts = np.unique(y_train, return_counts=True)
for value, count in zip(unique_values, counts):
    print(f"{value} occurs {count} time(s).")

0 occurs 125 time(s).
1 occurs 132 time(s).
2 occurs 130 time(s).
3 occurs 129 time(s).
4 occurs 121 time(s).
5 occurs 116 time(s).
6 occurs 128 time(s).
7 occurs 124 time(s).
8 occurs 131 time(s).
9 occurs 121 time(s).


Scale the training and test data.

In [5]:
normalized_training_features = utils.minmax_normalize_2d_array(X_train)
normalized_test_features = utils.minmax_normalize_2d_array(X_test)
X_train = None
X_test = None
normalized_training_features[0:2, :]

array([[0.        , 0.        , 0.3125    , 0.8125    , 0.8125    ,
        0.5       , 0.        , 0.        , 0.        , 0.        ,
        1.        , 0.6875    , 0.8125    , 1.        , 0.375     ,
        0.        , 0.        , 0.0625    , 1.        , 0.3125    ,
        0.125     , 0.875     , 0.5625    , 0.        , 0.        ,
        0.        , 0.5625    , 1.        , 1.        , 0.9375    ,
        0.        , 0.        , 0.        , 0.        , 0.625     ,
        1.        , 0.875     , 0.875     , 0.        , 0.        ,
        0.        , 0.3125    , 0.9375    , 0.25      , 0.        ,
        1.        , 0.375     , 0.        , 0.        , 0.375     ,
        0.875     , 0.4375    , 0.375     , 1.        , 0.25      ,
        0.        , 0.        , 0.        , 0.4375    , 0.9375    ,
        1.        , 0.625     , 0.        , 0.        ],
       [0.        , 0.        , 0.1875    , 0.875     , 1.        ,
        0.875     , 0.        , 0.        , 0.        , 0. 
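
`utils.minmax_normalize_2d_array` is not shown in this notebook, so its exact behavior is unknown. Because the training and test sets are scaled in two separate calls, each set may end up scaled with its own minimum and maximum. A common alternative is to compute the statistics on the training set only and reuse them for the test set. A minimal sketch, assuming per-feature min-max scaling (the helper names `fit_minmax` and `apply_minmax` are hypothetical):

```python
import numpy as np


def fit_minmax(train_X):
    """Per-feature minimum and range, computed on the training data only."""
    col_min = train_X.min(axis=0)
    col_range = train_X.max(axis=0) - col_min
    # constant features have zero range; use 1 to avoid division by zero
    col_range[col_range == 0] = 1.0
    return col_min, col_range


def apply_minmax(X, col_min, col_range):
    """Scale X with the statistics returned by fit_minmax."""
    return (X - col_min) / col_range


# toy data, made up for the example
toy_train = np.array([[0.0, 4.0], [8.0, 4.0], [16.0, 4.0]])
toy_test = np.array([[4.0, 4.0]])
col_min, col_range = fit_minmax(toy_train)
scaled_train = apply_minmax(toy_train, col_min, col_range)
scaled_test = apply_minmax(toy_test, col_min, col_range)
print(scaled_train)
print(scaled_test)
```

With this approach a given raw pixel value maps to the same scaled value in both sets, which keeps distances between training and test samples meaningful.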

In [6]:
class Tree_Node:
    def __init__(self, X, y, depth):
        self.X = X
        self.y = y
        self.left = None
        self.right = None
        self.depth = depth
        self.median_idx = None
        self.tr_sample = None
        self.label = None
        self.dim_to_split = None

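    # Build cost: the split is always at the median index of the sorted samples,
    # so the tree is balanced with height ~log2(m), and each level sorts ~m samples in total.
    # Average case TC: O(m log^2 m), from T(m) = 2T(m/2) + O(m log m).
    # Worst case TC: O(m^2) if a sort degrades to quadratic time, from T(m) = 2T(m/2) + O(m^2).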
    def build_Kd_tree(self):
        """
        Builds a Kd tree rooted at this node.
        """

        # check if the current node is a leaf node
        if len(self.X) == 1:
            self.tr_sample = self.X[0, :]
            self.label = self.y[0]
        else:
            # choose the feature to split on by cycling through the dimensions
            self.dim_to_split = self.depth % (len(self.X[0]))
            # find the median by sorting the training samples along that feature
            # TC: Average case: O(m log m) using quicksort.
            # TC: Worst case: O(m^2) using quicksort.
            sorted_indices = self.X[:, self.dim_to_split].argsort()
            self.X = self.X[sorted_indices]
            # reorder the training labels to match the sorted samples
            self.y = self.y[sorted_indices]
            # integer division gives the median index for both even and odd lengths
            self.median_idx = len(self.X) // 2

            # store the median point and label
            self.tr_sample = self.X[self.median_idx, :]
            self.label = self.y[self.median_idx]
            # separate left and right points
            left_points = self.X[0 : self.median_idx, :]
            left_labels = self.y[0 : self.median_idx]
            right_points = self.X[self.median_idx + 1 :, :]
            right_labels = self.y[self.median_idx + 1 :]
            # create KD subtrees' nodes from the left and the right points
            if len(left_points) == 0:
                self.left = None
            else:
                self.left = Tree_Node(left_points, left_labels, self.depth + 1)
                # build KD subtrees
                self.left.build_Kd_tree()
            if len(right_points) == 0:
                self.right = None
            else:
                self.right = Tree_Node(right_points, right_labels, self.depth + 1)
                # build KD subtree
                self.right.build_Kd_tree()

    # Worst TC: Might need to travel to all nodes when no pruning can be done. Thus, O(m)
    def find_nearest_neighbor(
        self, X_test, curr_nearest_dist, curr_nearest_sample, curr_nearest_label
    ):
        """
        Finds the nearest neighbor of test sample X_test by traversing the tree rooted at this node.
        """

        # check if the current node is the closest to the test sample till now
        test_sample_and_curr_node_dist = np.sqrt(np.sum((X_test - self.tr_sample) ** 2))
        if curr_nearest_sample is None or (
            test_sample_and_curr_node_dist < curr_nearest_dist
        ):
            curr_nearest_dist = test_sample_and_curr_node_dist
            curr_nearest_sample = self.tr_sample
            curr_nearest_label = self.label

        # traverse child nodes - find good and bad side to traverse first
        if self.left is None and self.right is None:
            # it is a leaf node
            good_child = None
            bad_child = None

        else:
            if X_test[self.dim_to_split] < self.tr_sample[self.dim_to_split]:
                # test sample is to the left side of the splitting node
                good_child = self.left
                bad_child = self.right
            else:
                good_child = self.right
                bad_child = self.left

        # traverse good side first
        if good_child is not None:
            (
                NN_dist_good_side,
                NN_good_side,
                NN_label_good_side,
            ) = good_child.find_nearest_neighbor(
                X_test, curr_nearest_dist, curr_nearest_sample, curr_nearest_label
            )
        else:
            NN_dist_good_side, NN_good_side, NN_label_good_side = (
                curr_nearest_dist,
                curr_nearest_sample,
                curr_nearest_label,
            )

        # traverse bad side second, but only if it is worth traversing:
        # skip it if the distance along the split dimension between the test sample
        # and this node is not smaller than the best distance found so far.
        # The bound is the best distance after searching the good side, so that
        # pruning uses the tightest distance available.
        if bad_child is not None and (
            abs(X_test[self.dim_to_split] - self.tr_sample[self.dim_to_split])
            < NN_dist_good_side
        ):
            (
                NN_dist_bad_side,
                NN_bad_side,
                NN_label_bad_side,
            ) = bad_child.find_nearest_neighbor(
                X_test, NN_dist_good_side, NN_good_side, NN_label_good_side
            )
        else:
            NN_dist_bad_side, NN_bad_side, NN_label_bad_side = (
                NN_dist_good_side,
                NN_good_side,
                NN_label_good_side,
            )

        # choose between the good-side best and the bad-side best
        if NN_dist_good_side < NN_dist_bad_side:
            curr_nearest_dist = NN_dist_good_side
            curr_nearest_sample = NN_good_side
            curr_nearest_label = NN_label_good_side
        else:
            curr_nearest_dist = NN_dist_bad_side
            curr_nearest_sample = NN_bad_side
            curr_nearest_label = NN_label_bad_side

        return curr_nearest_dist, curr_nearest_sample, curr_nearest_label
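
The pruning rule is easier to follow on a small 2-D example. The following is a compact sketch of the same build-and-search idea, using nested dicts instead of the `Tree_Node` class. The function names `build` and `nearest`, the `visited` list, and the toy points are assumptions made for illustration only.

```python
import numpy as np


def build(points, depth=0):
    """Build a Kd tree as nested dicts; returns None for an empty set."""
    if len(points) == 0:
        return None
    dim = depth % points.shape[1]
    # sort along the split dimension and take the median as this node
    points = points[points[:, dim].argsort()]
    mid = len(points) // 2
    return {
        "point": points[mid],
        "dim": dim,
        "left": build(points[:mid], depth + 1),
        "right": build(points[mid + 1:], depth + 1),
    }


def nearest(node, query, best=(float("inf"), None), visited=None):
    """Return (distance, point) of the nearest neighbor of `query`."""
    if node is None:
        return best
    if visited is not None:
        visited.append(node["point"])  # record which nodes were examined
    dist = float(np.sqrt(np.sum((query - node["point"]) ** 2)))
    if dist < best[0]:
        best = (dist, node["point"])
    diff = query[node["dim"]] - node["point"][node["dim"]]
    if diff < 0:
        good, bad = node["left"], node["right"]
    else:
        good, bad = node["right"], node["left"]
    best = nearest(good, query, best, visited)
    # the far side can only help if the splitting plane is closer than the best distance
    if abs(diff) < best[0]:
        best = nearest(bad, query, best, visited)
    return best


# toy data, made up for the example
toy_points = np.array(
    [[2.0, 3.0], [5.0, 4.0], [9.0, 6.0], [4.0, 7.0], [8.0, 1.0], [7.0, 2.0]]
)
tree = build(toy_points)
visited = []
nn_dist, nn_point = nearest(tree, np.array([9.0, 2.0]), visited=visited)
print(nn_point, nn_dist, f"visited {len(visited)} of {len(toy_points)} nodes")
```

In this example the root (7, 2) splits on x. Once the right subtree yields (8, 1) at distance √2, the left subtree is skipped entirely, because the query lies 2 units from the splitting plane.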

#### KD Tree Training

In [7]:
# build KD tree
KD_tree = Tree_Node(normalized_training_features, y_train, 0)
KD_tree.build_Kd_tree()

#### Inference
Search for one nearest neighbor of each test sample.

In [8]:
nearest_samples = []
nearest_labels = []
for idx in range(0, len(normalized_test_features)):
    _, nearest_sample, nearest_label = KD_tree.find_nearest_neighbor(
        normalized_test_features[idx, :], float("inf"), None, None
    )
    nearest_samples.append(nearest_sample)
    nearest_labels.append(nearest_label)

In [9]:
utils.calculate_accuracy(y_test, nearest_labels)

98.51851851851852
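
`utils.calculate_accuracy` is not shown in this notebook. Judging by the output above, it presumably returns the percentage of correct predictions. With 540 test samples, the reported value corresponds to 532 correct predictions. Under that assumption, an equivalent could look like this (the name `accuracy_percent` is hypothetical):

```python
import numpy as np


def accuracy_percent(y_true, y_pred):
    """Percentage of predictions that match the true labels."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return float(np.mean(y_true == y_pred) * 100)


# toy labels, made up for the example: 3 of 4 predictions are correct
toy_accuracy = accuracy_percent([0, 1, 2, 3], [0, 1, 2, 9])
print(toy_accuracy)
```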

Despite the worst-case query time complexity of O(m), Kd trees often perform well in practice because some pruning is usually possible on real data. However, Kd trees struggle with the curse of dimensionality: as the number of feature dimensions grows, pruning becomes less effective and the search approaches a brute-force scan.
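
One way to see the effect of dimensionality is to compare the nearest and farthest distances from a random query to random points. The sketch below uses uniformly distributed synthetic data, not the digits dataset, and the function name `nearest_to_farthest_ratio` is an assumption made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)


def nearest_to_farthest_ratio(n_points, n_dims):
    """Ratio of the smallest to the largest distance from a random query."""
    points = rng.random((n_points, n_dims))
    query = rng.random(n_dims)
    dists = np.sqrt(np.sum((points - query) ** 2, axis=1))
    return float(dists.min() / dists.max())


ratios = {d: nearest_to_farthest_ratio(1000, d) for d in (2, 64, 1024)}
for d, r in ratios.items():
    print(f"d={d:5d}  nearest/farthest = {r:.3f}")
```

As the ratio moves toward 1, the best distance found so far is rarely smaller than the distance to a splitting plane, so the far side of a split can seldom be skipped.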