-
Notifications
You must be signed in to change notification settings - Fork 129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add an example of using own dataset #114
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Example of using your own dataset | ||
## Usage | ||
``` | ||
python train.py dataset.csv --label value1 value2 | ||
``` | ||
|
||
## How to use your own dataset | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It reminds me that this is the only preferred way to use one's own dataset. But using |
||
1. Prepare a CSV file which contains the list of SMILES and the values you want to train. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think "... contains a list of and values ..." would be better. |
||
The first line of the CSV file should be label names. | ||
See `dataset.csv` as an example. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you write that |
||
|
||
2. Use `CSVFileParser` of Cheiner Chemistry to feed data to model. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be better to add a link to the document of CSVFileParser. |
||
See `train.csv` as an example. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
SMILES,value1,value2 | ||
CC1=CC2CC(CC1)O2,-0.227400004863739,0.010400000028312206 | ||
O=Cc1nccn1C=O,-0.2678000032901764,-0.09380000084638596 | ||
CCC(C)(C)C(O)C=O,-0.2685000002384186,-0.038100000470876694 | ||
C#CCC(C)(CO)OC,-0.2535000145435333,0.044599998742341995 | ||
Nc1coc(=O)nc1N,-0.2303999960422516,-0.04170000180602074 | ||
CC12C=CC(CCC1)C2,-0.2312999963760376,0.02239999920129776 | ||
CC12CCC1C2OC=O,-0.2605000138282776,0.005400000140070915 | ||
CC1C2CC3(COC3)N12,-0.23430000245571136,0.0697999969124794 | ||
O=C1NC=NC12CC2,-0.24070000648498535,-0.017000000923871994 | ||
C1=CC2CN2CC2NC12,-0.22169999778270721,0.007699999958276749 | ||
CC1C2COCC12O,-0.2467000037431717,0.07410000264644623 | ||
CC(=O)C1OCOC1=O,-0.2590000033378601,-0.042500000447034836 | ||
CC1N2C3CC1(C)C32,-0.2295999974012375,0.0835999995470047 | ||
CC1=CC2OC2(C#N)C1,-0.25999999046325684,-0.019899999722838402 | ||
OC1CCC1,-0.25600001215934753,0.08009999990463257 | ||
C#CC1(O)COC1C#N,-0.2849000096321106,-0.01769999973475933 | ||
CC1(C#N)CC12CCC2,-0.2685000002384186,0.03460000082850456 | ||
CCCC(N)(C#N)CO,-0.25760000944137573,0.028999999165534973 | ||
NC1=NC2(CC2)CC1=O,-0.22470000386238098,-0.053700000047683716 | ||
C#CC12C3CC1(C)OC32,-0.2273000031709671,0.026900000870227814 | ||
CC(C)C#CCC=O,-0.24539999663829803,-0.02669999934732914 | ||
CC#CC(C=O)CC,-0.24169999361038208,-0.02539999969303608 | ||
CC1OC2C1=CC1OC12,-0.2485000044107437,-0.01769999973475933 | ||
CNC(=N)C(C#N)OC,-0.23420000076293945,-0.0013000000035390258 | ||
C#CC(C#C)OCC=O,-0.26100000739097595,-0.031599998474121094 | ||
CN1CC(O)C12CC2,-0.20479999482631683,0.08730000257492065 | ||
OC1C2C3OC4C1C2C34,-0.24469999969005585,0.04230000078678131 | ||
OCC1C(O)C2CC12O,-0.24169999361038208,0.05739999935030937 | ||
O=C([O-])C12[NH2+]CC1C2O,-0.2508000135421753,-0.0003000000142492354 | ||
Cn1cc(O)c(CO)n1,-0.2045000046491623,0.01850000023841858 | ||
O=C1COC2C3OC2C13,-0.2498999983072281,-0.03700000047683716 | ||
C1#CCCOC=NCC1,-0.24279999732971191,0.012600000016391277 | ||
O=c1ocncc1CO,-0.2563000023365021,-0.06289999932050705 | ||
CC1NC1C(O)C(N)=O,-0.2547999918460846,0.023800000548362732 | ||
CC1OC(=N)CC2CC21,-0.2498999983072281,0.032499998807907104 | ||
OC12CCC3CN3C1C2,-0.21709999442100525,0.07280000299215317 | ||
C#CC(CCO)OC,-0.2581999897956848,0.033900000154972076 | ||
CCC1COC(CO)=N1,-0.2540999948978424,0.019200000911951065 | ||
ON=C1C=CC2C(O)C12,-0.2184000015258789,-0.04349999874830246 | ||
CN=c1cconn1,-0.23919999599456787,-0.037700001150369644 | ||
CC1(C)CC2CC2C1O,-0.2540999948978424,0.066600002348423 | ||
CCC1CCC(=N)O1,-0.2526000142097473,0.032600000500679016 | ||
O=C1C2CCC1C1NC21,-0.2282000035047531,-0.00279999990016222 | ||
CCOc1ccc(C)o1,-0.19059999287128448,0.033799998462200165 | ||
O=C1C2CC3C4C2C1N34,-0.23479999601840973,-0.026100000366568565 | ||
O=C1C=CCC=CC1=O,-0.24130000174045563,-0.08780000358819962 | ||
Cc1cc(F)c[nH]c1=O,-0.2117999941110611,-0.042100001126527786 | ||
CC1=CCc2nocc21,-0.22419999539852142,-0.019200000911951065 | ||
N#CC1(O)CN=COC1,-0.26980000734329224,-0.002400000113993883 | ||
Nc1n[nH]cc1N1CC1,-0.18649999797344208,0.03739999979734421 | ||
CN1C2CC3(O)C1C23C,-0.19619999825954437,0.07779999822378159 | ||
N=c1nccco1,-0.23680000007152557,-0.0689999982714653 | ||
COC12COC1(C)C2C,-0.22339999675750732,0.07020000368356705 | ||
CCOC1COC(=N)O1,-0.2547000050544739,0.0560000017285347 | ||
COC1(C(N)=O)CC1,-0.23800000548362732,0.0284000001847744 | ||
C#CCC#CC1NC1C,-0.23970000445842743,0.03180000185966492 | ||
C1NC1CN1C2CCC21,-0.2379000037908554,0.06539999693632126 | ||
CC(O)c1cc(N)[nH]n1,-0.21449999511241913,0.029899999499320984 | ||
CC1(O)C(O)C1C=O,-0.24230000376701355,-0.022099999710917473 | ||
C#CC1(C)C2C3OC3C21,-0.23819999396800995,0.025800000876188278 | ||
c1c[nH]c2cccc-2c1,-0.17229999601840973,-0.037300001829862595 | ||
CCC1(O)C(C)C1C=O,-0.24089999496936798,-0.01810000091791153 | ||
C1=C2C(CC1)CC1NC21,-0.2231999933719635,0.01940000057220459 | ||
C#CC1C2C(O)C1C2O,-0.24420000612735748,0.041999999433755875 | ||
CC1(C)CN2CC(C2)O1,-0.2093999981880188,0.07599999755620956 | ||
CC1OC1C1C2CN1C2,-0.22990000247955322,0.08429999649524689 | ||
CC(=O)C12CC(=O)C1C2,-0.25049999356269836,-0.04270000010728836 | ||
CC12C3=NCC1CC2O3,-0.23119999468326569,-0.016599999740719795 | ||
c1cc2onnc2[nH]1,-0.23520000278949738,-0.042399998754262924 | ||
O=CCCC1OC2CC12,-0.24369999766349792,-0.01850000023841858 | ||
OCCC1C2C3CC3N12,-0.2175000011920929,0.06040000170469284 | ||
OCC#CC1CC1,-0.23720000684261322,0.03359999880194664 | ||
OC1C2CC3C1N1C2C31,-0.22709999978542328,0.0640999972820282 | ||
CC1(C=O)C=CC(=O)N1,-0.25369998812675476,-0.05649999901652336 | ||
CC1CC23CC12CCO3,-0.20999999344348907,0.08139999955892563 | ||
CC(O)(C(N)=O)C1CO1,-0.24469999969005585,0.02889999933540821 | ||
CC1=NC2(CC2)C(=N)N1,-0.2134999930858612,0.0024999999441206455 | ||
N#CCCC(=O)C(N)=O,-0.25949999690055847,-0.08160000294446945 | ||
CC(O)(C#N)COC=N,-0.27379998564720154,0.00570000009611249 | ||
CC12C=CC(C)(N1)C2O,-0.22859999537467957,-0.0012000000569969416 | ||
CC12COC1CCO2,-0.2468000054359436,0.07940000295639038 | ||
c1noc2c1CCOC2,-0.24819999933242798,-0.010700000450015068 | ||
C#CC1CCCCOC1,-0.2467000037431717,0.053599998354911804 | ||
CN1C2C3OC2(C=O)C31,-0.23469999432563782,-0.04619999974966049 | ||
CCn1cc(O)nn1,-0.22519999742507935,0.0013000000035390258 | ||
CCOC(=NC)C(C)=O,-0.23420000076293945,-0.05640000104904175 | ||
CC12CC1(C#N)C1CC12,-0.26750001311302185,0.02070000022649765 | ||
CC(=O)C1OC1CC=O,-0.251800000667572,-0.04360000044107437 | ||
Nc1cc(=O)cno1,-0.23770000040531158,-0.053700000047683716 | ||
O=C1CC=CCC1O,-0.25519999861717224,-0.027300000190734863 | ||
CC1CC1CN1CC1C,-0.2190999984741211,0.08590000122785568 | ||
C#CCC(=N)OC=O,-0.2750999927520752,-0.032999999821186066 | ||
Cc1cnc(C=O)n1C,-0.23080000281333923,-0.053700000047683716 | ||
N=COCC(C=O)CO,-0.26260000467300415,-0.043699998408555984 | ||
CC1=C2CC3C(C1)C23C,-0.19580000638961792,-0.022700000554323196 | ||
CC1C=CCC(C)C1O,-0.2443999946117401,0.019099999219179153 | ||
C1c2n[nH]nc2C2CN12,-0.23350000381469727,-0.011800000444054604 | ||
COC1(C#N)CCC1C,-0.27480000257492065,0.02250000089406967 | ||
N=CNC(=O)C1CCO1,-0.25049999356269836,-0.020800000056624413 | ||
O=CC1(O)COC1=O,-0.27869999408721924,-0.06939999759197235 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
#!/usr/bin/env python | ||
|
||
from __future__ import print_function | ||
import argparse | ||
import sys | ||
|
||
from sklearn.preprocessing import StandardScaler | ||
|
||
try: | ||
import matplotlib | ||
matplotlib.use('Agg') | ||
except ImportError: | ||
pass | ||
|
||
|
||
import chainer | ||
from chainer import functions as F, cuda, Variable | ||
from chainer import iterators as I | ||
from chainer import links as L | ||
from chainer import optimizers as O | ||
from chainer.datasets import split_dataset_random | ||
from chainer import training | ||
from chainer.training import extensions as E | ||
import numpy | ||
|
||
from chainer_chemistry.models import MLP, NFP, GGNN, SchNet, WeaveNet, RSGCN | ||
from chainer_chemistry.dataset.converters import concat_mols | ||
from chainer_chemistry.dataset.parsers import CSVFileParser | ||
from chainer_chemistry.dataset.preprocessors import preprocess_method_dict | ||
from chainer_chemistry.datasets import NumpyTupleDataset | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You do not have to separate import of chainer_chemisty because Chainer Chemisty is a third party library for this example and therefore can be treated in the same way as Chainer or NumPy. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you sort import statements in alphabetical order? |
||
|
||
from rdkit import Chem | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as import of |
||
|
||
|
||
class GraphConvPredictor(chainer.Chain): | ||
|
||
def __init__(self, graph_conv, mlp=None): | ||
"""Initialize GraphConvPredictor | ||
|
||
Args: | ||
graph_conv: graph convolution network to obtain molecule feature | ||
representation | ||
mlp: multi layer perceptron, used as final connected layer. | ||
It can be `None` if no operation is necessary after | ||
`graph_conv` calculation. | ||
""" | ||
|
||
super(GraphConvPredictor, self).__init__() | ||
with self.init_scope(): | ||
self.graph_conv = graph_conv | ||
if isinstance(mlp, chainer.Link): | ||
self.mlp = mlp | ||
if not isinstance(mlp, chainer.Link): | ||
self.mlp = mlp | ||
|
||
def __call__(self, atoms, adjs): | ||
x = self.graph_conv(atoms, adjs) | ||
if self.mlp: | ||
x = self.mlp(x) | ||
return x | ||
|
||
|
||
def main(): | ||
# Supported preprocessing/network list | ||
method_list = ['nfp', 'ggnn', 'schnet', 'weavenet', 'rsgcn'] | ||
scale_list = ['standardize', 'none'] | ||
|
||
parser = argparse.ArgumentParser( | ||
description='Regression with own dataset.') | ||
parser.add_argument('datafile', type=str) | ||
parser.add_argument('--method', '-m', type=str, choices=method_list, | ||
default='nfp') | ||
parser.add_argument('--label', '-l', nargs='+', | ||
help='target label for regression') | ||
parser.add_argument('--scale', type=str, choices=scale_list, | ||
default='standardize', help='Label scaling method') | ||
parser.add_argument('--conv-layers', '-c', type=int, default=4) | ||
parser.add_argument('--batchsize', '-b', type=int, default=32) | ||
parser.add_argument('--gpu', '-g', type=int, default=-1) | ||
parser.add_argument('--out', '-o', type=str, default='result') | ||
parser.add_argument('--epoch', '-e', type=int, default=20) | ||
parser.add_argument('--unit-num', '-u', type=int, default=16) | ||
parser.add_argument('--seed', '-s', type=int, default=777) | ||
parser.add_argument('--train-data-ratio', '-t', type=float, default=0.7) | ||
args = parser.parse_args() | ||
|
||
seed = args.seed | ||
train_data_ratio = args.train_data_ratio | ||
method = args.method | ||
if args.label: | ||
labels = args.label | ||
class_num = len(labels) if isinstance(labels, list) else 1 | ||
else: | ||
sys.exit("Error: No target label is specified.") | ||
|
||
# Dataset preparation | ||
dataset = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need this line? |
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove this empty line. |
||
# Postprocess is required for regression task | ||
def postprocess_label(label_list): | ||
return numpy.asarray(label_list, dtype=numpy.float32) | ||
|
||
print('preprocessing dataset...') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I prefer to capitalize the first character, as we do in other places. |
||
preprocessor = preprocess_method_dict[method]() | ||
parser = CSVFileParser(preprocessor, | ||
postprocess_label=postprocess_label, | ||
labels=labels, smiles_col='SMILES') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it intentional that you hard-coded the value of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I want to show that we can specify the column name by |
||
dataset = parser.parse(args.datafile)["dataset"] | ||
|
||
if args.scale == 'standardize': | ||
# Standard Scaler for labels | ||
ss = StandardScaler() | ||
labels = ss.fit_transform(dataset.get_datasets()[-1]) | ||
dataset = NumpyTupleDataset(*dataset.get_datasets()[:-1], labels) | ||
|
||
train_data_size = int(len(dataset) * train_data_ratio) | ||
train, val = split_dataset_random(dataset, train_data_size, seed) | ||
|
||
# Network | ||
n_unit = args.unit_num | ||
conv_layers = args.conv_layers | ||
if method == 'nfp': | ||
print('Train NFP model...') | ||
model = GraphConvPredictor(NFP(out_dim=n_unit, hidden_dim=n_unit, | ||
n_layers=conv_layers), | ||
MLP(out_dim=class_num, hidden_dim=n_unit)) | ||
elif method == 'ggnn': | ||
print('Train GGNN model...') | ||
model = GraphConvPredictor(GGNN(out_dim=n_unit, hidden_dim=n_unit, | ||
n_layers=conv_layers), | ||
MLP(out_dim=class_num, hidden_dim=n_unit)) | ||
elif method == 'schnet': | ||
print('Train SchNet model...') | ||
model = GraphConvPredictor( | ||
SchNet(out_dim=class_num, hidden_dim=n_unit, n_layers=conv_layers), | ||
None) | ||
elif method == 'weavenet': | ||
print('Train WeaveNet model...') | ||
n_atom = 20 | ||
n_sub_layer = 1 | ||
weave_channels = [50] * conv_layers | ||
model = GraphConvPredictor( | ||
WeaveNet(weave_channels=weave_channels, hidden_dim=n_unit, | ||
n_sub_layer=n_sub_layer, n_atom=n_atom), | ||
MLP(out_dim=class_num, hidden_dim=n_unit)) | ||
elif method == 'rsgcn': | ||
print('Train RSGCN model...') | ||
model = GraphConvPredictor( | ||
RSGCN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers), | ||
MLP(out_dim=class_num, hidden_dim=n_unit)) | ||
else: | ||
raise ValueError('[ERROR] Invalid method {}'.format(method)) | ||
|
||
train_iter = I.SerialIterator(train, args.batchsize) | ||
val_iter = I.SerialIterator(val, args.batchsize, | ||
repeat=False, shuffle=False) | ||
|
||
def scaled_abs_error(x0, x1): | ||
if isinstance(x0, Variable): | ||
x0 = cuda.to_cpu(x0.data) | ||
if isinstance(x1, Variable): | ||
x1 = cuda.to_cpu(x1.data) | ||
if args.scale == 'standardize': | ||
scaled_x0 = ss.inverse_transform(cuda.to_cpu(x0)) | ||
scaled_x1 = ss.inverse_transform(cuda.to_cpu(x1)) | ||
diff = scaled_x0 - scaled_x1 | ||
elif args.scale == 'none': | ||
diff = cuda.to_cpu(x0) - cuda.to_cpu(x1) | ||
return numpy.mean(numpy.absolute(diff), axis=0)[0] | ||
|
||
classifier = L.Classifier(model, lossfun=F.mean_squared_error, | ||
accfun=scaled_abs_error) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a comment somewhere that scaled errors are reported as |
||
|
||
if args.gpu >= 0: | ||
chainer.cuda.get_device_from_id(args.gpu).use() | ||
classifier.to_gpu() | ||
|
||
optimizer = O.Adam() | ||
optimizer.setup(classifier) | ||
|
||
updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu, | ||
converter=concat_mols) | ||
trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out) | ||
trainer.extend(E.Evaluator(val_iter, classifier, device=args.gpu, | ||
converter=concat_mols)) | ||
trainer.extend(E.snapshot(), trigger=(args.epoch, 'epoch')) | ||
trainer.extend(E.LogReport()) | ||
trainer.extend(E.PrintReport(['epoch', 'main/loss', 'main/accuracy', | ||
'validation/main/loss', | ||
'validation/main/accuracy', | ||
'elapsed_time'])) | ||
trainer.extend(E.ProgressBar()) | ||
trainer.run() | ||
|
||
# Example of prediction using trained model | ||
smiles = 'c1ccccc1' | ||
mol = Chem.MolFromSmiles(smiles) | ||
preprocessor = preprocess_method_dict[method]() | ||
standardized_smiles, mol = preprocessor.prepare_smiles_and_mol(mol) | ||
input_features = preprocessor.get_input_features(mol) | ||
atoms, adjs = concat_mols([input_features]) | ||
prediction = model(atoms, adjs).data[0] | ||
print('Prediction for {}:'.format(smiles)) | ||
for i, label in enumerate(args.label): | ||
print('{}: {}'.format(label, prediction[i])) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please explain (a part of ) options of the script. At least
--label
needs description. But I think the option is the only one treat specially and it would be enough to write as "typepython train.py --help
to see complete options".