-
Notifications
You must be signed in to change notification settings - Fork 239
/
_my_manifold.py
150 lines (122 loc) · 5.01 KB
/
_my_manifold.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""Template file to illustrate how to create a manifold in geomstats.
For additional guidelines on how to contribute to geomstats, visit:
https://geomstats.github.io/contributing.html#contributing-code-workflow
Note: A manifold needs to be created with an associated test file.
The test file for this manifold can be found at:
tests/test__my_manifold.py.
"""
import geomstats.backend as gs
# Import the class(es) that MyManifold inherits from
from geomstats.geometry.manifold import Manifold
# This class inherits from the class Manifold.
# Inheritance in geomstats means that the class MyManifold will reuse code
# that is in the Manifold class.
class MyManifold(Manifold):
    r"""Give a one-liner description/definition of MyManifold.

    For example: Class for Euclidean spaces.

    Give a more detailed description/definition of MyManifold.
    For example: By definition, a Euclidean space is a vector space of a given
    dimension, equipped with a Euclidean metric.

    List the parameters of MyManifold, i.e. the parameters given as inputs
    of the constructor __init__.

    For example:

    Parameters
    ----------
    dim : int
        Dimension of the manifold.
    another_parameter : object
        Example of an additional, manifold-specific parameter. It is stored
        unchanged as the attribute `another_parameter`.
    kwargs : dict
        Additional keyword arguments, forwarded to the parent class
        `Manifold`.
    """

    def __init__(self, dim, another_parameter, **kwargs):
        # Forward **kwargs to the parent constructor: accepting them here and
        # dropping them would silently ignore options given by the caller.
        super().__init__(dim, **kwargs)
        self.another_parameter = another_parameter

    # Implement the main methods of MyManifold, for example belongs:
    def belongs(self, point, atol=gs.atol):
        """Give a one-liner description of the method.

        For example: Evaluate if a point belongs to MyManifold.

        The signature of the method should match the signature of the parent
        method, in this case the method `belongs` from the class `Manifold`.

        List the parameters of the method.
        In what follows, the ellipsis ... indicate either nothing
        or any number n of elements, i.e. shape=[..., dim] means
        shape=[dim] or shape=[n, dim] for any n.
        All functions/methods of geomstats should work for any number
        of inputs. In the case of the method `belongs`, it means:
        for any number of input points.

        For example:

        Parameters
        ----------
        point : array-like, shape=[..., dim]
            Point to evaluate.
        atol : float
            Tolerance, unused.
            Optional, default: backend atol

        List the outputs of the method.
        For example:

        Returns
        -------
        belongs : array-like, shape=[...,]
            Boolean evaluating if point belongs to the manifold.
        """
        # Perform operations to check if point belongs
        # to the manifold, for example: in this template, a point belongs
        # when its last axis has length `dim`.
        belongs = gs.shape(point)[-1] == self.dim
        if gs.ndim(point) == 2:
            # A batch of points: return one (identical) answer per point.
            belongs = gs.tile([belongs], (gs.shape(point)[0],))
        return belongs

    # Another example of method of MyManifold.
    def is_tangent(self, vector, base_point=None, atol=gs.atol):
        """Check whether vector is tangent to the manifold at base_point.

        In what follows, the ellipsis ... indicates either nothing
        or any number n of elements, i.e. shape=[..., dim] means
        shape=[dim] or shape=[n, dim] for any n.
        All functions/methods of geomstats should work for any number
        of inputs. In the case of the function `is_tangent`, it means:
        for any number of input vectors.

        Parameters
        ----------
        vector : array-like, shape=[..., dim]
            Vector.
        base_point : array-like, shape=[..., dim]
            Point on the manifold. Unused in this example.
            Optional, default: None.
        atol : float
            Absolute tolerance threshold. Unused in this example.

        Returns
        -------
        is_tangent : array-like, shape=[...,]
            Boolean denoting if vector is a tangent vector at the base point.
        """
        # Perform operations to determine if vector is a tangent vector,
        # for example: here any vector whose last axis has length `dim`
        # is tangent.
        is_tangent = gs.shape(vector)[-1] == self.dim
        if gs.ndim(vector) == 2:
            # A batch of vectors: return one (identical) answer per vector.
            is_tangent = gs.tile([is_tangent], (gs.shape(vector)[0],))
        return is_tangent

    def to_tangent(self, vector, base_point=None):
        """Project a vector to a tangent space of the manifold.

        In this example, `is_tangent` accepts every vector whose last axis
        has length `dim`, so the tangent space is the whole ambient space
        and the projection is the identity. Replace this with the actual
        projection of your manifold.

        Parameters
        ----------
        vector : array-like, shape=[..., dim]
            Vector.
        base_point : array-like, shape=[..., dim]
            Point on the manifold. Unused in this example.
            Optional, default: None.

        Returns
        -------
        tangent_vec : array-like, shape=[..., dim]
            Tangent vector at base point.
        """
        # Return a copy so callers cannot mutate the input through the result.
        return gs.copy(vector)

    def random_point(self, n_samples=1, bound=1.0):
        """Sample random points on the manifold.

        If the manifold is compact, a uniform distribution is used.
        In this example, the manifold is not compact: points are sampled
        uniformly in the hypercube [-bound, bound]^dim.

        Parameters
        ----------
        n_samples : int
            Number of samples.
            Optional, default: 1.
        bound : float
            Bound of the interval in which to sample for non compact manifolds.
            Optional, default: 1.

        Returns
        -------
        samples : array-like, shape=[..., dim]
            Points sampled on the manifold. The leading sample axis is
            omitted when n_samples == 1.
        """
        size = (self.dim,) if n_samples == 1 else (n_samples, self.dim)
        # gs.random.rand samples uniformly in [0, 1); rescale to [-bound, bound).
        return bound * (2.0 * gs.random.rand(*size) - 1.0)