From b0265c05edc77987933255f75e453e97c6ddeeb1 Mon Sep 17 00:00:00 2001 From: Christopher Jenness Date: Thu, 10 Aug 2017 05:41:10 -0400 Subject: [PATCH] Finish initial model selection testing --- tests/test_modelselection.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_modelselection.py b/tests/test_modelselection.py index 6610f83..0beee8f 100644 --- a/tests/test_modelselection.py +++ b/tests/test_modelselection.py @@ -62,3 +62,28 @@ def test_error_cross_entropy_error(): predictions = np.array([0.0, 0.0, 1.0, 1.0, 1.0, 1.0]) error = modelselection.Error.cross_entropy_error(y, predictions) assert np.isclose(error, 6.912757780650054) + + +def test_k_fold_generator(): + np.random.seed(10) + splitter = modelselection.k_fold_generator(20) + split = splitter.next() + assert len(split[0]) == 18 + assert len(split[1]) == 2 + + +def test_k_fold_generator_odd(): + np.random.seed(10) + splitter = modelselection.k_fold_generator(21) + split = splitter.next() + assert len(split[0]) == 19 + assert len(split[1]) == 2 + + +def test_test_train_splitter(): + X, y = data.categorical_2Dmatrix_data_big() + X_train, X_test, y_train, y_test = modelselection.test_train_splitter(X, y) + assert X_train.shape == (9, 2) + assert X_test.shape == (2, 2) + assert len(y_train) == 9 + assert len(y_test) == 2