From 7509ec8ade06b13bd10aff41ece78c02a7993535 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Fri, 6 Sep 2019 20:05:01 +0800 Subject: [PATCH] [python] fix group type in lgb.cv (#2384) --- python-package/lightgbm/engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 24b5403eeba..cf48d0e4423 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -307,7 +307,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi if hasattr(folds, 'split'): group_info = full_data.get_group() if group_info is not None: - group_info = group_info.astype(int) + group_info = np.array(group_info, dtype=int) flatted_group = np.repeat(range_(len(group_info)), repeats=group_info) else: flatted_group = np.zeros(num_data, dtype=int) @@ -317,7 +317,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi if not SKLEARN_INSTALLED: raise LightGBMError('Scikit-learn is required for lambdarank cv.') # lambdarank task, split according to groups - group_info = full_data.get_group().astype(int) + group_info = np.array(full_data.get_group(), dtype=int) flatted_group = np.repeat(range_(len(group_info)), repeats=group_info) group_kfold = _LGBMGroupKFold(n_splits=nfold) folds = group_kfold.split(X=np.zeros(num_data), groups=flatted_group)