VAE model now supports one_for_all keyword argument.
muammar committed Nov 12, 2019
1 parent dca42d6 · commit 97efd97
Showing 1 changed file with 92 additions and 29 deletions.
ml4chem/models/autoencoders.py: 121 changes (92 additions & 29 deletions)
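In practice, one_for_all switches the VAE from keeping one encoder/decoder per chemical symbol to sharing a single pair across all symbols; when it is set, encode() and decode() are called without a symbol argument (see the forward() changes below). A minimal usage sketch; the layer sizes, the activation name, and the exact constructor signature are illustrative assumptions rather than values taken from this commit:

from ml4chem.models.autoencoders import VAE

# Hypothetical layer sizes and activation; one_for_all=True shares one
# encoder/decoder across every chemical symbol instead of one per symbol.
vae = VAE(
    hiddenlayers={"encoder": (20, 10, 4), "decoder": (4, 10, 20)},
    activation="relu",
    variant="multivariate",
    one_for_all=True,
)

With one_for_all left unset, the same model keeps per-symbol submodules and indexes them as self.encoders[symbol], as the encode()/decode() changes further down show.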
@@ -136,7 +136,9 @@ def prepare_model(
decoder = []
decoder_layers = self.hiddenlayers["decoder"]

"""Build Encoder"""
"""
Encoder
"""
out_dimension = encoder_layers[0]
_encoder = torch.nn.Linear(input_dimension, out_dimension)
encoder.append(_encoder)
@@ -147,22 +149,50 @@ def prepare_model(
encoder.append(_encoder)
encoder.append(activation[self.activation]())

if self.name() == "VAE":
keys = ["h", "mu", "logvar"]
mu = []
logvar = []

index = -3
for _ in range(2):
index += 1
if index == -2:
mu.append(encoder.pop(index))
else:
encoder.pop(index)

h = torch.nn.Sequential(*encoder)
logvar = torch.nn.Linear(inp_dim, out_dim)
logvar = torch.nn.Sequential(*[logvar])
mu = torch.nn.Sequential(*mu)

values = [h, mu, logvar]
encoder = torch.nn.ModuleDict(list(map(list, zip(keys, values))))
else:

encoder = torch.nn.Sequential(*encoder)

"""
Decoder
"""
for inp_dim, out_dim in zip(decoder_layers, decoder_layers[1:]):
decoder.append(torch.nn.Linear(inp_dim, out_dim))
decoder.append(activation[self.activation]())

inp_dim = out_dim

decoder.append(torch.nn.Linear(inp_dim, output_dimension))
inp_dim = out_dim

encoder = torch.nn.Sequential(*encoder)
decoder = torch.nn.Sequential(*decoder)

if self.name() == "VAE":
raise NotImplementedError
if self.variant == "multivariate":
h = torch.nn.Sequential(*decoder)
mu = torch.nn.Linear(inp_dim, output_dimension)
mu = torch.nn.Sequential(*[mu])
logvar = torch.nn.Linear(inp_dim, output_dimension)
logvar = torch.nn.Sequential(*[logvar])
values = [h, mu, logvar]
decoder = torch.nn.ModuleDict(list(map(list, zip(keys, values))))
else:
decoder.append(torch.nn.Linear(inp_dim, output_dimension))
decoder = torch.nn.Sequential(*decoder)

self.encoders = encoder
self.decoders = decoder
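The VAE branch above packs the shared trunk and the two heads into a torch.nn.ModuleDict keyed by "h", "mu" and "logvar"; under one_for_all these are reached directly (self.encoders["mu"]), otherwise through a per-symbol lookup first (self.encoders[symbol]["mu"], see encode() below). A standalone sketch of the same construction idiom, with toy layer sizes:

import torch

keys = ["h", "mu", "logvar"]

# Toy dimensions: a shared trunk "h" followed by two linear heads that
# produce the latent mean and log-variance.
h = torch.nn.Sequential(torch.nn.Linear(10, 4), torch.nn.Tanh())
mu = torch.nn.Sequential(torch.nn.Linear(4, 2))
logvar = torch.nn.Sequential(torch.nn.Linear(4, 2))

# ModuleDict accepts an iterable of (key, module) pairs, so this behaves like
# {"h": h, "mu": mu, "logvar": logvar} while still registering the submodules.
encoder = torch.nn.ModuleDict(list(map(list, zip(keys, [h, mu, logvar]))))

hidden = encoder["h"](torch.randn(5, 10))
print(encoder["mu"](hidden).shape, encoder["logvar"](hidden).shape)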
@@ -261,7 +291,6 @@ def prepare_model(
)
for m in self.modules():
if isinstance(m, torch.nn.Linear):
print(m)
# nn.init.normal_(m.weight) # , mean=0, std=0.01)
torch.nn.init.xavier_uniform_(m.weight)

@@ -476,35 +505,40 @@ def name(cls):
"""Returns name of class"""
return cls.NAME

def encode(self, symbol, x):
def encode(self, x, symbol=None):
"""Encode input
Parameters
----------
symbol : str
Chemical symbol.
x : array
Input array.
symbol : str, optional
Chemical symbol. Default is None.
Returns
-------
mu, logvar
Mean and variance.
"""
h = self.encoders[symbol]["h"](x)
mu = self.encoders[symbol]["mu"](h)
logvar = self.encoders[symbol]["logvar"](h)
if symbol is None:
h = self.encoders["h"](x)
mu = self.encoders["mu"](h)
logvar = self.encoders["logvar"](h)
else:
h = self.encoders[symbol]["h"](x)
mu = self.encoders[symbol]["mu"](h)
logvar = self.encoders[symbol]["logvar"](h)
return mu, logvar

def decode(self, symbol, z):
def decode(self, z, symbol=None):
"""Decode latent vector, z
Parameters
----------
symbol : str
Chemical symbol.
z : array
Latent vector.
symbol : str, optional
Chemical symbol. Default is None.
Returns
-------
@@ -517,17 +551,30 @@ def decode(self, symbol, z):
Bayes. (2013)".
"""
if self.variant == "multivariate":
h = self.decoders[symbol]["h"](z)
mu = self.decoders[symbol]["mu"](h)
logvar = self.decoders[symbol]["logvar"](h)
if symbol is None:
h = self.decoders["h"](z)
mu = self.decoders["mu"](h)
logvar = self.decoders["logvar"](h)
else:
h = self.decoders[symbol]["h"](z)
mu = self.decoders[symbol]["mu"](h)
logvar = self.decoders[symbol]["logvar"](h)

return mu, logvar

elif self.variant == "bernoulli":
reconstruction = self.decoders[symbol](z)
if symbol is None:
reconstruction = self.decoders(z)
else:
reconstruction = self.decoders[symbol](z)

return torch.sigmoid(reconstruction)

elif self.variant == "dcgan":
reconstruction = self.decoders[symbol](z)
if symbol is None:
reconstruction = self.decoders(z)
else:
reconstruction = self.decoders[symbol](z)
return torch.tanh(reconstruction)
else:
raise NotImplementedError
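Taken together, the new signatures make the chemical symbol optional for both methods. A short sketch, reusing the hypothetical vae from the first example and assuming x and z are suitably sized tensors:

# Shared (one_for_all) networks: no symbol needed.
mu, logvar = vae.encode(x)
out = vae.decode(z)   # a (mu, logvar) pair when variant == "multivariate"

# Per-symbol networks: pass the chemical symbol explicitly.
mu, logvar = vae.encode(x, symbol="H")
out = vae.decode(z, symbol="H")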
@@ -578,17 +625,27 @@ def forward(self, X):
outputs = []
for hash, image in X.items():
for symbol, x in image:
mu_latent, logvar_latent = self.encode(symbol, x)
if self.one_for_all:
mu_latent, logvar_latent = self.encode(x)
else:
mu_latent, logvar_latent = self.encode(x, symbol=symbol)
z = self.reparameterize(mu_latent, logvar_latent)
mus_latent.append(mu_latent)
logvars_latent.append(logvar_latent)

if self.variant == "multivariate":
mu_decoder, logvar_decoder = self.decode(symbol, z)
if self.one_for_all:
mu_decoder, logvar_decoder = self.decode(z)
else:
mu_decoder, logvar_decoder = self.decode(z, symbol=symbol)

mus_decoder.append(mu_decoder)
logvars_decoder.append(logvar_decoder)
else:
reconstruction = self.decode(symbol, z)
if self.one_for_all:
reconstruction = self.decode(z)
else:
reconstruction = self.decode(z, symbol=symbol)
outputs.append(reconstruction)

mus_latent = torch.stack(mus_latent)
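forward() ties the two halves together through self.reparameterize(mu, logvar), which is not shown in this diff; it is presumably the standard reparameterization trick of Kingma & Welling (2013), cited in the decode() docstring. A minimal sketch under that assumption:

import torch

def reparameterize(mu, logvar):
    # Draw z = mu + sigma * eps with sigma = exp(0.5 * logvar) and eps ~ N(0, I),
    # keeping the sampling step differentiable with respect to mu and logvar.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

The sampled z is then handed to decode(), which returns a mean/log-variance pair for the "multivariate" variant and a sigmoid- or tanh-squashed reconstruction for "bernoulli" and "dcgan".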
@@ -645,7 +702,10 @@ def get_latent_space(self, X, svm=False, purpose=None):
hashes.append(hash)
_symbols = []
for symbol, x in image:
mu_latent, logvar_latent = self.encode(symbol, x)
if self.one_for_all:
mu_latent, logvar_latent = self.encode(x)
else:
mu_latent, logvar_latent = self.encode(x, symbol=symbol)
latent_vector = self.reparameterize(mu_latent, logvar_latent)
_symbols.append(symbol)

@@ -671,7 +731,10 @@ def get_latent_space(self, X, svm=False, purpose=None):
for hash, image in X.items():
latent_space[hash] = []
for symbol, x in image:
mu_latent, logvar_latent = self.encode(symbol, x)
if self.one_for_all:
mu_latent, logvar_latent = self.encode(x)
else:
mu_latent, logvar_latent = self.encode(x, symbol=symbol)
latent_vector = self.reparameterize(mu_latent, logvar_latent)

if svm:
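Finally, get_latent_space() (only partially shown above) walks the same nested input structure as forward(): a dictionary keyed by image hash whose values are lists of (symbol, feature vector) pairs. A hypothetical call, with made-up hashes and feature lengths:

import torch

# Hypothetical feature space: {hash: [(symbol, tensor), ...], ...}
X = {"hash0": [("H", torch.randn(20)), ("O", torch.randn(20))]}

latent = vae.get_latent_space(X)                 # latent vectors grouped by hash
latent_svm = vae.get_latent_space(X, svm=True)   # the svm flag switches the output format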
