Skip to content

Commit

Permalink
Fixed problems with classify_myo when starting with no training data.
Browse files Browse the repository at this point in the history
  • Loading branch information
dzhu committed Dec 2, 2014
1 parent 0baebc1 commit 460bdc2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
4 changes: 2 additions & 2 deletions classify_myo.py
Expand Up @@ -85,10 +85,10 @@ def __call__(self, emg, moving):
scr.fill((0,0,0), (x+130, y + txt.get_height() / 2 - 10, len(m.history) * 20, 20))
scr.fill(clr, (x+130, y + txt.get_height() / 2 - 10, m.history_cnt[i] * 20, 20))

if HAVE_SK:
if HAVE_SK and m.cls.nn is not None:
dists, inds = m.cls.nn.kneighbors(hnd.emg)
for i, (d, ind) in enumerate(zip(dists[0], inds[0])):
y = m.cls.Y[3*ind]
y = m.cls.Y[myo.SUBSAMPLE*ind]
text(scr, font, '%d %6d' % (y, d), (650, 20 * i))

pygame.display.flip()
Expand Down
11 changes: 7 additions & 4 deletions myo.py
Expand Up @@ -15,6 +15,8 @@
from common import *
from myo_raw import MyoRaw

SUBSAMPLE = 3
K = 15

class NNClassifier(object):
'''A wrapper for sklearn's nearest-neighbor classifier that stores
Expand Down Expand Up @@ -43,8 +45,9 @@ def read_data(self):
def train(self, X, Y):
self.X = X
self.Y = Y
if HAVE_SK and self.X.shape[0] >= 20:
self.nn = neighbors.KNeighborsClassifier(n_neighbors=15, algorithm='kd_tree').fit(self.X[::3], self.Y[::3])
if HAVE_SK and self.X.shape[0] >= K * SUBSAMPLE:
self.nn = neighbors.KNeighborsClassifier(n_neighbors=K, algorithm='kd_tree')
self.nn.fit(self.X[::SUBSAMPLE], self.Y[::SUBSAMPLE])
else:
self.nn = None

Expand All @@ -54,8 +57,8 @@ def nearest(self, d):
return self.Y[ind]

def classify(self, d):
if self.X.shape[0] < 20: return 0
if self.nn is None: return self.nearest(d)
if self.X.shape[0] < K * SUBSAMPLE: return 0
if not HAVE_SK: return self.nearest(d)
return int(self.nn.predict(d)[0])


Expand Down

0 comments on commit 460bdc2

Please sign in to comment.