Skip to content

Commit

Permalink
Fix for issue #10
Browse files Browse the repository at this point in the history
KMeansClustering now accepts an optional equality function. This is useful
when using numpy arrays as inputs.
  • Loading branch information
exhuma committed Mar 11, 2013
1 parent c628d8b commit 9b8358d
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 13 deletions.
18 changes: 15 additions & 3 deletions cluster.py
Expand Up @@ -680,7 +680,7 @@ class KMeansClustering:
>>> clusters = cl.getclusters(2)
"""

def __init__(self, data, distance=None):
def __init__(self, data, distance=None, equality=None):
"""
Constructor
Expand All @@ -690,11 +690,14 @@ def __init__(self, data, distance=None):
Default: It assumes the tuples contain numeric values
and applies a generalised form of the
euclidean-distance algorithm on them.
equality - A function to test equality of items. By default the
standard python equality operator (``==``) is applied.
"""
self.__clusters = []
self.__data = data
self.distance = distance
self.__initial_length = len(data)
self.equality = equality

# test if each item is of same dimensions
if len(data) > 1 and isinstance(data[0], TupleType):
Expand Down Expand Up @@ -768,7 +771,7 @@ def assign_item(self, item, origin):
centroid(closest_cluster)):
closest_cluster = cluster

if closest_cluster != origin:
if id(closest_cluster) != id(origin):
self.move_item(item, origin, closest_cluster)
return True
else:
Expand All @@ -784,7 +787,16 @@ def move_item(self, item, origin, destination):
origin - the originating cluster
destination - the target cluster
"""
destination.append(origin.pop(origin.index(item)))
# Locate the position of ``item`` inside ``origin``.  When a custom equality
# function was supplied it is used instead of ``list.index`` (which compares
# with ``==``).
if self.equality:
    for item_index, element in enumerate(origin):
        if self.equality(element, item):
            break
    else:
        # No element matched.  Fail loudly, the same way ``list.index`` does
        # on the default path, instead of silently moving the first element
        # of ``origin`` to ``destination``.
        raise ValueError("item is not in the origin cluster")
else:
    item_index = origin.index(item)

# Remove the item from its old cluster and append it to the new one.
destination.append(origin.pop(item_index))

def initialise_clusters(self, input_, clustercount):
"""
Expand Down
46 changes: 36 additions & 10 deletions test.py
Expand Up @@ -202,15 +202,41 @@ def testLostFunctionReference(self):
expected),
"Elements differ!\n%s\n%s" % (clusters, expected))

def testMultidimArray(self):
    """
    Smoke test: build 200 random two-element lists and request 10 clusters
    using a squared-distance function over the first two coordinates.
    Passes as long as no exception is raised.
    """
    from random import random

    def squared_distance(p0, p1):
        return (p0[0] - p1[0]) ** 2 + (p0[1] - p1[1]) ** 2

    points = [[random(), random()] for _ in range(200)]
    clusterer = KMeansClustering(points, squared_distance)
    clusterer.getclusters(10)


class NumpyTests(unittest.TestCase):
    """
    Tests which exercise KMeansClustering with numpy arrays as input.
    """

    def testNumpyRandom(self):
        """
        Smoke test: request 10 clusters from a ``rand(500, 2)`` numpy array,
        passing ``numpy.array_equal`` as the equality function.
        Passes as long as no exception is raised.
        """
        # ``numpy`` must be imported here: at module level it is only bound
        # inside the ``__main__`` guard, so ``numpy.array_equal`` would raise
        # a NameError when the tests are collected by any other runner.
        import numpy
        from cluster import KMeansClustering
        from numpy import random as rnd
        data = rnd.rand(500, 2)
        cl = KMeansClustering(data, lambda p0, p1: (
            p0[0] - p1[0]) ** 2 + (p0[1] - p1[1]) ** 2, numpy.array_equal)
        cl.getclusters(10)


if __name__ == '__main__':
unittest.TextTestRunner(verbosity=2).run(
unittest.TestSuite((
unittest.makeSuite(HClusterSmallListTestCase),
unittest.makeSuite(HClusterIntegerTestCase),
unittest.makeSuite(HClusterStringTestCase),
unittest.makeSuite(KClusterSmallListTestCase),
unittest.makeSuite(KCluster2DTestCase),
unittest.makeSuite(KClusterSFBugs),
))
)
# Build the suite from the test cases that have no optional dependencies.
# ``loadTestsFromTestCase`` is used instead of ``unittest.makeSuite``, which
# is deprecated and no longer available in recent Python releases.
loader = unittest.defaultTestLoader
suite = unittest.TestSuite((
    loader.loadTestsFromTestCase(HClusterSmallListTestCase),
    loader.loadTestsFromTestCase(HClusterIntegerTestCase),
    loader.loadTestsFromTestCase(HClusterStringTestCase),
    loader.loadTestsFromTestCase(KClusterSmallListTestCase),
    loader.loadTestsFromTestCase(KCluster2DTestCase),
    loader.loadTestsFromTestCase(KClusterSFBugs)))

# The numpy tests are optional: only add them when numpy can be imported.
try:
    import numpy  # NOQA
    tests = loader.loadTestsFromTestCase(NumpyTests)
    suite.addTests(tests)
except ImportError:
    # Single-argument call form: parses and prints the same text on both
    # Python 2 and Python 3.
    print("numpy not available. Associated test will not be loaded!")

unittest.TextTestRunner(verbosity=2).run(suite)

0 comments on commit 9b8358d

Please sign in to comment.