Skip to content

Commit

Permalink
Define metainfo and other parameters for all DMatrix interfaces.
Browse files Browse the repository at this point in the history
This PR ensures all DMatrix types have a common interface.

* Check for consistency between DMatrix types.
  • Loading branch information
trivialfis committed Jan 12, 2021
1 parent 03cd087 commit b21a3a4
Show file tree
Hide file tree
Showing 6 changed files with 326 additions and 117 deletions.
175 changes: 123 additions & 52 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,15 +331,18 @@ def data_handle(data, label=None, weight=None, base_margin=None,
data, feature_names, feature_types
)
dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
self.proxy.set_info(label=label, weight=weight,
base_margin=base_margin,
group=group,
qid=qid,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound,
feature_names=feature_names,
feature_types=feature_types,
feature_weights=feature_weights)
self.proxy.set_info(
label=label,
weight=weight,
base_margin=base_margin,
group=group,
qid=qid,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound,
feature_names=feature_names,
feature_types=feature_types,
feature_weights=feature_weights
)
try:
# Defer the exception in order to return 0 and stop the iteration.
# Exception inside a ctype callback function has no effect except
Expand Down Expand Up @@ -427,21 +430,34 @@ def inner_f(*args, **kwargs):
return inner_f


class DMatrix: # pylint: disable=too-many-instance-attributes
class DMatrix: # pylint: disable=too-many-instance-attributes
"""Data Matrix used in XGBoost.
DMatrix is an internal data structure that is used by XGBoost,
which is optimized for both memory efficiency and training speed.
You can construct DMatrix from multiple different sources of data.
"""

def __init__(self, data, label=None, weight=None, base_margin=None,
missing=None,
silent=False,
feature_names=None,
feature_types=None,
nthread=None,
enable_categorical=False):
@_deprecate_positional_args
def __init__(
self,
data,
label=None,
*,
weight=None,
base_margin=None,
group=None,
qid=None,
label_lower_bound=None,
label_upper_bound=None,
missing: Optional[float] = None,
silent: bool = False,
feature_weights=None,
feature_names=None,
feature_types=None,
nthread: Optional[int] = None,
enable_categorical: bool = False
) -> None:
"""Parameters
----------
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
Expand All @@ -451,12 +467,9 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
libsvm format txt file, csv file (by specifying uri parameter
'path_to_csv?format=csv'), or binary file that xgboost can read
from.
label : list, numpy 1-D array or cudf.DataFrame, optional
label : array_like
Label of the training data.
missing : float, optional
Value in the input data which needs to be present as a missing
value. If None, defaults to np.nan.
weight : list, numpy 1-D array or cudf.DataFrame , optional
weight : array_like
Weight for each instance.
.. note:: For ranking task, weights are per-group.
Expand All @@ -465,9 +478,19 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
data point). This is because we only care about the relative
ordering of data points within each group, so it doesn't make
sense to assign weights to individual data points.
base_margin : array_like
Base margin used for boosting from existing model.
group : array_like
Group sizes for all ranking groups.
qid : array_like
Query ID for data samples, used for ranking.
missing : float, optional
Value in the input data that should be treated as a missing
value. If None, defaults to np.nan.
silent : boolean, optional
Whether to print messages during construction.
feature_weights : array_like, optional
Set feature weights for column sampling.
feature_names : list, optional
Set names for features.
feature_types : list, optional
Expand All @@ -488,7 +511,9 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
"""
if isinstance(data, list):
raise TypeError('Input data can not be a list.')
raise TypeError("Input data can not be a list.")
if group is not None and qid is not None:
raise ValueError("Either one of `group` or `qid` should be None.")

self.missing = missing if missing is not None else np.nan
self.nthread = nthread if nthread is not None else -1
Expand All @@ -500,16 +525,28 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
return

from .data import dispatch_data_backend

handle, feature_names, feature_types = dispatch_data_backend(
data, missing=self.missing,
data,
missing=self.missing,
threads=self.nthread,
feature_names=feature_names,
feature_types=feature_types,
enable_categorical=enable_categorical)
enable_categorical=enable_categorical,
)
assert handle is not None
self.handle = handle

self.set_info(label=label, weight=weight, base_margin=base_margin)
self.set_info(
label=label,
weight=weight,
base_margin=base_margin,
group=group,
qid=qid,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound,
feature_weights=feature_weights,
)

if feature_names is not None:
self.feature_names = feature_names
Expand All @@ -522,17 +559,23 @@ def __del__(self):
self.handle = None

@_deprecate_positional_args
def set_info(self, *,
label=None, weight=None, base_margin=None,
group=None,
qid=None,
label_lower_bound=None,
label_upper_bound=None,
feature_names=None,
feature_types=None,
feature_weights=None):
'''Set meta info for DMatrix.'''
def set_info(
self,
*,
label=None,
weight=None,
base_margin=None,
group=None,
qid=None,
label_lower_bound=None,
label_upper_bound=None,
feature_names=None,
feature_types=None,
feature_weights=None
) -> None:
"""Set meta info for DMatrix. See doc string for DMatrix constructor."""
from .data import dispatch_meta_backend

if label is not None:
self.set_label(label)
if weight is not None:
Expand Down Expand Up @@ -937,39 +980,67 @@ class DeviceQuantileDMatrix(DMatrix):
information may be lost in quantisation. This DMatrix is primarily designed
to save memory in training from device memory inputs by avoiding
intermediate storage. Set max_bin to control the number of bins during
quantisation.
quantisation. See the `DMatrix` docstring for documentation on meta info.
You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.
.. versionadded:: 1.1.0
"""

def __init__(self, data, label=None, weight=None, # pylint: disable=W0231
base_margin=None,
missing=None,
silent=False,
feature_names=None,
feature_types=None,
nthread=None, max_bin=256):
@_deprecate_positional_args
def __init__( # pylint: disable=super-init-not-called
self,
data,
label=None,
*,
weight=None,
base_margin=None,
group=None,
qid=None,
label_lower_bound=None,
label_upper_bound=None,
missing=None,
silent=False,
feature_weights=None,
feature_names=None,
feature_types=None,
nthread: Optional[int] = None,
enable_categorical: bool = False,
max_bin: int = 256,
):
self.max_bin = max_bin
self.missing = missing if missing is not None else np.nan
self.nthread = nthread if nthread is not None else 1
self._silent = silent # unused, kept for compatibility

if isinstance(data, ctypes.c_void_p):
self.handle = data
return
from .data import init_device_quantile_dmatrix
handle, feature_names, feature_types = init_device_quantile_dmatrix(
data, missing=self.missing, threads=self.nthread,
max_bin=self.max_bin,
data,
label=label, weight=weight,
base_margin=base_margin,
group=None,
label_lower_bound=None,
label_upper_bound=None,
group=group,
qid=qid,
missing=self.missing,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound,
feature_weights=feature_weights,
feature_names=feature_names,
feature_types=feature_types)
feature_types=feature_types,
threads=self.nthread,
max_bin=self.max_bin,
)
if enable_categorical:
raise NotImplementedError(
'categorical support is not enabled on DeviceQuantileDMatrix.'
)
self.handle = handle
if qid is not None and group is not None:
raise ValueError(
'Only one of the eval_qid or eval_group for each evaluation '
'dataset should be provided.'
)

self.feature_names = feature_names
self.feature_types = feature_types
Expand Down

0 comments on commit b21a3a4

Please sign in to comment.