Skip to content

Commit

Permalink
Adapt unit tests in kmeans and online kmeans
Browse files Browse the repository at this point in the history
  • Loading branch information
ninamiolane committed Feb 12, 2020
1 parent 6fb156d commit 3e98187
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
11 changes: 6 additions & 5 deletions tests/test_online_kmeans.py
@@ -1,10 +1,9 @@
""" """Unit tests for Online k-means."""
Unit tests for Online k-means.
"""


import geomstats.backend as gs import geomstats.backend as gs
import geomstats.tests import geomstats.tests
from geomstats.geometry.hypersphere import Hypersphere from geomstats.geometry.hypersphere import Hypersphere
from geomstats.learning.frechet_mean import FrechetMean
from geomstats.learning.online_kmeans import OnlineKMeans from geomstats.learning.online_kmeans import OnlineKMeans


TOLERANCE = 1e-3 TOLERANCE = 1e-3
Expand All @@ -28,8 +27,10 @@ def test_fit(self):
clustering.fit(X) clustering.fit(X)


center = clustering.cluster_centers_ center = clustering.cluster_centers_
mean = self.metric.mean(X) mean = FrechetMean(metric=self.metric)
result = self.metric.dist(center, mean) mean.fit(X)

result = self.metric.dist(center, mean.estimate_)
expected = 0. expected = 0.
self.assertAllClose(expected, result, atol=TOLERANCE) self.assertAllClose(expected, result, atol=TOLERANCE)


Expand Down
8 changes: 6 additions & 2 deletions tests/test_riemannian_kmeans.py
Expand Up @@ -3,6 +3,7 @@
import geomstats.backend as gs import geomstats.backend as gs
import geomstats.tests import geomstats.tests
from geomstats.geometry import hypersphere from geomstats.geometry import hypersphere
from geomstats.learning.frechet_mean import FrechetMean
from geomstats.learning.kmeans import RiemannianKMeans from geomstats.learning.kmeans import RiemannianKMeans




Expand All @@ -21,8 +22,11 @@ def test_hypersphere_kmeans_fit(self):
kmeans = RiemannianKMeans(metric, 1, tol=1e-3) kmeans = RiemannianKMeans(metric, 1, tol=1e-3)
kmeans.fit(x) kmeans.fit(x)
center = kmeans.centroids center = kmeans.centroids
mean = metric.mean(x)
result = metric.dist(center, mean) mean = FrechetMean(metric=metric)
mean.fit(x)

result = metric.dist(center, mean.estimate_)
expected = 0. expected = 0.
self.assertAllClose(expected, result, atol=1e-2) self.assertAllClose(expected, result, atol=1e-2)


Expand Down

0 comments on commit 3e98187

Please sign in to comment.