Skip to content

Commit 2f9869d

Browse files
committed
TST: Add unit tests for constant nodes in concatenation
1 parent ad1b271 commit 2f9869d

File tree

2 files changed

+82
-11
lines changed

2 files changed

+82
-11
lines changed

bayespy/inference/vmp/nodes/concatenate.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from bayespy.utils import misc
2727

2828
from .deterministic import Deterministic
29+
from .node import Moments
2930

3031
class Concatenate(Deterministic):
3132
"""
@@ -58,8 +59,11 @@ def __init__(self, *nodes, axis=-1, **kwargs):
5859
self._parent_moments = (parent_moments,) * len(nodes)
5960
self._moments = parent_moments
6061
# Convert nodes
61-
nodes = [self._ensure_moments(node, self._parent_moments[0])
62-
for node in nodes]
62+
try:
63+
nodes = [self._ensure_moments(node, self._parent_moments[0])
64+
for node in nodes]
65+
except Moments.NoConverterError:
66+
raise ValueError("Parents have different moments")
6367
# Dimensionality of the node
6468
dims = tuple([dim for dim in nodes[0].dims])
6569
for node in nodes:

bayespy/inference/vmp/nodes/tests/test_concatenate.py

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
import numpy as np
3131

3232
from bayespy.nodes import (Concatenate,
33-
GaussianARD)
33+
GaussianARD,
34+
Gamma)
3435

3536
from bayespy.utils import random
3637

@@ -88,6 +89,21 @@ def test_init(self):
8889
self.assertEqual(Y.plates, (9,))
8990
self.assertEqual(Y.dims, ( (), () ))
9091

92+
# Constant parent
93+
X1 = [7.2, 3.5]
94+
X2 = GaussianARD(0, 1, plates=(3,), shape=())
95+
Y = Concatenate(X1, X2)
96+
self.assertEqual(Y.plates, (5,))
97+
self.assertEqual(Y.dims, ( (), () ))
98+
99+
# Different moments
100+
X1 = GaussianARD(0, 1, plates=(3,))
101+
X2 = Gamma(1, 1, plates=(4,))
102+
self.assertRaises(ValueError,
103+
Concatenate,
104+
X1,
105+
X2)
106+
91107
# Incompatible shapes
92108
X1 = GaussianARD(0, 1, plates=(3,), shape=(2,))
93109
X2 = GaussianARD(0, 1, plates=(2,), shape=())
@@ -119,14 +135,14 @@ def test_message_to_child(self):
119135
u1 = X1.get_moments()
120136
u2 = X2.get_moments()
121137
u = Y.get_moments()
122-
self.assertAllClose(u[0][:2] * np.ones((2,)),
123-
u1[0] * np.ones((2,)))
124-
self.assertAllClose(u[1][:2] * np.ones((2,)),
125-
u1[1] * np.ones((2,)))
126-
self.assertAllClose(u[0][2:] * np.ones((3,)),
127-
u2[0] * np.ones((3,)))
128-
self.assertAllClose(u[1][2:] * np.ones((3,)),
129-
u2[1] * np.ones((3,)))
138+
self.assertAllClose((u[0]*np.ones((5,)))[:2],
139+
u1[0]*np.ones((2,)))
140+
self.assertAllClose((u[1]*np.ones((5,)))[:2],
141+
u1[1]*np.ones((2,)))
142+
self.assertAllClose((u[0]*np.ones((5,)))[2:],
143+
u2[0]*np.ones((3,)))
144+
self.assertAllClose((u[1]*np.ones((5,)))[2:],
145+
u2[1]*np.ones((3,)))
130146

131147
# Two parents with shapes
132148
X1 = GaussianARD(0, 1, plates=(2,), shape=(4,))
@@ -144,6 +160,39 @@ def test_message_to_child(self):
144160
self.assertAllClose((u[1]*np.ones((5,4,4)))[2:],
145161
u2[1]*np.ones((3,4,4)))
146162

163+
# Test with non-constant axis
164+
X1 = GaussianARD(0, 1, plates=(2,4), shape=())
165+
X2 = GaussianARD(0, 1, plates=(3,4), shape=())
166+
Y = Concatenate(X1, X2, axis=-2)
167+
u1 = X1.get_moments()
168+
u2 = X2.get_moments()
169+
u = Y.get_moments()
170+
self.assertAllClose((u[0]*np.ones((5,4)))[:2],
171+
u1[0]*np.ones((2,4)))
172+
self.assertAllClose((u[1]*np.ones((5,4)))[:2],
173+
u1[1]*np.ones((2,4)))
174+
self.assertAllClose((u[0]*np.ones((5,4)))[2:],
175+
u2[0]*np.ones((3,4)))
176+
self.assertAllClose((u[1]*np.ones((5,4)))[2:],
177+
u2[1]*np.ones((3,4)))
178+
179+
# Test with constant parent
180+
X1 = np.random.randn(2, 4)
181+
X2 = GaussianARD(0, 1, plates=(3,), shape=(4,))
182+
Y = Concatenate(X1, X2)
183+
u1 = Y.parents[0].get_moments()
184+
u2 = X2.get_moments()
185+
u = Y.get_moments()
186+
self.assertAllClose((u[0]*np.ones((5,4)))[:2],
187+
u1[0]*np.ones((2,4)))
188+
self.assertAllClose((u[1]*np.ones((5,4,4)))[:2],
189+
u1[1]*np.ones((2,4,4)))
190+
self.assertAllClose((u[0]*np.ones((5,4)))[2:],
191+
u2[0]*np.ones((3,4)))
192+
self.assertAllClose((u[1]*np.ones((5,4,4)))[2:],
193+
u2[1]*np.ones((3,4,4)))
194+
195+
147196
pass
148197

149198

@@ -208,6 +257,24 @@ def test_message_to_parent(self):
208257
self.assertAllClose((m[1]*np.ones((5,4)))[2:],
209258
m2[1]*np.ones((3,4)))
210259

260+
# Constant parent
261+
X1 = np.random.randn(2,4,6)
262+
X2 = GaussianARD(0, 1, plates=(3,), shape=(4,6))
263+
Z = Concatenate(X1, X2)
264+
Y = GaussianARD(Z, 1)
265+
Y.observe(np.random.randn(*Y.get_shape(0)))
266+
m1 = Z._message_to_parent(0)
267+
m2 = X2._message_from_children()
268+
m = Z._message_from_children()
269+
self.assertAllClose((m[0]*np.ones((5,4,6)))[:2],
270+
m1[0]*np.ones((2,4,6)))
271+
self.assertAllClose((m[1]*np.ones((5,4,6,4,6)))[:2],
272+
m1[1]*np.ones((2,4,6,4,6)))
273+
self.assertAllClose((m[0]*np.ones((5,4,6)))[2:],
274+
m2[0]*np.ones((3,4,6)))
275+
self.assertAllClose((m[1]*np.ones((5,4,6,4,6)))[2:],
276+
m2[1]*np.ones((3,4,6,4,6)))
277+
211278
pass
212279

213280

0 commit comments

Comments
 (0)