-
Notifications
You must be signed in to change notification settings - Fork 239
/
hypersphere.py
1097 lines (919 loc) · 39.5 KB
/
hypersphere.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""The n-dimensional hypersphere.
The n-dimensional hypersphere embedded in (n+1)-dimensional
Euclidean space.
"""
import logging
import math
from itertools import product
from scipy.stats import beta
import geomstats.algebra_utils as utils
import geomstats.backend as gs
from geomstats.geometry.base import LevelSet
from geomstats.geometry.euclidean import Euclidean, EuclideanMetric
from geomstats.geometry.riemannian_metric import RiemannianMetric
class _Hypersphere(LevelSet):
    """Private class for the n-dimensional hypersphere.

    Class for the n-dimensional hypersphere embedded in the
    (n+1)-dimensional Euclidean space.

    By default, points are parameterized by their extrinsic
    (n+1)-coordinates. For dimensions 1 and 2, this can be changed with the
    `default_coords_type` parameter. For dimension 1 (the circle),
    the intrinsic coordinates correspond to angles in radians, with 0.
    mapping to point [1., 0.]. For dimension 2, the intrinsic coordinates
    are the spherical coordinates from the north pole, i.e. where angles
    [0., 0.] correspond to point [0., 0., 1.].

    Parameters
    ----------
    dim : int
        Dimension of the hypersphere.
    default_coords_type : str, {'extrinsic', 'intrinsic'}
        Type of representation for dimensions 1 and 2.
        Optional, default: 'extrinsic'.
    """

    def __init__(self, dim, default_coords_type="extrinsic"):
        # The sphere is the level set {x : |x|^2 = 1} of the submersion
        # x -> |x|^2, whose differential at x is v -> 2 <x, v>.
        super(_Hypersphere, self).__init__(
            dim=dim,
            embedding_space=Euclidean(dim + 1),
            submersion=lambda x: gs.sum(x ** 2, axis=-1),
            value=1.0,
            tangent_submersion=lambda v, x: 2 * gs.sum(x * v, axis=-1),
            default_coords_type=default_coords_type,
        )

    def projection(self, point):
        """Project a point on the hypersphere.

        Parameters
        ----------
        point : array-like, shape=[..., dim + 1]
            Point in embedding Euclidean space.

        Returns
        -------
        projected_point : array-like, shape=[..., dim + 1]
            Point projected on the hypersphere.
        """
        # Radial projection x / |x|; undefined for the zero vector.
        norm = gs.linalg.norm(point, axis=-1)
        projected_point = gs.einsum("...,...i->...i", 1.0 / norm, point)
        return projected_point

    def to_tangent(self, vector, base_point):
        """Project a vector to the tangent space.

        Project a vector in Euclidean space
        on the tangent space of the hypersphere at a base point.

        Parameters
        ----------
        vector : array-like, shape=[..., dim + 1]
            Vector in Euclidean space.
        base_point : array-like, shape=[..., dim + 1]
            Point on the hypersphere defining the tangent space,
            where the vector will be projected.

        Returns
        -------
        tangent_vec : array-like, shape=[..., dim + 1]
            Tangent vector in the tangent space of the hypersphere
            at the base point.
        """
        # Remove the component of `vector` along `base_point`. Dividing by
        # |base_point|^2 keeps this an orthogonal projection even when the
        # base point is not exactly unit-norm.
        sq_norm = gs.sum(base_point ** 2, axis=-1)
        inner_prod = self.embedding_metric.inner_product(base_point, vector)
        coef = inner_prod / sq_norm
        tangent_vec = vector - gs.einsum("...,...j->...j", coef, base_point)
        return tangent_vec

    @staticmethod
    def angle_to_extrinsic(point_angle):
        """Convert point from angle to extrinsic coordinates.

        Convert from the angle in radians to the extrinsic coordinates in
        2d plane. Angle 0 corresponds to point [1., 0.] and is expected in
        [-Pi, Pi). This method is only implemented in dimension 1.

        Parameters
        ----------
        point_angle : array-like, shape=[...]
            Point on the circle, i.e. an angle in radians in [-Pi, Pi].

        Returns
        -------
        point_extrinsic : array_like, shape=[..., 2]
            Point on the sphere, in extrinsic coordinates in Euclidean space.
        """
        cos = gs.cos(point_angle)
        sin = gs.sin(point_angle)
        return gs.stack([cos, sin], axis=-1)

    @staticmethod
    def extrinsic_to_angle(point_extrinsic):
        """Compute the angle of a point in the plane.

        Convert from the extrinsic coordinates in the 2d plane to angle in
        radians. Angle 0 corresponds to point [1., 0.]. This method is only
        implemented in dimension 1.

        Parameters
        ----------
        point_extrinsic : array-like, shape=[..., 2]
            Point on the circle, in extrinsic coordinates in Euclidean space.

        Returns
        -------
        point_angle : array_like, shape=[...]
            Point on the circle, i.e. an angle in radians in [-Pi, Pi].
        """
        return gs.arctan2(point_extrinsic[..., 1], point_extrinsic[..., 0])

    def spherical_to_extrinsic(self, point_spherical):
        """Convert point from spherical to extrinsic coordinates.

        Convert from the spherical coordinates in the hypersphere
        to the extrinsic coordinates in Euclidean space.
        Spherical coordinates are defined from the north pole, i.e. that
        angles [0., 0.] correspond to point [0., 0., 1.].
        Only implemented in dimension 2.

        Parameters
        ----------
        point_spherical : array-like, shape=[..., dim]
            Point on the sphere, in spherical coordinates.

        Returns
        -------
        point_extrinsic : array_like, shape=[..., dim + 1]
            Point on the sphere, in extrinsic coordinates in Euclidean space.

        Raises
        ------
        NotImplementedError
            If the dimension is not 2.
        ValueError
            If the resulting points do not belong to the sphere.
        """
        if self.dim != 2:
            raise NotImplementedError(
                "The conversion from spherical coordinates"
                " to extrinsic coordinates is implemented"
                " only in dimension 2."
            )

        theta = point_spherical[..., 0]
        phi = point_spherical[..., 1]

        point_extrinsic = gs.stack(
            [gs.sin(theta) * gs.cos(phi), gs.sin(theta) * gs.sin(phi), gs.cos(theta)],
            axis=-1,
        )

        if not gs.all(self.belongs(point_extrinsic)):
            raise ValueError("Points do not belong to the manifold.")

        return point_extrinsic

    def tangent_spherical_to_extrinsic(
        self, tangent_vec_spherical, base_point_spherical
    ):
        """Convert tangent vector from spherical to extrinsic coordinates.

        Convert from the spherical coordinates in the hypersphere
        to the extrinsic coordinates in Euclidean space for a tangent
        vector. Only implemented in dimension 2. Any number of batch
        dimensions is supported.

        Parameters
        ----------
        tangent_vec_spherical : array-like, shape=[..., dim]
            Tangent vector to the sphere, in spherical coordinates.
        base_point_spherical : array-like, shape=[..., dim]
            Point on the sphere, in spherical coordinates.

        Returns
        -------
        tangent_vec_extrinsic : array-like, shape=[..., dim + 1]
            Tangent vector to the sphere, at base point,
            in extrinsic coordinates in Euclidean space.

        Raises
        ------
        NotImplementedError
            If the dimension is not 2.
        """
        if self.dim != 2:
            raise NotImplementedError(
                "The conversion from spherical coordinates"
                " to extrinsic coordinates is implemented"
                " only in dimension 2."
            )

        theta = base_point_spherical[..., 0]
        phi = base_point_spherical[..., 1]
        # At the north pole phi is arbitrary; fix it to 0.
        phi = gs.where(theta == 0.0, 0.0, phi)
        zeros = gs.zeros_like(theta)

        cos_theta, sin_theta = gs.cos(theta), gs.sin(theta)
        cos_phi, sin_phi = gs.cos(phi), gs.sin(phi)

        # Jacobian d(x, y, z) / d(theta, phi), assembled along the trailing
        # axes so it has shape [..., 3, 2] for any batch shape.
        jac = gs.stack(
            [
                gs.stack([cos_theta * cos_phi, -sin_theta * sin_phi], axis=-1),
                gs.stack([cos_theta * sin_phi, sin_theta * cos_phi], axis=-1),
                gs.stack([-sin_theta, zeros], axis=-1),
            ],
            axis=-2,
        )

        tangent_vec_extrinsic = gs.einsum(
            "...ij,...j->...i", jac, tangent_vec_spherical
        )

        return tangent_vec_extrinsic

    def extrinsic_to_spherical(self, point_extrinsic):
        """Convert point from extrinsic to spherical coordinates.

        Convert from the extrinsic coordinates, i.e. embedded in Euclidean
        space of dim 3 to spherical coordinates in the hypersphere.
        Spherical coordinates are defined from the north pole, i.e.
        angles [0., 0.] correspond to point [0., 0., 1.].
        Only implemented in dimension 2.

        Parameters
        ----------
        point_extrinsic : array-like, shape=[..., dim + 1]
            Point on the sphere, in extrinsic coordinates.

        Returns
        -------
        point_spherical : array_like, shape=[..., dim]
            Point on the sphere, in spherical coordinates relative to the
            north pole: theta in [0, Pi] and phi in [0, 2 Pi).

        Raises
        ------
        NotImplementedError
            If the dimension is not 2.
        """
        if self.dim != 2:
            raise NotImplementedError(
                "The conversion from extrinsic coordinates "
                "to spherical coordinates is implemented"
                " only in dimension 2."
            )

        theta = gs.arccos(point_extrinsic[..., -1])
        x = point_extrinsic[..., 0]
        y = point_extrinsic[..., 1]
        phi = gs.arctan2(y, x)
        # Map arctan2's (-Pi, Pi] output to [0, 2 Pi).
        phi = gs.where(phi < 0, phi + 2 * gs.pi, phi)
        return gs.stack([theta, phi], axis=-1)

    def tangent_extrinsic_to_spherical(
        self, tangent_vec, base_point=None, base_point_spherical=None
    ):
        """Convert tangent vector from extrinsic to spherical coordinates.

        Convert a tangent vector from the extrinsic coordinates in Euclidean
        space to the spherical coordinates in the hypersphere.
        Spherical coordinates are considered from the north pole [0., 0.,
        1.]. This method is only implemented in dimension 2. Any number of
        batch dimensions is supported.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., dim + 1]
            Tangent vector to the sphere, in extrinsic coordinates.
        base_point : array-like, shape=[..., dim + 1]
            Point on the sphere, in extrinsic coordinates. Unused if
            `base_point_spherical` is given.
            Optional, default : None.
        base_point_spherical : array-like, shape=[..., dim]
            Point on the sphere, in spherical coordinates. Either
            `base_point` or `base_point_spherical` must be given.
            Optional, default : None.

        Returns
        -------
        tangent_vec_spherical : array-like, shape=[..., dim]
            Tangent vector to the sphere, at base point,
            in spherical coordinates relative to the north pole [0., 0., 1.].

        Raises
        ------
        NotImplementedError
            If the dimension is not 2.
        ValueError
            If neither `base_point` nor `base_point_spherical` is given.
        """
        if self.dim != 2:
            raise NotImplementedError(
                "The conversion from extrinsic coordinates "
                "to spherical coordinates is implemented"
                " only in dimension 2."
            )
        if base_point is None and base_point_spherical is None:
            raise ValueError(
                "A base point must be given, either in "
                "extrinsic or in spherical coordinates."
            )
        if base_point_spherical is None and base_point is not None:
            base_point_spherical = self.extrinsic_to_spherical(base_point)

        theta = base_point_spherical[..., 0]
        phi = base_point_spherical[..., 1]

        # Avoid dividing by sin(0) below; the affected entries are
        # overwritten by `jac_close_0` anyway.
        theta_safe = gs.where(gs.abs(theta) < gs.atol, gs.atol, theta)
        zeros = gs.zeros_like(theta)
        ones = gs.ones_like(theta)

        # Fallback Jacobian at the north pole (theta close to 0), where the
        # spherical chart is singular: keep the first two extrinsic
        # coordinates. Shape [..., 2, 3].
        jac_close_0 = gs.stack(
            [
                gs.stack([ones, zeros, zeros], axis=-1),
                gs.stack([zeros, ones, zeros], axis=-1),
            ],
            axis=-2,
        )

        # Jacobian d(theta, phi) / d(x, y, z), shape [..., 2, 3].
        jac = gs.stack(
            [
                gs.stack(
                    [
                        gs.cos(theta) * gs.cos(phi),
                        gs.cos(theta) * gs.sin(phi),
                        -gs.sin(theta),
                    ],
                    axis=-1,
                ),
                gs.stack(
                    [
                        -gs.sin(phi) / gs.sin(theta_safe),
                        gs.cos(phi) / gs.sin(theta_safe),
                        zeros,
                    ],
                    axis=-1,
                ),
            ],
            axis=-2,
        )

        # Broadcast theta to the Jacobian's shape to select entrywise.
        theta_criterion = gs.einsum("...,...ij->...ij", theta, gs.ones_like(jac))
        jac = gs.where(gs.abs(theta_criterion) < gs.atol, jac_close_0, jac)

        tangent_vec_spherical = gs.einsum("...ij,...j->...i", jac, tangent_vec)

        return tangent_vec_spherical

    def intrinsic_to_extrinsic_coords(self, point_intrinsic):
        """Convert point from intrinsic to extrinsic coordinates.

        Convert from the intrinsic coordinates in the hypersphere,
        to the extrinsic coordinates in Euclidean space.
        For dimension 1 (the circle), the intrinsic coordinates correspond
        to angles in radians, with 0. mapping to point [1., 0.]. For
        dimension 2, the intrinsic coordinates are the spherical coordinates
        from the north pole, i.e. that angles [0., 0.] correspond to point
        [0., 0., 1.].

        Parameters
        ----------
        point_intrinsic : array-like, shape=[..., dim]
            Point on the hypersphere, in intrinsic coordinates.

        Returns
        -------
        point_extrinsic : array-like, shape=[..., dim + 1]
            Point on the hypersphere, in extrinsic coordinates in
            Euclidean space.

        Raises
        ------
        NotImplementedError
            If the dimension is neither 1 nor 2.
        """
        if self.dim == 2:
            return self.spherical_to_extrinsic(point_intrinsic)
        if self.dim == 1:
            return self.angle_to_extrinsic(point_intrinsic)
        raise NotImplementedError(
            "Intrinsic coordinates are only implemented in dimension 1 and 2."
        )

    def extrinsic_to_intrinsic_coords(self, point_extrinsic):
        """Convert point from extrinsic to intrinsic coordinates.

        Convert from the extrinsic coordinates in Euclidean space,
        to some intrinsic coordinates in the hypersphere.

        Parameters
        ----------
        point_extrinsic : array-like, shape=[..., dim + 1]
            Point on the hypersphere, in extrinsic coordinates in
            Euclidean space.

        Returns
        -------
        point_intrinsic : array-like, shape=[..., dim]
            Point on the hypersphere, in intrinsic coordinates.

        Raises
        ------
        NotImplementedError
            If the dimension is neither 1 nor 2.
        """
        if self.dim == 2:
            return self.extrinsic_to_spherical(point_extrinsic)
        if self.dim == 1:
            return self.extrinsic_to_angle(point_extrinsic)
        raise NotImplementedError(
            "Intrinsic coordinates are only implemented in dimension 1 and 2."
        )

    def _replace_values(self, samples, new_samples, indcs):
        """Overwrite the rows of `samples` flagged in `indcs` by `new_samples`.

        `new_samples` holds one row of length dim + 1 per flagged row, in
        order.
        """
        replaced_indices = [i for i, is_replaced in enumerate(indcs) if is_replaced]
        value_indices = list(product(replaced_indices, range(self.dim + 1)))
        return gs.assignment(samples, gs.flatten(new_samples), value_indices)

    def random_point(self, n_samples=1, bound=1.0):
        """Sample in the hypersphere from the uniform distribution.

        Parameters
        ----------
        n_samples : int
            Number of samples.
            Optional, default: 1.
        bound : unused
            Kept for interface compatibility.

        Returns
        -------
        samples : array-like, shape=[..., dim + 1]
            Points sampled on the hypersphere.
        """
        return self.random_uniform(n_samples)

    def random_uniform(self, n_samples=1):
        """Sample in the hypersphere from the uniform distribution.

        Normalizes standard Gaussian samples, redrawing any sample whose
        norm is numerically zero.

        Parameters
        ----------
        n_samples : int
            Number of samples.
            Optional, default: 1.

        Returns
        -------
        samples : array-like, shape=[..., dim + 1]
            Points sampled on the hypersphere. Returned in intrinsic
            coordinates when dim is 1 or 2 and `default_coords_type` is
            'intrinsic'.
        """
        size = (n_samples, self.dim + 1)

        samples = gs.random.normal(size=size)
        while True:
            norms = gs.linalg.norm(samples, axis=1)
            indcs = gs.isclose(norms, 0.0, atol=gs.atol)
            num_bad_samples = gs.sum(indcs)
            if num_bad_samples == 0:
                break
            new_samples = gs.random.normal(size=(num_bad_samples, self.dim + 1))
            samples = self._replace_values(samples, new_samples, indcs)

        samples = gs.einsum("...,...i->...i", 1 / norms, samples)
        if n_samples == 1:
            samples = gs.squeeze(samples, axis=0)

        if self.dim in [1, 2] and self.default_coords_type == "intrinsic":
            return self.extrinsic_to_intrinsic_coords(samples)
        return samples

    def random_von_mises_fisher(self, mu=None, kappa=10, n_samples=1, max_iter=100):
        """Sample with the von Mises-Fisher distribution.

        This distribution corresponds to the maximum entropy distribution
        given a mean. In dimension 2, a closed form expression is available.
        In larger dimension, rejection sampling is used according to [Wood94]_

        References
        ----------
        https://en.wikipedia.org/wiki/Von_Mises-Fisher_distribution

        .. [Wood94] Wood, Andrew T. A. “Simulation of the von Mises Fisher
                    Distribution.” Communications in Statistics - Simulation
                    and Computation, June 27, 2007.
                    https://doi.org/10.1080/03610919408813161.

        Parameters
        ----------
        mu : array-like, shape=[dim + 1]
            Mean parameter of the distribution, a point on the sphere in
            extrinsic coordinates. If None, samples are centered at
            [1., 0., ..., 0.].
            Optional, default: None.
        kappa : float
            Kappa parameter of the von Mises distribution.
            Optional, default: 10.
        n_samples : int
            Number of samples.
            Optional, default: 1.
        max_iter : int
            Maximum number of trials in the rejection algorithm. In case it
            is reached, the current number of samples < n_samples is returned.
            Optional, default: 100.

        Returns
        -------
        point : array-like, shape=[n_samples, dim + 1]
            Points sampled on the sphere in extrinsic coordinates
            in Euclidean space of dimension dim + 1.
        """
        dim = self.dim

        if dim == 2:
            # Closed form: uniform azimuth, first coordinate by inverse CDF.
            angle = 2.0 * gs.pi * gs.random.rand(n_samples)
            angle = gs.to_ndarray(angle, to_ndim=2, axis=1)
            unit_vector = gs.hstack((gs.cos(angle), gs.sin(angle)))
            scalar = gs.random.rand(n_samples)

            coord_x = 1.0 + 1.0 / kappa * gs.log(
                scalar + (1.0 - scalar) * gs.exp(gs.array(-2.0 * kappa))
            )
            coord_x = gs.to_ndarray(coord_x, to_ndim=2, axis=1)
            coord_yz = gs.sqrt(1.0 - coord_x ** 2) * unit_vector
            sample = gs.hstack((coord_x, coord_yz))

        else:
            # rejection sampling in the general case
            sqrt = gs.sqrt(4 * kappa ** 2.0 + dim ** 2)
            envelop_param = (-2 * kappa + sqrt) / dim
            node = (1.0 - envelop_param) / (1.0 + envelop_param)
            correction = kappa * node + dim * gs.log(1.0 - node ** 2)

            n_accepted, n_iter = 0, 0
            result = []
            while (n_accepted < n_samples) and (n_iter < max_iter):
                sym_beta = beta.rvs(dim / 2, dim / 2, size=n_samples - n_accepted)
                sym_beta = gs.cast(sym_beta, node.dtype)
                coord_x = (1 - (1 + envelop_param) * sym_beta) / (
                    1 - (1 - envelop_param) * sym_beta
                )
                accept_tol = gs.random.rand(n_samples - n_accepted)
                criterion = (
                    kappa * coord_x + dim * gs.log(1 - node * coord_x) - correction
                ) > gs.log(accept_tol)
                result.append(coord_x[criterion])
                n_accepted += gs.sum(criterion)
                n_iter += 1
            if n_accepted < n_samples:
                logging.warning(
                    "Maximum number of iteration reached in rejection "
                    "sampling before n_samples were accepted."
                )
            coord_x = gs.concatenate(result)
            # Complete each first coordinate with a uniform direction on the
            # (dim - 1)-sphere scaled to keep unit norm.
            coord_rest = _Hypersphere(dim - 1).random_uniform(n_accepted)
            coord_rest = gs.einsum(
                "...,...i->...i", gs.sqrt(1 - coord_x ** 2), coord_rest
            )
            sample = gs.concatenate([coord_x[..., None], coord_rest], axis=1)

        if mu is not None:
            sample = utils.rotate_points(sample, mu)

        return sample if (n_samples > 1) else sample[0]

    def random_riemannian_normal(
        self, mean=None, precision=None, n_samples=1, max_iter=100
    ):
        r"""Sample from the Riemannian normal distribution.

        The Riemannian normal distribution, or spherical normal in this case,
        is defined by the probability density function (with respect to the
        Riemannian volume measure) proportional to:

        .. math::

            \exp \left(- \frac{\lambda}{2} \mathrm{arccos}^2(x^T\mu) \right)

        where :math:`\mu` is the mean and :math:`\lambda` is the isotropic
        precision. For the anisotropic case,
        :math:`\log_{\mu}(x)^T \Lambda \log_{\mu}(x)` is used instead.

        A rejection algorithm is used to sample from this distribution [Hau18]_

        Parameters
        ----------
        mean : array-like, shape=[dim + 1]
            Mean parameter of the distribution, in extrinsic coordinates.
            Optional, default: (0,...,0,1) (the north pole).
        precision : float or array-like, shape=[dim, dim]
            Inverse of the covariance parameter of the normal distribution.
            If a float is passed, the covariance matrix is precision times
            identity.
            Optional, default: identity.
        n_samples : int
            Number of samples.
            Optional, default: 1.
        max_iter : int
            Maximum number of trials in the rejection algorithm. In case it
            is reached, the current number of samples < n_samples is returned.
            Optional, default: 100.

        Returns
        -------
        point : array-like, shape=[n_samples, dim + 1]
            Points sampled on the sphere.

        References
        ----------
        .. [Hau18] Hauberg, Soren. “Directional Statistics with the
                   Spherical Normal Distribution.”
                   In 2018 21st International Conference on Information
                   Fusion (FUSION), 704–11, 2018.
                   https://doi.org/10.23919/ICIF.2018.8455242.
        """
        dim = self.dim
        n_accepted, n_iter = 0, 0
        result = []
        if precision is None:
            precision_ = gs.eye(self.dim)
        elif isinstance(precision, (float, int)):
            precision_ = precision * gs.eye(self.dim)
        else:
            precision_ = precision
        # Gaussian envelope in the tangent space at the north pole.
        precision_2 = precision_ + (dim - 1) / gs.pi * gs.eye(dim)
        tangent_cov = gs.linalg.inv(precision_2)

        def threshold(random_v):
            """Compute the acceptance threshold and the norm of `random_v`."""
            squared_norm = gs.sum(random_v ** 2, axis=-1)
            sinc = utils.taylor_exp_even_func(squared_norm, utils.sinc_close_0) ** (
                dim - 1
            )
            threshold_val = sinc * gs.exp(squared_norm * (dim - 1) / 2 / gs.pi)
            return threshold_val, squared_norm ** 0.5

        while (n_accepted < n_samples) and (n_iter < max_iter):
            envelope = gs.random.multivariate_normal(
                gs.zeros(dim), tangent_cov, size=(n_samples - n_accepted,)
            )
            thresh, norm = threshold(envelope)
            proposal = gs.random.rand(n_samples - n_accepted)
            criterion = gs.logical_and(norm <= gs.pi, proposal <= thresh)
            result.append(envelope[criterion])
            n_accepted += gs.sum(criterion)
            n_iter += 1
        if n_accepted < n_samples:
            logging.warning(
                "Maximum number of iteration reached in rejection "
                "sampling before n_samples were accepted."
            )
        # Embed the intrinsic samples in the tangent space at the north pole,
        # whose last extrinsic coordinate is zero.
        tangent_sample_intr = gs.concatenate(result)
        tangent_sample = gs.concatenate(
            [tangent_sample_intr, gs.zeros(n_accepted)[:, None]], axis=1
        )

        metric = HypersphereMetric(dim)
        north_pole = gs.array([0.0] * dim + [1.0])
        if mean is not None:
            # Move the samples from the north pole's tangent space to the
            # mean's tangent space before taking the exponential.
            mean_from_north = metric.log(mean, north_pole)
            tangent_sample_at_pt = metric.parallel_transport(
                tangent_sample, mean_from_north, north_pole
            )
        else:
            tangent_sample_at_pt = tangent_sample
            mean = north_pole
        sample = metric.exp(tangent_sample_at_pt, mean)
        return sample[0] if (n_samples == 1) else sample
class HypersphereMetric(RiemannianMetric):
"""Class for the Hypersphere Metric.
Parameters
----------
dim : int
Dimension of the hypersphere.
"""
def __init__(self, dim):
    # The sphere metric is Riemannian: signature (dim, 0).
    super(HypersphereMetric, self).__init__(dim=dim, signature=(dim, 0))
    ambient_dim = dim + 1
    # Inner products are those of the ambient Euclidean space.
    self.embedding_metric = EuclideanMetric(ambient_dim)
    self._space = _Hypersphere(dim=dim)
def metric_matrix(self, base_point=None):
    """Metric matrix at the tangent space at a base point.

    The metric is induced by the Euclidean metric of the embedding space,
    so in extrinsic coordinates its matrix is the identity at every point.

    Parameters
    ----------
    base_point : array-like, shape=[..., dim + 1]
        Base point. Unused: the matrix does not depend on it.
        Optional, default: None.

    Returns
    -------
    mat : array-like, shape=[dim + 1, dim + 1]
        Inner-product matrix (the identity). It is not broadcast against
        the batch shape of `base_point`.
    """
    return gs.eye(self.dim + 1)
def inner_product(self, tangent_vec_a, tangent_vec_b, base_point=None):
    """Compute the inner-product of two tangent vectors at a base point.

    Delegates to the Euclidean metric of the embedding space.

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[..., dim + 1]
        First tangent vector at base point.
    tangent_vec_b : array-like, shape=[..., dim + 1]
        Second tangent vector at base point.
    base_point : array-like, shape=[..., dim + 1], optional
        Point on the hypersphere.

    Returns
    -------
    inner_prod : array-like, shape=[...,]
        Inner-product of the two tangent vectors.
    """
    ambient_metric = self.embedding_metric
    return ambient_metric.inner_product(tangent_vec_a, tangent_vec_b, base_point)
def squared_norm(self, vector, base_point=None):
    """Compute the squared norm of a vector.

    Squared norm of a vector associated with the inner-product
    at the tangent space at a base point. Since the metric is induced by
    the embedding space, this is the Euclidean squared norm.

    Parameters
    ----------
    vector : array-like, shape=[..., dim + 1]
        Vector on the tangent space of the hypersphere at base point.
    base_point : array-like, shape=[..., dim + 1], optional
        Point on the hypersphere. Unused.

    Returns
    -------
    sq_norm : array-like, shape=[...,]
        Squared norm of the vector.
    """
    return self.embedding_metric.squared_norm(vector)
def exp(self, tangent_vec, base_point, **kwargs):
    """Compute the Riemannian exponential of a tangent vector.

    The input is first projected onto the tangent space at `base_point`.
    The result is then cos(|v|) x + (sin(|v|) / |v|) v, where x is the
    base point and v the projected vector. The two coefficients are
    evaluated with `utils.taylor_exp_even_func`; judging by its name, this
    presumably switches to a Taylor expansion for small |v|.

    Parameters
    ----------
    tangent_vec : array-like, shape=[..., dim + 1]
        Tangent vector at a base point.
    base_point : array-like, shape=[..., dim + 1]
        Point on the hypersphere.

    Returns
    -------
    exp : array-like, shape=[..., dim + 1]
        Point on the hypersphere equal to the Riemannian exponential
        of tangent_vec at the base point.
    """
    # Reuse the sphere built in __init__ instead of instantiating a new
    # Hypersphere (and its metric) on every call.
    proj_tangent_vec = self._space.to_tangent(tangent_vec, base_point)
    norm2 = self.embedding_metric.squared_norm(proj_tangent_vec)

    coef_1 = utils.taylor_exp_even_func(norm2, utils.cos_close_0, order=4)
    coef_2 = utils.taylor_exp_even_func(norm2, utils.sinc_close_0, order=4)
    exp = gs.einsum("...,...j->...j", coef_1, base_point) + gs.einsum(
        "...,...j->...j", coef_2, proj_tangent_vec
    )

    return exp
def log(self, point, base_point, **kwargs):
    """Compute the Riemannian logarithm of a point.

    With theta the angle between the two points, the result is
    (theta / sin(theta)) point - (theta / tan(theta)) base_point.
    Both coefficients are evaluated from theta^2 through
    `utils.taylor_exp_even_func`.

    Parameters
    ----------
    point : array-like, shape=[..., dim + 1]
        Point on the hypersphere.
    base_point : array-like, shape=[..., dim + 1]
        Point on the hypersphere.

    Returns
    -------
    log : array-like, shape=[..., dim + 1]
        Tangent vector at the base point equal to the Riemannian logarithm
        of point at the base point.
    """
    # Clip so that rounding slightly outside [-1, 1] cannot make arccos NaN.
    cos_angle = gs.clip(
        self.embedding_metric.inner_product(base_point, point), -1.0, 1.0
    )
    angle_sq = gs.arccos(cos_angle) ** 2

    coef_point = utils.taylor_exp_even_func(
        angle_sq, utils.inv_sinc_close_0, order=5
    )
    coef_base = utils.taylor_exp_even_func(
        angle_sq, utils.inv_tanc_close_0, order=5
    )

    point_part = gs.einsum("...,...j->...j", coef_point, point)
    base_part = gs.einsum("...,...j->...j", coef_base, base_point)
    return point_part - base_part
def dist(self, point_a, point_b, **kwargs):
    """Compute the geodesic distance between two points.

    The distance is the angle between the two points seen as vectors of
    the embedding space.

    Parameters
    ----------
    point_a : array-like, shape=[..., dim + 1]
        First point on the hypersphere.
    point_b : array-like, shape=[..., dim + 1]
        Second point on the hypersphere.
    **kwargs
        Ignored. Accepted for signature consistency with `squared_dist`,
        `exp` and `log`.

    Returns
    -------
    dist : array-like, shape=[...,]
        Geodesic distance between the two points.
    """
    norm_a = self.embedding_metric.norm(point_a)
    norm_b = self.embedding_metric.norm(point_b)
    inner_prod = self.embedding_metric.inner_product(point_a, point_b)

    # Normalizing tolerates inputs slightly off the sphere; clipping keeps
    # arccos defined when rounding pushes the cosine outside [-1, 1].
    cos_angle = inner_prod / (norm_a * norm_b)
    cos_angle = gs.clip(cos_angle, -1, 1)

    return gs.arccos(cos_angle)
def squared_dist(self, point_a, point_b, **kwargs):
    """Squared geodesic distance between two points.

    Parameters
    ----------
    point_a : array-like, shape=[..., dim + 1]
        Point on the hypersphere.
    point_b : array-like, shape=[..., dim + 1]
        Point on the hypersphere.
    **kwargs
        Ignored.

    Returns
    -------
    sq_dist : array-like, shape=[...,]
        Squared geodesic distance between the two points.
    """
    return self.dist(point_a, point_b) ** 2
@staticmethod
def parallel_transport(tangent_vec_a, tangent_vec_b, base_point, **kwargs):
    r"""Compute the parallel transport of a tangent vector.

    Closed-form solution for the parallel transport of a tangent vector a
    along the geodesic defined by :math:`t \mapsto exp_(base_point)(t*
    tangent_vec_b)`. The component of a along tangent_vec_b is rotated by
    the angle |tangent_vec_b| in the plane spanned by the base point and
    tangent_vec_b. The orthogonal component is left unchanged.

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[..., dim + 1]
        Tangent vector at base point to be transported.
    tangent_vec_b : array-like, shape=[..., dim + 1]
        Tangent vector at base point, along which the parallel transport
        is computed.
    base_point : array-like, shape=[..., dim + 1]
        Point on the hypersphere.

    Returns
    -------
    transported_tangent_vec: array-like, shape=[..., dim + 1]
        Transported tangent vector at `exp_(base_point)(tangent_vec_b)`.
    """
    angle = gs.linalg.norm(tangent_vec_b, axis=-1)
    # When tangent_vec_b vanishes it is left unscaled, so the direction is
    # zero and the transport reduces to the identity.
    safe_angle = gs.where(angle == 0.0, 1.0, angle)
    direction = gs.einsum("...,...i->...i", 1 / safe_angle, tangent_vec_b)

    component = gs.einsum("...i,...i->...", tangent_vec_a, direction)
    orthogonal_part = tangent_vec_a - gs.einsum(
        "...,...i->...i", component, direction
    )

    rotated_part = gs.einsum(
        "...,...i->...i", gs.cos(angle) * component, direction
    ) - gs.einsum("...,...i->...i", gs.sin(angle) * component, base_point)
    return rotated_part + orthogonal_part
def christoffels(self, point, point_type="spherical"):
    """Compute the Christoffel symbols at a point.

    Only implemented in dimension 2 and for spherical coordinates
    (theta, phi). There, the only non-zero symbols are
    Gamma^theta_{phi phi} = -sin(theta) cos(theta) and
    Gamma^phi_{theta phi} = Gamma^phi_{phi theta} = cot(theta).

    Parameters
    ----------
    point : array-like, shape=[..., dim]
        Point on hypersphere where the Christoffel symbols are computed.
    point_type: str, {'spherical', 'intrinsic', 'extrinsic'}
        Coordinates in which to express the Christoffel symbols.
        Optional, default: 'spherical'.

    Returns
    -------
    christoffel : array-like, shape=[..., contravariant index, 1st
                                     covariant index, 2nd covariant index]
        Christoffel symbols at point. The leading axis is squeezed when a
        single point is given.

    Raises
    ------
    NotImplementedError
        If the dimension is not 2 or `point_type` is not 'spherical'.
    """
    if self.dim != 2 or point_type != "spherical":
        raise NotImplementedError(
            "The Christoffel symbols are only implemented"
            " for spherical coordinates in the 2-sphere"
        )

    points = gs.to_ndarray(point, to_ndim=2)

    def _symbols_at(sample):
        """Return the [2, 2, 2] Christoffel symbols at one point."""
        theta = sample[0]
        cot_theta = gs.cos(theta) / gs.sin(theta)
        gamma_theta = gs.array([[0, 0], [0, -gs.sin(theta) * gs.cos(theta)]])
        gamma_phi = gs.array([[0, cot_theta], [cot_theta, 0]])
        return gs.stack([gamma_theta, gamma_phi])

    christoffel = gs.stack([_symbols_at(sample) for sample in points])

    is_single = gs.ndim(christoffel) == 4 and gs.shape(christoffel)[0] == 1
    return gs.squeeze(christoffel, axis=0) if is_single else christoffel
def curvature(self, tangent_vec_a, tangent_vec_b, tangent_vec_c, base_point):
    r"""Compute the curvature.

    For three tangent vectors at a base point :math:`x,y,z`,
    the curvature is defined by
    :math:`R(x, y)z = \nabla_{[x,y]}z
    - \nabla_x\nabla_y z + \nabla_y\nabla_x z`, where :math:`\nabla`
    is the Levi-Civita connection. In the case of the hypersphere,
    we have the closed formula
    :math:`R(x,y)z = \langle x, z \rangle y - \langle y,z \rangle x`.

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[..., dim + 1]
        Tangent vector at `base_point`.
    tangent_vec_b : array-like, shape=[..., dim + 1]
        Tangent vector at `base_point`.
    tangent_vec_c : array-like, shape=[..., dim + 1]
        Tangent vector at `base_point`.
    base_point : array-like, shape=[..., dim + 1]
        Point on the hypersphere. Unused by the closed formula.

    Returns
    -------
    curvature : array-like, shape=[..., dim + 1]
        Tangent vector at `base_point`.
    """
    inner_ac = self.inner_product(tangent_vec_a, tangent_vec_c)
    inner_bc = self.inner_product(tangent_vec_b, tangent_vec_c)
    first_term = gs.einsum("...,...i->...i", inner_bc, tangent_vec_a)
    second_term = gs.einsum("...,...i->...i", inner_ac, tangent_vec_b)
    return -first_term + second_term
def _normalization_factor_odd_dim(self, variances):
    """Compute the normalization factor - odd dimension.

    The factor is proportional to the integral over [0, pi] of
    exp(-p r^2 / 2) sin^(dim - 1)(r) dr. For odd `dim`, sin^(dim - 1) is
    expanded into a constant term plus cosines of frequency
    m = dim - 1 - 2k. Each term is integrated in closed form with the
    error function.

    Parameters
    ----------
    variances : array-like, shape=[...,]
        Parameter p of the integrand above. NOTE(review): judging from the
        exp(-m^2 / (2 * variances)) terms, it enters as a precision
        (inverse variance). Confirm against the callers.

    Returns
    -------
    norm_factor : array-like, shape=[...,]
        Normalization factor, same shape as `variances`.
    """
    dim = self.dim
    half_dim = int((dim + 1) / 2)
    # NOTE(review): this is the surface area of S^dim. The polar volume
    # element uses the area of S^(dim - 1); this only rescales the result
    # by a constant.
    area = 2 * gs.pi ** half_dim / math.factorial(half_dim - 1)
    comb = gs.comb(dim - 1, half_dim - 1)

    erf_arg = gs.sqrt(variances / 2) * gs.pi
    # Common factor of every integrated term: sqrt(pi / (2 p)).
    gaussian_factor = gs.sqrt(gs.pi / (2 * variances))
    # Constant term of the expansion: C(dim - 1, (dim - 1) / 2) / 2^(dim - 1).
    first_term = (
        area / 2 ** (dim - 1) * comb * gaussian_factor * gs.erf(erf_arg)
    )
    if half_dim == 1:
        # dim == 1: sin^0 is constant, so there are no cosine terms.
        return first_term

    def summand(k):
        # Cosine term of frequency m = dim - 1 - 2k, integrated on [0, pi]:
        # exp(-m^2 / (2p)) * Re erf((pi p - i m) / sqrt(2p)), up to the
        # common sqrt(pi / (2p)) factor.
        exp_arg = -((dim - 1 - 2 * k) ** 2) / 2 / variances
        erf_arg_2 = (gs.pi * variances - (dim - 1 - 2 * k) * 1j) / gs.sqrt(
            2 * variances
        )
        sign = (-1.0) ** k
        comb_2 = gs.comb(dim - 1, k)
        return sign * comb_2 * gs.exp(exp_arg) * gs.real(gs.erf(erf_arg_2))

    # Cosine terms for k = 0, ..., (dim - 3) / 2 inclusive. Summing over
    # the stacking axis only preserves the shape of `variances`.
    sum_term = gs.sum(
        gs.stack([summand(k) for k in range(half_dim - 1)], axis=0), axis=0
    )
    coef = area * gaussian_factor * (-1.0) ** (half_dim - 1)

    return first_term + coef / 2 ** (dim - 2) * sum_term
def _normalization_factor_even_dim(self, variances):
    """Compute the normalization factor - even dimension.

    The factor is proportional to the integral over [0, pi] of
    exp(-p r^2 / 2) sin^(dim - 1)(r) dr. For even `dim`, sin^(dim - 1) is
    expanded into sines of frequency m = dim - 1 - 2k. Each term is
    integrated in closed form with the error function.

    Parameters
    ----------
    variances : array-like, shape=[...,]
        Parameter p of the integrand above. NOTE(review): judging from the
        exp(-m^2 / (2 * variances)) terms, it enters as a precision
        (inverse variance). Confirm against the callers.

    Returns
    -------
    norm_factor : array-like, shape=[...,]
        Normalization factor, same shape as `variances`.
    """
    dim = self.dim
    half_dim = (dim + 1) / 2
    # NOTE(review): this is the surface area of S^dim (see the odd case);
    # it only rescales the result by a constant.
    area = 2 * gs.pi ** half_dim / math.gamma(half_dim)

    def summand(k):
        # Sine term of frequency m = dim - 1 - 2k, integrated on [0, pi]:
        # exp(-m^2 / (2p)) * Im[erf((pi p - i m) / sqrt(2p))
        # + erf(i m / sqrt(2p))], up to the sqrt(pi / (2p)) factor.
        exp_arg = -((dim - 1 - 2 * k) ** 2) / 2 / variances
        erf_arg_1 = (dim - 1 - 2 * k) * 1j / gs.sqrt(2 * variances)
        erf_arg_2 = (gs.pi * variances - (dim - 1 - 2 * k) * 1j) / gs.sqrt(
            2 * variances
        )
        sign = (-1.0) ** k
        comb = gs.comb(dim - 1, k)
        erf_terms = gs.imag(gs.erf(erf_arg_2) + gs.erf(erf_arg_1))
        return sign * comb * gs.exp(exp_arg) * erf_terms

    half_dim_2 = int((dim - 2) / 2)
    # Sine terms for k = 0, ..., (dim - 2) / 2 inclusive. Summing over the
    # stacking axis only preserves the shape of `variances`.
    sum_term = gs.sum(
        gs.stack([summand(k) for k in range(half_dim_2 + 1)], axis=0), axis=0
    )
    coef = (
        area
        * (-1.0) ** half_dim_2
        / 2 ** (dim - 2)
        * gs.sqrt(gs.pi / 2 / variances)
    )

    return coef * sum_term
def normalization_factor(self, variances):
"""Return normalization factor of the Gaussian distribution.
Parameters