30
30
import numpy as np
31
31
32
32
from bayespy .nodes import (Concatenate ,
33
- GaussianARD )
33
+ GaussianARD ,
34
+ Gamma )
34
35
35
36
from bayespy .utils import random
36
37
@@ -88,6 +89,21 @@ def test_init(self):
88
89
self .assertEqual (Y .plates , (9 ,))
89
90
self .assertEqual (Y .dims , ( (), () ))
90
91
92
+ # Constant parent
93
+ X1 = [7.2 , 3.5 ]
94
+ X2 = GaussianARD (0 , 1 , plates = (3 ,), shape = ())
95
+ Y = Concatenate (X1 , X2 )
96
+ self .assertEqual (Y .plates , (5 ,))
97
+ self .assertEqual (Y .dims , ( (), () ))
98
+
99
+ # Different moments
100
+ X1 = GaussianARD (0 , 1 , plates = (3 ,))
101
+ X2 = Gamma (1 , 1 , plates = (4 ,))
102
+ self .assertRaises (ValueError ,
103
+ Concatenate ,
104
+ X1 ,
105
+ X2 )
106
+
91
107
# Incompatible shapes
92
108
X1 = GaussianARD (0 , 1 , plates = (3 ,), shape = (2 ,))
93
109
X2 = GaussianARD (0 , 1 , plates = (2 ,), shape = ())
@@ -119,14 +135,14 @@ def test_message_to_child(self):
119
135
u1 = X1 .get_moments ()
120
136
u2 = X2 .get_moments ()
121
137
u = Y .get_moments ()
122
- self .assertAllClose (u [0 ][: 2 ] * np .ones ((2 ,)),
123
- u1 [0 ] * np .ones ((2 ,)))
124
- self .assertAllClose (u [1 ][: 2 ] * np .ones ((2 ,)),
125
- u1 [1 ] * np .ones ((2 ,)))
126
- self .assertAllClose (u [0 ][ 2 :] * np .ones ((3 ,)),
127
- u2 [0 ] * np .ones ((3 ,)))
128
- self .assertAllClose (u [1 ][ 2 :] * np .ones ((3 ,)),
129
- u2 [1 ] * np .ones ((3 ,)))
138
+ self .assertAllClose (( u [0 ]* np .ones ((5 ,)))[: 2 ] ,
139
+ u1 [0 ]* np .ones ((2 ,)))
140
+ self .assertAllClose (( u [1 ]* np .ones ((5 ,)))[: 2 ] ,
141
+ u1 [1 ]* np .ones ((2 ,)))
142
+ self .assertAllClose (( u [0 ]* np .ones ((5 ,)))[ 2 :] ,
143
+ u2 [0 ]* np .ones ((3 ,)))
144
+ self .assertAllClose (( u [1 ]* np .ones ((5 ,)))[ 2 :] ,
145
+ u2 [1 ]* np .ones ((3 ,)))
130
146
131
147
# Two parents with shapes
132
148
X1 = GaussianARD (0 , 1 , plates = (2 ,), shape = (4 ,))
@@ -144,6 +160,39 @@ def test_message_to_child(self):
144
160
self .assertAllClose ((u [1 ]* np .ones ((5 ,4 ,4 )))[2 :],
145
161
u2 [1 ]* np .ones ((3 ,4 ,4 )))
146
162
163
+ # Test with non-constant axis
164
+ X1 = GaussianARD (0 , 1 , plates = (2 ,4 ), shape = ())
165
+ X2 = GaussianARD (0 , 1 , plates = (3 ,4 ), shape = ())
166
+ Y = Concatenate (X1 , X2 , axis = - 2 )
167
+ u1 = X1 .get_moments ()
168
+ u2 = X2 .get_moments ()
169
+ u = Y .get_moments ()
170
+ self .assertAllClose ((u [0 ]* np .ones ((5 ,4 )))[:2 ],
171
+ u1 [0 ]* np .ones ((2 ,4 )))
172
+ self .assertAllClose ((u [1 ]* np .ones ((5 ,4 )))[:2 ],
173
+ u1 [1 ]* np .ones ((2 ,4 )))
174
+ self .assertAllClose ((u [0 ]* np .ones ((5 ,4 )))[2 :],
175
+ u2 [0 ]* np .ones ((3 ,4 )))
176
+ self .assertAllClose ((u [1 ]* np .ones ((5 ,4 )))[2 :],
177
+ u2 [1 ]* np .ones ((3 ,4 )))
178
+
179
+ # Test with constant parent
180
+ X1 = np .random .randn (2 , 4 )
181
+ X2 = GaussianARD (0 , 1 , plates = (3 ,), shape = (4 ,))
182
+ Y = Concatenate (X1 , X2 )
183
+ u1 = Y .parents [0 ].get_moments ()
184
+ u2 = X2 .get_moments ()
185
+ u = Y .get_moments ()
186
+ self .assertAllClose ((u [0 ]* np .ones ((5 ,4 )))[:2 ],
187
+ u1 [0 ]* np .ones ((2 ,4 )))
188
+ self .assertAllClose ((u [1 ]* np .ones ((5 ,4 ,4 )))[:2 ],
189
+ u1 [1 ]* np .ones ((2 ,4 ,4 )))
190
+ self .assertAllClose ((u [0 ]* np .ones ((5 ,4 )))[2 :],
191
+ u2 [0 ]* np .ones ((3 ,4 )))
192
+ self .assertAllClose ((u [1 ]* np .ones ((5 ,4 ,4 )))[2 :],
193
+ u2 [1 ]* np .ones ((3 ,4 ,4 )))
194
+
195
+
147
196
pass
148
197
149
198
@@ -208,6 +257,24 @@ def test_message_to_parent(self):
208
257
self .assertAllClose ((m [1 ]* np .ones ((5 ,4 )))[2 :],
209
258
m2 [1 ]* np .ones ((3 ,4 )))
210
259
260
+ # Constant parent
261
+ X1 = np .random .randn (2 ,4 ,6 )
262
+ X2 = GaussianARD (0 , 1 , plates = (3 ,), shape = (4 ,6 ))
263
+ Z = Concatenate (X1 , X2 )
264
+ Y = GaussianARD (Z , 1 )
265
+ Y .observe (np .random .randn (* Y .get_shape (0 )))
266
+ m1 = Z ._message_to_parent (0 )
267
+ m2 = X2 ._message_from_children ()
268
+ m = Z ._message_from_children ()
269
+ self .assertAllClose ((m [0 ]* np .ones ((5 ,4 ,6 )))[:2 ],
270
+ m1 [0 ]* np .ones ((2 ,4 ,6 )))
271
+ self .assertAllClose ((m [1 ]* np .ones ((5 ,4 ,6 ,4 ,6 )))[:2 ],
272
+ m1 [1 ]* np .ones ((2 ,4 ,6 ,4 ,6 )))
273
+ self .assertAllClose ((m [0 ]* np .ones ((5 ,4 ,6 )))[2 :],
274
+ m2 [0 ]* np .ones ((3 ,4 ,6 )))
275
+ self .assertAllClose ((m [1 ]* np .ones ((5 ,4 ,6 ,4 ,6 )))[2 :],
276
+ m2 [1 ]* np .ones ((3 ,4 ,6 ,4 ,6 )))
277
+
211
278
pass
212
279
213
280
0 commit comments