-
Notifications
You must be signed in to change notification settings - Fork 239
/
full_rank_matrices.py
107 lines (92 loc) · 3.59 KB
/
full_rank_matrices.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
r"""Module exposing the full rank Euclidean matrices :math:`R_*^{m\times n}`."""
import geomstats.backend as gs
from geomstats.geometry.base import OpenSet
from geomstats.geometry.matrices import Matrices
from geomstats.geometry.matrices import MatricesMetric
class FullRankMatrices(OpenSet):
    r"""Class for :math:`R_*^{m\times n}` matrices of dimension m x n and full rank.

    A matrix of shape m x n is of full rank when its rank equals min(m, n).
    The ambient space is the space of all m x n matrices, whose metric is
    also used for this open set.

    Parameters
    ----------
    m : int
        Integer representing the shape of the matrices: m x n
    n : int
        Integer representing the shape of the matrices: m x n
    """

    def __init__(self, m, n, **kwargs):
        # The open set has the dimension of its ambient matrix space unless
        # the caller explicitly provides another one.
        kwargs.setdefault("dim", m * n)
        super(FullRankMatrices, self).__init__(
            ambient_space=Matrices(m, n), metric=MatricesMetric(m, n), **kwargs
        )
        # An m x n matrix has full rank when its rank equals min(m, n).
        self.rank = min(m, n)

    def belongs(self, point, atol=gs.atol):
        r"""Check if the matrix belongs to :math:`R_*^{m\times n}`.

        Parameters
        ----------
        point : array-like, shape=[..., m, n]
            Matrix to be checked.
        atol : float
            Currently unused: the rank test relies on the default tolerance
            of gs.linalg.matrix_rank.
            Optional, default: gs.atol.

        Returns
        -------
        belongs : array-like, shape=[...]
            Boolean denoting if point is in :math:`R_*^{m\times n}`.
        """
        has_right_size = self.ambient_space.belongs(point)
        # The comparison already yields booleans; no gs.where wrapper needed.
        has_right_rank = gs.linalg.matrix_rank(point) == self.rank
        return gs.logical_and(gs.array(has_right_size), has_right_rank)

    def projection(self, point):
        r"""Project a matrix to the set of full rank matrices.

        As the space of full rank matrices is dense in the space of matrices,
        this is not a projection per se, but a regularization if the matrix
        input X is not already full rank: :math:`X + \epsilon [I_{rank}, 0]`
        is returned where :math:`\epsilon` = gs.atol. Matrices that already
        have full rank are returned unchanged.

        Parameters
        ----------
        point : array-like, shape=[..., m, n]
            Matrix to regularize.

        Returns
        -------
        projected : array-like, shape=[..., m, n]
            Regularized matrix.
        """
        belongs = self.belongs(point)
        m, n = self.ambient_space.shape[0], self.ambient_space.shape[1]
        # Per-matrix scale (gs.atol for rank-deficient inputs, 0 otherwise)
        # broadcast against the rectangular identity [I_rank, 0] of shape (m, n).
        # NOTE(review): adding this shift is not guaranteed to produce a full
        # rank matrix for every input (e.g. X = -gs.atol * [I_rank, 0]).
        regularization = gs.einsum(
            "...,ij->...ij",
            gs.where(~belongs, gs.atol, 0.0),
            gs.eye(m, n),
        )
        return point + regularization

    def random_point(self, n_samples=1, bound=1.0, n_iter=100):
        r"""Sample in :math:`R_*^{m\times n}` from a normal distribution.

        Entries are drawn from gs.random.normal; matrices whose rank is not
        full are rejected and new ones are drawn for the missing samples.

        Parameters
        ----------
        n_samples : int
            Number of samples.
            Optional, default: 1.
        bound : float
            Currently unused: entries are drawn with gs.random.normal
            regardless of this value.
            Optional, default: 1.
        n_iter : int
            Maximum number of rejection rounds used to collect full rank
            matrices.
            Optional, default: 100.

        Returns
        -------
        samples : array-like, shape=[..., m, n]
            Points sampled in :math:`R_*^{m\times n}`. If fewer than
            n_samples full rank matrices are found within n_iter rounds,
            only the ones found are returned.
        """
        m, n = self.ambient_space.shape[0], self.ambient_space.shape[1]
        batches = []
        n_accepted, iteration = 0, 0
        while n_accepted < n_samples and iteration < n_iter:
            # Only draw as many matrices as are still missing.
            raw_samples = gs.random.normal(size=(n_samples - n_accepted, m, n))
            selected = gs.linalg.matrix_rank(raw_samples) == self.rank
            batches.append(raw_samples[selected])
            n_accepted += gs.sum(selected)
            iteration += 1
        # Merge every round before indexing: the first batch can be empty when
        # all matrices drawn in the first round were rejected, so indexing
        # batches[0] directly could fail even though a later round succeeded.
        samples = gs.concatenate(batches)
        if n_samples == 1:
            return samples[0]
        return samples