In [493]:
import csv
import math
import operator
import random

In [494]:
def loadDataset(filename, split, trainingSet=None, testSet=None, n_features=4):
    """Read a CSV dataset and randomly partition its rows into train/test sets.

    Each non-empty row is expected to hold ``n_features`` numeric columns
    followed by a class label in the last column. The first ``n_features``
    columns are converted to ``float`` in place.

    Args:
        filename: Path to the CSV file.
        split: Probability that a row is assigned to the training set;
            otherwise it goes to the test set.
        trainingSet: Optional list to append training rows to. A fresh list
            is created when omitted.
        testSet: Optional list to append test rows to. A fresh list is
            created when omitted.
        n_features: Number of leading columns to convert to ``float``
            (4 for the iris data).

    Returns:
        ``(trainingSet, testSet)`` -- the list objects that were passed in,
        or the newly created ones.

    Raises:
        ValueError: If a feature column cannot be parsed as a float
            (e.g. a header row).
        IndexError: If a row has fewer than ``n_features`` columns.
    """
    # Mutable default arguments would be shared across calls, so rows from
    # every call without explicit lists would pile up in the same objects.
    if trainingSet is None:
        trainingSet = []
    if testSet is None:
        testSet = []
    # newline='' is what the csv module expects for files it reads.
    with open(filename, 'r', newline='') as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                # csv.reader yields [] for blank lines (e.g. a trailing newline).
                continue
            for y in range(n_features):
                row[y] = float(row[y])
            if random.random() < split:
                trainingSet.append(row)
            else:
                testSet.append(row)
    return trainingSet, testSet

In [495]:
# Split iris.csv into train/test lists and report how many rows landed in each.
trainingSet, testSet = [], []
split = 0.76
loadDataset(r'iris.csv', split, trainingSet, testSet)
print(f'Train: {len(trainingSet)}')
print(f'Test: {len(testSet)}')

Train: 40
Test: 10


In [496]:
# Euclidean distance
def euclideanDistance(point1, point2, length=3):
    """Return the Euclidean distance between the first ``length`` coordinates.

    Only indices ``0 .. length-1`` are compared. Trailing non-numeric fields,
    such as a class label in the last column, are ignored as long as
    ``length`` stops before them.

    Args:
        point1: Indexable sequence of numbers.
        point2: Indexable sequence of numbers.
        length: Number of leading coordinates to compare (default 3).

    Returns:
        The distance as a float.
    """
    # Compare coordinate i of both points (previously point1[1] was used for
    # every i, which produced wrong distances).
    return math.sqrt(sum((point2[i] - point1[i]) ** 2 for i in range(length)))

In [497]:
def getNeighbors(training_data, testing_instance, k=1, length=None):
    """Return the ``k`` training rows closest to ``testing_instance``.

    Args:
        training_data: Sequence of rows; each row holds numeric features
            followed by a class label.
        testing_instance: Row to classify. It may or may not carry a
            trailing label.
        k: Number of neighbours to return. If ``k`` exceeds the number of
            training rows, all rows are returned.
        length: Number of leading feature columns to compare. When omitted,
            it is the count of leading ``int``/``float`` values in
            ``testing_instance``. A trailing string label is therefore
            skipped and every feature is used (e.g. 4 for iris rows).

    Returns:
        List of training rows ordered from nearest to farthest.
    """
    if length is None:
        # Relying on euclideanDistance's default of 3 silently dropped the
        # 4th iris feature; derive the feature count from the instance.
        length = 0
        for value in testing_instance:
            if not isinstance(value, (int, float)):
                break
            length += 1
    distances = [
        (row, euclideanDistance(row, testing_instance, length))
        for row in training_data
    ]
    # list.sort is stable, so equally distant rows keep their training order.
    distances.sort(key=operator.itemgetter(1))
    # Slicing (rather than indexing 0..k-1) tolerates k > len(training_data).
    return [row for row, _ in distances[:k]]

In [498]:
# Toy data: three labelled 3-D points and one unlabelled query point.
training_data = [
    [2, 2, 2, 'a'],
    [4, 4, 4, 'b'],
    [5.4, 5.4, 5.4, 'c'],
]
testing_instance = [5, 5, 5]
neighbors = getNeighbors(training_data, testing_instance)
print(f'Instance: {testing_instance}', f'Closest point: {neighbors}')

Instance: [5, 5, 5] Closest point: [[5.4, 5.4, 5.4, 'c']]


In [499]:
# Same toy data, but ask for the k nearest points instead of just one.
training_data = [
    [2, 2, 2, 'a'],
    [4, 4, 4, 'b'],
    [5.4, 5.4, 5.4, 'c'],
]
testing_instance = [5, 5, 5]
k = 2
neighbors = getNeighbors(training_data, testing_instance, k)
print(f'Instance: {testing_instance}', f'Closest {k} points: {neighbors}')

Instance: [5, 5, 5] Closest 2 points: [[5.4, 5.4, 5.4, 'c'], [4, 4, 4, 'b']]


In [500]:
def getResponse(neighbors):
    """Return the most common class label among ``neighbors``.

    The label is taken from the last element of each neighbour row. On a tie,
    the label encountered first wins, because the sort below is stable and
    dicts keep insertion order.

    Raises:
        IndexError: If ``neighbors`` is empty.
    """
    tally = {}
    for neighbor in neighbors:
        label = neighbor[-1]
        tally[label] = tally.get(label, 0) + 1
    ranked = sorted(tally.items(), key=operator.itemgetter(1), reverse=True)
    top_label, _ = ranked[0]
    return top_label

In [501]:
# Two votes for 'a' and one for 'b': the majority label should be printed.
neighbors = [
    [1, 1, 1, 'a'],
    [2, 2, 2, 'a'],
    [3, 3, 3, 'b'],
]
print(getResponse(neighbors))

a


In [502]:
def getAccuracy(testset, predictions):
    """Return the percentage of rows whose true label matches the prediction.

    Args:
        testset: Sequence of rows whose last element is the true label.
        predictions: Sequence of predicted labels, one per row of ``testset``.

    Returns:
        Accuracy as a float percentage in [0, 100].

    Raises:
        ValueError: If the two sequences differ in length, or if ``testset``
            is empty (accuracy is undefined).
    """
    if len(testset) != len(predictions):
        # Previously a shorter predictions list raised IndexError and a
        # longer one was silently truncated.
        raise ValueError(
            'testset and predictions differ in length: %d vs %d'
            % (len(testset), len(predictions))
        )
    if not testset:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError('testset is empty; accuracy is undefined')
    correct = sum(row[-1] == pred for row, pred in zip(testset, predictions))
    return (correct / float(len(testset))) * 100

In [503]:
# Predicting 'a' for every row gets 2 of the 3 labels right.
testset = [
    [1, 1, 1, 'a'],
    [2, 2, 2, 'a'],
    [3, 3, 3, 'b'],
]
predictions = ['a'] * len(testset)
accuracy = getAccuracy(testset, predictions)
print(accuracy)

66.66666666666666


In [504]:
def main(filename=r'iris.csv', split=0.76, k=3, seed=None):
    """Evaluate a k-NN classifier on ``filename`` and print the results.

    Args:
        filename: CSV file passed to ``loadDataset``.
        split: Probability that each row is assigned to the training set.
        k: Number of neighbours that vote on each prediction.
        seed: Optional seed for ``random``. Setting it makes the train/test
            split, and thus the printed accuracy, reproducible. ``None``
            keeps the unseeded behaviour.

    Returns:
        The accuracy percentage, or ``None`` when the split produced no test
        rows.
    """
    if seed is not None:
        random.seed(seed)
    trainingSet = []
    testSet = []
    predictions = []
    loadDataset(filename, split, trainingSet, testSet)
    print('Train: ' + str(len(trainingSet)))
    print('Test: ' + str(len(testSet)))
    if not testSet:
        # Nothing to evaluate; computing accuracy would divide by zero.
        print('Test set is empty; nothing to evaluate.')
        return None
    for instance in testSet:
        neighbors = getNeighbors(trainingSet, instance, k)
        result = getResponse(neighbors)
        predictions.append(result)
        print('Predicted=' + str(result) + ', actual=' + str(instance[-1]))
    accuracy = getAccuracy(testSet, predictions)
    print('Accuracy: ' + str(accuracy) + '%')
    return accuracy

# Trailing ';' keeps Jupyter from echoing the returned accuracy a second time.
main();
    

Train: 39
Test: 11
[[4.438747872550238, 3.7026114079567156, 1.392436401603738, 0.155817369669899, 'Iris-setosa'], [4.029275043545472, 2.2847105365989426, 1.3351609832472833, 0.5489866887188687, 'Iris-setosa'], [4.131426943929544, 3.4182268219350394, 1.886383328900857, 0.7519077616648743, 'Iris-setosa'], [5.364614052853935, 3.7287934937376406, 1.4374800053181174, 0.8413735493168728, 'Iris-setosa'], [4.208430258441613, 3.3541509594983894, 1.9014649357654683, 0.961672972605098, 'Iris-setosa'], [5.369390600395904, 2.775953863534106, 1.001926949792523, 0.34826636542719713, 'Iris-setosa'], [5.097834124049086, 3.7328955020500234, 1.1037126585578467, 0.2939511256241618, 'Iris-setosa'], [5.373112961810461, 3.680052960513409, 1.5681631044512983, 0.5422128714447274, 'Iris-setosa'], [5.181345703842265, 2.1679235406253854, 1.9269413720687898, 0.19562040746897402, 'Iris-setosa'], [5.411279298045722, 2.8675748465563515, 1.3023749560583577, 0.45955947645466555, 'Iris-setosa'], [4.781803095700177, 3.11