Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] reorganize early stopping callback #6114

Merged
merged 8 commits into from Oct 5, 2023
Merged
4 changes: 2 additions & 2 deletions python-package/lightgbm/callback.py
Expand Up @@ -273,8 +273,8 @@ def __init__(
min_delta: Union[float, List[float]] = 0.0
) -> None:

if stopping_rounds <= 0:
raise ValueError(f"stopping_rounds should be greater than zero. got: {stopping_rounds}")
if not isinstance(stopping_rounds, int) or stopping_rounds <= 0:
raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}")

self.order = 30
self.before_iteration = False
Expand Down
11 changes: 11 additions & 0 deletions tests/python_package_test/test_callback.py
Expand Up @@ -21,6 +21,17 @@ def test_early_stopping_callback_is_picklable(serializer):
assert callback.stopping_rounds == rounds


def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors():
with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: 0"):
lgb.early_stopping(stopping_rounds=0)

with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: -1"):
lgb.early_stopping(stopping_rounds=-1)

with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: neverrrr"):
lgb.early_stopping(stopping_rounds="neverrrr")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love this

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haha thank you, thank you 😂



@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_log_evaluation_callback_is_picklable(serializer):
periods = 42
Expand Down
4 changes: 2 additions & 2 deletions tests/python_package_test/test_engine.py
Expand Up @@ -4501,9 +4501,9 @@ def test_train_raises_informative_error_if_any_valid_sets_are_not_dataset_object

def test_train_raises_informative_error_for_params_of_wrong_type():
X, y = make_synthetic_regression()
params = {"early_stopping_round": "too-many"}
params = {"num_leaves": "too-many"}
dtrain = lgb.Dataset(X, label=y)
with pytest.raises(lgb.basic.LightGBMError, match="Parameter early_stopping_round should be of type int, got \"too-many\""):
with pytest.raises(lgb.basic.LightGBMError, match="Parameter num_leaves should be of type int, got \"too-many\""):
lgb.train(params, dtrain)


Expand Down