From 798dc1d4191b93fd34797d62b79c66cd95209406 Mon Sep 17 00:00:00 2001 From: Nikita Titov Date: Fri, 29 Oct 2021 07:25:22 +0300 Subject: [PATCH] [tests] [python] add test for non-serializable callback (#4741) --- tests/python_package_test/test_sklearn.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 6b1ac8a9f3d..152757c7963 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -32,6 +32,14 @@ decreasing_generator = itertools.count(0, -1) +class UnpicklableCallback: + def __reduce__(self): + raise Exception("This class in not picklable") + + def __call__(self, env): + env.model.set_attr(attr_set_inside_callback=str(env.iteration * 10)) + + def custom_asymmetric_obj(y_true, y_pred): residual = (y_true - y_pred).astype(np.float64) grad = np.where(residual < 0, -2 * 10.0 * residual, -2 * residual) @@ -427,6 +435,18 @@ def test_joblib(): np.testing.assert_allclose(pred_origin, pred_pickle) +def test_non_serializable_objects_in_callbacks(tmp_path): + unpicklable_callback = UnpicklableCallback() + + with pytest.raises(Exception, match="This class in not picklable"): + joblib.dump(unpicklable_callback, tmp_path / 'tmp.joblib') + + X, y = load_boston(return_X_y=True) + gbm = lgb.LGBMRegressor(n_estimators=5) + gbm.fit(X, y, callbacks=[unpicklable_callback]) + assert gbm.booster_.attr('attr_set_inside_callback') == '40' + + def test_random_state_object(): X, y = load_iris(return_X_y=True) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)