Skip to content
Permalink
Browse files

Adapt unit tests in kmeans and online kmeans

  • Loading branch information
ninamiolane committed Feb 6, 2020
1 parent 6fb156d commit 3e981878a7e6c654e56732922e3f534caba22a55
Showing with 12 additions and 7 deletions.
  1. +6 −5 tests/test_online_kmeans.py
  2. +6 −2 tests/test_riemannian_kmeans.py
@@ -1,10 +1,9 @@
"""
Unit tests for Online k-means.
"""
"""Unit tests for Online k-means."""

import geomstats.backend as gs
import geomstats.tests
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.learning.frechet_mean import FrechetMean
from geomstats.learning.online_kmeans import OnlineKMeans

TOLERANCE = 1e-3
@@ -28,8 +27,10 @@ def test_fit(self):
clustering.fit(X)

center = clustering.cluster_centers_
mean = self.metric.mean(X)
result = self.metric.dist(center, mean)
mean = FrechetMean(metric=self.metric)
mean.fit(X)

result = self.metric.dist(center, mean.estimate_)
expected = 0.
self.assertAllClose(expected, result, atol=TOLERANCE)

@@ -3,6 +3,7 @@
import geomstats.backend as gs
import geomstats.tests
from geomstats.geometry import hypersphere
from geomstats.learning.frechet_mean import FrechetMean
from geomstats.learning.kmeans import RiemannianKMeans


@@ -21,8 +22,11 @@ def test_hypersphere_kmeans_fit(self):
kmeans = RiemannianKMeans(metric, 1, tol=1e-3)
kmeans.fit(x)
center = kmeans.centroids
mean = metric.mean(x)
result = metric.dist(center, mean)

mean = FrechetMean(metric=metric)
mean.fit(x)

result = metric.dist(center, mean.estimate_)
expected = 0.
self.assertAllClose(expected, result, atol=1e-2)

0 comments on commit 3e98187

Please sign in to comment.
You can’t perform that action at this time.