From f8b01bee5789eb3471ed91fd935a35b43ddc3050 Mon Sep 17 00:00:00 2001 From: corochann Date: Fri, 13 Apr 2018 14:15:36 +0900 Subject: [PATCH 1/2] change file name --- .../tox21/{predict_tox21.py => predict_tox21_with_classifier.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/tox21/{predict_tox21.py => predict_tox21_with_classifier.py} (100%) diff --git a/examples/tox21/predict_tox21.py b/examples/tox21/predict_tox21_with_classifier.py similarity index 100% rename from examples/tox21/predict_tox21.py rename to examples/tox21/predict_tox21_with_classifier.py From c5b455a2ccb0f0a052465439dbdbbeec733aab20 Mon Sep 17 00:00:00 2001 From: corochann Date: Fri, 13 Apr 2018 14:21:26 +0900 Subject: [PATCH 2/2] update example --- examples/tox21/predict_tox21_with_classifier.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/tox21/predict_tox21_with_classifier.py b/examples/tox21/predict_tox21_with_classifier.py index 4748898f..772d0b7a 100644 --- a/examples/tox21/predict_tox21_with_classifier.py +++ b/examples/tox21/predict_tox21_with_classifier.py @@ -14,7 +14,13 @@ from chainer_chemistry import datasets as D from chainer_chemistry.dataset.converters import concat_mols -from chainer_chemistry.models.prediction import Classifier +try: + from chainer_chemistry.models.prediction import Classifier +except ImportError: + print('[WARNING] If you want to use Classifier in Chainer Chemistry, ' + 'please install the library from master branch.\n See ' + 'https://github.com/pfnet-research/chainer-chemistry#installation' + ' for detail.') from chainer_chemistry.training.extensions.roc_auc_evaluator import \ ROCAUCEvaluator @@ -23,7 +29,6 @@ # Disable errors by RDKit occurred in preprocessing Tox21 dataset. - lg = RDLogger.logger() lg.setLevel(RDLogger.CRITICAL) @@ -101,6 +106,9 @@ def main(): # ---- predict --- print('Predicting...') + # We need to feed only input features `x` to `predict`/`predict_proba`. + # This converter extracts only inputs (x1, x2, ...) from the features which + # consist of input `x` and label `t` (x1, x2, ..., t). def extract_inputs(batch, device=None): return concat_mols(batch, device=device)[:-1]