
add early stopping in dart mode #4805

Closed
benwu232 opened this issue Nov 16, 2021 · 10 comments
@benwu232

I checked #1893 and #1895. It is said that early stopping is disabled in dart mode. The problem is that I don't know when to stop training in dart mode. Is it possible to add early stopping in dart mode, or is there any way to find the best model in dart mode?

Thanks a lot for the brilliant lgb!

@jameslamb
Collaborator

Thanks for using LightGBM and for your question!

Per #1893 (comment)

I think early stopping and dart cannot be used together.
The reason is when using dart, the previous trees will be updated.

You can learn more about DART in the original DART paper (link), especially the section "Description of the DART Algorithm".

any way to find the best model in dart mode

One way to do this is to use hyperparameter tuning over parameter num_iterations (number of trees to create), limiting the model complexity by setting conservative values of num_leaves, max_depth, and min_data_in_leaf.
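For illustration, here is a minimal sketch of that idea using the scikit-learn interface and GridSearchCV; the dataset, grid values, and scoring metric below are placeholders of my choosing, not a recommendation from this thread:

import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV

X, y = load_breast_cancer(return_X_y=True)

# Keep model complexity conservative, then tune the number of trees directly.
model = lgb.LGBMClassifier(
    boosting_type="dart",
    num_leaves=15,
    max_depth=4,
    min_child_samples=50,  # scikit-learn-API name for min_data_in_leaf
)
grid = GridSearchCV(
    model,
    param_grid={"n_estimators": [100, 300, 500, 1000]},  # alias of num_iterations
    scoring="neg_log_loss",
    cv=3,
)
grid.fit(X, y)
print(grid.best_params_)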

Maybe others will be able to propose alternative ideas. If you are sure you want to use early stopping even knowing that it will produce unstable results because of how DART works, I think it's possible by providing a custom callback function in the Python API. But I don't have time at the moment to create an example of that and I'm not certain which examples / docs to point you to.

@benwu232
Author

Thanks a lot for the suggestions!
I'll try those ideas.
I'll report back if I find anything useful.

@no-response

no-response bot commented Dec 17, 2021

This issue has been automatically closed because it has been awaiting a response for too long. When you have time to work with the maintainers to resolve this issue, please post a new comment and it will be re-opened. If the issue has been locked for editing by the time you return to it, please open a new issue and reference this one. Thank you for taking the time to improve LightGBM!

@no-response no-response bot closed this as completed Dec 17, 2021
@cheongalc

cheongalc commented Apr 20, 2023

Hello @jameslamb and @benwu232,

With reference to this:

If you are sure you want to use early stopping even knowing that it will produce unstable results because of how DART works, I think it's possible by providing a custom callback function in the Python API.

I used this Kaggle notebook as a reference to write an early stopping and model checkpoint callback for DART.

import lightgbm as lgbm

def dart_early_stopping(save_path: str, stopping_rounds: int):
    best_score = None
    best_iteration = -1
    counter = 0
    
    def _callback(env):
        nonlocal best_score
        nonlocal best_iteration
        nonlocal counter
        if env.evaluation_result_list is not None:
            # Format of example entry in env.evaluation_result_list:
            # ('valid_0', 'binary_logloss', 0.2946547510575049, False)
            # Index 0 = Result group name, in this case valid_0
            # Index 1 = Metric name, in this case binary_logloss
            # Index 2 = Numeric score, in this case 0.2946547510575049
            # Index 3 = Higher is Better, in this case False because lower loss is better
            score = env.evaluation_result_list[0][2]
            higher_is_better = 1 if env.evaluation_result_list[0][3] else -1

            if best_score is None or higher_is_better * best_score < higher_is_better * score:
                # Best score improved
                counter = 0
                print(f"\tBest score improved from {best_score} to {score}, saving model at iteration {env.iteration}")
                best_score = score
                best_iteration = env.iteration
                env.model.save_model(save_path, env.iteration)
                return

            counter += 1
            if counter >= stopping_rounds:
                print(f"\tBest score did not improve for {stopping_rounds} iterations, early stopping at iteration {env.iteration}")
                payload = [
                    (
                        env.evaluation_result_list[0][0],
                        env.evaluation_result_list[0][1],
                        best_score,
                        higher_is_better
                    )
                ]
                raise lgbm.callback.EarlyStopException(best_iteration=best_iteration, best_score=payload)
    return _callback

This is how I used it:

clf = lgbm.LGBMClassifier(
    boosting_type='dart',
    num_leaves=31,
    max_depth=10,
    learning_rate=0.003,
    n_estimators=1000,
    objective='binary',
    random_state=42,
    importance_type='gain',
    metric=None
)

clf.fit(
    x_train, 
    y_train,
    eval_set=[(x_val, y_val)],
    feature_name='auto',
    callbacks=[dart_early_stopping('model_bestscore.txt', 50)]
)
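
Not part of the original comment, but since the callback writes the best model to save_path, the checkpoint can presumably be loaded back afterwards like this (file name copied from the example above):

# Load the checkpoint written by dart_early_stopping and predict with it.
best_booster = lgbm.Booster(model_file='model_bestscore.txt')
y_pred = best_booster.predict(x_val)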

Note that in my setup, I only have one evaluation metric and one evaluation set.
Thus, my env.evaluation_result_list looks like this:

[ ('valid_0', 'binary_logloss', 0.2946547510575049, False) ]

The callback above is designed to work only when env.evaluation_result_list has this format.

Extending functionality

Bear in mind how env.evaluation_result_list changes, so that you grab the correct entry and monitor the correct metric. A sketch of one way to do this follows the cases below.

Case 1: More than one evaluation metric

clf.fit(
    x_train, 
    y_train,
    eval_set=[(x_val, y_val)],
    eval_metric=['auc', 'logloss'],
    feature_name='auto',
    callbacks=[dart_early_stopping('model_bestscore.txt', 50)]
)
[ ('valid_0', 'auc', 0.6404208382827083, True), ('valid_0', 'binary_logloss', 0.2946547510575049, False) ]

Case 2: More than one evaluation set

clf.fit(
    x_train, 
    y_train,
    eval_set=[(x_train, y_train), (x_val, y_val)],
    feature_name='auto',
    callbacks=[dart_early_stopping('model_bestscore.txt', 50)]
)
[ ('training', 'binary_logloss', 0.2943494874886071, False), ('valid_1', 'binary_logloss', 0.2946547510575049, False) ] 

Case 3: More than one evaluation metric and more than one evaluation set

clf.fit(
    x_train, 
    y_train,
    eval_set=[(x_train, y_train), (x_val, y_val)],
    eval_metric=['auc', 'logloss'],
    feature_name='auto',
    callbacks=[dart_early_stopping('model_bestscore.txt', 50)]
)
[ 
 ('training', 'auc', 0.7868969760983039, True), 
 ('training', 'binary_logloss', 0.2943494874886071, False), 
 ('valid_1', 'auc', 0.6404208382827083, True), 
 ('valid_1', 'binary_logloss', 0.2946547510575049, False)
]
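
If you do use one of those formats, here is a minimal sketch of how the hard-coded index 0 lookup in the callback above could be replaced with a lookup by evaluation-set and metric name (the helper name _select_entry and its arguments are mine, not part of the original callback):

def _select_entry(evaluation_result_list, monitor_set, monitor_metric):
    # Return the (set_name, metric_name, score, higher_is_better) tuple to monitor,
    # e.g. _select_entry(env.evaluation_result_list, 'valid_1', 'binary_logloss').
    for entry in evaluation_result_list:
        if entry[0] == monitor_set and entry[1] == monitor_metric:
            return entry
    raise ValueError(f"no evaluation result for {monitor_set}/{monitor_metric}")

Inside _callback, entry = _select_entry(env.evaluation_result_list, 'valid_1', 'binary_logloss') would then replace env.evaluation_result_list[0], and score / higher_is_better would be read from entry.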

Thank you!

@3zhang

3zhang commented May 18, 2023

You shouldn't use dart with early stopping. In dart mode, you should expect the validation error to be roughly monotonically decreasing (it will always have some fluctuations, but the overall trend should be decreasing). I think one major advantage of dart is that it gets rid of the tuning of boosting rounds, which in some cases is very sensitive and contributes to high variance.

@34j

34j commented May 19, 2023

It seems that there are actually cases where early stopping is needed in dart mode. (Below is the metric plot when the diabetes dataset was trained with LGBMRegressor(boosting_type="dart", n_estimators=1000).) I have created https://github.com/34j/lightgbm-callbacks for this purpose, which may be useful if copying @cheongalc's code is bothersome, assuming my code does not have a bug.

[metric plot: diabetes dataset, LGBMRegressor(boosting_type="dart", n_estimators=1000)]

@3zhang

3zhang commented May 21, 2023

It seems that there are actually cases where early stopping is needed in dart mode. (Below is the metric plot when the diabetes dataset was trained with LGBMRegressor(boosting_type="dart", n_estimators=1000).) I have created https://github.com/34j/lightgbm-callbacks for this purpose, which may be useful if copying @cheongalc's code is bothersome, assuming my code does not have a bug.

Well, sometimes you cannot prevent overfitting by simply using dart with its default parameters. That's why there are several hyperparameters for dart. For your example, I just decreased skip_drop and the overfitting seems to be gone.

[plots: validation metric with default skip_drop=0.5 and with skip_drop=0.3]

from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
import lightgbm as lgb

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train)

params = {'objective':'l2',
        'boosting': 'dart',
        'seed':1}

ds_train = lgb.Dataset(X_train, label=y_train)
ds_test = lgb.Dataset(X_test, label=y_test, reference=ds_train)

val_his = {}
model = lgb.train(params, ds_train, valid_sets=[ds_train, ds_test], 
                  callbacks = [lgb.record_evaluation(val_his)],
                  num_boost_round = 1000)
lgb.plot_metric(val_his, ylim=(0,6000), title='default skip_drop=0.5')

params['skip_drop'] = 0.3
val_his = {}
model = lgb.train(params, ds_train, valid_sets=[ds_train, ds_test], 
                  callbacks = [lgb.record_evaluation(val_his)],
                  num_boost_round = 1000)
lgb.plot_metric(val_his, ylim=(0,6000), title='skip_drop=0.3')

@34j

34j commented May 21, 2023

(quoting the script from the previous comment)

I don't know much about dart, but I think the reason your example seems to work is the slower convergence. Here is the result of your script with num_boost_round changed to 10000.
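
For reference, my reading of that change is simply a larger num_boost_round in the script above (reusing params, ds_train and ds_test from the previous comment):

# Same script as above, but with a 10x larger number of boosting rounds.
val_his = {}
model = lgb.train(params, ds_train, valid_sets=[ds_train, ds_test],
                  callbacks=[lgb.record_evaluation(val_his)],
                  num_boost_round=10000)
lgb.plot_metric(val_his, ylim=(0, 6000), title='num_boost_round=10000')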

[plots at num_boost_round=10000: default skip_drop=0.5 and skip_drop=0.3]

By the way, for some reason, the results were different when using GPU.

[GPU plots: default skip_drop=0.5 and skip_drop=0.3]
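
For completeness, the GPU run above was presumably configured through LightGBM's device_type parameter, something like the following (this is an assumption on my part, and it requires a GPU-enabled LightGBM build):

# Assumes LightGBM was built with GPU support; otherwise training raises an error.
params_gpu = dict(params, device_type='gpu')
val_his = {}
model = lgb.train(params_gpu, ds_train, valid_sets=[ds_train, ds_test],
                  callbacks=[lgb.record_evaluation(val_his)],
                  num_boost_round=10000)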

@3zhang

3zhang commented May 21, 2023

(quoting the script from the earlier comment)

Yes, it might ultimately overfit. My point is that by tuning dart parameters you don't need to care too much about the number of iterations. You can simply set it to a moderately large number and leave it at that.

@github-actions

This issue has been automatically locked since there has not been any recent activity since it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this.

@github-actions github-actions bot locked as resolved and limited conversation to collaborators Aug 23, 2023