In [1]:
import collections
import heapq
import operator
from itertools import groupby

import numpy

# A k-d tree node: the stored point, its label, and the two child subtrees
# (each either another TreeNode or None).
TreeNode = collections.namedtuple("TreeNode", ["value", "label", "left", "right"])


def kdtree(points, labels):
    """Build a balanced k-d tree from labelled points.

    At depth ``d`` the points are ordered by their coordinate on axis
    ``d % k`` (``k`` being the length of the first point) and the element
    at index ``len // 2`` becomes the node.  Everything before it goes to
    the left subtree and everything after it to the right subtree, so on
    the split axis every left descendant is <= the node and every right
    descendant is >= the node.

    Args:
        points: iterable of equal-length coordinate sequences.
        labels: iterable of labels; ``labels[i]`` belongs to ``points[i]``.

    Returns:
        The root ``TreeNode``, or ``None`` when ``points`` is empty.

    Raises:
        IndexError: if ``labels`` has fewer entries than ``points``.
    """
    points = list(points)
    labels = list(labels)

    if not points:
        return None

    k = len(points[0])
    # Pair every point with its label up front so both travel together
    # through the recursive sorts.
    items = [(point, labels[index]) for index, point in enumerate(points)]

    def build(*, items, depth):
        if not items:
            return None

        axis = depth % k
        # Order by the split-axis coordinate only; this is the invariant the
        # nearest-neighbour search relies on when it prunes a subtree.
        items = sorted(items, key=lambda item: item[0][axis])
        middle = len(items) // 2
        point, label = items[middle]

        return TreeNode(
            value=point,
            label=label,
            left=build(items=items[:middle], depth=depth + 1),
            right=build(items=items[middle + 1:], depth=depth + 1),
        )

    return build(items=items, depth=0)

In [2]:
def calc_distance(X, Y):
    """Return the squared Euclidean distance between ``X`` and ``Y``.

    Coordinates are paired positionally with ``zip``, so any surplus
    coordinates of the longer argument are ignored.  No square root is
    taken.
    """
    squared_gaps = [(a - b) ** 2 for a, b in zip(X, Y)]
    return sum(squared_gaps)

In [6]:
# One search result: the stored point, its label, and its distance to the
# query as computed by calc_distance.
Neighbor = collections.namedtuple("Neighbor", ["point", "label", "distance"])


def find_k_neighbors(*, tree, point, k):
    """Return up to ``k`` nearest neighbours of ``point`` stored in ``tree``.

    The tree is walked once.  A bounded heap keeps the ``k`` closest nodes
    seen so far, and a subtree on the far side of a splitting plane is only
    visited while it could still hold a closer node than the current worst
    candidate (or while fewer than ``k`` candidates have been collected).

    Args:
        tree: root node exposing ``value``, ``label``, ``left`` and
            ``right`` attributes, or ``None`` for an empty tree.
        point: query coordinates; its length sets the axis cycle.
        k: maximum number of neighbours to return.

    Returns:
        A list of ``Neighbor`` tuples ordered by increasing distance; equal
        distances keep the order in which the nodes were visited.  The list
        is shorter than ``k`` when the tree holds fewer than ``k`` nodes and
        empty when ``tree`` is ``None`` or ``k <= 0``.  Every tree node
        counts individually, so duplicated points can all be returned.
    """
    if tree is None or k <= 0:
        return []

    n_dim = len(point)
    # Min-heap over (-distance, -visit_number, neighbour): heap[0] is always
    # the farthest kept candidate and, among equally far ones, the one that
    # was visited last.  The unique visit number also guarantees that tuple
    # comparison never falls through to the Neighbor itself.
    heap = []
    visited = 0

    def search(*, tree, depth):
        nonlocal visited

        if tree is None:
            return

        distance = calc_distance(tree.value, point)
        visited += 1
        entry = (
            -distance,
            -visited,
            Neighbor(point=tree.value, label=tree.label, distance=distance),
        )
        if len(heap) < k:
            heapq.heappush(heap, entry)
        elif distance < -heap[0][0]:
            # Strictly closer than the current worst candidate: swap it in.
            heapq.heapreplace(heap, entry)

        axis = depth % n_dim
        diff = point[axis] - tree.value[axis]

        if diff <= 0:
            close, away = tree.left, tree.right
        else:
            close, away = tree.right, tree.left

        search(tree=close, depth=depth + 1)

        # diff**2 is compared with a calc_distance result, so calc_distance
        # is expected to return a squared distance.  The far side is skipped
        # only once k candidates exist and none of them is farther away than
        # the splitting plane.
        if len(heap) < k or diff ** 2 < -heap[0][0]:
            search(tree=away, depth=depth + 1)

    search(tree=tree, depth=0)

    # Descending (-distance, -visit_number) is ascending distance with ties
    # resolved by visit order.
    return [entry[2] for entry in sorted(heap, reverse=True)]

In [9]:
class CustomClassifier:
    """k-nearest-neighbours classifier backed by a k-d tree.

    ``fit`` builds the tree with ``kdtree``; ``predict`` looks up the ``k``
    nearest stored points with ``find_k_neighbors`` and returns the most
    frequent label among them.
    """

    def __init__(self, k):
        """Store ``k``, the number of neighbours consulted per prediction."""
        self.k = k

    def fit(self, X_train, y_train):
        """Build the search tree from the training data.

        Args:
            X_train: object providing ``itertuples(index=False, name=None)``
                (presumably a pandas DataFrame); each row is one point.
            y_train: iterable of labels, one per row of ``X_train``.
        """
        X_train = list(X_train.itertuples(index=False, name=None))
        y_train = list(y_train)
        self.tree = kdtree(X_train, y_train)

    def predict(self, X_test):
        """Predict one label per row of ``X_test``.

        Args:
            X_test: object providing ``itertuples(index=False, name=None)``.

        Returns:
            A list of predicted labels in row order.

        Raises:
            ValueError: if no neighbour is available for a row.
        """
        X_test = list(X_test.itertuples(index=False, name=None))
        y_pred = []
        for point in X_test:
            knn = find_k_neighbors(tree=self.tree, point=point, k=self.k)
            y_pred.append(self._majority_label(knn))
        return y_pred

    @staticmethod
    def _majority_label(neighbors):
        """Return the most frequent label; ties go to the smallest label."""
        # None entries (missing neighbours) carry no vote and are skipped.
        votes = collections.Counter(
            neighbor.label for neighbor in neighbors if neighbor is not None
        )
        if not votes:
            raise ValueError("no neighbors available to vote on a label")
        # Highest count first; among equal counts the smallest label wins.
        return min(votes.items(), key=lambda vote: (-vote[1], vote[0]))[0]

    def score(self, X_test, y_test):
        """Return the fraction of rows of ``X_test`` predicted correctly.

        Args:
            X_test: rows to classify, passed unchanged to ``predict``.
            y_test: iterable of expected labels, one per row.

        Raises:
            ValueError: if the number of expected labels differs from the
                number of predictions.
            ZeroDivisionError: if there are no rows to score.
        """
        y_test = list(y_test)
        y_pred = self.predict(X_test)
        # zip() would silently drop the unmatched tail and skew the result.
        if len(y_pred) != len(y_test):
            raise ValueError(
                "got %d predictions but %d expected labels"
                % (len(y_pred), len(y_test))
            )
        correct_count = sum(a == b for a, b in zip(y_pred, y_test))
        return correct_count / len(y_pred)