diff --git a/bayespy/inference/vmp/nodes/dot.py b/bayespy/inference/vmp/nodes/dot.py index 135a6e296..0659ae450 100644 --- a/bayespy/inference/vmp/nodes/dot.py +++ b/bayespy/inference/vmp/nodes/dot.py @@ -495,19 +495,35 @@ def _message_to_parent(self, index, u_parent=None): # Moments and keys of other parents for (k, u) in enumerate(u_parents): if k != index: - num_dims = (ind+1) * len(self.in_keys[k]) - num_plates = np.ndim(u[ind]) - num_dims - plates = np.shape(u[ind])[:num_plates] + num_dims = ( + (ind+1) * len(self.in_keys[k]) + if not self.is_constant[k] else + len(self.in_keys[k]) + ) + ui = ( + u[ind] if not self.is_constant[k] else + u[0] + ) + num_plates = np.ndim(ui) - num_dims + plates = np.shape(ui)[:num_plates] plate_keys = list(range(N + num_plates, N, -1)) - dim_keys = self.in_keys[k] - if ind == 1: - dim_keys = ([key + self.N_keys - for key in self.in_keys[k]] - + dim_keys) - args.append(u[ind]) - args.append(plate_keys + dim_keys) + if ind == 0: + args.append(ui) + args.append(plate_keys + self.in_keys[k]) + else: + in_keys2 = [key + self.N_keys for key in self.in_keys[k]] + if not self.is_constant[k]: + # Gaussian moments: Use second moment once + args.append(ui) + args.append(plate_keys + in_keys2 + self.in_keys[k]) + else: + # Delta moments: Use first moment twice + args.append(ui) + args.append(plate_keys + self.in_keys[k]) + args.append(ui) + args.append(plate_keys + in_keys2) result_num_plates = max(result_num_plates, num_plates) result_plates = misc.broadcasted_shape(result_plates, @@ -598,8 +614,11 @@ def _message_to_parent(self, index, u_parent=None): msg[ind] *= r if self.gaussian_gamma: - alphas = [u_parents[i][2] - for i in range(len(u_parents)) if i != index] + alphas = [ + (u_parents[i][2] if not is_const else 1.0) + for (i, is_const) in zip(range(len(u_parents)), self.is_constant) + if i != index + ] m2 = self._compute_message(m[2], mask, *alphas, ndim=0, plates_from=self.plates, diff --git a/bayespy/inference/vmp/nodes/tests/test_dot.py b/bayespy/inference/vmp/nodes/tests/test_dot.py index 3aeea216d..7fc02f8d8 100644 --- a/bayespy/inference/vmp/nodes/tests/test_dot.py +++ b/bayespy/inference/vmp/nodes/tests/test_dot.py @@ -909,6 +909,36 @@ def check_message(true_m0, true_m1, parent, *args, F=None): ['i'], F=F) + # Test with constant nodes + N = 10 + M = 8 + D = 5 + K = 3 + a = np.random.randn(N, D) + B = Gaussian( + np.random.randn(D), + random.covariance(D), + ) + C = GaussianARD( + np.random.randn(M, 1, D, K), + np.random.rand(M, 1, D, K), + ndim=2 + ) + F = SumMultiply('i,i,ij->', a, B, C) + tau = np.random.rand(M, N) + Y = GaussianARD(F, tau, plates=(M,N)) + y = np.random.randn(M, N) + Y.observe(y) + (m0, m1) = F._message_to_parent(1) + np.testing.assert_allclose( + m0, + np.einsum('mn,ni,mnik->i', tau*y, a, C.get_moments()[0]), + ) + np.testing.assert_allclose( + m1, + np.einsum('mn,ni,nj,mnikjl->ij', -0.5*tau, a, a, C.get_moments()[1]), + ) + # Check: Gaussian-gamma parents X1 = GaussianGamma( np.random.randn(2), @@ -939,6 +969,40 @@ def check_message(true_m0, true_m1, parent, *args, F=None): self.assertAllClose(m[2], m2) self.assertAllClose(m[3], m3) + # Delta moments + N = 10 + M = 8 + D = 5 + a = np.random.randn(N, D) + B = GaussianGamma( + np.random.randn(D), + random.covariance(D), + np.random.rand(), + np.random.rand(), + ndim=1 + ) + F = SumMultiply('i,i->', a, B) + tau = np.random.rand(M, N) + Y = GaussianARD(F, tau, plates=(M,N)) + y = np.random.randn(M, N) + Y.observe(y) + (m0, m1, m2, m3) = F._message_to_parent(1) + np.testing.assert_allclose( + m0, + np.einsum('mn,ni->i', tau*y, a), + ) + np.testing.assert_allclose( + m1, + np.einsum('mn,ni,nj->ij', -0.5*tau, a, a), + ) + np.testing.assert_allclose( + m2, + np.einsum('mn->', -0.5*tau*y**2), + ) + np.testing.assert_allclose( + m3, + np.einsum('mn->', 0.5*np.ones(np.shape(tau))), + ) pass