Skip to content
Permalink
Browse files

Merge pull request #267 from nguigs/nguigs-mean

prelimenary fixes
  • Loading branch information
ninamiolane committed Jan 14, 2020
2 parents c68ec17 + fa936d0 commit 386be6d97d96e2062b3732a42c19efcde0c9bbb5
@@ -315,7 +315,8 @@ def mean(self, points,
weights=None,
n_max_iterations=32,
epsilon=EPSILON,
point_type='vector'):
point_type='vector',
verbose=False):
"""
Frechet mean of (weighted) points.
@@ -324,14 +325,16 @@ def mean(self, points,
points: array-like, shape=[n_samples, dimension]
weights: array-like, shape=[n_samples, 1], optional
verbose: bool, optional
"""
# TODO(nina): Profile this code to study performance,
# i.e. what to do with sq_dists_between_iterates.
def while_loop_cond(iteration, mean, variance, sq_dist):
result = gs.logical_or(
result = ~gs.logical_or(
gs.isclose(variance, 0.),
gs.less_equal(sq_dist, epsilon * variance))
return result[0, 0]
return result[0, 0] or iteration == 0

def while_loop_body(iteration, mean, variance, sq_dist):
tangent_mean = gs.zeros_like(mean)
@@ -394,6 +397,10 @@ def while_loop_body(iteration, mean, variance, sq_dist):
print('Maximum number of iterations {} reached.'
'The mean may be inaccurate'.format(n_max_iterations))

if verbose:
print('n_iter: {}, final variance: {}, final dist: {}'.format(
last_iteration, variance, sq_dist))

mean = gs.to_ndarray(mean, to_ndim=2)
return mean

@@ -85,9 +85,8 @@ def belongs(self, point, point_type=None):
elif point_type == 'matrix':
point = gs.to_ndarray(point, to_ndim=3)
point_transpose = gs.transpose(point, axes=(0, 2, 1))
point_inverse = gs.linalg.inv(point)

mask = gs.isclose(point_inverse, point_transpose)
mask = gs.isclose(gs.matmul(point, point_transpose),
gs.eye(self.n))
mask = gs.all(mask, axis=(1, 2))

mask = gs.to_ndarray(mask, to_ndim=1)
@@ -1274,13 +1273,10 @@ def random_uniform(self, n_samples=1, point_type=None):
if point_type is None:
point_type = self.default_point_type

if point_type == 'vector':
random_point = gs.random.rand(n_samples, self.dimension) * 2 - 1
random_point = self.regularize(
random_point, point_type=point_type)
elif point_type == 'matrix':
random_matrix = gs.random.rand(n_samples, self.n, self.n)
random_point = self.projection(random_matrix)
random_point = gs.random.rand(n_samples, self.dimension) * 2 - 1
random_point = self.regularize(random_point, point_type='vector')
if point_type == 'matrix':
random_point = self.matrix_from_rotation_vector(random_point)

return random_point

@@ -216,12 +216,13 @@ def test_skew_matrix_from_vector_vectorization(self):

@geomstats.tests.np_only
def test_random_and_belongs(self):
for n in self.n_seq:
group = self.so[n]
point = group.random_uniform()
result = group.belongs(point)
expected = gs.array([[True]])
self.assertAllClose(result, expected)
for point_type in ('vector', 'matrix'):
for n in self.n_seq:
group = self.so[n]
point = group.random_uniform(point_type=point_type)
result = group.belongs(point, point_type=point_type)
expected = gs.array([[True]])
self.assertAllClose(result, expected)

@geomstats.tests.np_only
def test_random_and_belongs_vectorization(self):

0 comments on commit 386be6d

Please sign in to comment.
You can’t perform that action at this time.