diff --git a/dask_lightgbm/core.py b/dask_lightgbm/core.py
index bf41227..0bef3a9 100644
--- a/dask_lightgbm/core.py
+++ b/dask_lightgbm/core.py
@@ -63,9 +63,18 @@ def _train_part(params, model_factory, list_of_parts, worker_addresses, return_m
 
     # Concatenate many parts into one
     parts = tuple(zip(*list_of_parts))
-    data = concat(parts[0])
-    label = concat(parts[1])
-    weight = concat(parts[2]) if len(parts) == 3 else None
+    data = concat(parts[kwargs['parts_list'].index('X')])
+    label = concat(parts[kwargs['parts_list'].index('y')])
+    weight = concat(parts[kwargs['parts_list'].index('weight')]) if 'weight' in kwargs['parts_list'] else None
+    valid_X = concat(parts[kwargs['parts_list'].index('valid_X')]) if 'valid_X' in kwargs['parts_list'] else None
+    valid_y = concat(parts[kwargs['parts_list'].index('valid_y')]) if 'valid_y' in kwargs['parts_list'] else None
+    eval_sample_weight = concat(parts[kwargs['parts_list'].index('eval_sample_weight')]) if 'eval_sample_weight' in kwargs['parts_list'] else None
+    # only the first eval_set is supported
+    kwargs = kwargs.copy()  # avoid contaminating upstream
+    if valid_X is not None and valid_y is not None:
+        kwargs['eval_set'] = [(valid_X, valid_y)]
+        kwargs['eval_sample_weight'] = [eval_sample_weight]
+    kwargs.pop('parts_list', None)
 
     try:
         model = model_factory(**params)
@@ -86,13 +95,39 @@ def _split_to_parts(data, is_matrix):
 
 def train(client, data, label, params, model_factory, weight=None, **kwargs):
     # Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
+    tozip = []
+    kwargs['parts_list'] = []
     data_parts = _split_to_parts(data, is_matrix=True)
+    tozip.append(data_parts)
+    kwargs['parts_list'].append('X')
     label_parts = _split_to_parts(label, is_matrix=False)
-    if weight is None:
-        parts = list(map(delayed, zip(data_parts, label_parts)))
-    else:
+    tozip.append(label_parts)
+    kwargs['parts_list'].append('y')
+    if weight is not None:
         weight_parts = _split_to_parts(weight, is_matrix=False)
-        parts = list(map(delayed, zip(data_parts, label_parts, weight_parts)))
+        tozip.append(weight_parts)
+        kwargs['parts_list'].append('weight')
+
+    if 'eval_set' in kwargs and kwargs['eval_set'] is not None:
+        # only support one validation set for now (i.e. the first)
+        valid_data = kwargs['eval_set'][0][0]
+        valid_label = kwargs['eval_set'][0][1]
+        valid_data_parts = _split_to_parts(valid_data, is_matrix=True)
+        valid_label_parts = _split_to_parts(valid_label, is_matrix=False)
+        tozip.append(valid_data_parts)
+        kwargs['parts_list'].append('valid_X')
+        tozip.append(valid_label_parts)
+        kwargs['parts_list'].append('valid_y')
+        kwargs.pop('eval_set', None)
+    if 'eval_sample_weight' in kwargs and kwargs['eval_sample_weight'] is not None and len(kwargs['eval_sample_weight']) > 0:
+        # only support one validation set for now (i.e. the first)
+        valid_sample_weight = kwargs['eval_sample_weight'][0]
+        valid_weight_parts = _split_to_parts(valid_sample_weight, is_matrix=False)
+        tozip.append(valid_weight_parts)
+        kwargs['parts_list'].append('eval_sample_weight')
+        kwargs.pop('eval_sample_weight', None)
+
+    parts = list(map(delayed, zip(*tuple(tozip))))
 
     # Start computation in the background
     parts = client.compute(parts)
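
For context, a minimal usage sketch of the new code path, assuming a running Dask scheduler; the data, chunk sizes, and parameters below are illustrative and not part of the patch. train() now accepts eval_set (and optionally eval_sample_weight) through **kwargs, splits the first validation pair into parts alongside the training data, and each worker reassembles its local pieces before fitting. Because the validation parts are zipped together with the training parts, the validation set must have the same number of partitions as the training data.

    import dask.array as da
    import lightgbm
    from dask.distributed import Client

    from dask_lightgbm.core import train

    client = Client()  # assumes a local or remote Dask cluster is available

    # illustrative random data; any dask array/dataframe with matching partitioning works
    X = da.random.random((1000, 10), chunks=(100, 10))
    y = (da.random.random((1000,), chunks=(100,)) > 0.5).astype(int)

    # validation set chunked so it has the same number of partitions as the training data
    X_valid = da.random.random((200, 10), chunks=(20, 10))
    y_valid = (da.random.random((200,), chunks=(20,)) > 0.5).astype(int)

    # only the first (X, y) pair in eval_set is forwarded to the workers
    model = train(
        client, X, y,
        params={"objective": "binary"},
        model_factory=lightgbm.LGBMClassifier,
        eval_set=[(X_valid, y_valid)],
    )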