Skip to content

Commit

Permalink
Fix examples
Browse files Browse the repository at this point in the history
  • Loading branch information
natsukium committed Mar 22, 2019
1 parent 5947af9 commit 43f4989
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 182 deletions.
65 changes: 2 additions & 63 deletions examples/own_dataset/train_own_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
from chainer_chemistry.dataset.parsers import CSVFileParser
from chainer_chemistry.dataset.preprocessors import preprocess_method_dict
from chainer_chemistry.datasets import NumpyTupleDataset
from chainer_chemistry.models import (
MLP, NFP, GGNN, SchNet, WeaveNet, RSGCN, RelGCN, RelGAT, Regressor)
from chainer_chemistry.models.prediction import GraphConvPredictor
from chainer_chemistry.models import Regressor
from chainer_chemistry.models.prediction import set_up_predictor


class MeanAbsError(object):
Expand Down Expand Up @@ -95,66 +94,6 @@ def __call__(self, x0, x1):
return numpy.mean(numpy.absolute(diff), axis=0)[0]


def set_up_predictor(method, n_unit, conv_layers, class_num):
"""Sets up the graph convolution network predictor.
Args:
method: Method name. Currently, the supported ones are `nfp`, `ggnn`,
`schnet`, `weavenet` and `rsgcn`.
n_unit: Number of hidden units.
conv_layers: Number of convolutional layers for the graph convolution
network.
class_num: Number of output classes.
Returns:
An instance of the selected predictor.
"""

predictor = None
mlp = MLP(out_dim=class_num, hidden_dim=n_unit)

if method == 'nfp':
print('Training an NFP predictor...')
nfp = NFP(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(nfp, mlp)
elif method == 'ggnn':
print('Training a GGNN predictor...')
ggnn = GGNN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(ggnn, mlp)
elif method == 'schnet':
print('Training an SchNet predictor...')
schnet = SchNet(out_dim=class_num, hidden_dim=n_unit,
n_layers=conv_layers)
predictor = GraphConvPredictor(schnet, None)
elif method == 'weavenet':
print('Training a WeaveNet predictor...')
n_atom = 20
n_sub_layer = 1
weave_channels = [50] * conv_layers

weavenet = WeaveNet(weave_channels=weave_channels, hidden_dim=n_unit,
n_sub_layer=n_sub_layer, n_atom=n_atom)
predictor = GraphConvPredictor(weavenet, mlp)
elif method == 'rsgcn':
print('Training an RSGCN predictor...')
rsgcn = RSGCN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(rsgcn, mlp)
elif method == 'relgcn':
print('Training an RelGCN predictor...')
num_edge_type = 4
relgcn = RelGCN(out_channels=n_unit, num_edge_type=num_edge_type,
scale_adj=True)
predictor = GraphConvPredictor(relgcn, mlp)
elif method == 'relgat':
print('Training an RelGAT predictor...')
relgat = RelGAT(out_dim=n_unit, hidden_dim=n_unit,
n_layers=conv_layers)
predictor = GraphConvPredictor(relgat, mlp)
else:
raise ValueError('[ERROR] Invalid method: {}'.format(method))
return predictor


def parse_arguments():
# Lists of supported preprocessing methods/models.
method_list = ['nfp', 'ggnn', 'schnet', 'weavenet', 'rsgcn', 'relgcn',
Expand Down
61 changes: 1 addition & 60 deletions examples/qm9/train_qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
from chainer_chemistry import datasets as D
from chainer_chemistry.datasets import NumpyTupleDataset
from chainer_chemistry.links.scaler.standard_scaler import StandardScaler
from chainer_chemistry.models import (
MLP, NFP, GGNN, SchNet, WeaveNet, RSGCN, RelGCN, RelGAT)
from chainer_chemistry.models.prediction import Regressor
from chainer_chemistry.models.prediction import GraphConvPredictor
from chainer_chemistry.models.prediction import set_up_predictor


class MeanAbsError(object):
Expand Down Expand Up @@ -98,63 +96,6 @@ def parse_arguments():
return parser.parse_args()


def set_up_predictor(method, n_unit, conv_layers, class_num, scaler):
"""Sets up the predictor, consisting of a graph convolution network and
a multilayer perceptron.
Args:
method (str): Method name.
n_unit (int): Number of hidden units.
conv_layers (int): Number of convolutional layers for the graph
convolution network.
class_num (int): Number of output classes.
Returns:
predictor (chainer.Chain): An instance of the selected predictor.
"""
mlp = MLP(out_dim=class_num, hidden_dim=n_unit)

if method == 'nfp':
print('Training an NFP predictor...')
nfp = NFP(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(nfp, mlp, scaler)
elif method == 'ggnn':
print('Training a GGNN predictor...')
ggnn = GGNN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(ggnn, mlp, scaler)
elif method == 'schnet':
print('Training an SchNet predictor...')
schnet = SchNet(out_dim=class_num, hidden_dim=n_unit,
n_layers=conv_layers)
predictor = GraphConvPredictor(schnet, None, scaler)
elif method == 'weavenet':
print('Training a WeaveNet predictor...')
n_atom = 20
n_sub_layer = 1
weave_channels = [50] * conv_layers

weavenet = WeaveNet(weave_channels=weave_channels, hidden_dim=n_unit,
n_sub_layer=n_sub_layer, n_atom=n_atom)
predictor = GraphConvPredictor(weavenet, mlp, scaler)
elif method == 'rsgcn':
print('Training an RSGCN predictor...')
rsgcn = RSGCN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(rsgcn, mlp, scaler)
elif method == 'relgcn':
print('Use Relational GCN predictor...')
num_edge_type = 4
relgcn = RelGCN(out_channels=n_unit, num_edge_type=num_edge_type,
scale_adj=True)
predictor = GraphConvPredictor(relgcn, mlp, scaler)
elif method == 'relgat':
print('Train Relational GAT predictor...')
relgat = RelGAT(out_dim=n_unit, hidden_dim=n_unit,
n_layers=conv_layers)
predictor = GraphConvPredictor(relgat, mlp, scaler)
else:
raise ValueError('[ERROR] Invalid method: {}'.format(method))
return predictor


def main():
# Parse the arguments.
args = parse_arguments()
Expand Down
57 changes: 0 additions & 57 deletions examples/tox21/predictor.py

This file was deleted.

4 changes: 2 additions & 2 deletions examples/tox21/train_tox21.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from chainer_chemistry import datasets as D
from chainer_chemistry.iterators.balanced_serial_iterator import BalancedSerialIterator # NOQA
from chainer_chemistry.models.prediction import Classifier
from chainer_chemistry.models.prediction import set_up_predictor
from chainer_chemistry.training.extensions import ROCAUCEvaluator # NOQA

import data
import predictor

# Disable errors by RDKit occurred in preprocessing Tox21 dataset.
lg = RDLogger.logger()
Expand Down Expand Up @@ -94,7 +94,7 @@ def main():
train, val, _ = data.load_dataset(method, labels, num_data=args.num_data)

# Network
predictor_ = predictor.build_predictor(
predictor_ = set_up_predictor(
method, args.unit_num, args.conv_layers, class_num)

iterator_type = args.iterator_type
Expand Down

0 comments on commit 43f4989

Please sign in to comment.