Skip to content

Commit

Permalink
FIX Raise error when missing-values encountered in scikit-tree trees (#…
Browse files Browse the repository at this point in the history
…264)

* Implement smoke test

---------

Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed May 6, 2024
1 parent 920a819 commit dace5cf
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 7 deletions.
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Expand Up @@ -61,6 +61,7 @@ jobs:
run: |
brew install ccache
brew install gcc
brew install gettext
- name: show-gcc
run: |
Expand Down
8 changes: 4 additions & 4 deletions build_requirements.txt
@@ -1,9 +1,9 @@
meson
meson-python
cython>=3.0.8
meson>=1.4.0
meson-python>=0.16.0
cython>=3.0.10
ninja
numpy
scikit-learn>=1.4.1
scikit-learn>=1.4.2
click
rich-click
doit
Expand Down
6 changes: 6 additions & 0 deletions doc/whats_new/v0.8.rst
Expand Up @@ -13,6 +13,12 @@ Version 0.8
Changelog
---------

- |Fix| Previously, missing values in the ``X`` input array for sktree estimators
did not raise an error; the estimators silently ran, assuming the missing values
were encoded as an infinity value. This is now fixed, and the estimators will raise
a ``ValueError`` if missing values are encountered in the ``X`` input array.
By `Adam Li`_ (:pr:`264`)

Code and Documentation Contributors
-----------------------------------

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,15 +1,15 @@
[build-system]
build-backend = "mesonpy"
requires = [
"meson-python>=0.15.0",
"meson-python>=0.16.0",
'ninja',
# `wheel` is needed for non-isolated builds, given that `meson-python`
# doesn't list it as a runtime requirement (at least in 0.10.0)
# See https://github.com/FFY00/meson-python/blob/main/pyproject.toml#L4
"wheel",
"setuptools<=65.5",
"packaging",
"Cython>=3.0.8",
"Cython>=3.0.10",
"scikit-learn>=1.4.1",
"scipy>=1.5.0",
"numpy>=1.25; python_version>='3.9'"
Expand Down
4 changes: 4 additions & 0 deletions sktree/tree/_neighbors.py
Expand Up @@ -64,3 +64,7 @@ def compute_similarity_matrix(self, X):
The similarity matrix among the samples.
"""
return compute_forest_similarity_matrix(self, X)

def _more_tags(self):
# XXX: no scikit-tree estimators support NaNs as of now
return {"allow_nan": False}
23 changes: 22 additions & 1 deletion sktree/tree/tests/test_all_trees.py
Expand Up @@ -3,7 +3,7 @@
import pytest
from numpy.testing import assert_almost_equal, assert_array_equal
from sklearn.base import is_classifier
from sklearn.datasets import make_blobs
from sklearn.datasets import load_iris, make_blobs
from sklearn.tree._tree import TREE_LEAF

from sktree.tree import (
Expand Down Expand Up @@ -162,3 +162,24 @@ def test_similarity_matrix(tree):

assert np.allclose(sim_mat, sim_mat.T)
assert np.all((sim_mat.diagonal() == 1))


@pytest.mark.parametrize("tree", ALL_TREES)
def test_missing_values(tree):
    """Smoke test that fitting on data with missing values raises an error.

    About a quarter of the iris feature entries are replaced with NaN, and
    the tree under test is expected to reject the input with a ``ValueError``
    whose message matches "Input X contains NaN".

    xref: https://github.com/neurodata/scikit-tree/issues/263
    """
    rng = np.random.default_rng(123)

    iris_X, iris_y = load_iris(return_X_y=True, as_frame=True)

    # Replace roughly 25% of the feature entries with NaN. Uniform draws are
    # used so the masked fraction matches the threshold; thresholding a
    # standard normal at 0.25 would mask about 60% of the entries instead.
    iris_X = iris_X.mask(rng.random(iris_X.shape) < 0.25)

    estimator = tree()
    with pytest.raises(ValueError, match="Input X contains NaN"):
        if tree.__name__.startswith("Unsupervised"):
            # Trees whose class name starts with "Unsupervised" are fit on X only.
            estimator.fit(iris_X)
        else:
            estimator.fit(iris_X, iris_y)

0 comments on commit dace5cf

Please sign in to comment.