Skip to content

Commit

Permalink
Merge pull request #267 from nguigs/nguigs-mean
Browse files Browse the repository at this point in the history
prelimenary fixes
  • Loading branch information
ninamiolane committed Jan 14, 2020
2 parents c68ec17 + fa936d0 commit 386be6d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 19 deletions.
13 changes: 10 additions & 3 deletions geomstats/geometry/riemannian_metric.py
Expand Up @@ -315,7 +315,8 @@ def mean(self, points,
weights=None,
n_max_iterations=32,
epsilon=EPSILON,
point_type='vector'):
point_type='vector',
verbose=False):
"""
Frechet mean of (weighted) points.
Expand All @@ -324,14 +325,16 @@ def mean(self, points,
points: array-like, shape=[n_samples, dimension]
weights: array-like, shape=[n_samples, 1], optional
verbose: bool, optional
"""
# TODO(nina): Profile this code to study performance,
# i.e. what to do with sq_dists_between_iterates.
def while_loop_cond(iteration, mean, variance, sq_dist):
result = gs.logical_or(
result = ~gs.logical_or(
gs.isclose(variance, 0.),
gs.less_equal(sq_dist, epsilon * variance))
return result[0, 0]
return result[0, 0] or iteration == 0

def while_loop_body(iteration, mean, variance, sq_dist):
tangent_mean = gs.zeros_like(mean)
Expand Down Expand Up @@ -394,6 +397,10 @@ def while_loop_body(iteration, mean, variance, sq_dist):
print('Maximum number of iterations {} reached.'
'The mean may be inaccurate'.format(n_max_iterations))

if verbose:
print('n_iter: {}, final variance: {}, final dist: {}'.format(
last_iteration, variance, sq_dist))

mean = gs.to_ndarray(mean, to_ndim=2)
return mean

Expand Down
16 changes: 6 additions & 10 deletions geomstats/geometry/special_orthogonal_group.py
Expand Up @@ -85,9 +85,8 @@ def belongs(self, point, point_type=None):
elif point_type == 'matrix':
point = gs.to_ndarray(point, to_ndim=3)
point_transpose = gs.transpose(point, axes=(0, 2, 1))
point_inverse = gs.linalg.inv(point)

mask = gs.isclose(point_inverse, point_transpose)
mask = gs.isclose(gs.matmul(point, point_transpose),
gs.eye(self.n))
mask = gs.all(mask, axis=(1, 2))

mask = gs.to_ndarray(mask, to_ndim=1)
Expand Down Expand Up @@ -1274,13 +1273,10 @@ def random_uniform(self, n_samples=1, point_type=None):
if point_type is None:
point_type = self.default_point_type

if point_type == 'vector':
random_point = gs.random.rand(n_samples, self.dimension) * 2 - 1
random_point = self.regularize(
random_point, point_type=point_type)
elif point_type == 'matrix':
random_matrix = gs.random.rand(n_samples, self.n, self.n)
random_point = self.projection(random_matrix)
random_point = gs.random.rand(n_samples, self.dimension) * 2 - 1
random_point = self.regularize(random_point, point_type='vector')
if point_type == 'matrix':
random_point = self.matrix_from_rotation_vector(random_point)

return random_point

Expand Down
13 changes: 7 additions & 6 deletions tests/test_special_orthogonal_group.py
Expand Up @@ -216,12 +216,13 @@ def test_skew_matrix_from_vector_vectorization(self):

@geomstats.tests.np_only
def test_random_and_belongs(self):
for n in self.n_seq:
group = self.so[n]
point = group.random_uniform()
result = group.belongs(point)
expected = gs.array([[True]])
self.assertAllClose(result, expected)
for point_type in ('vector', 'matrix'):
for n in self.n_seq:
group = self.so[n]
point = group.random_uniform(point_type=point_type)
result = group.belongs(point, point_type=point_type)
expected = gs.array([[True]])
self.assertAllClose(result, expected)

@geomstats.tests.np_only
def test_random_and_belongs_vectorization(self):
Expand Down

0 comments on commit 386be6d

Please sign in to comment.