Commit 1e61f24

try to fix problem with multi-dimensional sliced object. (#1210)
guolinke committed Jan 24, 2018
1 parent 61fb5ea commit 1e61f24
Showing 2 changed files with 82 additions and 20 deletions.
51 changes: 31 additions & 20 deletions python-package/lightgbm/basic.py
@@ -179,11 +179,22 @@ def writelines(self, lines):
                      "group": C_API_DTYPE_INT32}
 
 
+def convert_from_sliced_object(data):
+    """fix the memory of a multi-dimensional sliced object"""
+    if isinstance(data, np.ndarray) and isinstance(data.base, np.ndarray):
+        if not data.flags.c_contiguous:
+            warnings.warn("Using a subset (sliced data) of np.ndarray is not recommended because it will double the peak memory cost in LightGBM.")
+            return np.copy(data)
+    return data
+
+
 def c_float_array(data):
     """get pointer of float numpy array / list"""
     if is_1d_list(data):
         data = np.array(data, copy=False)
     if is_numpy_1d_array(data):
+        data = convert_from_sliced_object(data)
+        assert data.flags.c_contiguous
         if data.dtype == np.float32:
             ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
             type_data = C_API_DTYPE_FLOAT32
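
The helper above is the heart of the change: a NumPy view produced by multi-dimensional slicing shares memory with its base array, and a column slice is not C-contiguous, so its raw data pointer cannot be handed to the C API as-is. A standalone sketch of the condition the helper detects (illustration only, not part of the diff):

    import numpy as np

    mat = np.random.rand(100, 5)       # C-contiguous 2-D array
    col = mat[:, 0]                    # a view: shares memory with `mat`

    print(col.base is mat)             # True  -> a "sliced object"
    print(col.flags.c_contiguous)      # False -> elements sit one row apart
    print(mat.strides, col.strides)    # (40, 8) vs (40,)

    # Reading 100 consecutive doubles from col's data pointer would walk
    # through mat row by row instead of down the column. Copying fixes it:
    fixed = np.copy(col)
    print(fixed.flags.c_contiguous)    # True -> safe to pass to the C side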
@@ -195,14 +206,16 @@ def c_float_array(data):
                             .format(data.dtype))
     else:
         raise TypeError("Unknown type({})".format(type(data).__name__))
-    return (ptr_data, type_data)
+    return (ptr_data, type_data, data)
 
 
 def c_int_array(data):
     """get pointer of int numpy array / list"""
     if is_1d_list(data):
         data = np.array(data, copy=False)
     if is_numpy_1d_array(data):
+        data = convert_from_sliced_object(data)
+        assert data.flags.c_contiguous
         if data.dtype == np.int32:
             ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
             type_data = C_API_DTYPE_INT32
@@ -214,7 +227,7 @@ def c_int_array(data):
                             .format(data.dtype))
     else:
         raise TypeError("Unknown type({})".format(type(data).__name__))
-    return (ptr_data, type_data)
+    return (ptr_data, type_data, data)
 
 
 PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int',
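
Both converters now return the (possibly copied) array as a third tuple element. Because convert_from_sliced_object may allocate the copy inside the callee, a bare ctypes pointer would be the only thing left referring to that buffer, and Python would be free to garbage-collect it while the C library still reads from it; returning data lets callers keep the buffer alive (callers that don't need it bind it to _ or __). A minimal sketch of the hazard, using a contrived helper that is not LightGBM code:

    import ctypes
    import numpy as np

    def ptr_only(arr):
        # BAD: ascontiguousarray may allocate a copy; after the return,
        # no Python reference to that copy remains and the pointer dangles.
        arr = np.ascontiguousarray(arr, dtype=np.float64)
        return arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))

    def ptr_and_ref(arr):
        # GOOD: return the array alongside the pointer so the caller can
        # keep the buffer alive for as long as the pointer is in use.
        arr = np.ascontiguousarray(arr, dtype=np.float64)
        return arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), arr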
@@ -472,7 +485,7 @@ def __pred_for_np2d(self, mat, num_iteration, predict_type):
         else:
             # change non-float data to float data, need to copy
             data = np.array(mat.reshape(mat.size), dtype=np.float32)
-        ptr_data, type_ptr_data = c_float_array(data)
+        ptr_data, type_ptr_data, _ = c_float_array(data)
         n_preds = self.__get_num_preds(num_iteration, mat.shape[0],
                                        predict_type)
         preds = np.zeros(n_preds, dtype=np.float64)
@@ -502,8 +515,8 @@ def __pred_for_csr(self, csr, num_iteration, predict_type):
         preds = np.zeros(n_preds, dtype=np.float64)
         out_num_preds = ctypes.c_int64(0)
 
-        ptr_indptr, type_ptr_indptr = c_int_array(csr.indptr)
-        ptr_data, type_ptr_data = c_float_array(csr.data)
+        ptr_indptr, type_ptr_indptr, __ = c_int_array(csr.indptr)
+        ptr_data, type_ptr_data, _ = c_float_array(csr.data)
 
         _safe_call(_LIB.LGBM_BoosterPredictForCSR(
             self.handle,
@@ -533,8 +546,8 @@ def __pred_for_csc(self, csc, num_iteration, predict_type):
         preds = np.zeros(n_preds, dtype=np.float64)
         out_num_preds = ctypes.c_int64(0)
 
-        ptr_indptr, type_ptr_indptr = c_int_array(csc.indptr)
-        ptr_data, type_ptr_data = c_float_array(csc.data)
+        ptr_indptr, type_ptr_indptr, __ = c_int_array(csc.indptr)
+        ptr_data, type_ptr_data, _ = c_float_array(csc.data)
 
         _safe_call(_LIB.LGBM_BoosterPredictForCSC(
             self.handle,
@@ -747,7 +760,7 @@ def __init_from_np2d(self, mat, params_str, ref_dataset):
             # change non-float data to float data, need to copy
             data = np.array(mat.reshape(mat.size), dtype=np.float32)
 
-        ptr_data, type_ptr_data = c_float_array(data)
+        ptr_data, type_ptr_data, _ = c_float_array(data)
         _safe_call(_LIB.LGBM_DatasetCreateFromMat(
             ptr_data,
             ctypes.c_int(type_ptr_data),
@@ -766,8 +779,8 @@ def __init_from_csr(self, csr, params_str, ref_dataset):
             raise ValueError('Length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data)))
         self.handle = ctypes.c_void_p()
 
-        ptr_indptr, type_ptr_indptr = c_int_array(csr.indptr)
-        ptr_data, type_ptr_data = c_float_array(csr.data)
+        ptr_indptr, type_ptr_indptr, __ = c_int_array(csr.indptr)
+        ptr_data, type_ptr_data, _ = c_float_array(csr.data)
 
         _safe_call(_LIB.LGBM_DatasetCreateFromCSR(
             ptr_indptr,
@@ -790,8 +803,8 @@ def __init_from_csc(self, csc, params_str, ref_dataset):
             raise ValueError('Length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data)))
         self.handle = ctypes.c_void_p()
 
-        ptr_indptr, type_ptr_indptr = c_int_array(csc.indptr)
-        ptr_data, type_ptr_data = c_float_array(csc.data)
+        ptr_indptr, type_ptr_indptr, __ = c_int_array(csc.indptr)
+        ptr_data, type_ptr_data, _ = c_float_array(csc.data)
 
         _safe_call(_LIB.LGBM_DatasetCreateFromCSC(
             ptr_indptr,
@@ -824,6 +837,7 @@ def construct(self):
         else:
             # construct subset
             used_indices = list_to_1d_numpy(self.used_indices, np.int32, name='used_indices')
+            assert used_indices.flags.c_contiguous
             self.handle = ctypes.c_void_p()
             params_str = param_dict_to_str(self.params)
             _safe_call(_LIB.LGBM_DatasetGetSubset(
@@ -952,15 +966,10 @@ def set_field(self, field_name, data):
         elif field_name == 'init_score':
             dtype = np.float64
         data = list_to_1d_numpy(data, dtype, name=field_name)
-        if data.dtype == np.float32:
-            ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
-            type_data = C_API_DTYPE_FLOAT32
-        elif data.dtype == np.float64:
-            ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
-            type_data = C_API_DTYPE_FLOAT64
+        if data.dtype == np.float32 or data.dtype == np.float64:
+            ptr_data, type_data, _ = c_float_array(data)
         elif data.dtype == np.int32:
-            ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
-            type_data = C_API_DTYPE_INT32
+            ptr_data, type_data, _ = c_int_array(data)
         else:
             raise TypeError("Expected np.float32/64 or np.int32, got type({})".format(data.dtype))
         if type_data != FIELD_TYPE_MAPPER[field_name]:
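
With this change set_field funnels through the same converters, so fields set from sliced arrays are copied (and kept alive) too. A hedged usage sketch, assuming a constructed Dataset whose field names follow FIELD_TYPE_MAPPER above:

    import numpy as np
    import lightgbm as lgb

    X = np.random.rand(100, 5)
    ds = lgb.Dataset(X, label=np.zeros(100))
    ds.construct()                      # set_field needs the C handle to exist
    ds.set_field('weight', np.ones(100, dtype=np.float32))      # float field
    ds.set_field('group', np.array([50, 50], dtype=np.int32))   # int32 field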
@@ -1536,6 +1545,8 @@ def __boost(self, grad, hess):
         """
         grad = list_to_1d_numpy(grad, name='gradient')
         hess = list_to_1d_numpy(hess, name='hessian')
+        assert grad.flags.c_contiguous
+        assert hess.flags.c_contiguous
         if len(grad) != len(hess):
             raise ValueError("Lengths of gradient({}) and hessian({}) don't match".format(len(grad), len(hess)))
         is_finished = ctypes.c_int(0)
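
The new asserts sit in __boost, where grad and hess come straight from a user-supplied objective function and could easily be column slices of some 2-D buffer. For reference, a hypothetical custom objective in the fobj signature used by lgb.train; its outputs are freshly computed 1-D arrays and therefore pass the contiguity check:

    import numpy as np

    def logistic_obj(preds, train_data):
        # Hypothetical binary-logloss objective for lgb.train(..., fobj=logistic_obj).
        labels = train_data.get_label()
        prob = 1.0 / (1.0 + np.exp(-preds))
        grad = prob - labels          # fresh 1-D array: C-contiguous
        hess = prob * (1.0 - prob)    # likewise contiguous
        return grad, hess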
51 changes: 51 additions & 0 deletions tests/python_package_test/test_engine.py
@@ -12,6 +12,7 @@
                               load_iris, load_svmlight_file)
 from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error
 from sklearn.model_selection import train_test_split, TimeSeriesSplit
+from scipy.sparse import csr_matrix
 
 try:
     import pandas as pd
@@ -548,3 +549,53 @@ def test_contribs(self):
                                 evals_result=evals_result)
 
         self.assertLess(np.linalg.norm(gbm.predict(X_test, raw_score=True) - np.sum(gbm.predict(X_test, pred_contrib=True), axis=1)), 1e-4)
+
+    def test_sliced_data(self):
+        def train_and_get_predictions(features, labels):
+            dataset = lgb.Dataset(features, label=labels)
+            lgb_params = {
+                'application': 'binary',
+                'verbose': -1,
+                'min_data': 5,
+            }
+            lgbm_model = lgb.train(
+                params=lgb_params,
+                train_set=dataset,
+                num_boost_round=10,
+            )
+            predictions = lgbm_model.predict(features)
+            return predictions
+        num_samples = 100
+        features = np.random.rand(num_samples, 5)
+        positive_samples = int(num_samples * 0.25)
+        labels = np.append(
+            np.ones(positive_samples, dtype=np.float32),
+            np.zeros(num_samples - positive_samples, dtype=np.float32),
+        )
+        # test sliced labels
+        origin_pred = train_and_get_predictions(features, labels)
+        stacked_labels = np.column_stack((labels, np.ones(num_samples, dtype=np.float32)))
+        sliced_labels = stacked_labels[:, 0]
+        sliced_pred = train_and_get_predictions(features, sliced_labels)
+        np.testing.assert_almost_equal(origin_pred, sliced_pred)
+        # append some columns
+        stacked_features = np.column_stack((np.ones(num_samples, dtype=np.float32), features))
+        stacked_features = np.column_stack((np.ones(num_samples, dtype=np.float32), stacked_features))
+        stacked_features = np.column_stack((stacked_features, np.ones(num_samples, dtype=np.float32)))
+        stacked_features = np.column_stack((stacked_features, np.ones(num_samples, dtype=np.float32)))
+        # append some rows
+        stacked_features = np.concatenate((np.ones(9, dtype=np.float32).reshape((1, 9)), stacked_features), axis=0)
+        stacked_features = np.concatenate((np.ones(9, dtype=np.float32).reshape((1, 9)), stacked_features), axis=0)
+        stacked_features = np.concatenate((stacked_features, np.ones(9, dtype=np.float32).reshape((1, 9))), axis=0)
+        stacked_features = np.concatenate((stacked_features, np.ones(9, dtype=np.float32).reshape((1, 9))), axis=0)
+        # test sliced 2d matrix
+        sliced_features = stacked_features[2:102, 2:7]
+        assert np.all(sliced_features == features)
+        sliced_pred = train_and_get_predictions(sliced_features, sliced_labels)
+        np.testing.assert_almost_equal(origin_pred, sliced_pred)
+        # test sliced CSR
+        stacked_csr = csr_matrix(stacked_features)
+        sliced_csr = stacked_csr[2:102, 2:7]
+        assert np.all(sliced_csr == features)
+        sliced_pred = train_and_get_predictions(sliced_csr, sliced_labels)
+        np.testing.assert_almost_equal(origin_pred, sliced_pred)
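
The slices this test builds cover both contiguous and non-contiguous views, which is exactly what convert_from_sliced_object has to distinguish. A quick check of the memory flags involved (illustration only, not part of the test):

    import numpy as np

    labels = np.zeros(100, dtype=np.float32)
    stacked_labels = np.column_stack((labels, np.ones(100, dtype=np.float32)))
    print(stacked_labels[:, 0].flags.c_contiguous)     # False: forces a copy

    stacked = np.random.rand(104, 9)
    print(stacked[2:102, :].flags.c_contiguous)        # True: whole-row slices stay contiguous
    print(stacked[2:102, 2:7].flags.c_contiguous)      # False: column offsets break contiguity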
