-
Notifications
You must be signed in to change notification settings - Fork 248
/
test_backend_numpy.py
115 lines (88 loc) · 3.6 KB
/
test_backend_numpy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
Unit tests for numpy backend.
"""
import importlib
import os
import unittest
import warnings
import geomstats.backend as gs
from geomstats.special_orthogonal_group import SpecialOrthogonalGroup
class TestBackendNumpy(unittest.TestCase):
    """Tests for ``gs.linalg.expm`` / ``gs.linalg.logm`` on the numpy backend.

    ``setUpClass`` sets the ``GEOMSTATS_BACKEND`` environment variable to
    ``'numpy'`` and reloads ``geomstats.backend``. ``tearDownClass`` puts the
    variable back exactly as it was found (including "not set") and reloads
    the backend again, so later test modules see the original configuration.
    """

    _multiprocess_can_split_ = True

    @classmethod
    def setUpClass(cls):
        # Use .get(): indexing os.environ raises KeyError when the variable
        # is unset, which would abort the whole class before any test runs.
        cls.initial_backend = os.environ.get('GEOMSTATS_BACKEND')
        os.environ['GEOMSTATS_BACKEND'] = 'numpy'
        importlib.reload(gs)

    @classmethod
    def tearDownClass(cls):
        # os.environ only accepts strings, so an originally-unset variable
        # is restored by removing the key rather than assigning None.
        if cls.initial_backend is None:
            os.environ.pop('GEOMSTATS_BACKEND', None)
        else:
            os.environ['GEOMSTATS_BACKEND'] = cls.initial_backend
        importlib.reload(gs)

    def setUp(self):
        warnings.simplefilter('ignore', category=ImportWarning)
        self.so3_group = SpecialOrthogonalGroup(n=3)
        self.n_samples = 2

    @staticmethod
    def _diagonal_point():
        """Return the 3x3 diagonal matrix diag(2, 3, 4)."""
        return gs.array([[2., 0., 0.],
                         [0., 3., 0.],
                         [0., 0., 4.]])

    @staticmethod
    def _diagonal_batch():
        """Return a (2, 3, 3) stack of diag(2, 3, 4) and diag(1, 5, 6)."""
        return gs.array([[[2., 0., 0.],
                          [0., 3., 0.],
                          [0., 0., 4.]],
                         [[1., 0., 0.],
                          [0., 5., 0.],
                          [0., 0., 6.]]])

    def test_logm(self):
        point = self._diagonal_point()
        result = gs.linalg.logm(point)
        # The log of a positive diagonal matrix is diag(ln 2, ln 3, ln 4).
        expected = gs.array([[0.693147180, 0., 0.],
                             [0., 1.098612288, 0.],
                             [0., 0., 1.38629436]])
        self.assertTrue(gs.allclose(result, expected))

    def test_expm_and_logm(self):
        point = self._diagonal_point()
        result = gs.linalg.expm(gs.linalg.logm(point))
        expected = point
        self.assertTrue(gs.allclose(result, expected))

    def test_expm_vectorization(self):
        point = self._diagonal_batch()
        # Element-wise exponentials of the diagonals:
        # diag(e^2, e^3, e^4) and diag(e^1, e^5, e^6).
        expected = gs.array([[[7.38905609, 0., 0.],
                              [0., 20.0855369, 0.],
                              [0., 0., 54.5981500]],
                             [[2.718281828, 0., 0.],
                              [0., 148.413159, 0.],
                              [0., 0., 403.42879349]]])
        result = gs.linalg.expm(point)
        self.assertTrue(gs.allclose(result, expected))

    def test_logm_vectorization_diagonal(self):
        point = self._diagonal_batch()
        # Element-wise natural logs of the diagonals:
        # diag(ln 2, ln 3, ln 4) and diag(ln 1, ln 5, ln 6).
        expected = gs.array([[[0.693147180, 0., 0.],
                              [0., 1.09861228866, 0.],
                              [0., 0., 1.38629436]],
                             [[0., 0., 0.],
                              [0., 1.609437912, 0.],
                              [0., 0., 1.79175946]]])
        result = gs.linalg.logm(point)
        self.assertTrue(gs.allclose(result, expected))

    def test_expm_and_logm_vectorization_random_rotation(self):
        # Sample rotation vectors and convert them to rotation matrices,
        # then check that expm undoes logm on the batch.
        point = self.so3_group.random_uniform(self.n_samples)
        point = self.so3_group.matrix_from_rotation_vector(point)
        result = gs.linalg.expm(gs.linalg.logm(point))
        expected = point
        self.assertTrue(gs.allclose(result, expected))

    def test_expm_and_logm_vectorization(self):
        point = self._diagonal_batch()
        result = gs.linalg.expm(gs.linalg.logm(point))
        expected = point
        self.assertTrue(gs.allclose(result, expected))
if __name__ == '__main__':
    # Script entry point: discover and run the TestCase classes of this
    # module (module='__main__' is unittest.main's default, stated explicitly).
    unittest.main(module='__main__')