In [543]:
import numpy as np
class Point:
    """A labelled sample: a label plus a numeric coordinate vector.

    Attributes:
        label: The label value as supplied by the caller.
        coordinates: ``np.ndarray`` built (copied) from the ``coordinates`` argument.
    """

    def __init__(self, label, coordinates):
        # Keep the features as an ndarray so vector arithmetic works directly.
        self.coordinates = np.array(coordinates)
        self.label = label

    def getDist(self, cur_point):
        """Return the Euclidean (L2) distance from this point to ``cur_point``."""
        delta = cur_point.coordinates - self.coordinates
        return np.linalg.norm(delta)

In [544]:
import csv
def getPoints(file_name):
    """Load labelled points from a CSV file.

    The file must have a header row containing a ``label`` column; every
    other column is parsed with ``float`` and used as a coordinate, in
    header order.

    Args:
        file_name: Path to the CSV file.

    Returns:
        list of Point, one per data row, in file order (empty for a file
        with no rows).

    Raises:
        KeyError: if a row has no ``label`` field.
        ValueError / TypeError: if a coordinate cell is not numeric or missing.
    """
    points = []
    # The csv module requires newline='' so that quoted fields containing
    # line breaks, and \r\n line endings, are parsed correctly.
    with open(file_name, 'r', newline='') as file_obj:
        reader = csv.DictReader(file_obj, delimiter=',')
        # The header is the same for every row, so resolve the feature
        # columns once; fieldnames is None for a completely empty file.
        feature_cols = [col for col in (reader.fieldnames or []) if col != 'label']
        for row in reader:
            label = row['label']
            coordinates = [float(row[col]) for col in feature_cols]
            points.append(Point(label, coordinates))
    return points

In [545]:
def get_neighbors(points, maxK):
    """Precompute every point's nearest neighbours, never including the point itself.

    Args:
        points: Sequence of objects exposing ``getDist(other)``; each is used
            as a dict key, so they must be hashable.
        maxK: Maximum number of neighbours to keep per point.

    Returns:
        dict mapping each point to a list of at most ``maxK`` *other* points,
        ordered by increasing distance. Equal distances keep input order.
    """
    ans = {}
    for i, point in enumerate(points):
        dist = np.array([other.getDist(point) for other in points], dtype=float)
        # Stable sort: ties are broken by input order, deterministically.
        order = np.argsort(dist, kind='stable')
        # Remove the query point by its index instead of assuming it sorts
        # first. A duplicate sample at distance 0 may sort ahead of it, and
        # slicing off position 0 would then keep the point as its own
        # neighbour, leaking its label into leave-one-out evaluation.
        order = order[order != i][:maxK]
        ans[point] = [points[j] for j in order]
    return ans

In [546]:
def knn(points, k, point, neighbors):
    """Predict ``point``'s label by majority vote of its first ``k`` neighbours.

    Args:
        points: Not used by this function; kept for call-site compatibility.
        k: Number of neighbours (taken from the front of the list) that vote.
        point: Key into ``neighbors``.
        neighbors: dict mapping a point to its neighbour list; each neighbour
            must have a ``label`` attribute.

    Returns:
        The label with the most votes. On a tie, the label that reached the
        winning count first (walking the list front to back) wins. Returns 0
        when no neighbour votes (e.g. ``k == 0`` or an empty list).
    """
    votes = {}
    best_count = 0
    best_label = 0
    for neighbor in neighbors[point][:k]:
        votes[neighbor.label] = votes.get(neighbor.label, 0) + 1
        # Strict '>' keeps the earlier leader when a later label only ties it.
        if votes[neighbor.label] > best_count:
            best_count = votes[neighbor.label]
            best_label = neighbor.label
    return best_label

In [547]:
def leave_one_out_error(file_name, max_k=10):
    """Print the leave-one-out kNN error rate of a CSV dataset for k = max_k..1.

    Neighbours are computed once (up to ``max_k`` per point) and reused for
    every k. For each k, each point is classified with ``knn`` from its
    precomputed neighbours. The fraction of misclassified points is printed
    as ``"<k>: loo=<rate>"``.

    Args:
        file_name: Path to the CSV dataset (loaded with ``getPoints``).
        max_k: Largest k to evaluate; defaults to 10 (the previous fixed value).

    Raises:
        ValueError: if the file contains no data rows.
    """
    print(file_name)
    points = getPoints(file_name)
    if not points:
        # Without this guard the error rate would divide by zero.
        raise ValueError('no data rows in ' + str(file_name))
    neighbors = get_neighbors(points, max_k)
    for k in range(max_k, 0, -1):
        errors = sum(1 for point in points
                     if knn(points, k, point, neighbors) != point.label)
        print(str(k) + ': loo=' + str(errors / len(points)))

In [548]:
# Cancer dataset: leave-one-out kNN error for k = 10 down to 1.
leave_one_out_error(file_name='datasets/cancer.csv')

datasets/cancer.csv
10: loo=0.06678383128295255
9: loo=0.06678383128295255
8: loo=0.06854130052724078
7: loo=0.06854130052724078
6: loo=0.06678383128295255
5: loo=0.06678383128295255
4: loo=0.07381370826010544
3: loo=0.07381370826010544
2: loo=0.0843585237258348
1: loo=0.0843585237258348


In [549]:
# Spam dataset: leave-one-out kNN error for k = 10 down to 1.
leave_one_out_error(file_name='datasets/spam.csv')

datasets/spam.csv
10: loo=0.20256465985655292
9: loo=0.20256465985655292
8: loo=0.1958269941317105
7: loo=0.1958269941317105
6: loo=0.18539447946098675
5: loo=0.18539447946098675
4: loo=0.18452510323842644
3: loo=0.18452510323842644
2: loo=0.16865898717670072
1: loo=0.16865898717670072
