Skip to content

Commit

Permalink
TST: Add unit tests for constant nodes in concatenation
Browse files Browse the repository at this point in the history
  • Loading branch information
jluttine committed Mar 14, 2015
1 parent ad1b271 commit 2f9869d
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 11 deletions.
8 changes: 6 additions & 2 deletions bayespy/inference/vmp/nodes/concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from bayespy.utils import misc

from .deterministic import Deterministic
from .node import Moments

class Concatenate(Deterministic):
"""
Expand Down Expand Up @@ -58,8 +59,11 @@ def __init__(self, *nodes, axis=-1, **kwargs):
self._parent_moments = (parent_moments,) * len(nodes)
self._moments = parent_moments
# Convert nodes
nodes = [self._ensure_moments(node, self._parent_moments[0])
for node in nodes]
try:
nodes = [self._ensure_moments(node, self._parent_moments[0])
for node in nodes]
except Moments.NoConverterError:
raise ValueError("Parents have different moments")
# Dimensionality of the node
dims = tuple([dim for dim in nodes[0].dims])
for node in nodes:
Expand Down
85 changes: 76 additions & 9 deletions bayespy/inference/vmp/nodes/tests/test_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
import numpy as np

from bayespy.nodes import (Concatenate,
GaussianARD)
GaussianARD,
Gamma)

from bayespy.utils import random

Expand Down Expand Up @@ -88,6 +89,21 @@ def test_init(self):
self.assertEqual(Y.plates, (9,))
self.assertEqual(Y.dims, ( (), () ))

# Constant parent
X1 = [7.2, 3.5]
X2 = GaussianARD(0, 1, plates=(3,), shape=())
Y = Concatenate(X1, X2)
self.assertEqual(Y.plates, (5,))
self.assertEqual(Y.dims, ( (), () ))

# Different moments
X1 = GaussianARD(0, 1, plates=(3,))
X2 = Gamma(1, 1, plates=(4,))
self.assertRaises(ValueError,
Concatenate,
X1,
X2)

# Incompatible shapes
X1 = GaussianARD(0, 1, plates=(3,), shape=(2,))
X2 = GaussianARD(0, 1, plates=(2,), shape=())
Expand Down Expand Up @@ -119,14 +135,14 @@ def test_message_to_child(self):
u1 = X1.get_moments()
u2 = X2.get_moments()
u = Y.get_moments()
self.assertAllClose(u[0][:2] * np.ones((2,)),
u1[0] * np.ones((2,)))
self.assertAllClose(u[1][:2] * np.ones((2,)),
u1[1] * np.ones((2,)))
self.assertAllClose(u[0][2:] * np.ones((3,)),
u2[0] * np.ones((3,)))
self.assertAllClose(u[1][2:] * np.ones((3,)),
u2[1] * np.ones((3,)))
self.assertAllClose((u[0]*np.ones((5,)))[:2],
u1[0]*np.ones((2,)))
self.assertAllClose((u[1]*np.ones((5,)))[:2],
u1[1]*np.ones((2,)))
self.assertAllClose((u[0]*np.ones((5,)))[2:],
u2[0]*np.ones((3,)))
self.assertAllClose((u[1]*np.ones((5,)))[2:],
u2[1]*np.ones((3,)))

# Two parents with shapes
X1 = GaussianARD(0, 1, plates=(2,), shape=(4,))
Expand All @@ -144,6 +160,39 @@ def test_message_to_child(self):
self.assertAllClose((u[1]*np.ones((5,4,4)))[2:],
u2[1]*np.ones((3,4,4)))

# Test with non-constant axis
X1 = GaussianARD(0, 1, plates=(2,4), shape=())
X2 = GaussianARD(0, 1, plates=(3,4), shape=())
Y = Concatenate(X1, X2, axis=-2)
u1 = X1.get_moments()
u2 = X2.get_moments()
u = Y.get_moments()
self.assertAllClose((u[0]*np.ones((5,4)))[:2],
u1[0]*np.ones((2,4)))
self.assertAllClose((u[1]*np.ones((5,4)))[:2],
u1[1]*np.ones((2,4)))
self.assertAllClose((u[0]*np.ones((5,4)))[2:],
u2[0]*np.ones((3,4)))
self.assertAllClose((u[1]*np.ones((5,4)))[2:],
u2[1]*np.ones((3,4)))

# Test with constant parent
X1 = np.random.randn(2, 4)
X2 = GaussianARD(0, 1, plates=(3,), shape=(4,))
Y = Concatenate(X1, X2)
u1 = Y.parents[0].get_moments()
u2 = X2.get_moments()
u = Y.get_moments()
self.assertAllClose((u[0]*np.ones((5,4)))[:2],
u1[0]*np.ones((2,4)))
self.assertAllClose((u[1]*np.ones((5,4,4)))[:2],
u1[1]*np.ones((2,4,4)))
self.assertAllClose((u[0]*np.ones((5,4)))[2:],
u2[0]*np.ones((3,4)))
self.assertAllClose((u[1]*np.ones((5,4,4)))[2:],
u2[1]*np.ones((3,4,4)))


pass


Expand Down Expand Up @@ -208,6 +257,24 @@ def test_message_to_parent(self):
self.assertAllClose((m[1]*np.ones((5,4)))[2:],
m2[1]*np.ones((3,4)))

# Constant parent
X1 = np.random.randn(2,4,6)
X2 = GaussianARD(0, 1, plates=(3,), shape=(4,6))
Z = Concatenate(X1, X2)
Y = GaussianARD(Z, 1)
Y.observe(np.random.randn(*Y.get_shape(0)))
m1 = Z._message_to_parent(0)
m2 = X2._message_from_children()
m = Z._message_from_children()
self.assertAllClose((m[0]*np.ones((5,4,6)))[:2],
m1[0]*np.ones((2,4,6)))
self.assertAllClose((m[1]*np.ones((5,4,6,4,6)))[:2],
m1[1]*np.ones((2,4,6,4,6)))
self.assertAllClose((m[0]*np.ones((5,4,6)))[2:],
m2[0]*np.ones((3,4,6)))
self.assertAllClose((m[1]*np.ones((5,4,6,4,6)))[2:],
m2[1]*np.ones((3,4,6,4,6)))

pass


Expand Down

0 comments on commit 2f9869d

Please sign in to comment.