Skip to content

Commit

Permalink
Adding a kNN classifier description.
Browse files Browse the repository at this point in the history
  • Loading branch information
openAccess committed Jan 25, 2012
1 parent a45063d commit 2f6a9ab
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions openAccess/gitPLoS/classifiers/kNN.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
#!/usr/bin/env python
# openAccess: getPLoS
#
#Copyright (c) 2001-2012 openAccess Project
# Author: Bill OConnor
# URL: <https://github.com/openAccess/gitPLoS>
# For license information, see LICENSE.TXT
#
"""
Description
===========
kNN - k Nearest Neighbor classifier. Takes a list of
(category, [values]) tuples as training data. Data points
are classified by calculating the distance from each
point in the training data. The distances are sorted giving
higher preference to smallest distances. Then k shortest
distances are used to calculate the weights for each class.
"""
from __future__ import division
import math
from collections import *
from math import sqrt

def _vec_minkowski(p, q, e):
l = [ abs(p[i] - q[i])**e for i in xrange(len(q)) ]
Expand All @@ -19,10 +32,10 @@ def _vec_manhattan_dist(p, q):

def _vec_euclidean_dist(p, q):
"""
Vector distance function.
Euclidean Vector distance function.
"""
l = [ p[i]-q[i] for i in xrange(len(p)) ]
return math.sqrt(sum([ d*d for d in l ]))
return sqrt(sum([ d*d for d in l ]))

def _euclidean_dist(x, y):
"""
Expand All @@ -35,10 +48,11 @@ def _eq_weight(x, y):

class kNN(object):
"""
k Nearest Neighbor classifier.
"""
def __init__(self, data, k):
"""
data - a list of (category, value) tuples.
data - a list of (category, [values] ) tuples.
"""
self.data = data # list of (category, value) tuples used for training
self.k = k
Expand Down Expand Up @@ -67,11 +81,3 @@ def classify(self, x, **kwargs):
if len(class_lst) > 0:
return class_lst[0]
return None

if __name__ == '__main__':
data = [ ('a', [1.0]), ('b', [3.0]), ('b', [4.0]), ('a', [1.5]), ('c', [3.5]) ]
knn = kNN(data,3)
r = knn.calculate([4.0])
print knn.classify([4.0])
print r

0 comments on commit 2f6a9ab

Please sign in to comment.