Skip to content

Commit

Permalink
fix #200 (#312)
Browse files Browse the repository at this point in the history
  • Loading branch information
wxchan authored and guolinke committed Mar 1, 2017
1 parent 9ea487b commit 7ae0e23
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 31 deletions.
4 changes: 2 additions & 2 deletions docs/Python-API.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

* [Training API](Python-API.md#training-api)
- [train](Python-API.md#trainparams-train_set-num_boost_round100-valid_setsnone-valid_namesnone-fobjnone-fevalnone-init_modelnone-feature_nameauto-early_stopping_roundsnone-evals_resultnone-verbose_evaltrue-learning_ratesnone-callbacksnone)
- [cv](Python-API.md#cvparams-train_set-num_boost_round10-nfold5-stratifiedfalse-shuffletrue-metricsnone-fobjnone-fevalnone-init_modelnone-feature_nameauto-early_stopping_roundsnone-fpreprocnone-verbose_evalnone-show_stdvtrue-seed0-callbacksnone)
- [cv](Python-API.md#cvparams-train_set-num_boost_round10-data_splitternone-nfold5-stratifiedfalse-shuffletrue-metricsnone-fobjnone-fevalnone-init_modelnone-feature_nameauto-early_stopping_roundsnone-fpreprocnone-verbose_evalnone-show_stdvtrue-seed0-callbacksnone)

* [Scikit-learn API](Python-API.md#scikit-learn-api)
- [Common Methods](Python-API.md#common-methods)
Expand Down Expand Up @@ -515,7 +515,7 @@ The methods of each Class is in alphabetical order.
booster : a trained booster model


####cv(params, train_set, num_boost_round=10, nfold=5, stratified=False, shuffle=True, metrics=None, fobj=None, feval=None, init_model=None, feature_name='auto', early_stopping_rounds=None, fpreproc=None, verbose_eval=None, show_stdv=True, seed=0, callbacks=None)
####cv(params, train_set, num_boost_round=10, data_splitter=None, nfold=5, stratified=False, shuffle=True, metrics=None, fobj=None, feval=None, init_model=None, feature_name='auto', early_stopping_rounds=None, fpreproc=None, verbose_eval=None, show_stdv=True, seed=0, callbacks=None)

Cross-validation with given paramaters.

Expand Down
30 changes: 10 additions & 20 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
import numpy as np
import scipy.sparse

from .compat import (DataFrame, Series, integer_types, json,
json_default_with_numpy, numeric_types, range_,
string_type)
from .compat import (DataFrame, Series, integer_types, json, numeric_types,
range_, string_type)
from .libpath import find_lib_path


Expand Down Expand Up @@ -223,25 +222,19 @@ def c_int_array(data):

def _data_from_pandas(data, feature_name):
if isinstance(data, DataFrame):
if feature_name == 'auto' or feature_name is None:
bad_fields = [data.columns[i] for i, dtype in enumerate(data.dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]
if bad_fields:
msg = """DataFrame.dtypes for data must be int, float or bool. Did not expect the data types in fields: """
raise ValueError(msg + ', '.join(bad_fields))
if feature_name == 'auto':
if all([isinstance(name, integer_types + (np.integer, )) for name in data.columns]):
msg = """Using Pandas (default) integer column names, not column indexes. You can use indexes with DataFrame.values."""
warnings.filterwarnings('once')
warnings.warn(msg, stacklevel=5)
data = data.rename(columns=str)
if feature_name == 'auto':
feature_name = list(data.columns)
data_dtypes = data.dtypes
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes):
bad_fields = [data.columns[i] for i, dtype in
enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]

msg = """DataFrame.dtypes for data must be int, float or bool. Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields))
feature_name = [str(name) for name in data.columns]
data = data.values.astype('float')
else:
if feature_name == 'auto':
feature_name = None
elif feature_name == 'auto':
feature_name = None
return data, feature_name


Expand Down Expand Up @@ -366,9 +359,6 @@ def predict(self, data, num_iteration=-1,
elif isinstance(data, np.ndarray):
preds, nrow = self.__pred_for_np2d(data, num_iteration,
predict_type)
elif isinstance(data, DataFrame):
preds, nrow = self.__pred_for_np2d(data.values, num_iteration,
predict_type)
else:
try:
csr = scipy.sparse.csr_matrix(data)
Expand Down
9 changes: 0 additions & 9 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,6 @@ def argc_(func):
import json


def json_default_with_numpy(obj):
if isinstance(obj, (np.integer, np.floating, np.bool_)):
return obj.item()
elif isinstance(obj, np.ndarray):
return obj.tolist()
else:
return obj


"""pandas"""
try:
from pandas import Series, DataFrame
Expand Down

0 comments on commit 7ae0e23

Please sign in to comment.