Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Charles Marsh
committed
Apr 30, 2014
1 parent
e8f0eba
commit c1c16e8
Showing
4 changed files
with
51 additions
and
16 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from sklearn.neighbors import KNeighborsClassifier | ||
import numpy as np | ||
|
||
|
||
class kNN(object): | ||
|
||
def __init__(self, classes): | ||
self.model = KNeighborsClassifier(n_neighbors=2) | ||
self.X = None | ||
self.y = None | ||
|
||
def partial_fit(self, x, y): | ||
if self.X is None and self.y is None: | ||
self.X = np.array([x]) | ||
self.y = y | ||
else: | ||
self.X = np.vstack((self.X, x)) | ||
self.y = np.hstack((self.y, y)) | ||
|
||
self.model.fit(self.X, self.y) | ||
|
||
def predict(self, x): | ||
return self.model.predict(x)[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,33 @@ | ||
import unittest | ||
from perceptron import Perceptron | ||
import sys | ||
sys.path.append("../") | ||
import numpy as np | ||
from ..learners.knn import kNN | ||
|
||
|
||
class TestPerceptron(unittest.TestCase): | ||
class TestLearners(unittest.TestCase): | ||
|
||
def testSimple(self): | ||
def setUp(self): | ||
x1 = [1, 2, 3] | ||
y1 = 1 | ||
x2 = [1, 3, 5] | ||
y2 = -1 | ||
y2 = 1 | ||
x3 = [-1, -1, -1] | ||
y3 = 1 | ||
y3 = -1 | ||
x4 = [1, 2, 4] | ||
y4 = 1 | ||
data = [(x1, y1), (x2, y2), (x3, y3), (x4, y4)] | ||
p = Perceptron() | ||
for (x, y) in data: | ||
p.update(x, y) | ||
self.assertEqual(p.w, [1.0, 1.0, 2.0]) | ||
x5 = [-0.5, -0.5, 0] | ||
y5 = -1 | ||
data = [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x5, y5)] | ||
self.data = [(np.array(x), np.array([y])) for (x, y) in data] | ||
self.labels = np.array([-1, 1]) | ||
|
||
def testKNN(self): | ||
model = kNN(self.labels) | ||
for (x, y) in self.data: | ||
model.partial_fit(x, y) | ||
print model.predict(np.array([0, 0, 0])) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |