Skip to content
This repository has been archived by the owner on Jul 16, 2021. It is now read-only.

Commit

Permalink
Add sample_weight and sample_weight_eval_set (#74)
Browse files Browse the repository at this point in the history
* Added sample_weight and sample_weight_eval_set.
  • Loading branch information
Mike McCarty committed Jul 27, 2020
1 parent 9fd6362 commit 34f9b66
Show file tree
Hide file tree
Showing 8 changed files with 282 additions and 99 deletions.
19 changes: 11 additions & 8 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ jobs:
docker:
- image: continuumio/miniconda:latest
environment:
PYTHON: "3.6"
PYTHON: "3.7"
steps:
- checkout
- run:
Expand All @@ -14,20 +14,23 @@ jobs:
conda update -q conda
conda install conda-build anaconda-client --yes
conda config --add channels conda-forge
conda create -q -n test-environment python=${PYTHON}
source activate test-environment
conda install -c defaults -c conda-forge -q coverage flake8 pytest pytest-cov pytest-xdist numpy pandas xgboost dask distributed scikit-learn sparse scipy 'pytest-asyncio>=0.10.0'
conda env create -f ci/environment-${PYTHON}.yaml
source activate dask-xgboost-test
pip install --no-deps -e .
conda list test-environment
conda list dask-xgboost-test
- run:
# TODO: Check on the conda-forge recipe for why this is nescessary
command: |
source activate test-environment
source activate dask-xgboost-test
- run:
command: |
source activate test-environment
source activate dask-xgboost-test
pytest -v -s dask_xgboost
- run:
command: |
source activate test-environment
source activate dask-xgboost-test
flake8 dask_xgboost
- run:
command: |
source activate dask-xgboost-test
black .
16 changes: 16 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
repos:
- repo: https://github.com/python/black
rev: 19.10b0
hooks:
- id: black
language_version: python3.7
- repo: https://gitlab.com/pycqa/flake8
rev: 3.7.9
hooks:
- id: flake8
language_version: python3.7
- repo: https://github.com/pre-commit/mirrors-isort
rev: v4.3.21
hooks:
- id: isort

31 changes: 31 additions & 0 deletions ci/environment-3.7.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: dask-xgboost-test
channels:
- conda-forge
- defaults
dependencies:
- black
- coverage
- codecov
- dask
- dask-glm >=0.2.0
- distributed
- flake8
- isort==4.3.21
- multipledispatch >=0.4.9
- mypy
- numba
- numpy >=1.16.3
- numpydoc
- packaging
- pandas
- psutil
- pytest
- pytest-cov
- pytest-mock
- pytest-xdist
- python=3.7.*
- scikit-learn>=0.23.0
- scipy
- sparse
- toolz
- xgboost=0.90
119 changes: 70 additions & 49 deletions dask_xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

try:
import sparse
import scipy.sparse as ss
except ImportError:
sparse = False
ss = False

try:
import scipy.sparse as ss
except ImportError:
ss = False

logger = logging.getLogger(__name__)

Expand All @@ -36,7 +38,7 @@ def parse_host_port(address):
def start_tracker(host, n_workers):
""" Start Rabit tracker """
if host is None:
host = get_host_ip('auto')
host = get_host_ip("auto")
env = {"DMLC_NUM_WORKER": n_workers}
rabit = RabitTracker(hostIP=host, nslave=n_workers)
env.update(rabit.slave_envs())
Expand All @@ -54,7 +56,7 @@ def concat(L):
return np.concatenate(L, axis=0)
elif isinstance(L[0], (pd.DataFrame, pd.Series)):
return pd.concat(L, axis=0)
elif ss and isinstance(L[0], ss.spmatrix):
elif ss and isinstance(L[0], ss.csr_matrix):
return ss.vstack(L, format="csr")
elif sparse and isinstance(L[0], sparse.SparseArray):
return sparse.concatenate(L, axis=0)
Expand Down Expand Up @@ -86,14 +88,16 @@ def train_part(
-------
model if rank zero, None otherwise
"""
data, labels = zip(*list_of_parts) # Prepare data
data, labels, sample_weight = zip(*list_of_parts) # Prepare data
data = concat(data) # Concatenate many parts into one
labels = concat(labels)
sample_weight = concat(sample_weight) if np.all(sample_weight) else None

if dmatrix_kwargs is None:
dmatrix_kwargs = {}

dmatrix_kwargs["feature_names"] = getattr(data, "columns", None)
dtrain = xgb.DMatrix(data, labels, **dmatrix_kwargs)
dtrain = xgb.DMatrix(data, labels, weight=sample_weight, **dmatrix_kwargs)

evals = _package_evals(
eval_set,
Expand Down Expand Up @@ -123,32 +127,32 @@ def train_part(
return result, evals_result


def _package_evals(
eval_set, sample_weight_eval_set=None, missing=None, n_jobs=None
):
def _package_evals(eval_set, sample_weight_eval_set=None, missing=None, n_jobs=None):
if eval_set is not None:
if sample_weight_eval_set is None:
sample_weight_eval_set = [None] * len(eval_set)
evals = list(
xgb.DMatrix(
data,
label=label,
missing=missing,
weight=weight,
nthread=n_jobs,
)
for ((data, label), weight) in zip(
eval_set, sample_weight_eval_set
data, label=label, missing=missing, weight=weight, nthread=n_jobs,
)
for ((data, label), weight) in zip(eval_set, sample_weight_eval_set)
)
evals = list(
zip(evals, ["validation_{}".format(i) for i in range(len(evals))])
)
evals = list(zip(evals, ["validation_{}".format(i) for i in range(len(evals))]))
else:
evals = ()
return evals


def _has_dask_collections(list_of_collections, message):
list_of_collections = list_of_collections or []
if any(
is_dask_collection(collection)
for collections in list_of_collections
for collection in collections
):
raise TypeError(message)


@gen.coroutine
def _train(
client,
Expand All @@ -157,6 +161,7 @@ def _train(
labels,
dmatrix_kwargs={},
evals_result=None,
sample_weight=None,
**kwargs
):
"""
Expand All @@ -175,42 +180,52 @@ def _train(
if isinstance(label_parts, np.ndarray):
assert label_parts.ndim == 1 or label_parts.shape[1] == 1
label_parts = label_parts.flatten().tolist()
if sample_weight is not None:
sample_weight_parts = sample_weight.to_delayed()
if isinstance(sample_weight_parts, np.ndarray):
assert sample_weight_parts.ndim == 1 or sample_weight_parts.shape[1] == 1
sample_weight_parts = sample_weight_parts.flatten().tolist()
else:
# If sample_weight is None construct a list of Nones to keep
# the structure of parts consistent.
sample_weight_parts = [None] * len(data_parts)

# Check that data, labels, and sample_weights are the same length
lists = [data_parts, label_parts, sample_weight_parts]
if len(set([len(l) for l in lists])) > 1:
raise ValueError(
"data, label, and sample_weight parts/chunks must have same length."
)

# Arrange parts into pairs. This enforces co-locality
parts = list(map(delayed, zip(data_parts, label_parts)))
# Arrange parts into triads. This enforces co-locality
parts = list(map(delayed, zip(data_parts, label_parts, sample_weight_parts)))
parts = client.compute(parts) # Start computation in the background
yield wait(parts)

for part in parts:
if part.status == "error":
yield part # trigger error locally

if kwargs.get("eval_set"):
if any(
is_dask_collection(e)
for evals in kwargs.get("eval_set")
for e in evals
):
raise TypeError(
"Evaluation set must not contain dask collections."
)
_has_dask_collections(
kwargs.get("eval_set", []), "Evaluation set must not contain dask collections."
)
_has_dask_collections(
kwargs.get("sample_weight_eval_set", []),
"Sample weight evaluation set must not contain dask collections.",
)

# Because XGBoost-python doesn't yet allow iterative training, we need to
# find the locations of all chunks and map them to particular Dask workers
key_to_part_dict = dict([(part.key, part) for part in parts])
who_has = yield client.scheduler.who_has(
keys=[part.key for part in parts]
)
who_has = yield client.scheduler.who_has(keys=[part.key for part in parts])
worker_map = defaultdict(list)
for key, workers in who_has.items():
worker_map[first(workers)].append(key_to_part_dict[key])

ncores = yield client.scheduler.ncores() # Number of cores per worker

# Start the XGBoost tracker on the Dask scheduler
env = yield client._run_on_scheduler(start_tracker,
None,
len(worker_map))
env = yield client._run_on_scheduler(start_tracker, None, len(worker_map))

# Tell each worker to train on the chunks/parts that it has locally
futures = [
Expand Down Expand Up @@ -246,6 +261,7 @@ def train(
labels,
dmatrix_kwargs={},
evals_result=None,
sample_weight=None,
**kwargs
):
""" Train an XGBoost model on a Dask Cluster
Expand All @@ -264,6 +280,8 @@ def train(
evals_result: dict, optional
Stores the evaluation result history of all the items in the eval_set
by mutating evals_result in place.
sample_weight : array_like, optional
instance weights
**kwargs: Keywords to give to XGBoost train
Examples
Expand All @@ -287,6 +305,7 @@ def train(
labels,
dmatrix_kwargs,
evals_result,
sample_weight,
**kwargs
)

Expand Down Expand Up @@ -341,14 +360,10 @@ def predict(client, model, data):
num_class = int(num_class)

if num_class > 2:
kwargs = dict(
drop_axis=None, chunks=(data.chunks[0], (num_class,))
)
kwargs = dict(drop_axis=None, chunks=(data.chunks[0], (num_class,)))
else:
kwargs = dict(drop_axis=1)
result = data.map_blocks(
_predict_part, model=model, dtype=np.float32, **kwargs
)
result = data.map_blocks(_predict_part, model=model, dtype=np.float32, **kwargs)
else:
model = model.result() # Future to concrete
if not isinstance(data, xgb.DMatrix):
Expand All @@ -364,6 +379,7 @@ def fit(
X,
y=None,
eval_set=None,
sample_weight=None,
sample_weight_eval_set=None,
eval_metric=None,
early_stopping_rounds=None,
Expand All @@ -388,6 +404,8 @@ def fit(
A list of (X, y) tuple pairs to use as validation sets, for which
metrics will be computed.
Validation metrics will help us track the performance of the model.
sample_weight : array_like, optional
instance weights
sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list
of instance weights on the i-th validation set.
Expand Down Expand Up @@ -436,6 +454,8 @@ def fit(
y,
num_boost_round=self.n_estimators,
eval_set=eval_set,
sample_weight=sample_weight,
sample_weight_eval_set=sample_weight_eval_set,
missing=self.missing,
n_jobs=self.n_jobs,
early_stopping_rounds=early_stopping_rounds,
Expand All @@ -460,6 +480,7 @@ def fit(
y=None,
classes=None,
eval_set=None,
sample_weight=None,
sample_weight_eval_set=None,
eval_metric=None,
early_stopping_rounds=None,
Expand All @@ -476,6 +497,8 @@ def fit(
A list of (X, y) tuple pairs to use as validation sets, for which
metrics will be computed.
Validation metrics will help us track the performance of the model.
sample_weight : array_like, optional
instance weights
sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list
of instance weights on the i-th validation set.
Expand Down Expand Up @@ -518,8 +541,7 @@ def fit(
-----
This differs from the XGBoost version in three ways
1. The ``sample_weight`` and ``verbose`` fit kwargs are not
supported.
1. The ``verbose`` fit kwargs are not supported.
2. The labels are not automatically label-encoded
3. The ``classes_`` and ``n_classes_`` attributes are not learned
"""
Expand Down Expand Up @@ -558,7 +580,6 @@ def fit(

# TODO: auto label-encode y
# that will require a dependency on dask-ml
# TODO: sample weight

self.evals_result_ = {}
self._Booster = train(
Expand All @@ -568,6 +589,8 @@ def fit(
y,
num_boost_round=self.n_estimators,
eval_set=eval_set,
sample_weight=sample_weight,
sample_weight_eval_set=sample_weight_eval_set,
missing=self.missing,
n_jobs=self.n_jobs,
early_stopping_rounds=early_stopping_rounds,
Expand All @@ -592,8 +615,6 @@ def predict(self, X):
def predict_proba(self, data, ntree_limit=None):
client = default_client()
if ntree_limit is not None:
raise NotImplementedError(
"'ntree_limit' is not currently " "supported."
)
raise NotImplementedError("'ntree_limit' is not currently " "supported.")
class_probs = predict(client, self._Booster, data)
return class_probs
Loading

0 comments on commit 34f9b66

Please sign in to comment.