Add code for generator based training, testing, sampling (#43)
* added functionality to generate train and test tensors on the fly
* fixed an encoding bug so that padding is encoded properly
* fixed bugs related to training on generators; generator-based training now works correctly
* added a sampling script that works with generator-based models
* added test_split as an optional command-line argument; also changed the default batch_size so it evenly divides the default epoch_size
Showing 3 changed files with 365 additions and 0 deletions.
molecules/vectorizer.py (new file, +153 lines; the path is inferred from the `from molecules.vectorizer import SmilesDataGenerator` statements in the two scripts below):

```python
import numpy as np
import itertools
import random


class CharacterTable(object):
    '''
    Given a set of characters:
    + Encode them to a one-hot integer representation
    + Decode the one-hot integer representation to their character output
    + Decode a vector of probabilities to their character output
    first version by rmcgibbo
    '''
    def __init__(self, chars, maxlen):
        self.chars = sorted(set(chars))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
        self.maxlen = maxlen

    def encode(self, C, maxlen=None):
        maxlen = maxlen if maxlen else self.maxlen
        X = np.zeros((maxlen, len(self.chars)))
        for i, c in enumerate(C):
            X[i, self.char_indices[c]] = 1
        return X

    def decode(self, X, mode='argmax'):
        if mode == 'argmax':
            # take the most probable character at each position
            X = X.argmax(axis=-1)
        elif mode == 'choice':
            # sample a character index from each (normalized) probability row
            X = np.apply_along_axis(lambda vec:
                                    np.random.choice(len(vec), 1,
                                                     p=(vec / np.sum(vec))),
                                    axis=-1, arr=X).ravel()
        return str.join('', (self.indices_char[x] for x in X))
```
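Not part of the diff — a minimal round-trip sketch of the CharacterTable API above, using a toy alphabet:

```python
# Encode a string to a one-hot matrix, then decode it back.
table = CharacterTable('abc ', maxlen=6)
X = table.encode('cab ab')
print(X.shape)          # (6, 4): maxlen x alphabet size
print(table.decode(X))  # 'cab ab'
```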
```python
class SmilesDataGenerator(object):
    """
    Given a list of SMILES strings, yields batches of randomly sampled
    strings encoded as one-hot tensors, together with a weight vector
    marking the true (unpadded) length of each string.
    """
    SMILES_CHARS = [' ',
                    '#', '%', '(', ')', '+', '-', '.', '/',
                    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                    '=', '@',
                    'A', 'B', 'C', 'F', 'H', 'I', 'K', 'L', 'M', 'N', 'O', 'P',
                    'R', 'S', 'T', 'V', 'X', 'Z',
                    '[', '\\', ']',
                    'a', 'b', 'c', 'e', 'g', 'i', 'l', 'n', 'o', 'p', 'r', 's',
                    't', 'u']

    def __init__(self, words, maxlen,
                 pad_char=' ',
                 pad_min=1,
                 pad_weight=0.0,
                 test_split=0.20):
        self.maxlen = maxlen
        self.words = words
        self.max_words = len(words)
        # list() so that random.shuffle also works under Python 3
        self.word_ixs = list(range(self.max_words))
        self.shuffled_word_ixs = list(range(self.max_words))
        random.shuffle(self.shuffled_word_ixs)

        self.pad_char = pad_char
        self.pad_min = pad_min
        self.pad_weight = pad_weight
        self.test_split = test_split
        self.chars = sorted(set.union(set(SmilesDataGenerator.SMILES_CHARS),
                                      set(pad_char)))
        self.table = CharacterTable(self.chars, self.maxlen)

    def encode(self, word):
        # right-pad the string to maxlen before one-hot encoding
        padded_word = word + self.pad_char * (self.maxlen - len(word))
        return self.table.encode(padded_word)

    def weight(self, word):
        # 1.0 over the true string (plus pad_min positions), pad_weight after
        weight_vec = np.ones((self.maxlen,)) * self.pad_weight
        weight_vec[np.arange(min(len(word) + self.pad_min, self.maxlen))] = 1
        return weight_vec

    def sample(self, predicate=None):
        # rejection-sample until a word satisfies the predicate and fits maxlen
        if predicate:
            word_ix = random.choice(self.word_ixs)
            if not predicate(word_ix):
                return self.sample(predicate=predicate)
            word = self.words[self.shuffled_word_ixs[word_ix]]
        else:
            word = random.choice(self.words)
        if len(word) < self.maxlen:
            return word
        return self.sample(predicate=predicate)

    def train_sample(self):
        if self.test_split > 0:
            threshold = self.max_words * self.test_split
            return self.sample(lambda word_ix: word_ix >= threshold)
        return self.sample()

    def test_sample(self):
        if self.test_split > 0:
            threshold = self.max_words * self.test_split
            return self.sample(lambda word_ix: word_ix < threshold)
        return self.sample()

    def generator(self, batch_size, sample_func=None):
        if not sample_func:
            sample_func = self.sample
        while True:
            data_tensor = np.zeros((batch_size, self.maxlen, len(self.chars)), dtype=np.bool)
            weight_tensor = np.zeros((batch_size, self.maxlen))
            for word_ix in range(batch_size):
                word = sample_func()
                data_tensor[word_ix, ...] = self.encode(word)
                weight_tensor[word_ix, ...] = self.weight(word)
            yield (data_tensor, data_tensor, weight_tensor)

    def train_generator(self, batch_size):
        return self.generator(batch_size, sample_func=self.train_sample)

    def test_generator(self, batch_size):
        return self.generator(batch_size, sample_func=self.test_sample)
```
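Again not part of the diff — a sketch of the generator contract using a hypothetical three-string corpus. Each batch is a (data, data, weights) triple because the autoencoder's input doubles as its reconstruction target:

```python
gen_obj = SmilesDataGenerator(['CCO', 'c1ccccc1', 'CC(=O)O'], 20,
                              test_split=0.0)  # hold out no test words
batch_gen = gen_obj.generator(2)

x, y, w = next(batch_gen)
print(x.shape)                     # (2, 20, len(gen_obj.chars))
print(x is y)                      # True: input is also the target
print(gen_obj.table.decode(x[0]))  # a sampled SMILES string, right-padded
```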
```python
class CanonicalSmilesDataGenerator(SmilesDataGenerator):
    """
    Like SmilesDataGenerator, but canonicalizes each sampled SMILES
    string with RDKit before it is encoded.
    """

    def sample(self, predicate=None):
        from rdkit import Chem
        mol = Chem.MolFromSmiles(super(CanonicalSmilesDataGenerator, self).sample(predicate=predicate))
        if mol:
            canon_word = Chem.MolToSmiles(mol)
            if len(canon_word) < self.maxlen:
                return canon_word
        # unparsable, or too long after canonicalization: draw again
        return self.sample(predicate=predicate)

    def train_sample(self):
        if self.test_split > 0:
            threshold = self.max_words * self.test_split
            return self.sample(lambda word_ix: word_ix >= threshold)
        return self.sample()

    def test_sample(self):
        if self.test_split > 0:
            threshold = self.max_words * self.test_split
            return self.sample(lambda word_ix: word_ix < threshold)
        return self.sample()
```
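The subclass relies on RDKit canonicalization, which maps equivalent SMILES spellings to a single canonical string; a quick illustration (not part of the diff):

```python
from rdkit import Chem

# Two spellings of ethanol collapse to the same canonical SMILES.
print(Chem.MolToSmiles(Chem.MolFromSmiles('OCC')))    # expected: CCO
print(Chem.MolToSmiles(Chem.MolFromSmiles('C(O)C')))  # expected: CCO
```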
Sampling script (new file, +127 lines; its path is not shown in this view):

```python
from __future__ import print_function

import argparse
import os
import h5py
import numpy as np
import sys

from molecules.model import MoleculeVAE
from molecules.utils import one_hot_array, one_hot_index, from_one_hot_array, \
    decode_smiles_from_indexes, load_dataset
from molecules.vectorizer import SmilesDataGenerator

LATENT_DIM = 292
NUM_SAMPLED = 100
TARGET = 'autoencoder'


def get_arguments():
    parser = argparse.ArgumentParser(description='Molecular autoencoder network')
    parser.add_argument('data', type=str, help='File of latent representation tensors for decoding.')
    parser.add_argument('model', type=str, help='Trained Keras model to use.')
    parser.add_argument('--save_h5', type=str, help='Name of a file to write HDF5 output to.')
    parser.add_argument('--target', type=str, default=TARGET,
                        help='What model to sample from: autoencoder, encoder, decoder.')
    parser.add_argument('--latent_dim', type=int, metavar='N', default=LATENT_DIM,
                        help='Dimensionality of the latent representation.')
    parser.add_argument('--sample', type=int, metavar='N', default=NUM_SAMPLED,
                        help='Number of items to sample from the data generator.')
    return parser.parse_args()


def read_latent_data(filename):
    h5f = h5py.File(filename, 'r')
    data = h5f['latent_vectors'][:]
    charset = h5f['charset'][:]
    h5f.close()
    return (data, charset)


def read_smiles_data(filename):
    import pandas as pd
    h5f = pd.read_hdf(filename, 'table')
    data = h5f['structure'][:]
    # import gzip
    # data = [line.split()[0].strip() for line in gzip.open(filename) if line]
    return data


def autoencoder(args, model):
    # Print (true, reconstructed) SMILES pairs for randomly sampled inputs.
    latent_dim = args.latent_dim

    structures = read_smiles_data(args.data)

    datobj = SmilesDataGenerator(structures, 120)
    train_gen = datobj.generator(1)

    if os.path.isfile(args.model):
        model.load(datobj.chars, args.model, latent_rep_size=latent_dim)
    else:
        raise ValueError("Model file %s doesn't exist" % args.model)

    true_pred_gen = ((mat, weight, model.autoencoder.predict(mat))
                     for (mat, _, weight) in train_gen)
    # trim each decoded true string at the first padded position, i.e. where
    # the weight vector first falls back to pad_weight
    text_gen = (str.join('\n',
                         [str((datobj.table.decode(true_mat[vec_ix])[:np.argmin(weight[vec_ix])],
                               datobj.table.decode(vec)[:]))
                          for (vec_ix, vec) in enumerate(pred_mat)])
                for (true_mat, weight, pred_mat) in true_pred_gen)
    for _ in range(args.sample):
        print(next(text_gen))


def decoder(args, model):
    # Decode stored latent vectors back into SMILES strings.
    latent_dim = args.latent_dim
    data, charset = read_latent_data(args.data)

    if os.path.isfile(args.model):
        model.load(charset, args.model, latent_rep_size=latent_dim)
    else:
        raise ValueError("Model file %s doesn't exist" % args.model)

    for ix in range(len(data)):
        sampled = model.decoder.predict(data[ix]).argmax(axis=2)[0]
        sampled = decode_smiles_from_indexes(sampled, charset)
        print(sampled)


def encoder(args, model):
    # Encode sampled SMILES strings into latent vectors; print them, or save
    # them to HDF5 if --save_h5 is given.
    latent_dim = args.latent_dim

    structures = read_smiles_data(args.data)

    datobj = SmilesDataGenerator(structures, 120)
    train_gen = datobj.generator(1)

    if os.path.isfile(args.model):
        model.load(datobj.chars, args.model, latent_rep_size=latent_dim)
    else:
        raise ValueError("Model file %s doesn't exist" % args.model)

    true_pred_gen = ((mat, weight, model.encoder.predict(mat))
                     for (mat, _, weight) in train_gen)
    if args.save_h5:
        h5f = h5py.File(args.save_h5, 'w')
        h5f.create_dataset('charset', data=datobj.chars)
        h5f.create_dataset('latent_vectors', (args.sample, 120, latent_dim))
        for ix in range(args.sample):
            _, _, x_latent = next(true_pred_gen)
            h5f['latent_vectors'][ix] = x_latent[0]
        h5f.close()
    else:
        text_gen = (str.join('\n',
                             [str((datobj.table.decode(true_mat[vec_ix])[:np.argmin(weight[vec_ix])],
                                   vec[:]))
                              for (vec_ix, vec) in enumerate(pred_mat)])
                    for (true_mat, weight, pred_mat) in true_pred_gen)
        for _ in range(args.sample):
            print(next(text_gen))


def main():
    args = get_arguments()
    model = MoleculeVAE()

    if args.target == 'autoencoder':
        autoencoder(args, model)
    elif args.target == 'encoder':
        encoder(args, model)
    elif args.target == 'decoder':
        decoder(args, model)


if __name__ == '__main__':
    main()
```
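Both autoencoder() and encoder() above cut the decoded true string at np.argmin(weight[vec_ix]), the first index where the weight vector returns to pad_weight. A small sketch of why that works, using the vectorizer's defaults (pad_weight=0.0, pad_min=1); not part of the diff:

```python
import numpy as np
from molecules.vectorizer import SmilesDataGenerator

gen_obj = SmilesDataGenerator(['CCO'], 10, test_split=0.0)
w = gen_obj.weight('CCO')
print(w)             # [1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
print(np.argmin(w))  # 4: decoded[:4] keeps 'CCO' plus one pad character
```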
Training script (new file, +85 lines; its path is not shown in this view):

```python
from __future__ import print_function

import argparse
import os
import h5py
import numpy as np
import pandas as pd

from molecules.model import MoleculeVAE
from molecules.vectorizer import SmilesDataGenerator
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

NUM_EPOCHS = 1
EPOCH_SIZE = 500000
BATCH_SIZE = 500   # evenly divides EPOCH_SIZE
LATENT_DIM = 292
MAX_LEN = 120
TEST_SPLIT = 0.20


def get_arguments():
    parser = argparse.ArgumentParser(description='Molecular autoencoder network')
    parser.add_argument('data', type=str, help='The HDF5 file containing structures.')
    parser.add_argument('model', type=str,
                        help='Where to save the trained model. If this file exists, it will be opened and resumed.')
    parser.add_argument('--epochs', type=int, metavar='N', default=NUM_EPOCHS,
                        help='Number of epochs to run during training.')
    parser.add_argument('--latent_dim', type=int, metavar='N', default=LATENT_DIM,
                        help='Dimensionality of the latent representation.')
    parser.add_argument('--batch_size', type=int, metavar='N', default=BATCH_SIZE,
                        help='Number of samples to process per minibatch during training.')
    parser.add_argument('--epoch_size', type=int, metavar='N', default=EPOCH_SIZE,
                        help='Number of samples to process per epoch during training.')
    parser.add_argument('--test_split', type=float, metavar='N', default=TEST_SPLIT,
                        help='Fraction of dataset to use as test data; the rest is training data.')
    return parser.parse_args()


def main():
    args = get_arguments()

    data = pd.read_hdf(args.data, 'table')
    structures = data['structure']

    # import gzip
    # filepath = args.data
    # structures = [line.split()[0].strip() for line in gzip.open(filepath) if line]

    # can also use CanonicalSmilesDataGenerator
    datobj = SmilesDataGenerator(structures, MAX_LEN,
                                 test_split=args.test_split)
    test_divisor = int((1 - datobj.test_split) / (datobj.test_split))
    train_gen = datobj.train_generator(args.batch_size)
    test_gen = datobj.test_generator(args.batch_size)

    # reformulate the generators to not use the per-timestep weights
    train_gen = ((tens, tens) for (tens, _, weights) in train_gen)
    test_gen = ((tens, tens) for (tens, _, weights) in test_gen)

    model = MoleculeVAE()
    if os.path.isfile(args.model):
        model.load(datobj.chars, args.model, latent_rep_size=args.latent_dim)
    else:
        model.create(datobj.chars, latent_rep_size=args.latent_dim)

    checkpointer = ModelCheckpoint(filepath=args.model,
                                   verbose=1,
                                   save_best_only=True)

    reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                  factor=0.2,
                                  patience=3,
                                  min_lr=0.0001)

    # Keras 1 fit_generator API: the second positional argument is
    # samples_per_epoch, and validation size is given via nb_val_samples.
    model.autoencoder.fit_generator(
        train_gen,
        args.epoch_size,
        nb_epoch=args.epochs,
        callbacks=[checkpointer, reduce_lr],
        validation_data=test_gen,
        nb_val_samples=args.epoch_size // test_divisor,
        pickle_safe=True
    )


if __name__ == '__main__':
    main()
```
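A worked check of the validation sizing above with the default values (and of the commit note that the default batch_size now evenly divides the default epoch_size); not part of the diff:

```python
test_split = 0.20
epoch_size = 500000
batch_size = 500

test_divisor = int((1 - test_split) / test_split)  # int(0.8 / 0.2) = 4
print(epoch_size // test_divisor)  # 125000 validation samples per epoch
print(epoch_size % batch_size)     # 0: minibatches tile the epoch exactly
```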