# K-Nearest Neighbor

## From Scratch

### Distance

In [None]:
# Notebook-only display setup: setting ast_node_interactivity to "all" asks
# IPython to display the value of every top-level expression in a cell, not
# just the last one (see IPython's InteractiveShell configuration docs).
# This has no effect on the algorithms defined below.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

def vector_subtract(v, w):
    """Return the element-wise difference v - w as a new list.

    Like ``zip``, pairing stops at the end of the shorter vector.
    """
    return list(map(lambda a, b: a - b, v, w))

def dot(v, w):
    """Return the dot product v_1*w_1 + ... + v_n*w_n.

    Components beyond the shorter vector are ignored; empty input gives 0.
    """
    return sum(map(lambda a, b: a * b, v, w))

def sum_of_squares(v):
    """Return v_1**2 + ... + v_n**2, computed as the dot product of v with itself."""
    self_product = dot(v, v)
    return self_product

def squared_distance(v, w):
    """Return the squared Euclidean distance between v and w (no square root taken)."""
    difference = vector_subtract(v, w)
    return sum_of_squares(difference)

def distance(v, w):
    """Return the Euclidean distance between vectors v and w.

    The square root of ``squared_distance(v, w)``. ``math`` is imported
    locally because the module-level ``import math`` lives in a later
    notebook cell. Without this import, calling ``distance`` right after
    this cell runs raises NameError.
    """
    import math
    return math.sqrt(squared_distance(v, w))

def mean(x):
    """Return the arithmetic mean of the values in x.

    x must support ``len``; an empty x raises ZeroDivisionError.
    """
    total = sum(x)
    count = len(x)
    return total / count

## Wrap up

In [None]:
from collections import Counter
# from linear_algebra import distance
# from stats import mean
import math, random
import matplotlib.pyplot as plt

def raw_majority_vote(labels):
    """Return the most frequent label, with no special tie handling.

    On a tie, the result is whichever label ``Counter.most_common`` ranks first.
    An empty ``labels`` raises IndexError.
    """
    top = Counter(labels).most_common(1)
    return top[0][0]

def majority_vote(labels):
    """Return the winning label, breaking ties by discarding the farthest vote.

    ``labels`` must be ordered from nearest to farthest. While the top count
    is shared by more than one label, the last (farthest) label is dropped
    and the vote is retaken. A single remaining label always wins outright.
    An empty ``labels`` raises IndexError.
    """
    remaining = labels
    while True:
        tally = Counter(remaining)
        leader, leader_count = tally.most_common(1)[0]
        tied = sum(1 for c in tally.values() if c == leader_count)
        if tied == 1:
            return leader
        remaining = remaining[:-1]  # drop the farthest neighbour and re-vote
    
def knn_classify(k, labeled_points, new_point):
    """Predict the label of ``new_point`` from its k nearest labeled neighbours.

    ``labeled_points`` is a sequence of (point, label) pairs. Pairs are ranked
    by ``distance(point, new_point)``; ``sorted`` is stable, so equal
    distances keep their input order. The labels of the first k pairs are
    passed, nearest first, to ``majority_vote``.
    """
    def _distance_to_query(pair):
        return distance(pair[0], new_point)

    ranked = sorted(labeled_points, key=_distance_to_query)
    nearest_labels = [label for _point, label in ranked[:k]]
    return majority_vote(nearest_labels)

# Labelled training data: one (longitude, latitude, favourite language) row
# per city; the labels used are 'Java', 'Python' and 'R'.
# NOTE(review): the city names and data source are not recorded here.
cities = [(-86.75,33.56,'Python'),(-88.25,30.68,'Python'),
          (-112.01,33.43,'Java'),(-110.93,32.11,'Java'),
          (-92.23,34.73,'R'),(-121.95,37.7,'R'),
          (-118.15,33.81,'Python'),(-118.23,34.05,'Java'),
          (-122.31,37.81,'R'),(-117.6,34.05,'Python'),
          (-116.53,33.81,'Python'),(-121.5,38.51,'R'),
          (-117.16,32.73,'R'),(-122.38,37.61,'R'),
          (-121.93,37.36,'R'),(-122.01,36.98,'Python'),
          (-104.71,38.81,'Python'),(-104.86,39.75,'Python'),
          (-72.65,41.73,'R'),(-75.6,39.66,'Python'),
          (-77.03,38.85,'Python'),(-80.26,25.8,'Java'),
          (-81.38,28.55,'Java'),(-82.53,27.96,'Java'),
          (-84.43,33.65,'Python'),(-116.21,43.56,'Python'),
          (-87.75,41.78,'Java'),(-86.28,39.73,'Java'),
          (-93.65,41.53,'Java'),(-97.41,37.65,'Java'),
          (-85.73,38.18,'Python'),(-90.25,29.98,'Java'),
          (-70.31,43.65,'R'),(-76.66,39.18,'R'),
          (-71.03,42.36,'R'),(-72.53,42.2,'R'),
          (-83.01,42.41,'Python'),(-84.6,42.78,'Python'),
          (-93.21,44.88,'Python'),(-90.08,32.31,'Java'),
          (-94.58,39.11,'Java'),(-90.38,38.75,'Python'),
          (-108.53,45.8,'Python'),(-95.9,41.3,'Python'),
          (-115.16,36.08,'Java'),(-71.43,42.93,'R'),
          (-74.16,40.7,'R'),(-106.61,35.05,'Python'),
          (-78.73,42.93,'R'),(-73.96,40.78,'R'),
          (-80.93,35.21,'Python'),(-78.78,35.86,'Python'),
          (-100.75,46.76,'Java'),(-84.51,39.15,'Java'),
          (-81.85,41.4,'Java'),(-82.88,40,'Java'),
          (-97.60,35.40,'Python'),(-122.66,45.53,'Python'),
          (-75.25,39.88,'Python'),(-80.21,40.50,'Python'),
          (-71.43,41.73,'R'),(-81.11,33.95,'R'),
          (-96.73,43.56,'Python'),(-90.00,35.05,'R'),
          (-86.68,36.11,'R'),(-97.70,30.30,'Python'),
          (-96.85,32.85,'Java'),(-95.35,29.96,'Java'),
          (-98.46,29.53,'Java'),(-111.96,40.76,'Python'),
          (-73.15,44.46,'R'),(-77.33,37.50,'Python'),
          (-122.30,47.53,'Python'),(-89.33,43.13,'R'),
          (-104.81,41.15,'Java')]
# Reshape each row into the (point, label) pair form that knn_classify
# expects, with point = [longitude, latitude].
cities = [([longitude, latitude], language) for longitude, latitude, language in cities]

In [None]:
def plot_state_borders(plt, color='0.8'):
    """Placeholder for drawing state borders onto ``plt``; currently draws nothing.

    ``color`` is accepted so callers can pass it, but it is not used yet.
    """
    return None

def plot_cities():
    """Scatter-plot every entry of ``cities``, one marker/colour series per language.

    Draws the series in Java, Python, R order (which also fixes the legend
    order), calls ``plot_state_borders``, fixes the axes to
    longitude [-130, -60] x latitude [20, 55], and shows the figure.
    """
    # language -> (marker, colour)
    style = {"Java": ("o", "r"), "Python": ("s", "b"), "R": ("^", "g")}
    lons = {lang: [] for lang in style}
    lats = {lang: [] for lang in style}

    for (lon, lat), lang in cities:
        lons[lang].append(lon)
        lats[lang].append(lat)

    for lang, (marker, colour) in style.items():
        plt.scatter(lons[lang], lats[lang], color=colour, marker=marker,
                    label=lang, zorder=10)

    plot_state_borders(plt)

    plt.legend(loc=0)              # loc=0: let matplotlib pick the spot
    plt.axis([-130, -60, 20, 55])
    plt.title("Favorite Programming Languages")
    plt.show()

In [None]:
def classify_and_plot_grid(k=1):
    """Colour a 1-degree grid by each grid point's k-NN predicted language.

    Every integer point with -130 <= longitude < -60 and 20 <= latitude < 55
    is classified against ``cities`` with ``knn_classify(k, ...)``. The
    points are then drawn as one scatter series per predicted language.
    """
    # language -> (marker, colour)
    style = {"Java": ("o", "r"), "Python": ("s", "b"), "R": ("^", "g")}
    xs = {lang: [] for lang in style}
    ys = {lang: [] for lang in style}

    for lon in range(-130, -60):
        for lat in range(20, 55):
            guess = knn_classify(k, cities, [lon, lat])
            xs[guess].append(lon)
            ys[guess].append(lat)

    for lang, (marker, colour) in style.items():
        plt.scatter(xs[lang], ys[lang], color=colour, marker=marker,
                    label=lang, zorder=0)

    plot_state_borders(plt, color='black')

    plt.legend(loc=0)              # loc=0: let matplotlib pick the spot
    plt.axis([-130, -60, 20, 55])
    plt.title(str(k) + "-Nearest Neighbor Programming Languages")
    plt.show()

In [None]:
#
# the curse of dimensionality
#

def random_point(dim):
    """Return a list of ``dim`` independent ``random.random()`` draws in [0, 1)."""
    coords = []
    for _ in range(dim):
        coords.append(random.random())
    return coords

def random_distances(dim, num_pairs):
    """Return ``num_pairs`` distances between freshly drawn random ``dim``-D points.

    Each pair draws its first point, then its second, from ``random_point``.
    """
    result = []
    for _ in range(num_pairs):
        a = random_point(dim)
        b = random_point(dim)
        result.append(distance(a, b))
    return result

In [None]:
if __name__ == "__main__":

    # Leave-one-out evaluation: classify each city from all the *other*
    # cities, for several values of k, and count correct predictions.
    for k in [1, 3, 5, 7]:
        num_correct = 0

        for i, (location, actual_language) in enumerate(cities):
            # Hold out this entry by position, not by value, so exactly one
            # row is excluded even if the data ever contains duplicate rows.
            other_cities = cities[:i] + cities[i + 1:]
            predicted_language = knn_classify(k, other_cities, location)

            if predicted_language == actual_language:
                num_correct += 1

        print(k, "neighbor[s]:", num_correct, "correct out of", len(cities))

    # Curse of dimensionality: for each dimension, compare the closest and
    # the average distance among random point pairs in the unit cube.
    dimensions = range(1, 101, 5)

    avg_distances = []
    min_distances = []

    random.seed(0)
    for dim in dimensions:
        distances = random_distances(dim, 10000)  # 10,000 random pairs
        avg_distance = mean(distances)
        min_distance = min(distances)
        avg_distances.append(avg_distance)        # track the average
        min_distances.append(min_distance)        # track the minimum
        print(dim, min_distance, avg_distance, min_distance / avg_distance)

## scikit-learn

In [None]:
import numpy as np
from sklearn import neighbors, datasets

In [None]:
n_neighbors = 15

# import some data to play with
iris = datasets.load_iris()

# Keep only the first two feature columns so the problem stays 2-D.
X = iris.data[:, :2]
y = iris.target

h = .02  # step size in the mesh; NOTE(review): unused in this cell (decision-boundary plotting)

for weights in ['distance']:
    # Build a k-NN classifier with the weighting scheme named by the loop
    # variable. Previously `weights` was never passed, so every iteration
    # silently used scikit-learn's default weighting instead.
    clf = neighbors.KNeighborsClassifier(n_neighbors, weights=weights)
    clf.fit(X, y)

    # Predict the labels of the same points the model was fitted on.
    # NOTE: this is training-set accuracy, which overstates how well the
    # model generalizes to unseen data.
    Z = clf.predict(X)

    print("result prediction")
    print(Z)

    # Number of points whose predicted label matches the true label.
    correctness = int(np.sum(Z == y))

    print("Accuracy: {0:.2f} %".format(correctness / len(X) * 100))