-
Notifications
You must be signed in to change notification settings - Fork 239
/
quotient_metric.py
429 lines (370 loc) · 17.8 KB
/
quotient_metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
"""Classes for fiber bundles and quotient metrics."""
import geomstats.backend as gs
from geomstats.geometry.fiber_bundle import FiberBundle
from geomstats.geometry.riemannian_metric import RiemannianMetric
class QuotientMetric(RiemannianMetric):
"""Quotient metric.
Given a (principal) fiber bundle, or more generally a manifold with a
Lie group acting on it by the right, the quotient space is the space of
orbits under this action. The quotient metric is defined such that the
canonical projection is a Riemannian submersion, i.e. it is isometric to
the restriction of the metric of the total space to horizontal subspaces.
Parameters
----------
fiber_bundle : geomstats.geometry.fiber_bundle.FiberBundle
Bundle structure to define the quotient.
"""
def __init__(self, fiber_bundle: FiberBundle, dim: int = None):
if dim is None:
if fiber_bundle.base is not None:
dim = fiber_bundle.base.dim
elif fiber_bundle.group is not None:
dim = fiber_bundle.dim - fiber_bundle.group.dim
else:
raise ValueError(
"Either the base manifold, "
"its dimension, or the group acting on the "
"total space must be provided."
)
super(QuotientMetric, self).__init__(
dim=dim, default_point_type=fiber_bundle.default_point_type
)
self.fiber_bundle = fiber_bundle
self.group = fiber_bundle.group
self.ambient_metric = fiber_bundle.ambient_metric
def inner_product(
self, tangent_vec_a, tangent_vec_b, base_point=None, fiber_point=None
):
"""Compute the inner-product of two tangent vectors at a base point.
Parameters
----------
tangent_vec_a : array-like, shape=[..., {dim, [n, n]}]
Tangent vector to the quotient manifold.
tangent_vec_b : array-like, shape=[..., {dim, [n, n]}]
Tangent vector to the quotient manifold.
base_point : array-like, shape=[..., {dim, [n, n]}]
Point on the quotient manifold.
Optional, default: None.
fiber_point : array-like, shape=[..., {dim, [n, n]}]
Point on the total space, lift of `base_point`, i.e. such that
`submersion` applied to `point` results in `base_point`.
Optional, default: None. In this case, it is computed using the
method lift.
Returns
-------
inner_product : float, shape=[...]
Inner products
"""
if fiber_point is None:
if base_point is not None:
fiber_point = self.fiber_bundle.lift(base_point)
else:
raise ValueError(
"Either a point (of the total space) or a "
"base point (of the quotient manifold) must "
"be given."
)
horizontal_a = self.fiber_bundle.horizontal_lift(
tangent_vec_a, fiber_point=fiber_point
)
horizontal_b = self.fiber_bundle.horizontal_lift(
tangent_vec_b, fiber_point=fiber_point
)
return self.ambient_metric.inner_product(
horizontal_a, horizontal_b, fiber_point
)
def exp(self, tangent_vec, base_point, **kwargs):
"""Compute the Riemannian exponential of a tangent vector.
Parameters
----------
tangent_vec : array-like, shape=[..., {dim, [n, n]}]
Tangent vector to the quotient manifold.
base_point : array-like, shape=[..., {dim, [n, n]}]
Point on the quotient manifold.
Optional, default: None.
Returns
-------
exp : array-like, shape=[..., {dim, [n, n]}]
Point on the quotient manifold.
"""
lift = self.fiber_bundle.lift(base_point)
horizontal_vec = self.fiber_bundle.horizontal_lift(
tangent_vec, fiber_point=lift
)
return self.fiber_bundle.riemannian_submersion(
self.ambient_metric.exp(horizontal_vec, lift)
)
def log(self, point, base_point, **kwargs):
"""Compute the Riemannian logarithm of a point.
Parameters
----------
point : array-like, shape=[..., {dim, [n, n]}]
Point on the quotient manifold.
base_point : array-like, shape=[..., {dim, [n, n]}]
Point on the quotient manifold.
Returns
-------
log : array-like, shape=[..., {dim, [n, n]}]
Tangent vector at the base point equal to the Riemannian logarithm
of point at the base point.
"""
fiber_point = self.fiber_bundle.lift(point)
bp_fiber = self.fiber_bundle.lift(base_point)
aligned = self.fiber_bundle.align(fiber_point, bp_fiber, **kwargs)
return self.fiber_bundle.tangent_riemannian_submersion(
self.ambient_metric.log(aligned, bp_fiber), bp_fiber
)
def squared_dist(self, point_a, point_b, **kwargs):
"""Squared geodesic distance between two points.
Parameters
----------
point_a : array-like, shape=[..., {dim, [n, n]}]
Point.
point_b : array-like, shape=[..., {dim, [n, n]}]
Point.
Returns
-------
sq_dist : array-like, shape=[...,]
Squared distance.
"""
lift_a = self.fiber_bundle.lift(point_a)
lift_b = self.fiber_bundle.lift(point_b)
aligned = self.fiber_bundle.align(lift_a, lift_b, **kwargs)
return self.ambient_metric.squared_dist(aligned, lift_b)
def curvature(self, tangent_vec_a, tangent_vec_b, tangent_vec_c, base_point):
r"""Compute the curvature.
For three vectors fields :math:`X|_P = tangent_vec_a,
Y|_P = tangent_vec_b, Z|_P = tangent_vec_c` with tangent vector
specified in argument at the base point :math:`P`,
the curvature is defined by :math:`R(X,Y)Z = \nabla_{[X,Y]}Z
- \nabla_X\nabla_Y Z + \nabla_Y\nabla_X Z`.
In the case of quotient metrics, the fundamental equations of a
Riemannian submersion allow to compute the curvature of the base
manifold from the one of the total space and a correction term that
uses the integrability tensor A [O'Neill]_.
In more details, let :math:`X, Y, Z` be the horizontal lift of
vector fields extending the tangent vectors given in argument in a
neighborhood of the base-point P in the base-space. Then the
curvature of the base-space at the base-points is
:math:`R(X,Y) Z = hor( R^T(X,Y) Z) - 2 A_Z A_X Y + A_X A_Y Z - A_Y
A_X Z`, where :math:`R^T(X,Y)Z` is the curvature tensor of the
total space.
Parameters
----------
tangent_vec_a : array-like, shape=[..., {dim, [n, n]}]
Tangent vector at `base_point`.
tangent_vec_b : array-like, shape=[..., {dim, [n, n]}]
Tangent vector at `base_point`.
tangent_vec_c : array-like, shape=[..., {dim, [n, n]}]
Tangent vector at `base_point`.
base_point : array-like, shape=[..., {dim, [n, n]}]
Point on the base manifold.
Returns
-------
curvature : array-like, shape=[..., {dim, [n, n]}]
Tangent vector at `base_point`.
References
----------
.. [O'Neill] O’Neill, Barrett. The Fundamental Equations of a
Submersion, Michigan Mathematical Journal 13, no. 4 (December 1966):
459–69. https://doi.org/10.1307/mmj/1028999604.
"""
bundle = self.fiber_bundle
fiber_point = bundle.lift(base_point)
horizontal_a = bundle.horizontal_lift(tangent_vec_a, base_point)
horizontal_b = bundle.horizontal_lift(tangent_vec_b, base_point)
horizontal_c = bundle.horizontal_lift(tangent_vec_c, base_point)
top_curvature = self.ambient_metric.curvature(
horizontal_a, horizontal_b, horizontal_c, fiber_point
)
projected_top_curvature = bundle.tangent_riemannian_submersion(
top_curvature, fiber_point
)
f_ab = bundle.integrability_tensor(horizontal_a, horizontal_b, fiber_point)
f_c_f_ab = bundle.integrability_tensor(horizontal_c, f_ab, fiber_point)
f_c_f_ab = bundle.tangent_riemannian_submersion(f_c_f_ab, fiber_point)
f_ac = bundle.integrability_tensor(horizontal_a, horizontal_c, fiber_point)
f_b_f_ac = bundle.integrability_tensor(horizontal_b, f_ac, fiber_point)
f_b_f_ac = bundle.tangent_riemannian_submersion(f_b_f_ac, fiber_point)
f_bc = bundle.integrability_tensor(horizontal_b, horizontal_c, fiber_point)
f_a_f_bc = bundle.integrability_tensor(horizontal_a, f_bc, fiber_point)
f_a_f_bc = bundle.tangent_riemannian_submersion(f_a_f_bc, fiber_point)
return projected_top_curvature - 2 * f_c_f_ab + f_a_f_bc - f_b_f_ac
def curvature_derivative(
self,
tangent_vec_a,
tangent_vec_b,
tangent_vec_c,
tangent_vec_d,
base_point=None,
):
r"""Compute the covariant derivative of the curvature.
For four vectors fields :math:`H|_P = tangent_vec_a, X|_P =
tangent_vec_b, Y|_P = tangent_vec_c, Z|_P = tangent_vec_d` with
tangent vector value specified in argument at the base point `P`,
the covariant derivative of the curvature
:math:`(\nabla_H R)(X, Y)Z |_P` is computed at the base point P.
In the case of quotient metrics, the fundamental equations of a
Riemannian submersion allow to compute this tensor on the base manifold
from the one of the total space T and its covariant derivative with
additional correction terms involving the integrability tensor A and
its covariant derivatives [Pennec]_.
In more details, let :math:`H, X, Y, Z` be the horizontal lift of
parallel vector fields extending the tangent vectors given in argument
by parallel transport in a neighborhood of the base-point P in the
base-space. Such vector fields verify :math:`\nabla^T_H H=0`,
:math:`\nabla^T_H^X = A_H X` (and similarly for Y and Z) using the
connection :math:`\nabla^T` of the total space. Then the covariant
derivative of the curvature tensor is given by
:math:`\nabla_H (R(X, Y) Z) =
\hor\nabla_H^T(R^T(X,Y)Z) - A_H(ver R^T(X,Y)Z )
+ (2 A_H A_Z A_X Y - A_H A_X A_Y Z + A_H A_Y A_X Z)
- (2 \nabla_H^T A_Z A_X Y - \nabla_H^T A_X A_Y Z +
\nabla_H^T A_Y A_X Z)`, where :math:`R^T(X,Y)Z` is the curvature
tensor of the total space.
Parameters
----------
tangent_vec_a : array-like, shape=[..., n, n]
Tangent vector at `base_point` (derivative direction)).
tangent_vec_b : array-like, shape=[..., n, n]
Tangent vector at `base_point`.
tangent_vec_c : array-like, shape=[..., n, n]
Tangent vector at `base_point`.
tangent_vec_d : array-like, shape=[..., n, n]
Tangent vector at `base_point`.
base_point : array-like, shape=[..., n, n]
Point on the base manifold.
Returns
-------
curvature_derivative : array-like, shape=[..., n, n]
Tangent vector at base point.
References
----------
.. [Pennec] Pennec, Xavier. Computing the curvature and its gradient
in Kendall shape spaces. Unpublished.
"""
bundle = self.fiber_bundle
point_fiber = bundle.lift(base_point)
hor_h = bundle.horizontal_lift(tangent_vec_a, point_fiber)
hor_x = bundle.horizontal_lift(tangent_vec_b, point_fiber)
hor_y = bundle.horizontal_lift(tangent_vec_c, point_fiber)
hor_z = bundle.horizontal_lift(tangent_vec_d, point_fiber)
nabla_h_x = bundle.integrability_tensor(hor_h, hor_x, point_fiber)
nabla_h_y = bundle.integrability_tensor(hor_h, hor_y, point_fiber)
nabla_h_z = bundle.integrability_tensor(hor_h, hor_z, point_fiber)
nabla_curvature_top = self.ambient_metric.curvature_derivative(
hor_h, hor_x, hor_y, hor_z, point_fiber
)
hor_nabla_curvature_top = bundle.horizontal_projection(
nabla_curvature_top, point_fiber
)
ver_nabla_curvature_top = nabla_curvature_top - hor_nabla_curvature_top
a_h_ver_nabla_curvature_top = bundle.integrability_tensor(
hor_h, ver_nabla_curvature_top, point_fiber
)
# A_H A_Z A_X Y and \nabla_H A_Z A_X Y
nabla_h_a_x_y, a_x_y = bundle.integrability_tensor_derivative(
hor_h, hor_x, nabla_h_x, hor_y, nabla_h_y, point_fiber
)
nabla_h_a_z_a_x_y, a_z_a_x_y = bundle.integrability_tensor_derivative(
hor_h, hor_z, nabla_h_z, a_x_y, nabla_h_a_x_y, point_fiber
)
a_h_a_z_a_x_y = bundle.integrability_tensor(hor_h, a_z_a_x_y, point_fiber)
# A_H A_X A_Y Z and \nabla_H A_X A_Y Z
nabla_h_a_y_z, a_y_z = bundle.integrability_tensor_derivative(
hor_h, hor_y, nabla_h_y, hor_z, nabla_h_z, point_fiber
)
nabla_h_a_x_a_y_z, a_x_a_y_z = bundle.integrability_tensor_derivative(
hor_h, hor_x, nabla_h_x, a_y_z, nabla_h_a_y_z, point_fiber
)
a_h_a_x_a_y_z = bundle.integrability_tensor(hor_h, a_x_a_y_z, point_fiber)
# A_H A_Y A_X Z and \nabla_H A_Y A_X Z
nabla_h_a_x_z, a_x_z = bundle.integrability_tensor_derivative(
hor_h, hor_x, nabla_h_x, hor_z, nabla_h_z, point_fiber
)
nabla_h_a_y_a_x_z, a_y_a_x_z = bundle.integrability_tensor_derivative(
hor_h, hor_y, nabla_h_y, a_x_z, nabla_h_a_x_z, point_fiber
)
a_h_a_y_a_x_z = bundle.integrability_tensor(hor_h, a_y_a_x_z, point_fiber)
return (
hor_nabla_curvature_top
- a_h_ver_nabla_curvature_top
- 2.0 * (nabla_h_a_z_a_x_y - a_h_a_z_a_x_y)
+ (nabla_h_a_x_a_y_z - a_h_a_x_a_y_z)
- (nabla_h_a_y_a_x_z - a_h_a_y_a_x_z)
)
def directional_curvature_derivative(
self, tangent_vec_a, tangent_vec_b, base_point=None
):
r"""Compute the covariant derivative of the directional curvature.
For two vectors fields :math:`X|_P = tangent_vec_a, Y|_P =
tangent_vec_b` with tangent vector value specified in argument at the
base point `P`, the covariant derivative (in the direction 'X')
:math:`(\nabla_X R_Y)(X) |_P = (\nabla_X R)(Y, X) Y |_P` of the
directional curvature (in the direction `Y`)
:math:`R_Y(X) = R(Y, X) Y` is a quadratic tensor in 'X' and 'Y' that
plays an important role in the computation of the moments of the
empirical Fréchet mean.
This tensor can be computed from the covariant derivative of the
curvature tensor as in done generically the Connection class.
However, in the case of quotient metrics, a simplified expression can
be implemented based on the directional curvature of the total space T
and its covariant derivative with additional correction terms
involving the integrability tensor A and its covariant derivatives
[Pennec]_.
In more details, let :math:`X, Y` be the horizontal lift of parallel
vector fields extending the tangent vectors given in argument by
parallel transport in a neighborhood of the base-point P in the
base-space. Such vector fields verify :math:`\nabla^T_X X=0` and
:math:`\nabla^T_X^Y = A_X Y` using the connection :math:`\nabla^T`
of the total space. Then the covariant derivative of the
directional curvature tensor is given by :math:
`\nabla_X (R_Y(X)) = hor \nabla^T_X (R^T_Y(X)) - A_X( ver R^T_Y(X))
+ 3 A_X A_Y A_X Y - 3 \nabla_X^T A_Y A_X Y `, where :math:`R^T_Y(X)`
is the directional curvature tensor of the total space.
Parameters
----------
tangent_vec_a : array-like, shape=[..., n, n]
Tangent vector at `base_point`.
tangent_vec_b : array-like, shape=[..., n, n]
Tangent vector at `base_point`.
base_point : array-like, shape=[..., n, n]
Point on the base manifold.
Returns
-------
curvature_derivative : array-like, shape=[..., n, n]
Tangent vector at base point.
References
----------
.. [Pennec] Pennec, Xavier. Computing the curvature and its gradient
in Kendall shape spaces. Unpublished.
"""
bundle = self.fiber_bundle
point_fiber = bundle.lift(base_point)
hor_x = bundle.horizontal_lift(tangent_vec_a, point_fiber)
hor_y = bundle.horizontal_lift(tangent_vec_b, point_fiber)
nabla_x_x = gs.zeros_like(hor_x)
nabla_x_y = bundle.integrability_tensor(hor_x, hor_y, point_fiber)
nabla_curvature_top = self.ambient_metric.curvature_derivative(
hor_x, hor_x, hor_y, hor_y, point_fiber
)
hor_nabla_curvature_top = bundle.horizontal_projection(
nabla_curvature_top, point_fiber
)
ver_nabla_curvature_top = nabla_curvature_top - hor_nabla_curvature_top
a_x_ver_nabla_curvature_top = bundle.integrability_tensor(
hor_x, ver_nabla_curvature_top, point_fiber
)
# A_X A_Y A_X Y and \nabla_X A_Y A_X Y
nabla_x_a_x_y, a_x_y = bundle.integrability_tensor_derivative(
hor_x, hor_x, nabla_x_x, hor_y, nabla_x_y, point_fiber
)
nabla_x_a_y_a_x_y, a_y_a_x_y = bundle.integrability_tensor_derivative(
hor_x, hor_y, nabla_x_y, a_x_y, nabla_x_a_x_y, point_fiber
)
a_x_a_y_a_x_y = bundle.integrability_tensor(hor_x, a_y_a_x_y, point_fiber)
return (
hor_nabla_curvature_top
- a_x_ver_nabla_curvature_top
+ 3.0 * (nabla_x_a_y_a_x_y - a_x_a_y_a_x_y)
)