Skip to content

Commit

Permalink
Fully support (one) validation set so that early stopping can be handled
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudotensor committed Dec 30, 2020
1 parent c19000b commit 8cc8e83
Showing 1 changed file with 42 additions and 7 deletions.
49 changes: 42 additions & 7 deletions dask_lightgbm/core.py
Expand Up @@ -63,9 +63,18 @@ def _train_part(params, model_factory, list_of_parts, worker_addresses, return_m

# Concatenate many parts into one
parts = tuple(zip(*list_of_parts))
data = concat(parts[0])
label = concat(parts[1])
weight = concat(parts[2]) if len(parts) == 3 else None
data = concat(parts[kwargs['parts_list'].index('X')])
label = concat(parts[kwargs['parts_list'].index('y')])
weight = concat(parts[kwargs['parts_list'].index('weight')]) if 'weight' in kwargs['parts_list'] else None
valid_X = concat(parts[kwargs['parts_list'].index('valid_X')]) if 'valid_X' in kwargs['parts_list'] else None
valid_y = concat(parts[kwargs['parts_list'].index('valid_y')]) if 'valid_y' in kwargs['parts_list'] else None
eval_sample_weight = concat(parts[kwargs['parts_list'].index('eval_sample_weight')]) if 'eval_sample_weight' in kwargs['parts_list'] else None
# only the first eval_set is supported
kwargs = kwargs.copy() # avoid contaminating upstream
if valid_X is not None and valid_y is not None:
kwargs['eval_set'] = [(valid_X, valid_y)]
kwargs['eval_sample_weight'] = [eval_sample_weight]
kwargs.pop('parts_list', None)

try:
model = model_factory(**params)
Expand All @@ -86,13 +95,39 @@ def _split_to_parts(data, is_matrix):

def train(client, data, label, params, model_factory, weight=None, **kwargs):
# Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
tozip = []
kwargs['parts_list'] = []
data_parts = _split_to_parts(data, is_matrix=True)
tozip.append(data_parts)
kwargs['parts_list'].append('X')
label_parts = _split_to_parts(label, is_matrix=False)
if weight is None:
parts = list(map(delayed, zip(data_parts, label_parts)))
else:
tozip.append(label_parts)
kwargs['parts_list'].append('y')
if weight is not None:
weight_parts = _split_to_parts(weight, is_matrix=False)
parts = list(map(delayed, zip(data_parts, label_parts, weight_parts)))
tozip.append(weight_parts)
kwargs['parts_list'].append('weight')

if 'eval_set' in kwargs and kwargs['eval_set'] is not None:
# only one validation set is supported for now (i.e. the first entry of eval_set)
valid_data = kwargs['eval_set'][0][0]
valid_label = kwargs['eval_set'][0][1]
valid_data_parts = _split_to_parts(valid_data, is_matrix=True)
valid_label_parts = _split_to_parts(valid_label, is_matrix=False)
tozip.append(valid_data_parts)
kwargs['parts_list'].append('valid_X')
tozip.append(valid_label_parts)
kwargs['parts_list'].append('valid_y')
kwargs.pop('eval_set', None)
if 'eval_sample_weight' in kwargs and kwargs['eval_sample_weight'] is not None and len(kwargs['eval_sample_weight']) > 0:
# only the sample weights of the first validation set are supported for now
valid_sample_weight = kwargs['eval_sample_weight'][0]
valid_weight_parts = _split_to_parts(valid_sample_weight, is_matrix=False)
tozip.append(valid_weight_parts)
kwargs['parts_list'].append('eval_sample_weight')
kwargs.pop('eval_sample_weight', None)

parts = list(map(delayed, zip(*tuple(tozip))))

# Start computation in the background
parts = client.compute(parts)
Expand Down

0 comments on commit 8cc8e83

Please sign in to comment.