Skip to content

Commit

Permalink
Implement par.msg from SumMultiply with Constants
Browse files Browse the repository at this point in the history
  • Loading branch information
jluttine committed Oct 3, 2019
1 parent c106882 commit 7740300
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 12 deletions.
43 changes: 31 additions & 12 deletions bayespy/inference/vmp/nodes/dot.py
Expand Up @@ -495,19 +495,35 @@ def _message_to_parent(self, index, u_parent=None):
# Moments and keys of other parents
for (k, u) in enumerate(u_parents):
if k != index:
num_dims = (ind+1) * len(self.in_keys[k])
num_plates = np.ndim(u[ind]) - num_dims
plates = np.shape(u[ind])[:num_plates]
num_dims = (
(ind+1) * len(self.in_keys[k])
if not self.is_constant[k] else
len(self.in_keys[k])
)
ui = (
u[ind] if not self.is_constant[k] else
u[0]
)
num_plates = np.ndim(ui) - num_dims
plates = np.shape(ui)[:num_plates]
plate_keys = list(range(N + num_plates,
N,
-1))
dim_keys = self.in_keys[k]
if ind == 1:
dim_keys = ([key + self.N_keys
for key in self.in_keys[k]]
+ dim_keys)
args.append(u[ind])
args.append(plate_keys + dim_keys)
if ind == 0:
args.append(ui)
args.append(plate_keys + self.in_keys[k])
else:
in_keys2 = [key + self.N_keys for key in self.in_keys[k]]
if not self.is_constant[k]:
# Gaussian moments: Use second moment once
args.append(ui)
args.append(plate_keys + in_keys2 + self.in_keys[k])
else:
# Delta moments: Use first moment twice
args.append(ui)
args.append(plate_keys + self.in_keys[k])
args.append(ui)
args.append(plate_keys + in_keys2)

result_num_plates = max(result_num_plates, num_plates)
result_plates = misc.broadcasted_shape(result_plates,
Expand Down Expand Up @@ -598,8 +614,11 @@ def _message_to_parent(self, index, u_parent=None):
msg[ind] *= r

if self.gaussian_gamma:
alphas = [u_parents[i][2]
for i in range(len(u_parents)) if i != index]
alphas = [
(u_parents[i][2] if not is_const else 1.0)
for (i, is_const) in zip(range(len(u_parents)), self.is_constant)
if i != index
]
m2 = self._compute_message(m[2], mask, *alphas,
ndim=0,
plates_from=self.plates,
Expand Down
64 changes: 64 additions & 0 deletions bayespy/inference/vmp/nodes/tests/test_dot.py
Expand Up @@ -909,6 +909,36 @@ def check_message(true_m0, true_m1, parent, *args, F=None):
['i'],
F=F)

# Test with constant nodes
N = 10
M = 8
D = 5
K = 3
a = np.random.randn(N, D)
B = Gaussian(
np.random.randn(D),
random.covariance(D),
)
C = GaussianARD(
np.random.randn(M, 1, D, K),
np.random.rand(M, 1, D, K),
ndim=2
)
F = SumMultiply('i,i,ij->', a, B, C)
tau = np.random.rand(M, N)
Y = GaussianARD(F, tau, plates=(M,N))
y = np.random.randn(M, N)
Y.observe(y)
(m0, m1) = F._message_to_parent(1)
np.testing.assert_allclose(
m0,
np.einsum('mn,ni,mnik->i', tau*y, a, C.get_moments()[0]),
)
np.testing.assert_allclose(
m1,
np.einsum('mn,ni,nj,mnikjl->ij', -0.5*tau, a, a, C.get_moments()[1]),
)

# Check: Gaussian-gamma parents
X1 = GaussianGamma(
np.random.randn(2),
Expand Down Expand Up @@ -939,6 +969,40 @@ def check_message(true_m0, true_m1, parent, *args, F=None):
self.assertAllClose(m[2], m2)
self.assertAllClose(m[3], m3)

# Delta moments
N = 10
M = 8
D = 5
a = np.random.randn(N, D)
B = GaussianGamma(
np.random.randn(D),
random.covariance(D),
np.random.rand(),
np.random.rand(),
ndim=1
)
F = SumMultiply('i,i->', a, B)
tau = np.random.rand(M, N)
Y = GaussianARD(F, tau, plates=(M,N))
y = np.random.randn(M, N)
Y.observe(y)
(m0, m1, m2, m3) = F._message_to_parent(1)
np.testing.assert_allclose(
m0,
np.einsum('mn,ni->i', tau*y, a),
)
np.testing.assert_allclose(
m1,
np.einsum('mn,ni,nj->ij', -0.5*tau, a, a),
)
np.testing.assert_allclose(
m2,
np.einsum('mn->', -0.5*tau*y**2),
)
np.testing.assert_allclose(
m3,
np.einsum('mn->', 0.5*np.ones(np.shape(tau))),
)
pass


Expand Down

0 comments on commit 7740300

Please sign in to comment.