Skip to content

Commit

Permalink
Added Multivariate Gaussian Variational Autoencoder.
Browse files Browse the repository at this point in the history
- Instantiate VAE with `multivariate=True`.
  • Loading branch information
muammar committed Nov 6, 2019
1 parent 51d32d4 commit 93f3df0
Showing 1 changed file with 31 additions and 17 deletions.
48 changes: 31 additions & 17 deletions ml4chem/models/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,14 @@ def get_pairwise_distances(positions, squared=False):


def VAELoss(
outputs,
targets,
mus_latent,
logvars_latent,
annealing,
outputs=None,
targets=None,
mus_latent=None,
logvars_latent=None,
mus_decoder=None,
logvars_decoder=None,
annealing=None,
multivariate=None,
latent=None,
input_dimension=None,
):
Expand All @@ -270,14 +273,17 @@ def VAELoss(
targets : tensor
Expected value of outputs.
mus_latent : tensor
The latent space tensor.
Mean values of distribution.
logvars_latent : tensor
The latent space tensor.
Logarithm of the variance.
multivariate : bool
If multivariate is set to True we treat the distribution as a
multivariate Gaussian distribution otherwise we use Bernoulli.
annealing : float
Contribution of distance loss function to total loss.
latent : tensor
latent : tensor, optional
The latent space tensor.
input_dimension : int
input_dimension : int, optional
Input's dimension.
Expand All @@ -289,15 +295,23 @@ def VAELoss(
"""

loss = []
# LOG_2_PI = np.log(2 * np.pi)
# loss_rec = LOG_2_PI + torch.sum(logvars_output + (targets - mus_output) ** 2 / (2 * torch.exp(logvars_output)))

loss_rec = torch.nn.functional.binary_cross_entropy(
outputs, targets, reduction="sum"
)
# criterion = torch.nn.MSELoss(reduction="sum")
# loss_rec = criterion(outputs, targets) * input_dimension
loss_rec *= input_dimension
if multivariate:
# loss_rec = LOG_2_PI + logvar_x + (x - mu_x)**2 / (2*torch.exp(logvar_x))
# loss_rec = -torch.mean(torch.sum(-(0.5 * np.log(2 * np.pi) + 0.5 * logvars_decoder) - 0.5 * ((targets - mus_decoder)**2 / torch.exp(logvars_decoder)), dim=0))
loss_rec = torch.sum(
(-0.5 * np.log(2.0 * np.pi))
+ (-0.5 * logvars_decoder)
+ ((-0.5 / torch.exp(logvars_decoder)) * (targets - mus_decoder) ** 2.0)
)
loss_rec *= -1.0

else:
loss_rec = torch.nn.functional.binary_cross_entropy(
outputs, targets, reduction="sum"
)
loss_rec *= input_dimension

loss.append(loss_rec)

# see Appendix B from VAE paper:
Expand Down

0 comments on commit 93f3df0

Please sign in to comment.