Skip to content

Commit

Permalink
[python] updated multiclass objective in sklearn (#1218)
Browse files Browse the repository at this point in the history
* added comment to not forget

* updated LGBMClassifier according to new aliases
  • Loading branch information
StrikerRUS authored and guolinke committed Jan 25, 2018
1 parent 1e61f24 commit 87f2acb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
6 changes: 3 additions & 3 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def convert_from_sliced_object(data):
"""fix the memory of multi-dimensional sliced object"""
if data.base is not None and isinstance(data, np.ndarray) and isinstance(data.base, np.ndarray):
if not data.flags.c_contiguous:
warnings.warn("Use subset(sliced data) of np.ndarray is not recommended due to it will double the peak memory cost in LightGBM.")
warnings.warn("Usage subset(sliced data) of np.ndarray is not recommended due to it will double the peak memory cost in LightGBM.")
return np.copy(data)
return data

Expand All @@ -206,7 +206,7 @@ def c_float_array(data):
.format(data.dtype))
else:
raise TypeError("Unknown type({})".format(type(data).__name__))
return (ptr_data, type_data, data)
return (ptr_data, type_data, data) # return `data` to avoid the temporary copy is freed


def c_int_array(data):
Expand All @@ -227,7 +227,7 @@ def c_int_array(data):
.format(data.dtype))
else:
raise TypeError("Unknown type({})".format(type(data).__name__))
return (ptr_data, type_data, data)
return (ptr_data, type_data, data) # return `data` to avoid the temporary copy is freed


PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int',
Expand Down
3 changes: 2 additions & 1 deletion python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,8 @@ def fit(self, X, y,
self._n_classes = len(self._classes)
if self._n_classes > 2:
# Switch to using a multiclass objective in the underlying LGBM instance
if self._objective != "multiclassova" and not callable(self._objective):
ova_aliases = ("multiclassova", "multiclass_ova", "ova", "ovr")
if self._objective not in ova_aliases and not callable(self._objective):
self._objective = "multiclass"
if eval_metric == 'logloss' or eval_metric == 'binary_logloss':
eval_metric = "multi_logloss"
Expand Down

0 comments on commit 87f2acb

Please sign in to comment.