Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nguigs authored and ninamiolane committed Mar 22, 2020
1 parent 16ca17b commit 3f1ec06
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 6 additions & 4 deletions geomstats/learning/frechet_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,22 @@ def variance(points,
if weights is None:
weights = gs.ones((n_points, 1))
weights = gs.array(weights)
weights = gs.reshape(weights, (n_points, 1))
einsum_str = 'nk,nj->j'

sum_weights = gs.sum(weights)
if point_type == 'vector':
points = gs.to_ndarray(points, to_ndim=2)
base_point = gs.to_ndarray(base_point, to_ndim=2)
weights = gs.to_ndarray(weights, to_ndim=2, axis=1)
if point_type == 'matrix':
points = gs.to_ndarray(points, to_ndim=3)
base_point = gs.to_ndarray(base_point, to_ndim=3)
weights = gs.to_ndarray(weights, to_ndim=3, axis=1)
weights = weights[:, :, 0]

var = 0.

sq_dists = metric.squared_dist(base_point, points)
var += gs.einsum(einsum_str, weights, sq_dists)
var += gs.einsum('nk,nj->j', weights, sq_dists)

var = gs.array(var)
var /= sum_weights
Expand Down Expand Up @@ -116,7 +117,8 @@ def while_loop_body(iteration, mean, var, sq_dist):
points=points,
weights=weights,
metric=metric,
base_point=estimate_next)
base_point=estimate_next,
point_type=point_type)

mean = estimate_next
iteration += 1
Expand Down
2 changes: 1 addition & 1 deletion tests/test_spd_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def test_squared_dist_vectorization(self):

self.assertAllClose(gs.shape(result), (1, 1))

@geomstats.tests.np_and_pytorch_only
@geomstats.tests.np_only
def test_parallel_transport_affine_invariant(self):
n_samples = self.n_samples
gs.random.seed(1)
Expand Down

0 comments on commit 3f1ec06

Please sign in to comment.