In [1]:
import numpy as np
import matplotlib as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

# KNN with K-D Tree Implementation

In [2]:
import heapq as hq

In [3]:
class Node:
    """A single node of a binary tree.

    Attributes are set directly from the constructor arguments:
    ``data`` is the payload stored at this node, ``left`` and ``right``
    are the child subtrees (or ``None`` when a child is absent).
    """

    def __init__(self, data, left, right):
        # Payload first, then the two children.
        self.data, self.left, self.right = data, left, right

In [4]:
def euclidean(x1, x2):
    """Return the Euclidean (L2) distance between ``x1`` and ``x2``.

    The element-wise difference is squared, summed over all elements,
    and the square root of that sum is returned.
    """
    delta = x1 - x2
    squared_norm = np.sum(delta ** 2)
    return np.sqrt(squared_norm)

In [5]:
class KDTree:
    """K-d tree over the rows of a 2-D array, with K-nearest-neighbour search.

    Nodes store *row indices* into ``self.X`` rather than copies of the
    points, so query results can be mapped straight back to the training
    rows (and to their labels).
    """

    def __init__(self, X):
        """Build the tree.

        Parameters
        ----------
        X : array-like of shape (N, M)
            N points with M features each. Converted with ``np.asarray``,
            so an existing ndarray is stored as-is (no copy).
        """
        self.X = np.asarray(X)
        N, self.M = self.X.shape
        self.root = self.build(np.arange(N))

    def build(self, data, depth=0):
        """Recursively build the subtree for the row indices in ``data``.

        Parameters
        ----------
        data : ndarray of int
            Row indices into ``self.X`` that belong to this subtree.
        depth : int
            Depth of the subtree root; selects the split axis
            (``depth % M``), cycling through the features.

        Returns
        -------
        Node or None
            Root of the subtree, or ``None`` when ``data`` is empty.
        """
        if len(data) == 0:
            return None

        axis = depth % self.M
        # Order the indices by their coordinate on the split axis so the
        # median row becomes this node and the halves become the children.
        data = data[np.argsort(self.X[data, axis])]

        median_idx = len(data) // 2
        return Node(
            data[median_idx],
            self.build(data[:median_idx], depth + 1),
            self.build(data[median_idx + 1:], depth + 1),
        )

    def query(self, point, K):
        """Find the (up to) ``K`` stored points nearest to ``point``.

        Parameters
        ----------
        point : array-like of shape (M,)
            Query point; must have the same number of features as ``X``.
        K : int
            Number of neighbours requested. If ``K <= 0`` an empty list
            is returned; if ``K`` exceeds the number of stored points,
            all stored points are returned.

        Returns
        -------
        list of (float, int)
            ``(-distance, row_index)`` tuples laid out as a ``heapq`` heap:
            distances are negated, so element 0 is the *farthest* of the
            returned neighbours. The list is not sorted by distance.

        Raises
        ------
        ValueError
            If ``point`` does not have shape ``(M,)``.
        """
        point = np.asarray(point)
        if point.shape != (self.M,):
            raise ValueError(
                "point must have shape (%d,), got %s" % (self.M, point.shape)
            )

        heap = []
        # With K <= 0 nothing may be kept; bail out before the search peeks
        # at heap[0] on an empty heap.
        if K <= 0:
            return heap

        def recursive_search(node, depth):
            if node is None:
                return

            # Same axis rule as build(), so the descent matches the splits.
            axis = depth % self.M
            node_point = self.X[node.data]
            if node_point[axis] <= point[axis]:
                nearer_path, further_path = node.right, node.left
            else:
                nearer_path, further_path = node.left, node.right

            # Descend first into the side of the split containing the query.
            recursive_search(nearer_path, depth + 1)

            distance = euclidean(node_point, point)

            # Keep the K best seen so far in a max-heap keyed on distance
            # (implemented by negating distances in heapq's min-heap).
            if len(heap) < K:
                hq.heappush(heap, (-distance, node.data))
            elif distance < -heap[0][0]:
                hq.heappop(heap)
                hq.heappush(heap, (-distance, node.data))

            # The far side can only hold a better neighbour if the splitting
            # plane is closer than the current worst kept distance (or the
            # heap is not yet full).
            if len(heap) < K or np.abs(point[axis] - node_point[axis]) < -heap[0][0]:
                recursive_search(further_path, depth + 1)

        recursive_search(self.root, 0)
        return heap

In [6]:
class KNN:
    """K-nearest-neighbour classifier using a KDTree for neighbour search.

    Prediction is an unweighted majority vote over the labels of the
    nearest training points.
    """

    def __init__(self, X, y):
        """Index the training points and remember their labels.

        Parameters
        ----------
        X : array-like of shape (N, M)
            Training points, handed to ``KDTree``.
        y : array-like of shape (N,)
            Label of each training point. Converted with ``np.asarray`` so
            that labels are always looked up by position (row index).
        """
        self.kdtree = KDTree(X)
        self.y = np.asarray(y)
        # Distinct labels in sorted order; this order decides vote ties.
        self.values = np.unique(self.y)

    def predict(self, sample, K=9):
        """Predict a label for every row of ``sample``.

        Parameters
        ----------
        sample : sequence of points
            Points to classify; ``sample[j]`` is passed to the tree query.
        K : int, default 9
            Number of neighbours that vote. If fewer than ``K`` training
            points exist, all of them vote.

        Returns
        -------
        ndarray of float, shape (len(sample),)
            Predicted labels. The result array is created with ``np.zeros``,
            so labels are stored as floats and must be numeric.

        Raises
        ------
        ValueError
            If ``K`` is smaller than 1.
        """
        if K < 1:
            raise ValueError("K must be at least 1, got %r" % (K,))

        y_preds = np.zeros(len(sample))
        for j in range(len(sample)):
            # K nearest neighbours via the KD tree; each heap entry is
            # (-distance, row_index) and only the index is needed here.
            heap = self.kdtree.query(sample[j], K)
            k_labels = [self.y[idx] for _, idx in heap]

            # Majority vote. Iterate over the neighbours actually returned
            # (there may be fewer than K). The strict '>' keeps the first
            # label reaching the maximum, i.e. the smallest label on a tie.
            y_pred = None
            max_wv = float("-inf")
            for v in self.values:
                wv_sum = sum(1 for label in k_labels if label == v)
                if wv_sum > max_wv:
                    max_wv = wv_sum
                    y_pred = v
            y_preds[j] = y_pred
        return y_preds

In [7]:
# Load the Iris dataset: X is the feature matrix, y the class labels.
iris = load_iris()
X = iris.data
y = iris.target

In [8]:
# Hold out 30% of the samples for testing; the fixed random_state keeps
# the shuffled split the same on every run.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    shuffle=True,
    random_state=1,
)

In [9]:
# Build the KD-tree-backed classifier from the training split.
knn_model = KNN(X=X_train, y=y_train)

In [10]:
# Classify the held-out samples using predict's default K.
y_preds = knn_model.predict(sample=X_test)

In [14]:
# Report how many held-out samples were labelled incorrectly.
print(
    "Misclassifications:",
    np.sum(y_preds != y_test),
    "out of",
    len(y_test),
)

Misclassifications: 1 out of 45


In [16]:
# Comparing with the sklearn implementation (same number of neighbours,
# KD-tree search) as a reference for the misclassification count above.
from sklearn.neighbors import KNeighborsClassifier

# fit() returns the estimator itself, so construction and fitting are chained.
knn_model_sk = KNeighborsClassifier(
    algorithm='kd_tree',
    n_neighbors=9,
).fit(X_train, y_train)
y_preds = knn_model_sk.predict(X_test)
print(
    "Misclassifications:",
    np.sum(y_preds != y_test),
    "out of",
    len(y_test),
)

Misclassifications: 1 out of 45
