Skip to content

Commit

Permalink
Merge pull request #2236 from hsjang001205/GCN_reload
Browse files Browse the repository at this point in the history
Fix graph conv model save/load
  • Loading branch information
Bharath Ramsundar committed Oct 21, 2020
2 parents e5f9545 + ad24660 commit e06055e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 55 deletions.
4 changes: 2 additions & 2 deletions deepchem/models/layers.py
Expand Up @@ -127,14 +127,14 @@ def build(self, input_shape):
num_deg = 2 * self.max_degree + (1 - self.min_degree)
self.W_list = [
self.add_weight(
name='kernel',
name='kernel' + str(k),
shape=(int(input_shape[0][-1]), self.out_channel),
initializer='glorot_uniform',
trainable=True) for k in range(num_deg)
]
self.b_list = [
self.add_weight(
name='bias',
name='bias' + str(k),
shape=(self.out_channel,),
initializer='zeros',
trainable=True) for k in range(num_deg)
Expand Down
99 changes: 46 additions & 53 deletions deepchem/models/tests/test_reload.py
Expand Up @@ -849,59 +849,52 @@ def test_1d_cnn_regression_reload():
assert scores[regression_metric.name] < 0.1


### TODO: THIS IS FAILING!
#def test_graphconvmodel_reload():
# featurizer = dc.feat.ConvMolFeaturizer()
# tasks = ["outcome"]
# n_tasks = len(tasks)
# mols = ["C", "CO", "CC"]
# n_samples = len(mols)
# X = featurizer(mols)
# y = np.array([0, 1, 0])
# dataset = dc.data.NumpyDataset(X, y)
#
# classification_metric = dc.metrics.Metric(
# dc.metrics.roc_auc_score, np.mean, mode="classification")
#
# batch_size = 10
# model_dir = tempfile.mkdtemp()
# model = dc.models.GraphConvModel(
# len(tasks),
# batch_size=batch_size,
# batch_normalize=False,
# mode='classification',
# model_dir=model_dir)
#
# model.fit(dataset, nb_epoch=10)
# scores = model.evaluate(dataset, [classification_metric])
# assert scores[classification_metric.name] >= 0.9
#
#
# # Reload trained Model
# reloaded_model = dc.models.GraphConvModel(
# len(tasks),
# batch_size=batch_size,
# batch_normalize=False,
# mode='classification',
# model_dir=model_dir)
# reloaded_model.restore()
#
# # Check predictions match on random sample
# predmols = ["CCCC", "CCCCCO", "CCCCC"]
# Xpred = featurizer(predmols)
# predset = dc.data.NumpyDataset(Xpred)
# origpred = model.predict(predset)
# reloadpred = reloaded_model.predict(predset)
# assert np.all(origpred == reloadpred)
#
# # Try re-restore
# reloaded_model.restore()
# reloadpred = reloaded_model.predict(predset)
# assert np.all(origpred == reloadpred)
#
# # Eval model on train
# scores = reloaded_model.evaluate(dataset, [classification_metric])
# assert scores[classification_metric.name] > .9
def test_graphconvmodel_reload():
featurizer = dc.feat.ConvMolFeaturizer()
tasks = ["outcome"]
n_tasks = len(tasks)
mols = ["C", "CO", "CC"]
n_samples = len(mols)
X = featurizer(mols)
y = np.array([0, 1, 0])
dataset = dc.data.NumpyDataset(X, y)

classification_metric = dc.metrics.Metric(
dc.metrics.roc_auc_score, np.mean, mode="classification")

batch_size = 10
model_dir = tempfile.mkdtemp()
model = dc.models.GraphConvModel(
len(tasks),
batch_size=batch_size,
batch_normalize=False,
mode='classification',
model_dir=model_dir)

model.fit(dataset, nb_epoch=10)
scores = model.evaluate(dataset, [classification_metric])
assert scores[classification_metric.name] >= 0.6

# Reload trained Model
reloaded_model = dc.models.GraphConvModel(
len(tasks),
batch_size=batch_size,
batch_normalize=False,
mode='classification',
model_dir=model_dir)
reloaded_model.restore()

# Check predictions match on random sample
predmols = ["CCCC", "CCCCCO", "CCCCC"]
Xpred = featurizer(predmols)
predset = dc.data.NumpyDataset(Xpred)
origpred = model.predict(predset)
reloadpred = reloaded_model.predict(predset)
assert np.all(origpred == reloadpred)

# Eval model on train
scores = reloaded_model.evaluate(dataset, [classification_metric])
assert scores[classification_metric.name] > .6


def test_chemception_reload():
Expand Down

0 comments on commit e06055e

Please sign in to comment.