-
Notifications
You must be signed in to change notification settings - Fork 239
/
lie_algebra_data.py
45 lines (40 loc) · 1.72 KB
/
lie_algebra_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import geomstats.backend as gs
from geomstats.algebra_utils import from_vector_to_diagonal_matrix
from geomstats.geometry.skew_symmetric_matrices import SkewSymmetricMatrices
from geomstats.geometry.special_euclidean import SpecialEuclidean
from geomstats.geometry.special_orthogonal import SpecialOrthogonal
from tests.data_generation import TestData
class TestDataLieAlgebra(TestData):
    """Test-case data for Lie algebra tests.

    Every ``*_test_data`` method collects a list of keyword-argument dicts
    and returns the result of passing that list to ``self.generate_tests``
    (inherited from ``TestData``).
    """

    def dimension_test_data(self):
        """Return a case pairing so(4) with its expected dimension.

        The skew-symmetric 4x4 matrices have 4 * 3 / 2 = 6 free entries.
        """
        cases = [
            dict(algebra=SkewSymmetricMatrices(4), expected=6),
        ]
        return self.generate_tests(cases)

    def matrix_representation_and_belongs_test_data(self):
        """Return a case made of so(4) and a (2, 6) array of random values.

        The second axis has length 6, matching the dimension expected in
        ``dimension_test_data``; presumably each row is a coordinate vector
        in the algebra's basis (TODO confirm against the consuming test).
        """
        algebra = SkewSymmetricMatrices(4)
        coefficients = gs.random.rand(2, 6)
        return self.generate_tests([dict(algebra=algebra, point=coefficients)])

    def orthonormal_basis_test_data(self):
        """Return SO(3) cases without a metric matrix and with diag(1, 2, 3)."""
        metric_matrices = (
            None,
            from_vector_to_diagonal_matrix(gs.array([1.0, 2.0, 3.0])),
        )
        # A new group object is built for every case so that no two cases
        # share the same instance.
        cases = [
            dict(group=SpecialOrthogonal(3, equip=False), metric_mat_at_identity=mat)
            for mat in metric_matrices
        ]
        return self.generate_tests(cases)

    def orthonormal_basis_se3_test_data(self):
        """Return SE(3) cases without a metric matrix and with diag(1, ..., dim).

        ``dim`` is read from ``SpecialEuclidean(3, equip=False).dim`` and the
        diagonal entries are cast to ``gs.float32``.
        """
        se3_dim = SpecialEuclidean(3, equip=False).dim
        # NOTE(review): this diagonal is explicitly float32, whereas the SO(3)
        # case uses gs.array's default float dtype -- confirm this is intended.
        diagonal = gs.cast(gs.arange(1, se3_dim + 1), gs.float32)
        metric_matrices = (None, from_vector_to_diagonal_matrix(diagonal))
        # One fresh group object per case, mirroring the SO(3) data above.
        cases = [
            dict(group=SpecialEuclidean(3, equip=False), metric_mat_at_identity=mat)
            for mat in metric_matrices
        ]
        return self.generate_tests(cases)