-
Notifications
You must be signed in to change notification settings - Fork 239
/
preprocessing.py
179 lines (150 loc) · 6.81 KB
/
preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""Transformer for manifold-valued data.
Lead author: Nicolas Guigui.
"""
from sklearn.base import BaseEstimator, TransformerMixin
import geomstats.backend as gs
from geomstats.geometry.lie_group import LieGroup
from geomstats.geometry.matrices import Matrices
from geomstats.geometry.riemannian_metric import RiemannianMetric
from geomstats.geometry.skew_symmetric_matrices import SkewSymmetricMatrices
from geomstats.geometry.symmetric_matrices import SymmetricMatrices
from geomstats.learning.exponential_barycenter import ExponentialBarycenter
from geomstats.learning.frechet_mean import FrechetMean
class ToTangentSpace(BaseEstimator, TransformerMixin):
    """Lift data to a tangent space.

    Compute the logs of all data points and reshape them to 1d vectors if
    necessary. This means that all the data points, which belong to a
    possibly non-linear manifold, are lifted to one of the tangent spaces of
    the manifold, which is a vector space. By default, the mean of the data
    is computed (with the FrechetMean or the ExponentialBarycenter estimator,
    as appropriate) and the tangent space at the mean is used. Any other base
    point can be passed. The data points are then represented by the initial
    velocities of the geodesics that lead from base_point to each data point.
    Any machine learning algorithm can then be used with the output array.

    Parameters
    ----------
    geometry : {Manifold, LieGroup, RiemannianMetric}
        Metric or Lie group to use to compute the log and exp. If a Lie group
        is passed, its group exp/log will be used, which don't necessarily
        correspond to a Riemannian metric. To use a `metric` on the Lie group,
        explicitly pass `geometry=metric`.
    **kwargs : dict
        Keyword arguments forwarded to the FrechetMean/ExponentialBarycenter
        estimator.

    Raises
    ------
    ValueError
        If `geometry` is not a LieGroup, has no `metric` attribute and is not
        a RiemannianMetric.
    """

    def __init__(self, geometry, **kwargs):
        if isinstance(geometry, LieGroup):
            # Lie groups use their group exp/log and the matching barycenter.
            self._used_geometry = geometry
            self.estimator = ExponentialBarycenter(
                group=self._used_geometry, **kwargs
            )
        else:
            if hasattr(geometry, "metric"):
                self._used_geometry = geometry.metric
            elif isinstance(geometry, RiemannianMetric):
                self._used_geometry = geometry
            else:
                raise ValueError(
                    "The input geometry must be either a "
                    "Manifold equipped with a "
                    "RiemannianMetric, or a RiemannianMetric or a"
                    " LieGroup"
                )
            self.estimator = FrechetMean(metric=self._used_geometry, **kwargs)
        self.point_type = geometry.default_point_type
        self.geometry = geometry

    def _resolve_base_point(self, base_point):
        """Return the base point to use for log/exp.

        Parameters
        ----------
        base_point : array-like or None
            Base point explicitly supplied by the caller.

        Returns
        -------
        base_point : array-like
            The supplied base point if any, otherwise the estimate stored on
            the underlying estimator.

        Raises
        ------
        RuntimeError
            If no base point is supplied and the estimator holds no estimate.
        """
        if base_point is not None:
            # An explicit base point never requires a prior call to fit.
            return base_point
        estimate = getattr(self.estimator, "estimate_", None)
        if estimate is None:
            raise RuntimeError(
                "fit needs to be called first or a base_point passed."
            )
        return estimate

    def fit(self, X, y=None, weights=None, base_point=None):
        """Compute the central point at which to take the log.

        This method is only used if `base_point=None` to compute the mean of
        the input data. When a `base_point` is given nothing is estimated and
        the base point is not stored: it must be passed again to `transform`
        and `inverse_transform`.

        Parameters
        ----------
        X : array-like, shape=[..., {dim, [n, n]}]
            The training input samples.
        y : array-like, shape (n_samples,) or (n_samples, n_outputs)
            Forwarded to the underlying estimator's `fit`.
        weights : array-like, shape=[..., 1]
            Weights associated to the points.
            Optional, default: None.
        base_point : array-like, shape=[{dim, [n, n]}]
            Point similar to the input data from which to compute the logs.
            Optional, default: None.

        Returns
        -------
        self : object
            Returns self.
        """
        if base_point is None:
            self.estimator.fit(X, y, weights)
        return self

    def transform(self, X, base_point=None):
        """Lift data to a tangent space.

        Compute the logs of all data points and reshape them to 1d vectors
        if necessary. By default the logs are taken at the mean but any other
        base point can be passed. Any machine learning algorithm can then be
        used with the output array.

        Parameters
        ----------
        X : array-like, shape=[..., {dim, [n, n]}]
            Data to transform.
        base_point : array-like, shape={dim, [n, n]}, optional (mean)
            Point on the manifold, the returned samples will be tangent
            vectors at the base point.

        Returns
        -------
        X_new : array-like, shape=[..., dim]
            Lifted data.

        Raises
        ------
        RuntimeError
            If no `base_point` is passed and no estimate is available.
        """
        base_point = self._resolve_base_point(base_point)
        tangent_vecs = self._used_geometry.log(X, base_point=base_point)

        if self.point_type == "vector":
            return tangent_vecs

        # Matrix-valued tangent vectors: use a compact vector representation
        # when the logs are all symmetric or all skew-symmetric.
        if gs.all(Matrices.is_symmetric(tangent_vecs)):
            return SymmetricMatrices.to_vector(tangent_vecs)
        if gs.all(Matrices.is_skew_symmetric(tangent_vecs)):
            skew_space = SkewSymmetricMatrices(tangent_vecs.shape[-1])
            return skew_space.basis_representation(tangent_vecs)
        # Otherwise flatten each matrix into a row.
        return gs.reshape(tangent_vecs, (len(X), -1))

    def inverse_transform(self, X, base_point=None):
        """Reconstruction of X.

        The reconstruction will match X_original whose transform would be X.

        Parameters
        ----------
        X : array-like, shape=[..., dim]
            New data, where dim is the dimension of the manifold data belong
            to.
        base_point : array-like, shape={dim, [n, n]}, optional (mean)
            Point on the manifold, where the input samples are tangent
            vectors.

        Returns
        -------
        X_original : array-like, shape=[..., {dim, [n, n]}]
            Data lying on the manifold.

        Raises
        ------
        RuntimeError
            If no `base_point` is passed and no estimate is available.
        """
        base_point = self._resolve_base_point(base_point)

        if self.point_type != "matrix":
            return self._used_geometry.exp(X, base_point)

        n = base_point.shape[-1]
        n_vecs = X.shape[-1]
        dim_sym = n * (n + 1) // 2
        dim_skew = n * (n - 1) // 2

        if gs.all(Matrices.is_symmetric(base_point)) and dim_sym == n_vecs:
            tangent_vecs = SymmetricMatrices(n).from_vector(X)
        elif dim_skew == n_vecs:
            # The space is built from the matrix size n, mirroring
            # `transform`; dim_skew only coincides with n when n == 3.
            tangent_vecs = SkewSymmetricMatrices(n).matrix_representation(X)
        else:
            tangent_vecs = gs.reshape(X, (len(X), n, n))

        return self._used_geometry.exp(tangent_vecs, base_point)