Skip to content

Commit

Permalink
Rework classify0 into knn_classify
Browse files Browse the repository at this point in the history
Fix pep8 issues along with renaming
the variables for better quality.
  • Loading branch information
James Saryerwinnie committed Jun 4, 2011
1 parent 4a2578c commit 10e5cf1
Showing 1 changed file with 27 additions and 17 deletions.
44 changes: 27 additions & 17 deletions Ch02/kNN.py
Expand Up @@ -17,24 +17,33 @@
from numpy import tile, array, zeros, shape


def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize, 1)) - dataSet
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances ** 0.5
sortedDistIndicies = distances.argsort()
classCount = {}
def knn_classify(input_vector, training_set, labels, k):
"""Classify input vector using k nearest neighbors.
Args:
input_vector: The input vector to classify.
training_set: The matrix of training examples.
labels: The class labels.
k: The number of neighbors t ouse.
"""
training_set_size = training_set.shape[0]
diff_matrix = tile(input_vector, (training_set_size, 1)) - training_set
diff_matrix_squared = diff_matrix ** 2
distances_squared = diff_matrix_squared.sum(axis=1)
distances = distances_squared ** 0.5
sorted_distance_indices = distances.argsort()
class_count = {}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
current_label = labels[sorted_distance_indices[i]]
class_count[current_label] = class_count.get(current_label, 0) + 1

sortedClassCount = sorted(classCount.iteritems(),
sorted_class_count = sorted(class_count.iteritems(),
key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
return sorted_class_count[0][0]


def createDataSet():
def create_data_set():
group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
labels = ['A', 'A', 'B', 'B']
return group, labels
Expand Down Expand Up @@ -68,15 +77,16 @@ def autoNorm(dataSet):
return normDataSet, ranges, minVals


def datingClassTest():
hoRatio = 0.50 #hold out 10%
def dating_class_test():
# Hold out 10%.
hoRatio = 0.50
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m * hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i,:], normMat[numTestVecs:m,:],
classifierResult = knn_classify(normMat[i,:], normMat[numTestVecs:m,:],
datingLabels[numTestVecs:m], 3)
print "the classifier came back with: %d, the real answer is: %d" % \
(classifierResult, datingLabels[i])
Expand Down Expand Up @@ -115,7 +125,7 @@ def handwritingClassTest():
fileStr = fileNameStr.split('.')[0] #take off .txt
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
classifierResult = knn_classify(vectorUnderTest, trainingMat, hwLabels, 3)
print "the classifier came back with: %d, the real answer is: %d" % \
(classifierResult, classNumStr)
if classifierResult != classNumStr:
Expand Down

0 comments on commit 10e5cf1

Please sign in to comment.