Skip to content

Commit

Permalink
Merge pull request #128 from jubatus/fix-sklearn-eol-functions
Browse files Browse the repository at this point in the history
fix example: sklearn cross_validation module EOL
  • Loading branch information
rimms committed Oct 19, 2018
2 parents 5b08606 + 32e432b commit 8096681
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions example/classifier_kfold.py
Expand Up @@ -11,12 +11,19 @@
calculate metrics (`classification_report`) using scikit-learn.
"""

import sklearn
import sklearn.datasets
import sklearn.metrics
from sklearn.cross_validation import StratifiedKFold

from jubakit.classifier import Classifier, Dataset, Config

# switch StratifiedKFold API
sklearn_version = int(sklearn.__version__.split('.')[1])
if sklearn_version < 18:
from sklearn.cross_validation import StratifiedKFold
else:
from sklearn.model_selection import StratifiedKFold


# Load built-in `iris` dataset from scikit-learn.
iris = sklearn.datasets.load_iris()

Expand All @@ -37,7 +44,14 @@
predicted_labels = []

# Run stratified K-fold validation.
for train_idx, test_idx in StratifiedKFold(list(dataset.get_labels()), n_folds=10):
labels = list(dataset.get_labels())
if sklearn_version < 18:
train_test_indices = StratifiedKFold(labels, n_folds=10)
else:
skf = StratifiedKFold(n_splits=10)
train_test_indices = skf.split(labels, labels)

for train_idx, test_idx in train_test_indices:
# Clear the classifier (call `clear` RPC).
classifier.clear()

Expand Down

0 comments on commit 8096681

Please sign in to comment.