Skip to content

# geomstats / geomstats

Merge pull request #267 from nguigs/nguigs-mean

`prelimenary fixes`
• Loading branch information
ninamiolane committed Jan 14, 2020
2 parents c68ec17 + fa936d0 commit 386be6d97d96e2062b3732a42c19efcde0c9bbb5
 @@ -315,7 +315,8 @@ def mean(self, points, weights=None, n_max_iterations=32, epsilon=EPSILON, point_type='vector'): point_type='vector', verbose=False): """ Frechet mean of (weighted) points. @@ -324,14 +325,16 @@ def mean(self, points, points: array-like, shape=[n_samples, dimension] weights: array-like, shape=[n_samples, 1], optional verbose: bool, optional """ # TODO(nina): Profile this code to study performance, # i.e. what to do with sq_dists_between_iterates. def while_loop_cond(iteration, mean, variance, sq_dist): result = gs.logical_or( result = ~gs.logical_or( gs.isclose(variance, 0.), gs.less_equal(sq_dist, epsilon * variance)) return result[0, 0] return result[0, 0] or iteration == 0 def while_loop_body(iteration, mean, variance, sq_dist): tangent_mean = gs.zeros_like(mean) @@ -394,6 +397,10 @@ def while_loop_body(iteration, mean, variance, sq_dist): print('Maximum number of iterations {} reached.' 'The mean may be inaccurate'.format(n_max_iterations)) if verbose: print('n_iter: {}, final variance: {}, final dist: {}'.format( last_iteration, variance, sq_dist)) mean = gs.to_ndarray(mean, to_ndim=2) return mean
 @@ -85,9 +85,8 @@ def belongs(self, point, point_type=None): elif point_type == 'matrix': point = gs.to_ndarray(point, to_ndim=3) point_transpose = gs.transpose(point, axes=(0, 2, 1)) point_inverse = gs.linalg.inv(point) mask = gs.isclose(point_inverse, point_transpose) mask = gs.isclose(gs.matmul(point, point_transpose), gs.eye(self.n)) mask = gs.all(mask, axis=(1, 2)) mask = gs.to_ndarray(mask, to_ndim=1) @@ -1274,13 +1273,10 @@ def random_uniform(self, n_samples=1, point_type=None): if point_type is None: point_type = self.default_point_type if point_type == 'vector': random_point = gs.random.rand(n_samples, self.dimension) * 2 - 1 random_point = self.regularize( random_point, point_type=point_type) elif point_type == 'matrix': random_matrix = gs.random.rand(n_samples, self.n, self.n) random_point = self.projection(random_matrix) random_point = gs.random.rand(n_samples, self.dimension) * 2 - 1 random_point = self.regularize(random_point, point_type='vector') if point_type == 'matrix': random_point = self.matrix_from_rotation_vector(random_point) return random_point
 @@ -216,12 +216,13 @@ def test_skew_matrix_from_vector_vectorization(self): @geomstats.tests.np_only def test_random_and_belongs(self): for n in self.n_seq: group = self.so[n] point = group.random_uniform() result = group.belongs(point) expected = gs.array([[True]]) self.assertAllClose(result, expected) for point_type in ('vector', 'matrix'): for n in self.n_seq: group = self.so[n] point = group.random_uniform(point_type=point_type) result = group.belongs(point, point_type=point_type) expected = gs.array([[True]]) self.assertAllClose(result, expected) @geomstats.tests.np_only def test_random_and_belongs_vectorization(self):

#### 0 comments on commit `386be6d`

Please sign in to comment.
You can’t perform that action at this time.