From 32e432bdd87ad0c9029bfe7b60415093d61780cc Mon Sep 17 00:00:00 2001 From: Tetsuya Shioda Date: Thu, 4 Oct 2018 10:34:45 +0900 Subject: [PATCH] fix example: sklearn cross_validation module EOL --- example/classifier_kfold.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/example/classifier_kfold.py b/example/classifier_kfold.py index 83ced0c..e892a4e 100644 --- a/example/classifier_kfold.py +++ b/example/classifier_kfold.py @@ -11,12 +11,19 @@ calculate metrics (`classification_report`) using scikit-learn. """ +import sklearn import sklearn.datasets import sklearn.metrics -from sklearn.cross_validation import StratifiedKFold - from jubakit.classifier import Classifier, Dataset, Config +# switch StratifiedKFold API +sklearn_version = int(sklearn.__version__.split('.')[1]) +if sklearn_version < 18: + from sklearn.cross_validation import StratifiedKFold +else: + from sklearn.model_selection import StratifiedKFold + + # Load built-in `iris` dataset from scikit-learn. iris = sklearn.datasets.load_iris() @@ -37,7 +44,14 @@ predicted_labels = [] # Run stratified K-fold validation. -for train_idx, test_idx in StratifiedKFold(list(dataset.get_labels()), n_folds=10): +labels = list(dataset.get_labels()) +if sklearn_version < 18: + train_test_indices = StratifiedKFold(labels, n_folds=10) +else: + skf = StratifiedKFold(n_splits=10) + train_test_indices = skf.split(labels, labels) + +for train_idx, test_idx in train_test_indices: # Clear the classifier (call `clear` RPC). classifier.clear()