Skip to content

Commit

Permalink
[python] fix group type in lgb.cv (#2384)
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke authored and StrikerRUS committed Sep 6, 2019
1 parent 29525ff commit 7509ec8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python-package/lightgbm/engine.py
Expand Up @@ -307,7 +307,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
if hasattr(folds, 'split'):
group_info = full_data.get_group()
if group_info is not None:
group_info = group_info.astype(int)
group_info = np.array(group_info, dtype=int)
flatted_group = np.repeat(range_(len(group_info)), repeats=group_info)
else:
flatted_group = np.zeros(num_data, dtype=int)
Expand All @@ -317,7 +317,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for lambdarank cv.')
# lambdarank task, split according to groups
group_info = full_data.get_group().astype(int)
group_info = np.array(full_data.get_group(), dtype=int)
flatted_group = np.repeat(range_(len(group_info)), repeats=group_info)
group_kfold = _LGBMGroupKFold(n_splits=nfold)
folds = group_kfold.split(X=np.zeros(num_data), groups=flatted_group)
Expand Down

0 comments on commit 7509ec8

Please sign in to comment.