From c1c16e855be2a029de0b16988bea337831483111 Mon Sep 17 00:00:00 2001 From: Charles Marsh Date: Wed, 30 Apr 2014 14:00:51 -0400 Subject: [PATCH] Iterative kNN with no flash --- __init__.py | 0 learners/knn.py | 23 +++++++++++++++++++++++ main.py | 14 ++++++++------ utils/test.py | 30 ++++++++++++++++++++---------- 4 files changed, 51 insertions(+), 16 deletions(-) create mode 100644 __init__.py create mode 100644 learners/knn.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/learners/knn.py b/learners/knn.py new file mode 100644 index 0000000..2df2f6d --- /dev/null +++ b/learners/knn.py @@ -0,0 +1,23 @@ +from sklearn.neighbors import KNeighborsClassifier +import numpy as np + + +class kNN(object): + + def __init__(self, classes): + self.model = KNeighborsClassifier(n_neighbors=2) + self.X = None + self.y = None + + def partial_fit(self, x, y): + if self.X is None and self.y is None: + self.X = np.array([x]) + self.y = y + else: + self.X = np.vstack((self.X, x)) + self.y = np.hstack((self.y, y)) + + self.model.fit(self.X, self.y) + + def predict(self, x): + return self.model.predict(x)[0] diff --git a/main.py b/main.py index cc4e9d1..855d3a3 100644 --- a/main.py +++ b/main.py @@ -1,21 +1,23 @@ -from sklearn.datasets import load_svmlight_file -import numpy as np from random import shuffle +import warnings + +from sklearn.datasets import load_svmlight_file +from utils.experiment import test + from ensemblers.adaboost import AdaBooster from ensemblers.ogboost import OGBooster from ensemblers.ocpboost import OCPBooster from ensemblers.expboost import EXPBooster + from learners.sk_naive_bayes import NaiveBayes from learners.perceptron import Perceptron from learners.random_stump import RandomStump from learners.decision_stump import DecisionStump -from learners.ce_knn import kNN +from learners.knn import kNN from learners.histogram import RNB from learners.winnow import Winnow -from utils.experiment import test -import warnings -warnings.filterwarnings("ignore", module="sklearn") +warnings.filterwarnings("ignore", module="sklearn") def loadData(filename): X, y = load_svmlight_file(filename) diff --git a/utils/test.py b/utils/test.py index 1c962cf..0f71d69 100644 --- a/utils/test.py +++ b/utils/test.py @@ -1,23 +1,33 @@ import unittest -from perceptron import Perceptron +import sys +sys.path.append("../") +import numpy as np +from ..learners.knn import kNN -class TestPerceptron(unittest.TestCase): +class TestLearners(unittest.TestCase): - def testSimple(self): + def setUp(self): x1 = [1, 2, 3] y1 = 1 x2 = [1, 3, 5] - y2 = -1 + y2 = 1 x3 = [-1, -1, -1] - y3 = 1 + y3 = -1 x4 = [1, 2, 4] y4 = 1 - data = [(x1, y1), (x2, y2), (x3, y3), (x4, y4)] - p = Perceptron() - for (x, y) in data: - p.update(x, y) - self.assertEqual(p.w, [1.0, 1.0, 2.0]) + x5 = [-0.5, -0.5, 0] + y5 = -1 + data = [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x5, y5)] + self.data = [(np.array(x), np.array([y])) for (x, y) in data] + self.labels = np.array([-1, 1]) + + def testKNN(self): + model = kNN(self.labels) + for (x, y) in self.data: + model.partial_fit(x, y) + print model.predict(np.array([0, 0, 0])) + if __name__ == '__main__': unittest.main()