-
Notifications
You must be signed in to change notification settings - Fork 239
/
manifold.py
152 lines (126 loc) · 4.13 KB
/
manifold.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Manifold module.

A manifold is a topological space that locally resembles Euclidean
space near each point.
"""
import abc
import geomstats.backend as gs
import geomstats.errors
from geomstats.geometry.riemannian_metric import RiemannianMetric
class Manifold(abc.ABC):
    """Abstract base class for manifolds.

    A manifold is a topological space that locally resembles Euclidean
    space near each point. Concrete subclasses must implement
    ``belongs``, ``is_tangent``, ``to_tangent`` and ``random_point``.

    Parameters
    ----------
    dim : int
        Dimension of the manifold.
    metric : RiemannianMetric
        Riemannian metric associated with the manifold. If its ``dim``
        differs from ``dim``, the metric's ``dim`` is overwritten in place.
        Optional, default: None.
    default_point_type : str, {'vector', 'matrix'}
        Point type.
        Optional, default: 'vector'.
    default_coords_type : str, {'intrinsic', 'extrinsic', etc}
        Coordinate type.
        Optional, default: 'intrinsic'.
    **kwargs
        Forwarded to the next ``__init__`` in the MRO, so the class can be
        used with cooperative multiple inheritance.
    """

    def __init__(
        self,
        dim,
        metric=None,
        default_point_type="vector",
        default_coords_type="intrinsic",
        **kwargs
    ):
        super().__init__(**kwargs)
        # Validation is delegated to geomstats helpers; the exception types
        # they raise are defined in geomstats.errors.
        geomstats.errors.check_integer(dim, "dim")
        geomstats.errors.check_parameter_accepted_values(
            default_point_type, "default_point_type", ["vector", "matrix"]
        )

        self.dim = dim
        self.default_point_type = default_point_type
        # NOTE(review): coords type is not validated because the accepted
        # set is open-ended ("etc") in the documented contract.
        self.default_coords_type = default_coords_type
        # Must run after ``self.dim`` is assigned: the ``metric`` setter
        # reads ``self.dim`` to align the metric's dimension.
        self.metric = metric

    @abc.abstractmethod
    def belongs(self, point, atol=gs.atol):
        """Evaluate if a point belongs to the manifold.

        Parameters
        ----------
        point : array-like, shape=[..., dim]
            Point to evaluate.
        atol : float
            Absolute tolerance.
            Optional, default: backend atol.

        Returns
        -------
        belongs : array-like, shape=[...,]
            Boolean evaluating if point belongs to the manifold.
        """

    @abc.abstractmethod
    def is_tangent(self, vector, base_point, atol=gs.atol):
        """Check whether the vector is tangent at base_point.

        Parameters
        ----------
        vector : array-like, shape=[..., dim]
            Vector.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.
        atol : float
            Absolute tolerance.
            Optional, default: backend atol.

        Returns
        -------
        is_tangent : bool or array-like, shape=[...,]
            Boolean denoting if vector is a tangent vector at the base point.
        """

    @abc.abstractmethod
    def to_tangent(self, vector, base_point):
        """Project a vector to a tangent space of the manifold.

        Parameters
        ----------
        vector : array-like, shape=[..., dim]
            Vector.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[..., dim]
            Tangent vector at base point.
        """

    @abc.abstractmethod
    def random_point(self, n_samples=1, bound=1.0):
        """Sample random points on the manifold.

        If the manifold is compact, a uniform distribution is used.

        Parameters
        ----------
        n_samples : int
            Number of samples.
            Optional, default: 1.
        bound : float
            Bound of the interval in which to sample for non compact manifolds.
            Optional, default: 1.0.

        Returns
        -------
        samples : array-like, shape=[..., {dim, [n, n]}]
            Points sampled on the manifold.
        """

    def regularize(self, point):
        """Regularize a point to the canonical representation for the manifold.

        The base implementation is the identity; subclasses with several
        equivalent representations of a point may override it.

        Parameters
        ----------
        point : array-like, shape=[..., dim]
            Point.

        Returns
        -------
        regularized_point : array-like, shape=[..., dim]
            Regularized point (the input object itself in this base class).
        """
        regularized_point = point
        return regularized_point

    @property
    def metric(self):
        """Riemannian metric associated with the manifold, or None."""
        return self._metric

    @metric.setter
    def metric(self, metric):
        """Set the metric, aligning its dimension with the manifold's.

        Parameters
        ----------
        metric : RiemannianMetric or None
            Metric to attach. ``None`` detaches any metric.

        Raises
        ------
        ValueError
            If ``metric`` is neither None nor a RiemannianMetric instance.
        """
        if metric is not None:
            if not isinstance(metric, RiemannianMetric):
                raise ValueError("The argument must be a RiemannianMetric object")
            # Side effect: the caller's metric object is mutated in place so
            # that its dimension matches this manifold.
            if metric.dim != self.dim:
                metric.dim = self.dim
        self._metric = metric