Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move setup predictor to Library #336

Merged
merged 5 commits into from
Mar 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions chainer_chemistry/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@
from chainer_chemistry.models.prediction.classifier import Classifier # NOQA
from chainer_chemistry.models.prediction.graph_conv_predictor import GraphConvPredictor # NOQA
from chainer_chemistry.models.prediction.regressor import Regressor # NOQA
from chainer_chemistry.models.prediction.set_up_predictor import set_up_predictor # NOQA
1 change: 1 addition & 0 deletions chainer_chemistry/models/prediction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from chainer_chemistry.models.prediction.classifier import Classifier # NOQA
from chainer_chemistry.models.prediction.graph_conv_predictor import GraphConvPredictor # NOQA
from chainer_chemistry.models.prediction.regressor import Regressor # NOQA
from chainer_chemistry.models.prediction.set_up_predictor import set_up_predictor # NOQA
96 changes: 96 additions & 0 deletions chainer_chemistry/models/prediction/set_up_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from typing import Any # NOQA
from typing import Dict # NOQA
from typing import Optional # NOQA

import chainer # NOQA

from chainer_chemistry.models.ggnn import GGNN
from chainer_chemistry.models.mlp import MLP
from chainer_chemistry.models.nfp import NFP
from chainer_chemistry.models.prediction.graph_conv_predictor import GraphConvPredictor # NOQA
from chainer_chemistry.models.relgat import RelGAT
from chainer_chemistry.models.relgcn import RelGCN
from chainer_chemistry.models.rsgcn import RSGCN
from chainer_chemistry.models.schnet import SchNet
from chainer_chemistry.models.weavenet import WeaveNet


def set_up_predictor(
method, # type: str
n_unit, # type: int
conv_layers, # type: int
class_num, # type: int
label_scaler=None, # type: Optional[chainer.Link]
postprocess_fn=None, # type: Optional[chainer.FunctionNode]
conv_kwargs=None # type: Optional[Dict[str, Any]]
):
# type: (...) -> GraphConvPredictor
"""Set up the predictor, consisting of a GCN and a MLP.

Args:
method (str): Method name.
n_unit (int): Number of hidden units.
conv_layers (int): Number of convolutional layers for the graph
convolution network.
class_num (int): Number of output classes.
label_scaler (chainer.Link or None): scaler link
postprocess_fn (chainer.FunctionNode or None):
postprocess function for prediction.
conv_kwargs (dict): keyword args for GraphConvolution model.
"""
mlp = MLP(out_dim=class_num, hidden_dim=n_unit) # type: Optional[MLP]
if conv_kwargs is None:
conv_kwargs = {}

if method == 'nfp':
print('Training an NFP predictor...')
conv = NFP(
out_dim=n_unit,
hidden_dim=n_unit,
n_layers=conv_layers,
**conv_kwargs)
elif method == 'ggnn':
print('Training a GGNN predictor...')
conv = GGNN(
out_dim=n_unit,
hidden_dim=n_unit,
n_layers=conv_layers,
**conv_kwargs)
elif method == 'schnet':
print('Training an SchNet predictor...')
conv = SchNet(
out_dim=class_num,
hidden_dim=n_unit,
n_layers=conv_layers,
**conv_kwargs)
mlp = None
elif method == 'weavenet':
print('Training a WeaveNet predictor...')
conv = WeaveNet(hidden_dim=n_unit, **conv_kwargs)
elif method == 'rsgcn':
print('Training an RSGCN predictor...')
conv = RSGCN(
out_dim=n_unit,
hidden_dim=n_unit,
n_layers=conv_layers,
**conv_kwargs)
elif method == 'relgcn':
print('Training a Relational GCN predictor...')
num_edge_type = 4
conv = RelGCN(
out_channels=n_unit,
num_edge_type=num_edge_type,
scale_adj=True,
**conv_kwargs)
elif method == 'relgat':
print('Training a Relational GAT predictor...')
conv = RelGAT(
out_dim=n_unit,
hidden_dim=n_unit,
n_layers=conv_layers,
**conv_kwargs)
else:
raise ValueError('[ERROR] Invalid method: {}'.format(method))

predictor = GraphConvPredictor(conv, mlp, label_scaler, postprocess_fn)
return predictor
65 changes: 2 additions & 63 deletions examples/own_dataset/train_own_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
from chainer_chemistry.dataset.parsers import CSVFileParser
from chainer_chemistry.dataset.preprocessors import preprocess_method_dict
from chainer_chemistry.datasets import NumpyTupleDataset
from chainer_chemistry.models import (
MLP, NFP, GGNN, SchNet, WeaveNet, RSGCN, RelGCN, RelGAT, Regressor)
from chainer_chemistry.models.prediction import GraphConvPredictor
from chainer_chemistry.models import Regressor
from chainer_chemistry.models.prediction import set_up_predictor


class MeanAbsError(object):
Expand Down Expand Up @@ -95,66 +94,6 @@ def __call__(self, x0, x1):
return numpy.mean(numpy.absolute(diff), axis=0)[0]


def set_up_predictor(method, n_unit, conv_layers, class_num):
"""Sets up the graph convolution network predictor.

Args:
method: Method name. Currently, the supported ones are `nfp`, `ggnn`,
`schnet`, `weavenet` and `rsgcn`.
n_unit: Number of hidden units.
conv_layers: Number of convolutional layers for the graph convolution
network.
class_num: Number of output classes.

Returns:
An instance of the selected predictor.
"""

predictor = None
mlp = MLP(out_dim=class_num, hidden_dim=n_unit)

if method == 'nfp':
print('Training an NFP predictor...')
nfp = NFP(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(nfp, mlp)
elif method == 'ggnn':
print('Training a GGNN predictor...')
ggnn = GGNN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(ggnn, mlp)
elif method == 'schnet':
print('Training an SchNet predictor...')
schnet = SchNet(out_dim=class_num, hidden_dim=n_unit,
n_layers=conv_layers)
predictor = GraphConvPredictor(schnet, None)
elif method == 'weavenet':
print('Training a WeaveNet predictor...')
n_atom = 20
n_sub_layer = 1
weave_channels = [50] * conv_layers

weavenet = WeaveNet(weave_channels=weave_channels, hidden_dim=n_unit,
n_sub_layer=n_sub_layer, n_atom=n_atom)
predictor = GraphConvPredictor(weavenet, mlp)
elif method == 'rsgcn':
print('Training an RSGCN predictor...')
rsgcn = RSGCN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(rsgcn, mlp)
elif method == 'relgcn':
print('Training an RelGCN predictor...')
num_edge_type = 4
relgcn = RelGCN(out_channels=n_unit, num_edge_type=num_edge_type,
scale_adj=True)
predictor = GraphConvPredictor(relgcn, mlp)
elif method == 'relgat':
print('Training an RelGAT predictor...')
relgat = RelGAT(out_dim=n_unit, hidden_dim=n_unit,
n_layers=conv_layers)
predictor = GraphConvPredictor(relgat, mlp)
else:
raise ValueError('[ERROR] Invalid method: {}'.format(method))
return predictor


def parse_arguments():
# Lists of supported preprocessing methods/models.
method_list = ['nfp', 'ggnn', 'schnet', 'weavenet', 'rsgcn', 'relgcn',
Expand Down
61 changes: 1 addition & 60 deletions examples/qm9/train_qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
from chainer_chemistry import datasets as D
from chainer_chemistry.datasets import NumpyTupleDataset
from chainer_chemistry.links.scaler.standard_scaler import StandardScaler
from chainer_chemistry.models import (
MLP, NFP, GGNN, SchNet, WeaveNet, RSGCN, RelGCN, RelGAT)
from chainer_chemistry.models.prediction import Regressor
from chainer_chemistry.models.prediction import GraphConvPredictor
from chainer_chemistry.models.prediction import set_up_predictor


class MeanAbsError(object):
Expand Down Expand Up @@ -98,63 +96,6 @@ def parse_arguments():
return parser.parse_args()


def set_up_predictor(method, n_unit, conv_layers, class_num, scaler):
"""Sets up the predictor, consisting of a graph convolution network and
a multilayer perceptron.

Args:
method (str): Method name.
n_unit (int): Number of hidden units.
conv_layers (int): Number of convolutional layers for the graph
convolution network.
class_num (int): Number of output classes.
Returns:
predictor (chainer.Chain): An instance of the selected predictor.
"""
mlp = MLP(out_dim=class_num, hidden_dim=n_unit)

if method == 'nfp':
print('Training an NFP predictor...')
nfp = NFP(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(nfp, mlp, scaler)
elif method == 'ggnn':
print('Training a GGNN predictor...')
ggnn = GGNN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(ggnn, mlp, scaler)
elif method == 'schnet':
print('Training an SchNet predictor...')
schnet = SchNet(out_dim=class_num, hidden_dim=n_unit,
n_layers=conv_layers)
predictor = GraphConvPredictor(schnet, None, scaler)
elif method == 'weavenet':
print('Training a WeaveNet predictor...')
n_atom = 20
n_sub_layer = 1
weave_channels = [50] * conv_layers

weavenet = WeaveNet(weave_channels=weave_channels, hidden_dim=n_unit,
n_sub_layer=n_sub_layer, n_atom=n_atom)
predictor = GraphConvPredictor(weavenet, mlp, scaler)
elif method == 'rsgcn':
print('Training an RSGCN predictor...')
rsgcn = RSGCN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(rsgcn, mlp, scaler)
elif method == 'relgcn':
print('Use Relational GCN predictor...')
num_edge_type = 4
relgcn = RelGCN(out_channels=n_unit, num_edge_type=num_edge_type,
scale_adj=True)
predictor = GraphConvPredictor(relgcn, mlp, scaler)
elif method == 'relgat':
print('Train Relational GAT predictor...')
relgat = RelGAT(out_dim=n_unit, hidden_dim=n_unit,
n_layers=conv_layers)
predictor = GraphConvPredictor(relgat, mlp, scaler)
else:
raise ValueError('[ERROR] Invalid method: {}'.format(method))
return predictor


def main():
# Parse the arguments.
args = parse_arguments()
Expand Down
57 changes: 0 additions & 57 deletions examples/tox21/predictor.py

This file was deleted.

4 changes: 2 additions & 2 deletions examples/tox21/train_tox21.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from chainer_chemistry import datasets as D
from chainer_chemistry.iterators.balanced_serial_iterator import BalancedSerialIterator # NOQA
from chainer_chemistry.models.prediction import Classifier
from chainer_chemistry.models.prediction import set_up_predictor
from chainer_chemistry.training.extensions import ROCAUCEvaluator # NOQA

import data
import predictor

# Disable errors by RDKit occurred in preprocessing Tox21 dataset.
lg = RDLogger.logger()
Expand Down Expand Up @@ -94,7 +94,7 @@ def main():
train, val, _ = data.load_dataset(method, labels, num_data=args.num_data)

# Network
predictor_ = predictor.build_predictor(
predictor_ = set_up_predictor(
method, args.unit_num, args.conv_layers, class_num)

iterator_type = args.iterator_type
Expand Down
Loading