Skip to content

Commit

Permalink
Fix unsupervised fitting
Browse files Browse the repository at this point in the history
This fix allows fitting unsupervised estimators with the assumption that
they will always predict to shape (n_samples,).

Output dtype is now determined based on the `_estimator_type` attribute.
This is likely a temporary solution as `_estimator_type` is planned for
deprecation in favor of tags and explicit estimator type checking
functions, but neither of those solutions are fully implemented yet.

See scikit-learn/scikit-learn#28960
  • Loading branch information
aazuspan committed May 17, 2024
1 parent fc1bdbf commit 29195f7
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 8 deletions.
15 changes: 9 additions & 6 deletions src/sknnr_spatial/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ def _reset_estimator(estimator: EstimatorType) -> EstimatorType:

return estimator

def _get_n_targets(self, y: np.ndarray | pd.DataFrame | pd.Series) -> int:
def _get_n_targets(self, y: np.ndarray | pd.DataFrame | pd.Series | None) -> int:
"""Get the number of targets used to fit the estimator."""
if y.ndim == 1:
# Unsupervised and single-output estimators should both return a single target
if y is None or y.ndim == 1:
return 1

return y.shape[-1]
Expand Down Expand Up @@ -104,10 +105,12 @@ def fit(self, X, y=None, **kwargs) -> ImageEstimator[EstimatorType]:
self : ImageEstimator
The wrapper around the fitted estimator.
"""
# Squeeze extra y dimensions. This will convert from shape (n_samples, 1) which
# causes inconsistent output shapes with different sklearn estimators, to
# (n_samples,), which has a consistent output shape.
y = y.squeeze()
if y is not None:
# Squeeze extra y dimensions. This will convert from shape (n_samples, 1)
# which causes inconsistent output shapes with different sklearn estimators,
# to (n_samples,), which has a consistent output shape.
y = y.squeeze()

self._wrapped = self._wrapped.fit(X, y, **kwargs)

self._wrapped_meta = FittedMetadata(
Expand Down
14 changes: 13 additions & 1 deletion src/sknnr_spatial/image/_dask_backed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING

import dask.array as da
import numpy as np
from sklearn.utils.validation import check_is_fitted

from ..types import DaskBackedType
Expand All @@ -16,6 +17,12 @@
from .dataarray import DataArrayPreprocessor
from .dataset import DatasetPreprocessor

ESTIMATOR_OUTPUT_DTYPES: dict[str, np.dtype] = {
"classifier": np.int32,
"clusterer": np.int32,
"regressor": np.float64,
}


class DaskBackedWrapper(ImageWrapper[DaskBackedType]):
"""A wrapper around a Dask-backed image that provides sklearn methods."""
Expand All @@ -39,12 +46,17 @@ def predict(
signature = "(x)->(y)"
output_sizes = {"y": meta.n_targets}

# Any estimator with an undefined type should fall back to floating
# point for safety.
estimator_type = getattr(estimator, "_estimator_type", "")
output_dtype = ESTIMATOR_OUTPUT_DTYPES.get(estimator_type, np.float64)

y_pred = da.apply_gufunc(
estimator._wrapped.predict,
signature,
self.preprocessor.flat,
axis=self.preprocessor.flat_band_dim,
output_dtypes=[float],
output_dtypes=[output_dtype],
output_sizes=output_sizes,
allow_rechunk=True,
)
Expand Down
38 changes: 37 additions & 1 deletion tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import xarray as xr
from numpy.testing import assert_array_equal
from sklearn.base import clone
from sklearn.cluster import AffinityPropagation, KMeans, MeanShift
from sklearn.ensemble import RandomForestRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.neighbors import KNeighborsRegressor, NearestNeighbors
from sklearn.utils.validation import NotFittedError

from sknnr_spatial import wrap
Expand Down Expand Up @@ -45,6 +46,22 @@ def test_predict(dummy_model_data, image_type, estimator, single_output, squeeze
assert_array_equal(y_pred.shape, expected_shape)


@parametrize_image_types
@pytest.mark.parametrize("estimator", [KMeans, MeanShift, AffinityPropagation])
def test_predict_unsupervised(dummy_model_data, image_type, estimator):
"""Test that predict works with all image types with unsupervised estimators."""
X_image, X, _ = dummy_model_data

estimator = wrap(estimator()).fit(X)

X_wrapped = wrap_image(X_image, type=image_type.cls)
y_pred = unwrap_image(estimator.predict(X_wrapped))

assert y_pred.ndim == 3
expected_shape = (X_image.shape[0], X_image.shape[1], 1)
assert_array_equal(y_pred.shape, expected_shape)


@parametrize_image_types
@pytest.mark.parametrize("k", [1, 3], ids=lambda k: f"k{k}")
def test_kneighbors_with_distance(dummy_model_data, image_type, k):
Expand Down Expand Up @@ -80,6 +97,25 @@ def test_kneighbors_without_distance(dummy_model_data, image_type, k):
assert_array_equal(nn.shape, (X_image.shape[0], X_image.shape[1], k))


@parametrize_image_types
@pytest.mark.parametrize("k", [1, 3], ids=lambda k: f"k{k}")
def test_kneighbors_unsupervised(dummy_model_data, image_type, k):
"""Test kneighbors works with all image types when unsupervised."""
X_image, X, _ = dummy_model_data
estimator = wrap(NearestNeighbors(n_neighbors=k)).fit(X)

X_wrapped = wrap_image(X_image, type=image_type.cls)
dist, nn = estimator.kneighbors(X_wrapped, return_distance=True)
dist = unwrap_image(dist)
nn = unwrap_image(nn)

assert dist.ndim == 3
assert nn.ndim == 3

assert_array_equal(dist.shape, (X_image.shape[0], X_image.shape[1], k))
assert_array_equal(nn.shape, (X_image.shape[0], X_image.shape[1], k))


def test_predict_dataarray_with_custom_dim_name(dummy_model_data):
"""Test that predict works if the band dimension is not named "variable"."""
X_image, X, y = dummy_model_data
Expand Down

0 comments on commit 29195f7

Please sign in to comment.