Skip to content
Browse files

Fix for issue #10

KMeansClustering now accepts an optional equality function. This is useful
when using numpy arrays as inputs, because ``==`` on arrays compares
element-wise instead of returning a single boolean.
  • Loading branch information...
1 parent c628d8b commit 9b8358d8329aa43975f7fc21a063727b8fbc8075 @exhuma committed Mar 11, 2013
Showing with 51 additions and 13 deletions.
  1. +15 −3 cluster.py
  2. +36 −10 test.py
View
18 cluster.py
@@ -680,7 +680,7 @@ class KMeansClustering:
>>> clusters = cl.getclusters(2)
"""
- def __init__(self, data, distance=None):
+ def __init__(self, data, distance=None, equality=None):
"""
Constructor
@@ -690,11 +690,14 @@ def __init__(self, data, distance=None):
Default: It assumes the tuples contain numeric values
and applies a generalised form of the
euclidean-distance algorithm on them.
+ equality - A function to test equality of items. By default the
+ standard Python equality operator (``==``) is applied.
"""
self.__clusters = []
self.__data = data
self.distance = distance
self.__initial_length = len(data)
+ self.equality = equality
# test if each item is of same dimensions
if len(data) > 1 and isinstance(data[0], TupleType):
@@ -768,7 +771,7 @@ def assign_item(self, item, origin):
centroid(closest_cluster)):
closest_cluster = cluster
- if closest_cluster != origin:
+ if id(closest_cluster) != id(origin):
self.move_item(item, origin, closest_cluster)
return True
else:
@@ -784,7 +787,16 @@ def move_item(self, item, origin, destination):
origin - the originating cluster
destination - the target cluster
"""
- destination.append(origin.pop(origin.index(item)))
+ if self.equality:
+ item_index = 0
+ for i, element in enumerate(origin):
+ if self.equality(element, item):
+ item_index = i
+ break
+ else:
+ item_index = origin.index(item)
+
+ destination.append(origin.pop(item_index))
def initialise_clusters(self, input_, clustercount):
"""
View
46 test.py
@@ -202,15 +202,41 @@ def testLostFunctionReference(self):
expected),
"Elements differ!\n%s\n%s" % (clusters, expected))
+ def testMultidimArray(self):
+ from random import random
+ data = []
+ for _ in range(200):
+ data.append([random(), random()])
+ cl = KMeansClustering(data, lambda p0, p1: (
+ p0[0] - p1[0]) ** 2 + (p0[1] - p1[1]) ** 2)
+ cl.getclusters(10)
+
+
+class NumpyTests(unittest.TestCase):
+
+ def testNumpyRandom(self):
+ from cluster import KMeansClustering
+ from numpy import random as rnd
+ data = rnd.rand(500, 2)
+ cl = KMeansClustering(data, lambda p0, p1: (
+ p0[0] - p1[0]) ** 2 + (p0[1] - p1[1]) ** 2, numpy.array_equal)
+ cl.getclusters(10)
+
if __name__ == '__main__':
- unittest.TextTestRunner(verbosity=2).run(
- unittest.TestSuite((
- unittest.makeSuite(HClusterSmallListTestCase),
- unittest.makeSuite(HClusterIntegerTestCase),
- unittest.makeSuite(HClusterStringTestCase),
- unittest.makeSuite(KClusterSmallListTestCase),
- unittest.makeSuite(KCluster2DTestCase),
- unittest.makeSuite(KClusterSFBugs),
- ))
- )
+ suite = unittest.TestSuite((
+ unittest.makeSuite(HClusterSmallListTestCase),
+ unittest.makeSuite(HClusterIntegerTestCase),
+ unittest.makeSuite(HClusterStringTestCase),
+ unittest.makeSuite(KClusterSmallListTestCase),
+ unittest.makeSuite(KCluster2DTestCase),
+ unittest.makeSuite(KClusterSFBugs)))
+
+ try:
+ import numpy # NOQA
+ tests = unittest.makeSuite(NumpyTests)
+ suite.addTests(tests)
+ except ImportError:
+ print "numpy not available. Associated test will not be loaded!"
+
+ unittest.TextTestRunner(verbosity=2).run(suite)

0 comments on commit 9b8358d

Please sign in to comment.
Something went wrong with that request. Please try again.