Commit
Add code for generator based training, testing, sampling (#43)
* add functionality to generate train and test tensors on the fly

* fix encoding bug to properly encode padding

* fix bugs related to training on generators; generator-based training now runs correctly

* add a sampling script to work with generator-based models

* add test_split as an optional command-line argument; also change the default batch_size to evenly divide the default epoch_size
pechersky authored and maxhodak committed Nov 25, 2016
1 parent 0e24228 commit c53f54f
Showing 3 changed files with 365 additions and 0 deletions.
153 changes: 153 additions & 0 deletions molecules/vectorizer.py
@@ -0,0 +1,153 @@
import numpy as np
import itertools
import random

class CharacterTable(object):
    '''
    Given a set of characters:
    + Encode them to a one-hot integer representation
    + Decode the one-hot integer representation to their character output
    + Decode a vector of probabilities to their character output
    first version by rmcgibbo
    '''
def __init__(self, chars, maxlen):
self.chars = sorted(set(chars))
self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
self.maxlen = maxlen

def encode(self, C, maxlen=None):
maxlen = maxlen if maxlen else self.maxlen
X = np.zeros((maxlen, len(self.chars)))
for i, c in enumerate(C):
X[i, self.char_indices[c]] = 1
return X

    def decode(self, X, mode='argmax'):
        if mode == 'argmax':
            X = X.argmax(axis=-1)
        elif mode == 'choice':
            X = np.apply_along_axis(lambda vec:
                                    np.random.choice(len(vec), 1,
                                                     p=(vec / np.sum(vec))),
                                    axis=-1, arr=X).ravel()
        return ''.join(self.indices_char[x] for x in X)
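
# Illustrative usage sketch (not part of the original commit); all-zero
# padding rows decode to the first character of the sorted alphabet
# (here, the space):
#
#   table = CharacterTable('abc ', maxlen=5)
#   onehot = table.encode('ab')   # shape (5, 4); rows 0-1 one-hot, rest zero
#   table.decode(onehot)          # -> 'ab   '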


class SmilesDataGenerator(object):
"""
Given a list of SMILES strings,
returns a generator that returns batches of
randomly sampled strings encoded as one-hot,
as well as a weighting vector indicating true length of string
"""
SMILES_CHARS = [' ',
'#', '%', '(', ')', '+', '-', '.', '/',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
'=', '@',
'A', 'B', 'C', 'F', 'H', 'I', 'K', 'L', 'M', 'N', 'O', 'P',
'R', 'S', 'T', 'V', 'X', 'Z',
'[', '\\', ']',
'a', 'b', 'c', 'e', 'g', 'i', 'l', 'n', 'o', 'p', 'r', 's',
't', 'u']

def __init__(self, words, maxlen,
pad_char=' ',
pad_min=1,
pad_weight=0.0,
test_split=0.20):
self.maxlen = maxlen
self.words = words
self.max_words = len(words)
        self.word_ixs = list(range(self.max_words))
        self.shuffled_word_ixs = list(range(self.max_words))
        random.shuffle(self.shuffled_word_ixs)

self.pad_char = pad_char
self.pad_min = pad_min
self.pad_weight = pad_weight
self.test_split = test_split
self.chars = sorted(set.union(set(SmilesDataGenerator.SMILES_CHARS),
set(pad_char)))
self.table = CharacterTable(self.chars, self.maxlen)

def encode(self, word):
padded_word = word + self.pad_char*(self.maxlen-len(word))
return self.table.encode(padded_word)

def weight(self, word):
weight_vec = np.ones((self.maxlen,))*self.pad_weight
weight_vec[np.arange(min(len(word)+self.pad_min, self.maxlen))] = 1
return weight_vec
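
    # Illustrative note (not in the original commit): with maxlen=6, pad_min=1
    # and pad_weight=0.0, weight('CCO') gives [1, 1, 1, 1, 0, 0] -- the true
    # characters plus one trailing pad position carry full weight.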

def sample(self, predicate=None):
if predicate:
word_ix = random.choice(self.word_ixs)
if not predicate(word_ix):
return self.sample(predicate=predicate)
word = self.words[self.shuffled_word_ixs[word_ix]]
else:
word = random.choice(self.words)
if len(word) < self.maxlen:
return word
return self.sample(predicate=predicate)

def train_sample(self):
if self.test_split > 0:
threshold = self.max_words * self.test_split
return self.sample(lambda word_ix: word_ix >= threshold)
return self.sample()

def test_sample(self):
if self.test_split > 0:
threshold = self.max_words * self.test_split
return self.sample(lambda word_ix: word_ix < threshold)
return self.sample()
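
    # Illustrative note (not in the original commit): with 1000 words and
    # test_split=0.2, positions 0..199 of the shuffled index order form the
    # test pool and positions 200..999 the training pool.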

    def generator(self, batch_size, sample_func=None):
        if sample_func is None:
            sample_func = self.sample
        while True:
            data_tensor = np.zeros((batch_size, self.maxlen, len(self.chars)), dtype=bool)
            weight_tensor = np.zeros((batch_size, self.maxlen))
            for word_ix in range(batch_size):
                word = sample_func()
                data_tensor[word_ix, ...] = self.encode(word)
                weight_tensor[word_ix, ...] = self.weight(word)
            yield (data_tensor, data_tensor, weight_tensor)

def train_generator(self, batch_size):
return self.generator(batch_size, sample_func=self.train_sample)

def test_generator(self, batch_size):
return self.generator(batch_size, sample_func=self.test_sample)
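
    # Illustrative usage sketch (not in the original commit); assumes `words`
    # is a list of SMILES strings shorter than maxlen:
    #
    #   gen = SmilesDataGenerator(words, maxlen=120)
    #   x, y, w = next(gen.train_generator(8))
    #   # x.shape == (8, 120, len(gen.chars)); y is x; w.shape == (8, 120)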


class CanonicalSmilesDataGenerator(SmilesDataGenerator):
"""
Given a list of SMILES strings,
returns a generator that returns batches of
randomly sampled strings, canonicalized, encoded as one-hot,
as well as a weighting vector indicating true length of string
"""

def sample(self, predicate=None):
from rdkit import Chem
mol = Chem.MolFromSmiles(super(CanonicalSmilesDataGenerator, self).sample(predicate=predicate))
if mol:
canon_word = Chem.MolToSmiles(mol)
if len(canon_word) < self.maxlen:
return canon_word
return self.sample(predicate=predicate)

def train_sample(self):
if self.test_split > 0:
threshold = self.max_words * self.test_split
return self.sample(lambda word_ix: word_ix >= threshold)
return self.sample()

def test_sample(self):
if self.test_split > 0:
threshold = self.max_words * self.test_split
return self.sample(lambda word_ix: word_ix < threshold)
return self.sample()
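
# Illustrative usage sketch (not in the original commit): the canonical
# variant requires RDKit; each sampled string is canonicalized before the
# length check, so only canonical SMILES shorter than maxlen are returned.
#
#   gen = CanonicalSmilesDataGenerator(words, maxlen=120)
#   print(gen.sample())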
127 changes: 127 additions & 0 deletions sample_gen.py
@@ -0,0 +1,127 @@
from __future__ import print_function

import argparse
import os
import h5py
import numpy as np
import sys

from molecules.model import MoleculeVAE
from molecules.utils import decode_smiles_from_indexes
from molecules.vectorizer import SmilesDataGenerator

LATENT_DIM = 292
NUM_SAMPLED = 100
TARGET = 'autoencoder'

def get_arguments():
parser = argparse.ArgumentParser(description='Molecular autoencoder network')
parser.add_argument('data', type=str, help='File of latent representation tensors for decoding.')
parser.add_argument('model', type=str, help='Trained Keras model to use.')
parser.add_argument('--save_h5', type=str, help='Name of a file to write HDF5 output to.')
parser.add_argument('--target', type=str, default=TARGET,
help='What model to sample from: autoencoder, encoder, decoder.')
parser.add_argument('--latent_dim', type=int, metavar='N', default=LATENT_DIM,
help='Dimensionality of the latent representation.')
parser.add_argument('--sample', type=int, metavar='N', default=NUM_SAMPLED,
help='Number of items to sample from data generator.')
return parser.parse_args()

def read_latent_data(filename):
h5f = h5py.File(filename, 'r')
data = h5f['latent_vectors'][:]
charset = h5f['charset'][:]
h5f.close()
return (data, charset)

def read_smiles_data(filename):
import pandas as pd
h5f = pd.read_hdf(filename, 'table')
data = h5f['structure'][:]
# import gzip
# data = [line.split()[0].strip() for line in gzip.open(filename) if line]
return data

def autoencoder(args, model):
latent_dim = args.latent_dim

structures = read_smiles_data(args.data)

datobj = SmilesDataGenerator(structures, 120)
train_gen = datobj.generator(1)

if os.path.isfile(args.model):
model.load(datobj.chars, args.model, latent_rep_size = latent_dim)
else:
raise ValueError("Model file %s doesn't exist" % args.model)

    true_pred_gen = ((mat, weight, model.autoencoder.predict(mat))
                     for (mat, _, weight) in train_gen)
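    # decode each true one-hot matrix, truncating at the first zero-weight
    # (padding) position, and pair it with the full decoded prediction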
    text_gen = ('\n'.join(
                    [str((datobj.table.decode(true_mat[vec_ix])[:np.argmin(weight[vec_ix])],
                          datobj.table.decode(vec)))
                     for (vec_ix, vec) in enumerate(pred_mat)])
                for (true_mat, weight, pred_mat) in true_pred_gen)
    for _ in range(args.sample):
        print(next(text_gen))

def decoder(args, model):
latent_dim = args.latent_dim
data, charset = read_latent_data(args.data)

if os.path.isfile(args.model):
model.load(charset, args.model, latent_rep_size = latent_dim)
else:
raise ValueError("Model file %s doesn't exist" % args.model)

    for ix in range(len(data)):
        # predict expects a batch, so wrap the single latent vector
        sampled = model.decoder.predict(data[ix].reshape(1, latent_dim)).argmax(axis=2)[0]
        sampled = decode_smiles_from_indexes(sampled, charset)
        print(sampled)

def encoder(args, model):
latent_dim = args.latent_dim

structures = read_smiles_data(args.data)

datobj = SmilesDataGenerator(structures, 120)
train_gen = datobj.generator(1)

if os.path.isfile(args.model):
model.load(datobj.chars, args.model, latent_rep_size = latent_dim)
else:
raise ValueError("Model file %s doesn't exist" % args.model)

    true_pred_gen = ((mat, weight, model.encoder.predict(mat))
                     for (mat, _, weight) in train_gen)
    if args.save_h5:
        h5f = h5py.File(args.save_h5, 'w')
        h5f.create_dataset('charset', data = datobj.chars)
        # each encoded sample is a single latent vector of length latent_dim
        h5f.create_dataset('latent_vectors', (args.sample, latent_dim))
        for ix in range(args.sample):
            _, _, x_latent = next(true_pred_gen)
            h5f['latent_vectors'][ix] = x_latent[0]
        h5f.close()
    else:
        text_gen = ('\n'.join(
                        [str((datobj.table.decode(true_mat[vec_ix])[:np.argmin(weight[vec_ix])],
                              vec))
                         for (vec_ix, vec) in enumerate(pred_mat)])
                    for (true_mat, weight, pred_mat) in true_pred_gen)
        for _ in range(args.sample):
            print(next(text_gen))

def main():
args = get_arguments()
model = MoleculeVAE()

    if args.target == 'autoencoder':
        autoencoder(args, model)
    elif args.target == 'encoder':
        encoder(args, model)
    elif args.target == 'decoder':
        decoder(args, model)
    else:
        raise ValueError("Unknown target %s" % args.target)

if __name__ == '__main__':
main()
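
# Example invocations (illustrative; file names are placeholders):
#
#   python sample_gen.py structures.h5 model.h5 --target autoencoder --sample 10
#   python sample_gen.py latent.h5 model.h5 --target decoder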
85 changes: 85 additions & 0 deletions train_gen.py
@@ -0,0 +1,85 @@
from __future__ import print_function

import argparse
import os
import h5py
import numpy as np
import pandas as pd

from molecules.model import MoleculeVAE
from molecules.vectorizer import SmilesDataGenerator
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

NUM_EPOCHS = 1
EPOCH_SIZE = 500000
BATCH_SIZE = 500
LATENT_DIM = 292
MAX_LEN = 120
TEST_SPLIT = 0.20

def get_arguments():
parser = argparse.ArgumentParser(description='Molecular autoencoder network')
parser.add_argument('data', type=str, help='The HDF5 file containing structures.')
parser.add_argument('model', type=str,
help='Where to save the trained model. If this file exists, it will be opened and resumed.')
parser.add_argument('--epochs', type=int, metavar='N', default=NUM_EPOCHS,
help='Number of epochs to run during training.')
parser.add_argument('--latent_dim', type=int, metavar='N', default=LATENT_DIM,
help='Dimensionality of the latent representation.')
parser.add_argument('--batch_size', type=int, metavar='N', default=BATCH_SIZE,
help='Number of samples to process per minibatch during training.')
parser.add_argument('--epoch_size', type=int, metavar='N', default=EPOCH_SIZE,
help='Number of samples to process per epoch during training.')
    parser.add_argument('--test_split', type=float, metavar='N', default=TEST_SPLIT,
                        help='Fraction of the dataset to use as test data; the rest is training data.')
    return parser.parse_args()

def main():
args = get_arguments()

data = pd.read_hdf(args.data, 'table')
structures = data['structure']

# import gzip
# filepath = args.data
# structures = [line.split()[0].strip() for line in gzip.open(filepath) if line]

# can also use CanonicalSmilesDataGenerator
datobj = SmilesDataGenerator(structures, MAX_LEN,
test_split=args.test_split)
test_divisor = int((1 - datobj.test_split) / (datobj.test_split))
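    # e.g. test_split=0.2 gives test_divisor=4, so validation draws one
    # quarter as many samples per epoch as training (illustrative note,
    # not in the original commit)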
train_gen = datobj.train_generator(args.batch_size)
test_gen = datobj.test_generator(args.batch_size)

    # reformulate generators to yield (input, target) pairs without weights
    train_gen = ((tens, tens) for (tens, _, _) in train_gen)
    test_gen = ((tens, tens) for (tens, _, _) in test_gen)

model = MoleculeVAE()
if os.path.isfile(args.model):
model.load(datobj.chars, args.model, latent_rep_size = args.latent_dim)
else:
model.create(datobj.chars, latent_rep_size = args.latent_dim)

checkpointer = ModelCheckpoint(filepath = args.model,
verbose = 1,
save_best_only = True)

reduce_lr = ReduceLROnPlateau(monitor = 'val_loss',
factor = 0.2,
patience = 3,
min_lr = 0.0001)

model.autoencoder.fit_generator(
train_gen,
args.epoch_size,
nb_epoch = args.epochs,
callbacks = [checkpointer, reduce_lr],
validation_data = test_gen,
        nb_val_samples = args.epoch_size // test_divisor,
pickle_safe = True
)

if __name__ == '__main__':
main()
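
# Example invocation (illustrative; file names are placeholders):
#
#   python train_gen.py structures.h5 model.h5 --epochs 5 --test_split 0.2
#
# If model.h5 already exists, training resumes from the saved weights.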
