-
Notifications
You must be signed in to change notification settings - Fork 134
/
test_torchutils.py
124 lines (100 loc) · 4.45 KB
/
test_torchutils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""Tests for the PyTorch utility functions."""
import numpy as np
import torch
import torchtestcase
import pytest
from sbi.utils import torchutils
class TorchUtilsTest(torchtestcase.TorchTestCase):
    """Unit tests for the reshaping, linear-algebra and binning helpers in `torchutils`."""

    def test_split_leading_dim(self):
        """Check that `split_leading_dim` reshapes dim 0 into the requested shape.

        A ``-1`` entry is inferred from the remaining size. An empty shape, or a
        shape whose product does not match dim 0, must raise.
        """
        x = torch.randn(24, 5)
        self.assertEqual(torchutils.split_leading_dim(x, [-1]), x)
        self.assertEqual(torchutils.split_leading_dim(x, [2, -1]), x.view(2, 12, 5))
        self.assertEqual(
            torchutils.split_leading_dim(x, [2, 3, -1]), x.view(2, 3, 4, 5)
        )
        # Call the helper directly rather than inside `assertEqual`: an
        # AssertionError from a successfully computed but mismatched result is
        # also an `Exception` and would satisfy `assertRaises`, hiding a missing
        # error.
        with self.assertRaises(Exception):
            torchutils.split_leading_dim(x, [])
        with self.assertRaises(Exception):
            torchutils.split_leading_dim(x, [5, 5])

    def test_merge_leading_dims(self):
        """Check that `merge_leading_dims` flattens the first `num_dims` dims into one.

        A ``num_dims`` of 0, or one larger than the tensor's rank, must raise.
        """
        x = torch.randn(2, 3, 4, 5)
        self.assertEqual(torchutils.merge_leading_dims(x, 1), x)
        self.assertEqual(torchutils.merge_leading_dims(x, 2), x.view(6, 4, 5))
        self.assertEqual(torchutils.merge_leading_dims(x, 3), x.view(24, 5))
        self.assertEqual(torchutils.merge_leading_dims(x, 4), x.view(120))
        with self.assertRaises(Exception):
            torchutils.merge_leading_dims(x, 0)
        with self.assertRaises(Exception):
            torchutils.merge_leading_dims(x, 5)

    def test_split_merge_leading_dims_are_consistent(self):
        """Splitting a merged tensor with its original leading shape restores it."""
        x = torch.randn(2, 3, 4, 5)
        # num_dims = 1..4 pairs with leading shapes [2], [2, 3], [2, 3, 4], [2, 3, 4, 5].
        for num_dims in range(1, x.dim() + 1):
            with self.subTest(num_dims=num_dims):
                merged = torchutils.merge_leading_dims(x, num_dims)
                leading_shape = list(x.shape[:num_dims])
                restored = torchutils.split_leading_dim(merged, leading_shape)
                self.assertEqual(restored, x)

    def test_repeat_rows(self):
        """Check that `repeat_rows` repeats each entry along dim 0 consecutively."""
        x = torch.randn(2, 3, 4, 5)
        self.assertEqual(torchutils.repeat_rows(x, 1), x)
        y = torchutils.repeat_rows(x, 2)
        self.assertEqual(y.shape, torch.Size([4, 3, 4, 5]))
        # Copies of a row are adjacent ([x0, x0, x1, x1]), not tiled ([x0, x1, x0, x1]).
        self.assertEqual(x[0], y[0])
        self.assertEqual(x[0], y[1])
        self.assertEqual(x[1], y[2])
        self.assertEqual(x[1], y[3])
        # Zero repeats must be rejected.
        with self.assertRaises(Exception):
            torchutils.repeat_rows(x, 0)

    def test_logabsdet(self):
        """Compare `logabsdet` against ``log(|det(matrix)|)`` computed directly."""
        size = 10
        matrix = torch.randn(size, size)
        logabsdet = torchutils.logabsdet(matrix)
        logabsdet_ref = torch.log(torch.abs(matrix.det()))
        # Presumably the tolerance TorchTestCase.assertEqual uses for tensor
        # comparisons (set per test) -- confirm against torchtestcase.
        self.eps = 1e-6
        self.assertEqual(logabsdet, logabsdet_ref)

    def test_random_orthogonal(self):
        """Check that `random_orthogonal` returns a square orthogonal matrix.

        Orthogonality is checked as Q @ Q.T == Q.T @ Q == I, Q.T == inverse(Q)
        and |det(Q)| == 1.
        """
        size = 100
        matrix = torchutils.random_orthogonal(size)
        self.assertIsInstance(matrix, torch.Tensor)
        self.assertEqual(matrix.shape, torch.Size([size, size]))
        # Looser tolerance than in test_logabsdet: these products accumulate
        # rounding over 100-element sums.
        self.eps = 1e-5
        unit = torch.eye(size, size)
        self.assertEqual(matrix @ matrix.t(), unit)
        self.assertEqual(matrix.t() @ matrix, unit)
        self.assertEqual(matrix.t(), matrix.inverse())
        self.assertEqual(torch.abs(matrix.det()), torch.tensor(1.0))

    def test_searchsorted(self):
        """Points in bin ``i`` (left edge included) of a uniform grid map to index ``i``."""
        bin_locations = torch.linspace(0, 1, 10)  # 9 bins == 10 locations
        # Each bin is 1/9 ~= 0.111 wide, so offsets of 0.0, 0.05 and 0.1 from a
        # bin's left edge all stay strictly inside that bin (0.1 lands just below
        # its right edge; 0.05 is near, not exactly at, its midpoint).
        left_boundaries = bin_locations[:-1]
        near_right_boundaries = bin_locations[:-1] + 0.1
        near_mid_points = bin_locations[:-1] + 0.05
        for inputs in [left_boundaries, near_right_boundaries, near_mid_points]:
            with self.subTest(inputs=inputs):
                idx = torchutils.searchsorted(bin_locations[None, :], inputs)
                self.assertEqual(idx, torch.arange(0, 9))

    def test_searchsorted_arbitrary_shape(self):
        """Check that `searchsorted` accepts batched bins and returns the inputs' shape."""
        shape = [2, 3, 4]
        # One copy of the 10-point grid per input element: shape (2, 3, 4, 10).
        bin_locations = torch.linspace(0, 1, 10).repeat(*shape, 1)
        inputs = torch.rand(*shape)
        idx = torchutils.searchsorted(bin_locations, inputs)
        self.assertEqual(idx.shape, inputs.shape)
def test_box_distribution():
    """A `BoxUniform` with scalar `low` and a length-3 `high` has event shape (3,)."""
    upper_bounds = torch.Tensor([3.0, 3.0, 3.0])
    box = torchutils.BoxUniform(low=0.0, high=upper_bounds)
    expected_event_shape = torch.Size([3])
    assert box.event_shape == expected_event_shape
def test_make_conform():
    """Check that `make_shapes_conform` gives its first argument the rank of the second.

    After squeezing, the result must still equal the first argument element-wise.
    """
    values = torch.tensor([0.0, -1.0, 1.0])
    reference = torch.tensor([[1, 2, 3]])

    conformed = torchutils.make_shapes_conform(values, reference)

    assert torch.all(conformed.squeeze() == values)
    assert conformed.ndim == reference.ndim
def test_atleast_2d():
    """Check that `atleast_2d` returns rank-2 outputs for a 1-D array and a 2-D tensor.

    The NumPy input must come back as a `torch.Tensor`.
    """
    numpy_vector = np.array([0.0, -1.0, 1.0])
    torch_row = torch.tensor([[1, 2, 3]])

    first, second = torchutils.atleast_2d(numpy_vector, torch_row)

    assert isinstance(first, torch.Tensor)
    for result in (first, second):
        assert result.ndim == 2