
VAE improvements.
muammar committed Nov 6, 2019
1 parent 93f3df0 commit b5c9fc0
Showing 1 changed file with 77 additions and 67 deletions.
144 changes: 77 additions & 67 deletions ml4chem/models/autoencoders.py
@@ -67,8 +67,9 @@ def __init__(self, hiddenlayers=None, activation="relu", **kwargs):

self.hiddenlayers = hiddenlayers
self.activation = activation

# A white list of supported kwargs.
supported_kwargs = []
supported_kwargs = ["multivariate"]

for k, v in kwargs.items():
if k in supported_kwargs:
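A hedged usage sketch of how the newly whitelisted "multivariate" kwarg might be passed at construction time; the hiddenlayers dictionary follows the class docstring further below, and the layer sizes here are illustrative, not taken from this commit:

from ml4chem.models.autoencoders import VAE

# Illustrative sizes; "multivariate" is the only kwarg whitelisted by this commit.
vae = VAE(
    hiddenlayers={"encoder": (20, 10, 4), "decoder": (4, 10, 20)},
    activation="relu",
    multivariate=True,
)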
@@ -100,8 +101,8 @@ def prepare_model(
}

if purpose == "training":
logger.info("Model Training")
logger.info("==============")
logger.info("Model")
logger.info("=====")
logger.info("Model name: {}.".format(self.name()))
logger.info(
"Structure of {}: {}".format(
@@ -174,8 +175,7 @@ def prepare_model(

inp_dim = out_dim

"""
if self.name() == "VAE":
if self.multivariate:
h = torch.nn.Sequential(*decoder)
mu = torch.nn.Linear(inp_dim, output_dimension)
mu = torch.nn.Sequential(*[mu])
@@ -184,15 +184,14 @@
values = [h, mu, logvar]
decoder = torch.nn.ModuleDict(list(map(list, zip(keys, values))))
else:
"""
# The last decoder layer for symbol
decoder.append(torch.nn.Linear(inp_dim, output_dimension))
# According to this video https://youtu.be/xTU79Zs4XKY?t=416
# real numbered inputs need no activation function in the output
# layer decoder.append(activation[self.activation]())
# The last decoder layer for symbol
decoder.append(torch.nn.Linear(inp_dim, output_dimension))
# According to this video https://youtu.be/xTU79Zs4XKY?t=416
# real numbered inputs need no activation function in the output
# layer decoder.append(activation[self.activation]())

# Stacking up the layers.
decoder = torch.nn.Sequential(*decoder)
# Stacking up the layers.
decoder = torch.nn.Sequential(*decoder)

symbol_decoder_pair.append([symbol, decoder])
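Read as a hedged standalone sketch, the hunk above now builds each per-symbol decoder in one of two shapes: a ModuleDict with a shared hidden stack "h" plus separate "mu" and "logvar" heads when multivariate is enabled, or a plain Sequential ending in a Linear layer otherwise (sizes below are illustrative):

import torch

inp_dim, output_dimension = 10, 20
decoder = [torch.nn.Linear(4, inp_dim), torch.nn.ReLU()]  # illustrative hidden stack
multivariate = True

if multivariate:
    # Shared hidden stack plus separate Gaussian mean / log-variance heads.
    h = torch.nn.Sequential(*decoder)
    mu = torch.nn.Sequential(torch.nn.Linear(inp_dim, output_dimension))
    logvar = torch.nn.Sequential(torch.nn.Linear(inp_dim, output_dimension))
    decoder = torch.nn.ModuleDict({"h": h, "mu": mu, "logvar": logvar})
else:
    # Bernoulli path: final Linear layer, no output activation here
    # (decode() applies the sigmoid).
    decoder.append(torch.nn.Linear(inp_dim, output_dimension))
    decoder = torch.nn.Sequential(*decoder)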

@@ -367,6 +366,9 @@ class VAE(AutoEncoder):
Dictionary with encoder, and decoder layers in the Auto Encoder.
activation : str
The activation function.
multivariate : bool
If set to True, the output distribution is treated as a multivariate
Gaussian; otherwise a Bernoulli distribution is used.
Notes
Expand Down Expand Up @@ -425,35 +427,20 @@ def decode(self, symbol, z):
-------
reconstruction
Tensor with reconstruction.
Notes
-----
See page 11 "Kingma, D. P. & Welling, M. Auto-Encoding Variational
Bayes. (2013)".
"""
reconstruction = self.decoders[symbol](z)
return torch.sigmoid(reconstruction)

# def decode(self, symbol, z):
# """Decode latent vector, z
#
# Parameters
# ----------
# symbol : str
# Chemical symbol.
# z : array
# Latent vector.
#
# Returns
# -------
# mu, logvar
# Mean and variance.
#
# Notes
# -----
# See page 11 "Kingma, D. P. & Welling, M. Auto-Encoding Variational
# Bayes. (2013)".
# """

# h = self.decoders[symbol]["h"](z)
# mu = self.decoders[symbol]["mu"](h)
# logvar = self.decoders[symbol]["logvar"](h)
# return mu, logvar
if self.multivariate:
h = self.decoders[symbol]["h"](z)
mu = self.decoders[symbol]["mu"](h)
logvar = self.decoders[symbol]["logvar"](h)
return mu, logvar
else:
reconstruction = self.decoders[symbol](z)
return torch.sigmoid(reconstruction)
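A hedged caller-side sketch of the two return shapes: the Bernoulli branch yields a sigmoid-squashed reconstruction, while the multivariate branch yields the mean and log-variance of a Gaussian over the input features (the symbol "H", the latent size, and the vae instance are placeholders):

import torch

latent_dimension = 4                    # illustrative
z = torch.randn(1, latent_dimension)    # would normally come from reparameterize()

if vae.multivariate:
    mu_x, logvar_x = vae.decode("H", z)  # Gaussian parameters over the input features
else:
    x_hat = vae.decode("H", z)           # reconstruction squashed into (0, 1)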

def reparameterize(self, mu, logvar):
"""Reparameterization trick
@@ -467,7 +454,7 @@ def reparameterize(self, mu, logvar):
mu : tensor
Mean values of distribution.
logvar : tensor
Logarithm of variance of distribution,
Logarithm of variance of distribution.
Returns
-------
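The method body is unchanged by this commit and not shown in the hunk; for reference, the reparameterization trick named here is usually implemented along these lines (a sketch, not necessarily the exact code in this file):

std = torch.exp(0.5 * logvar)   # logvar stores log(sigma^2)
eps = torch.randn_like(std)     # eps ~ N(0, I)
z = mu + eps * std              # differentiable sample from N(mu, sigma^2 I)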
@@ -496,29 +483,34 @@ def forward(self, X):

mus_latent = []
logvars_latent = []
# mus_output = []
# logvars_output = []
mus_decoder = []
logvars_decoder = []
outputs = []
for hash, image in X.items():
for symbol, x in image:
mu_latent, logvar_latent = self.encode(symbol, x)
z = self.reparameterize(mu_latent, logvar_latent)
mus_latent.append(mu_latent)
logvars_latent.append(logvar_latent)
reconstruction = self.decode(symbol, z)
# mu_output, logvar_output = self.decode(symbol, z)
# mus_output.append(mu_output)
# logvars_output.append(logvar_output)
outputs.append(reconstruction)

if self.multivariate:
mu_decoder, logvar_decoder = self.decode(symbol, z)
mus_decoder.append(mu_decoder)
logvars_decoder.append(logvar_decoder)
else:
reconstruction = self.decode(symbol, z)
outputs.append(reconstruction)

mus_latent = torch.stack(mus_latent)
logvars_latent = torch.stack(logvars_latent)
# mus_output = torch.stack(mus_output)
# logvars_output = torch.stack(logvars_output)
outputs = torch.stack(outputs)

# return outputs, mus_latent, logvars_latent, mus_output, logvars_output
return outputs, mus_latent, logvars_latent
if self.multivariate:
mus_decoder = torch.stack(mus_decoder)
logvars_decoder = torch.stack(logvars_decoder)
return mus_decoder, logvars_decoder, mus_latent, logvars_latent
else:
outputs = torch.stack(outputs)
return outputs, mus_latent, logvars_latent
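In multivariate mode forward() now returns four stacked tensors instead of three; a hedged sketch of unpacking either case, where X follows the {hash: [(symbol, features), ...]} layout iterated over above:

if vae.multivariate:
    mus_decoder, logvars_decoder, mus_latent, logvars_latent = vae(X)
else:
    outputs, mus_latent, logvars_latent = vae(X)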

def get_latent_space(self, X, svm=False, purpose=None):
"""Get latent space for training ML4Chem models
@@ -973,23 +965,38 @@ def train_batches(
loss_name = lossfxn.__name__

if model.name() == "VAE":
# outputs, mus_latent, logvars_latent, mus_output, logvars_output = model(inputs)
outputs, mus_latent, logvars_latent, = model(inputs)
if model.multivariate:
mus_decoder, logvars_decoder, mus_latent, logvars_latent = model(inputs)

args = {
"targets": targets[index],
"mus_decoder": mus_decoder,
"logvars_decoder": logvars_decoder,
"mus_latent": mus_latent,
"logvars_latent": logvars_latent,
"annealing": annealing,
"multivariate": model.multivariate,
"input_dimension": model.input_dimension,
}

else:
outputs, mus_latent, logvars_latent, = model(inputs)

args = {
"outputs": outputs,
"targets": targets[index],
"mus_latent": mus_latent,
"logvars_latent": logvars_latent,
"annealing": annealing,
"multivariate": model.multivariate,
"input_dimension": model.input_dimension,
}

args = {
"outputs": outputs,
"targets": targets[index],
"mus_latent": mus_latent,
"logvars_latent": logvars_latent,
"annealing": annealing,
"input_dimension": model.input_dimension,
}
else:
outputs = model(inputs)
args = {"outputs": outputs, "targets": targets[index]}

# Latent space penalizations

# Latent space penalization
if penalize_latent:
latent = {
"latent": model.get_latent_space(
Expand Down Expand Up @@ -1018,7 +1025,10 @@ def train_batches(
for param in model.parameters():
gradients.append(param.grad.detach().numpy())

return outputs, loss, gradients
if model.multivariate:
return mus_decoder, loss, gradients
else:
return outputs, loss, gradients

@staticmethod
def get_inputs_chunks(chunks):
