From 4e58b8c2a6be7dbfe949df2c8b5f6c55d090ffdc Mon Sep 17 00:00:00 2001 From: "Ilya V. Schurov" Date: Wed, 20 Mar 2019 02:30:09 +0300 Subject: [PATCH] fixes error with size mismatch --- train_variational_autoencoder_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_variational_autoencoder_pytorch.py b/train_variational_autoencoder_pytorch.py index 3b5e5d2..f663dea 100644 --- a/train_variational_autoencoder_pytorch.py +++ b/train_variational_autoencoder_pytorch.py @@ -71,7 +71,7 @@ def forward(self, x, n_samples=1): scale = self.softplus(scale_arg) eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]), device=loc.device) z = loc + scale * eps # reparameterization - log_q_z = self.log_q_z(loc, scale, z).sum(-1) + log_q_z = self.log_q_z(loc, scale, z).sum(-1, keepdim=True) return z, log_q_z