Skip to content

Commit

Permalink
tests: add a test suite for sklearn objects
Browse files Browse the repository at this point in the history
Add the test recipe from #155 to demonstrate that
sklearn DecisionTreeClassifier objects can be properly
serialized.

This behavior was originally fixed in #170.

Closes #155
Signed-off-by: David Aguilar <davvid@gmail.com>
  • Loading branch information
davvid committed Jan 18, 2021
1 parent 86b87d1 commit 0df93f8
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements-dev.txt
Expand Up @@ -13,6 +13,7 @@ pytest-black-multipy
pytest-cov
pytest-flake8
simplejson
scikit-learn
sqlalchemy
ujson; python_version < '3.8'
yajl; sys_platform != 'win32' and python_version < '3.8'
1 change: 1 addition & 0 deletions setup.cfg
Expand Up @@ -57,6 +57,7 @@ testing =
numpy
pandas
pymongo
scikit-learn
sqlalchemy

testing.libs =
Expand Down
57 changes: 57 additions & 0 deletions tests/sklearn_test.py
@@ -0,0 +1,57 @@
from __future__ import absolute_import, division, unicode_literals

import pytest

try:
import numpy as np
from sklearn.tree import DecisionTreeClassifier
except ImportError:
pytest.skip('sklearn is not available', allow_module_level=True)

import jsonpickle
import jsonpickle.ext.numpy


@pytest.fixture(scope='module', autouse=True)
def numpy_extension():
    """Enable jsonpickle's numpy handlers while this module's tests run.

    The fixture is module-scoped and autouse, so the handlers are registered
    once before the first test in this module and unregistered after the
    last one.
    """
    numpy_ext = jsonpickle.ext.numpy
    # Setup phase.
    numpy_ext.register_handlers()
    # Hand control over to the tests; everything below is teardown.
    yield
    numpy_ext.unregister_handlers()


def test_decision_tree():
    """Round-trip a fitted DecisionTreeClassifier through jsonpickle.

    A small seeded dataset is used to fit a depth-1 tree. The fitted
    classifier is flattened and restored, and the restored object must have
    the same type, prediction, depth and score as the original.
    """
    # Seeded toy dataset: a 4x3 integer feature matrix and a 4x1 column of
    # binary labels.
    np.random.seed(13)
    features = np.random.randint(low=0, high=10, size=12).reshape(4, 3)
    labels = np.random.randint(low=0, high=2, size=4).reshape(-1, 1)

    # Fit the model that will be serialized.
    original = DecisionTreeClassifier(max_depth=1)
    original.fit(features, labels)

    # Flatten to jsonpickle's intermediate form, then rebuild the object.
    flattened = jsonpickle.pickler.Pickler().flatten(original)
    restored = jsonpickle.unpickler.Unpickler().restore(flattened)

    # The restored estimator, and its tree when present, keep their types.
    assert isinstance(restored, type(original))
    if hasattr(original, 'tree_'):
        assert isinstance(restored.tree_, type(original.tree_))

    # The restored estimator must still be able to predict.
    sample = np.array([1, 2, 3]).reshape(1, -1)
    assert restored.predict(sample)[0] == 1

    # Hyper-parameters and fitted behaviour match the original.
    assert restored.max_depth == original.max_depth
    assert restored.score(features, labels) == original.score(features, labels)
    if hasattr(original, 'get_depth'):
        assert restored.get_depth() == original.get_depth()


if __name__ == '__main__':
    # Propagate pytest's exit status so that running this file directly
    # reports test failures to the calling shell instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))

0 comments on commit 0df93f8

Please sign in to comment.