From 2f9869d5ad091a4d92cf74ec04a5084ae4f65f37 Mon Sep 17 00:00:00 2001 From: Jaakko Luttinen Date: Sat, 14 Mar 2015 15:38:42 +0200 Subject: [PATCH] TST: Add unit tests for constant nodes in concatenation --- bayespy/inference/vmp/nodes/concatenate.py | 8 +- .../vmp/nodes/tests/test_concatenate.py | 85 +++++++++++++++++-- 2 files changed, 82 insertions(+), 11 deletions(-) diff --git a/bayespy/inference/vmp/nodes/concatenate.py b/bayespy/inference/vmp/nodes/concatenate.py index e1b9231df..e687b9e90 100644 --- a/bayespy/inference/vmp/nodes/concatenate.py +++ b/bayespy/inference/vmp/nodes/concatenate.py @@ -26,6 +26,7 @@ from bayespy.utils import misc from .deterministic import Deterministic +from .node import Moments class Concatenate(Deterministic): """ @@ -58,8 +59,11 @@ def __init__(self, *nodes, axis=-1, **kwargs): self._parent_moments = (parent_moments,) * len(nodes) self._moments = parent_moments # Convert nodes - nodes = [self._ensure_moments(node, self._parent_moments[0]) - for node in nodes] + try: + nodes = [self._ensure_moments(node, self._parent_moments[0]) + for node in nodes] + except Moments.NoConverterError: + raise ValueError("Parents have different moments") # Dimensionality of the node dims = tuple([dim for dim in nodes[0].dims]) for node in nodes: diff --git a/bayespy/inference/vmp/nodes/tests/test_concatenate.py b/bayespy/inference/vmp/nodes/tests/test_concatenate.py index 7e8ea9464..2de2c815b 100644 --- a/bayespy/inference/vmp/nodes/tests/test_concatenate.py +++ b/bayespy/inference/vmp/nodes/tests/test_concatenate.py @@ -30,7 +30,8 @@ import numpy as np from bayespy.nodes import (Concatenate, - GaussianARD) + GaussianARD, + Gamma) from bayespy.utils import random @@ -88,6 +89,21 @@ def test_init(self): self.assertEqual(Y.plates, (9,)) self.assertEqual(Y.dims, ( (), () )) + # Constant parent + X1 = [7.2, 3.5] + X2 = GaussianARD(0, 1, plates=(3,), shape=()) + Y = Concatenate(X1, X2) + self.assertEqual(Y.plates, (5,)) + self.assertEqual(Y.dims, ( (), () )) + + # Different moments + X1 = GaussianARD(0, 1, plates=(3,)) + X2 = Gamma(1, 1, plates=(4,)) + self.assertRaises(ValueError, + Concatenate, + X1, + X2) + # Incompatible shapes X1 = GaussianARD(0, 1, plates=(3,), shape=(2,)) X2 = GaussianARD(0, 1, plates=(2,), shape=()) @@ -119,14 +135,14 @@ def test_message_to_child(self): u1 = X1.get_moments() u2 = X2.get_moments() u = Y.get_moments() - self.assertAllClose(u[0][:2] * np.ones((2,)), - u1[0] * np.ones((2,))) - self.assertAllClose(u[1][:2] * np.ones((2,)), - u1[1] * np.ones((2,))) - self.assertAllClose(u[0][2:] * np.ones((3,)), - u2[0] * np.ones((3,))) - self.assertAllClose(u[1][2:] * np.ones((3,)), - u2[1] * np.ones((3,))) + self.assertAllClose((u[0]*np.ones((5,)))[:2], + u1[0]*np.ones((2,))) + self.assertAllClose((u[1]*np.ones((5,)))[:2], + u1[1]*np.ones((2,))) + self.assertAllClose((u[0]*np.ones((5,)))[2:], + u2[0]*np.ones((3,))) + self.assertAllClose((u[1]*np.ones((5,)))[2:], + u2[1]*np.ones((3,))) # Two parents with shapes X1 = GaussianARD(0, 1, plates=(2,), shape=(4,)) @@ -144,6 +160,39 @@ def test_message_to_child(self): self.assertAllClose((u[1]*np.ones((5,4,4)))[2:], u2[1]*np.ones((3,4,4))) + # Test with non-constant axis + X1 = GaussianARD(0, 1, plates=(2,4), shape=()) + X2 = GaussianARD(0, 1, plates=(3,4), shape=()) + Y = Concatenate(X1, X2, axis=-2) + u1 = X1.get_moments() + u2 = X2.get_moments() + u = Y.get_moments() + self.assertAllClose((u[0]*np.ones((5,4)))[:2], + u1[0]*np.ones((2,4))) + self.assertAllClose((u[1]*np.ones((5,4)))[:2], + u1[1]*np.ones((2,4))) + self.assertAllClose((u[0]*np.ones((5,4)))[2:], + u2[0]*np.ones((3,4))) + self.assertAllClose((u[1]*np.ones((5,4)))[2:], + u2[1]*np.ones((3,4))) + + # Test with constant parent + X1 = np.random.randn(2, 4) + X2 = GaussianARD(0, 1, plates=(3,), shape=(4,)) + Y = Concatenate(X1, X2) + u1 = Y.parents[0].get_moments() + u2 = X2.get_moments() + u = Y.get_moments() + self.assertAllClose((u[0]*np.ones((5,4)))[:2], + u1[0]*np.ones((2,4))) + self.assertAllClose((u[1]*np.ones((5,4,4)))[:2], + u1[1]*np.ones((2,4,4))) + self.assertAllClose((u[0]*np.ones((5,4)))[2:], + u2[0]*np.ones((3,4))) + self.assertAllClose((u[1]*np.ones((5,4,4)))[2:], + u2[1]*np.ones((3,4,4))) + + pass @@ -208,6 +257,24 @@ def test_message_to_parent(self): self.assertAllClose((m[1]*np.ones((5,4)))[2:], m2[1]*np.ones((3,4))) + # Constant parent + X1 = np.random.randn(2,4,6) + X2 = GaussianARD(0, 1, plates=(3,), shape=(4,6)) + Z = Concatenate(X1, X2) + Y = GaussianARD(Z, 1) + Y.observe(np.random.randn(*Y.get_shape(0))) + m1 = Z._message_to_parent(0) + m2 = X2._message_from_children() + m = Z._message_from_children() + self.assertAllClose((m[0]*np.ones((5,4,6)))[:2], + m1[0]*np.ones((2,4,6))) + self.assertAllClose((m[1]*np.ones((5,4,6,4,6)))[:2], + m1[1]*np.ones((2,4,6,4,6))) + self.assertAllClose((m[0]*np.ones((5,4,6)))[2:], + m2[0]*np.ones((3,4,6))) + self.assertAllClose((m[1]*np.ones((5,4,6,4,6)))[2:], + m2[1]*np.ones((3,4,6,4,6))) + pass