Skip to content

Commit

Permalink
Merge 70abc0a into 53c3b55
Browse files Browse the repository at this point in the history
  • Loading branch information
mufeili committed Nov 2, 2020
2 parents 53c3b55 + 70abc0a commit eec4df9
Show file tree
Hide file tree
Showing 9 changed files with 495 additions and 8 deletions.
21 changes: 14 additions & 7 deletions deepchem/feat/graph_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,29 +123,36 @@ def to_pyg_graph(self):
edge_attr=edge_features,
pos=node_pos_features)

def to_dgl_graph(self):
def to_dgl_graph(self, self_loop: bool = False):
"""Convert to DGL graph data instance
Returns
-------
dgl.DGLGraph
Graph data for DGL
self_loop: bool
Whether to add self loops for the nodes, i.e. edges from nodes
to themselves. Default to False.
Notes
-----
This method requires DGL to be installed.
"""
try:
import dgl
import torch
from dgl import DGLGraph
except ModuleNotFoundError:
raise ImportError("This function requires DGL to be installed.")

g = DGLGraph()
g.add_nodes(self.num_nodes)
g.add_edges(
torch.from_numpy(self.edge_index[0]).long(),
torch.from_numpy(self.edge_index[1]).long())
src = self.edge_index[0]
dst = self.edge_index[1]
if self_loop:
src = np.concatenate([src, np.arange(self.num_nodes)])
dst = np.concatenate([dst, np.arange(self.num_nodes)])

g = dgl.graph(
(torch.from_numpy(src).long(), torch.from_numpy(dst).long()),
num_nodes=self.num_nodes)
g.ndata['x'] = torch.from_numpy(self.node_features).float()

if self.node_pos_features is not None:
Expand Down
1 change: 1 addition & 0 deletions deepchem/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from deepchem.models.torch_models import TorchModel
from deepchem.models.torch_models import CGCNN, CGCNNModel
from deepchem.models.torch_models import GAT, GATModel
from deepchem.models.torch_models import GCN, GCNModel
except ModuleNotFoundError:
pass

Expand Down
102 changes: 102 additions & 0 deletions deepchem/models/tests/test_gcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import unittest
import tempfile

import numpy as np

import deepchem as dc
from deepchem.feat import MolGraphConvFeaturizer
from deepchem.models import GCNModel
from deepchem.models.tests.test_graph_models import get_dataset

try:
import dgl
import dgllife
import torch
has_torch_and_dgl = True
except:
has_torch_and_dgl = False


@unittest.skipIf(not has_torch_and_dgl,
'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_gcn_regression():
# load datasets
featurizer = MolGraphConvFeaturizer()
tasks, dataset, transformers, metric = get_dataset(
'regression', featurizer=featurizer)

# initialize models
n_tasks = len(tasks)
model = GCNModel(
mode='regression',
n_tasks=n_tasks,
number_atom_features=30,
batch_size=10)

# overfit test
model.fit(dataset, nb_epoch=100)
scores = model.evaluate(dataset, [metric], transformers)
assert scores['mean_absolute_error'] < 0.5


@unittest.skipIf(not has_torch_and_dgl,
'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_gcn_classification():
# load datasets
featurizer = MolGraphConvFeaturizer()
tasks, dataset, transformers, metric = get_dataset(
'classification', featurizer=featurizer)

# initialize models
n_tasks = len(tasks)
model = GCNModel(
mode='classification',
n_tasks=n_tasks,
number_atom_features=30,
batch_size=10,
learning_rate=0.001)

# overfit test
model.fit(dataset, nb_epoch=50)
scores = model.evaluate(dataset, [metric], transformers)
assert scores['mean-roc_auc_score'] >= 0.85


@unittest.skipIf(not has_torch_and_dgl,
'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_gcn_reload():
# load datasets
featurizer = MolGraphConvFeaturizer()
tasks, dataset, transformers, metric = get_dataset(
'classification', featurizer=featurizer)

# initialize models
n_tasks = len(tasks)
model_dir = tempfile.mkdtemp()
model = GCNModel(
mode='classification',
n_tasks=n_tasks,
number_atom_features=30,
model_dir=model_dir,
batch_size=10,
learning_rate=0.001)

model.fit(dataset, nb_epoch=50)
scores = model.evaluate(dataset, [metric], transformers)
assert scores['mean-roc_auc_score'] >= 0.85

reloaded_model = GCNModel(
mode='classification',
n_tasks=n_tasks,
number_atom_features=30,
model_dir=model_dir,
batch_size=10,
learning_rate=0.001)
reloaded_model.restore()

pred_mols = ["CCCC", "CCCCCO", "CCCCC"]
X_pred = featurizer(pred_mols)
random_dataset = dc.data.NumpyDataset(X_pred)
original_pred = model.predict(random_dataset)
reload_pred = reloaded_model.predict(random_dataset)
assert np.all(original_pred == reload_pred)
1 change: 1 addition & 0 deletions deepchem/models/torch_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from deepchem.models.torch_models.torch_model import TorchModel
from deepchem.models.torch_models.cgcnn import CGCNN, CGCNNModel
from deepchem.models.torch_models.gat import GAT, GATModel
from deepchem.models.torch_models.gcn import GCN, GCNModel

0 comments on commit eec4df9

Please sign in to comment.