ENH Adds pandas IntegerArray support to check_array (scikit-learn#16508)
thomasjpfan authored and gio8tisu committed May 15, 2020
1 parent 752a4ba commit 6d383d0
Showing 9 changed files with 186 additions and 34 deletions.
20 changes: 20 additions & 0 deletions doc/whats_new/v0.23.rst
@@ -316,6 +316,10 @@ Changelog
``max_value`` and ``min_value``. Array-like inputs allow a different max and min to be specified
for each feature. :pr:`16403` by :user:`Narendra Mukherjee <narendramukherjee>`.

- |Enhancement| :class:`impute.IterativeImputer`, :class:`impute.KNNImputer`, and
:class:`impute.SimpleImputer` accept pandas' nullable integer dtype with
missing values. :pr:`16508` by `Thomas Fan`_.

:mod:`sklearn.inspection`
.........................

@@ -485,6 +489,13 @@ Changelog
can now contain `None`, where `drop_idx_[i] = None` means that no category
is dropped for index `i`. :pr:`16585` by :user:`Chiara Marmo <cmarmo>`.

- |Enhancement| :class:`preprocessing.MaxAbsScaler`,
:class:`preprocessing.MinMaxScaler`, :class:`preprocessing.StandardScaler`,
:class:`preprocessing.PowerTransformer`,
:class:`preprocessing.QuantileTransformer`, and
:class:`preprocessing.RobustScaler` now support pandas' nullable integer
dtype with missing values. :pr:`16508` by `Thomas Fan`_.

- |Efficiency| :class:`preprocessing.OneHotEncoder` is now faster at
transforming. :pr:`15762` by `Thomas Fan`_.

@@ -566,6 +577,15 @@ Changelog
matrix from a pandas DataFrame that contains only `SparseArray` columns.
:pr:`16728` by `Thomas Fan`_.

- |Enhancement| :func:`utils.validation.check_array` supports pandas'
nullable integer dtype with missing values when `force_all_finite` is set to
`False` or `'allow-nan'`, in which case the data is converted to floating
point values where `pd.NA` values are replaced by `np.nan`. As a consequence,
all :mod:`sklearn.preprocessing` transformers that accept numeric inputs with
missing values represented as `np.nan` now also accept being directly fed
pandas dataframes with `pd.Int*` or `pd.UInt*` typed columns that use `pd.NA`
as a missing value marker. :pr:`16508` by `Thomas Fan`_.

- |API| Passing classes to :func:`utils.estimator_checks.check_estimator` and
:func:`utils.estimator_checks.parametrize_with_checks` is now deprecated,
and support for classes will be removed in 0.24. Pass instances instead.
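
The `check_array` entry above says that `pd.NA` values end up as `np.nan` once the data is converted to floating point. A minimal sketch of that conversion using pandas alone is shown below; the frame `X_df` and its column names are made up for illustration, and the snippet does not call scikit-learn.

```python
import numpy as np
import pandas as pd

# Nullable integer columns use pd.NA as the missing-value marker.
# The column names and values here are illustrative only.
X_df = pd.DataFrame({
    "a": pd.array([1, None, 3], dtype="Int64"),
    "b": pd.array([4, 5, None], dtype="UInt8"),
})

# Casting to a floating point dtype replaces pd.NA with np.nan,
# which is the representation that NaN-aware estimators expect.
X_float = X_df.astype("float64")
X = X_float.to_numpy()

print(X.dtype)
print(np.isnan(X))
```

This is also why the imputers in this commit document `missing_values=np.nan` for such frames: by the time the estimator sees the data, the marker is `np.nan`, not `pd.NA`.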
9 changes: 6 additions & 3 deletions sklearn/impute/_base.py
@@ -128,7 +128,9 @@ class SimpleImputer(_BaseImputer):
----------
missing_values : number, string, np.nan (default) or None
The placeholder for the missing values. All occurrences of
`missing_values` will be imputed.
`missing_values` will be imputed. For pandas' dataframes with
nullable integer dtypes with missing values, `missing_values`
should be set to `np.nan`, since `pd.NA` will be converted to `np.nan`.
strategy : string, default='mean'
The imputation strategy.
@@ -476,8 +478,9 @@ class MissingIndicator(TransformerMixin, BaseEstimator):
----------
missing_values : number, string, np.nan (default) or None
The placeholder for the missing values. All occurrences of
`missing_values` will be indicated (True in the output array), the
other values will be marked as False.
`missing_values` will be imputed. For pandas' dataframes with
nullable integer dtypes with missing values, `missing_values`
should be set to `np.nan`, since `pd.NA` will be converted to `np.nan`.
features : str, default=None
Whether the imputer mask should represent all or a subset of
4 changes: 3 additions & 1 deletion sklearn/impute/_iterative.py
@@ -54,7 +54,9 @@ class IterativeImputer(_BaseImputer):
missing_values : int, np.nan, default=np.nan
The placeholder for the missing values. All occurrences of
``missing_values`` will be imputed.
`missing_values` will be imputed. For pandas' dataframes with
nullable integer dtypes with missing values, `missing_values`
should be set to `np.nan`, since `pd.NA` will be converted to `np.nan`.
sample_posterior : boolean, default=False
Whether to sample from the (Gaussian) predictive posterior of the
4 changes: 3 additions & 1 deletion sklearn/impute/_knn.py
@@ -32,7 +32,9 @@ class KNNImputer(_BaseImputer):
----------
missing_values : number, string, np.nan or None, default=`np.nan`
The placeholder for the missing values. All occurrences of
`missing_values` will be imputed.
`missing_values` will be imputed. For pandas' dataframes with
nullable integer dtypes with missing values, `missing_values`
should be set to `np.nan`, since `pd.NA` will be converted to `np.nan`.
n_neighbors : int, default=5
Number of neighboring samples to use for imputation.
29 changes: 29 additions & 0 deletions sklearn/impute/tests/test_common.py
@@ -84,3 +84,32 @@ def test_imputers_add_indicator_sparse(imputer, marker):
imputer.set_params(add_indicator=False)
X_trans_no_indicator = imputer.fit_transform(X)
assert_allclose_dense_sparse(X_trans[:, :-4], X_trans_no_indicator)


# ConvergenceWarning will be raised by the IterativeImputer
@pytest.mark.filterwarnings("ignore::sklearn.exceptions.ConvergenceWarning")
@pytest.mark.parametrize("imputer", IMPUTERS)
@pytest.mark.parametrize("add_indicator", [True, False])
def test_imputers_pandas_na_integer_array_support(imputer, add_indicator):
# Test pandas IntegerArray with pd.NA
pd = pytest.importorskip('pandas', minversion="1.0")
marker = np.nan
imputer = imputer.set_params(add_indicator=add_indicator,
missing_values=marker)

X = np.array([
[marker, 1, 5, marker, 1],
[2, marker, 1, marker, 2],
[6, 3, marker, marker, 3],
[1, 2, 9, marker, 4]
])
# fit on numpy array
X_trans_expected = imputer.fit_transform(X)

# Creates dataframe with IntegerArrays with pd.NA
X_df = pd.DataFrame(X, dtype="Int16", columns=["a", "b", "c", "d", "e"])

# fit on pandas dataframe with IntegerArrays
X_trans = imputer.fit_transform(X_df)

assert_allclose(X_trans_expected, X_trans)
23 changes: 15 additions & 8 deletions sklearn/metrics/pairwise.py
@@ -100,17 +100,20 @@ def check_pairwise_arrays(X, Y, *, precomputed=False, dtype=None,
raise an error.
force_all_finite : boolean or 'allow-nan', (default=True)
Whether to raise an error on np.inf and np.nan in array. The
Whether to raise an error on np.inf, np.nan, pd.NA in array. The
possibilities are:
- True: Force all values of array to be finite.
- False: accept both np.inf and np.nan in array.
- 'allow-nan': accept only np.nan values in array. Values cannot
be infinite.
- False: accepts np.inf, np.nan, pd.NA in array.
- 'allow-nan': accepts only np.nan and pd.NA values in array. Values
cannot be infinite.
.. versionadded:: 0.22
``force_all_finite`` accepts the string ``'allow-nan'``.
.. versionchanged:: 0.23
Accepts `pd.NA` and converts it into `np.nan`
copy : bool
Whether a forced copy will be triggered. If copy=False, a copy might
be triggered by a conversion.
@@ -1691,15 +1694,19 @@ def pairwise_distances(X, Y=None, metric="euclidean", *, n_jobs=None,
for more details.
force_all_finite : boolean or 'allow-nan', (default=True)
Whether to raise an error on np.inf and np.nan in array. The
Whether to raise an error on np.inf, np.nan, pd.NA in array. The
possibilities are:
- True: Force all values of array to be finite.
- False: accept both np.inf and np.nan in array.
- 'allow-nan': accept only np.nan values in array. Values cannot
be infinite.
- False: accepts np.inf, np.nan, pd.NA in array.
- 'allow-nan': accepts only np.nan and pd.NA values in array. Values
cannot be infinite.
.. versionadded:: 0.22
``force_all_finite`` accepts the string ``'allow-nan'``.
.. versionchanged:: 0.23
Accepts `pd.NA` and converts it into `np.nan`
**kwds : optional keyword parameters
Any further parameters are passed directly to the distance function.
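
The three documented `force_all_finite` modes can be summarised with a small standard-library sketch. The helper `check_finite` below is hypothetical and is not scikit-learn's implementation; its error messages are placeholders. It assumes any `pd.NA` has already been converted to NaN, as the docstrings above describe.

```python
import math


def check_finite(values, force_all_finite=True):
    """Hypothetical helper mirroring the three documented modes."""
    if force_all_finite not in (True, False, "allow-nan"):
        raise ValueError(
            'force_all_finite should be a bool or "allow-nan". '
            "Got {!r} instead".format(force_all_finite)
        )
    if force_all_finite is False:
        # False: both NaN and infinity are accepted.
        return values
    for value in values:
        if math.isinf(value):
            # True and 'allow-nan' both reject infinite values.
            raise ValueError("Input contains infinity")
        if math.isnan(value) and force_all_finite is True:
            # Only True rejects NaN; 'allow-nan' lets it through.
            raise ValueError("Input contains NaN")
    return values


# NaN passes under 'allow-nan'; infinity would not.
check_finite([1.0, float("nan")], force_all_finite="allow-nan")
```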
30 changes: 30 additions & 0 deletions sklearn/preprocessing/tests/test_common.py
@@ -126,3 +126,33 @@ def test_missing_value_handling(est, func, support_sparse, strictly_positive):
Xt_inv_sp = est_sparse.inverse_transform(Xt_sp)
assert len(records) == 0
assert_allclose(Xt_inv_sp.A, Xt_inv_dense)


@pytest.mark.parametrize(
"est, func",
[(MaxAbsScaler(), maxabs_scale),
(MinMaxScaler(), minmax_scale),
(StandardScaler(), scale),
(StandardScaler(with_mean=False), scale),
(PowerTransformer('yeo-johnson'), power_transform),
(PowerTransformer('box-cox'), power_transform,),
(QuantileTransformer(n_quantiles=3), quantile_transform),
(RobustScaler(), robust_scale),
(RobustScaler(with_centering=False), robust_scale)]
)
def test_missing_value_pandas_na_support(est, func):
# Test pandas IntegerArray with pd.NA
pd = pytest.importorskip('pandas', minversion="1.0")

X = np.array([[1, 2, 3, np.nan, np.nan, 4, 5, 1],
[np.nan, np.nan, 8, 4, 6, np.nan, np.nan, 8],
[1, 2, 3, 4, 5, 6, 7, 8]]).T

# Creates dataframe with IntegerArrays with pd.NA
X_df = pd.DataFrame(X, dtype="Int16", columns=['a', 'b', 'c'])
X_df['c'] = X_df['c'].astype('int')

X_trans = est.fit_transform(X)
X_df_trans = est.fit_transform(X_df)

assert_allclose(X_trans, X_df_trans)
31 changes: 31 additions & 0 deletions sklearn/utils/tests/test_validation.py
@@ -349,6 +349,37 @@ def test_check_array():
check_array(X, dtype="numeric")


@pytest.mark.parametrize("pd_dtype", ["Int8", "Int16", "UInt8", "UInt16"])
@pytest.mark.parametrize("dtype, expected_dtype", [
([np.float32, np.float64], np.float32),
(np.float64, np.float64),
("numeric", np.float64),
])
def test_check_array_pandas_na_support(pd_dtype, dtype, expected_dtype):
# Test pandas IntegerArray with pd.NA
pd = pytest.importorskip('pandas', minversion="1.0")

X_np = np.array([[1, 2, 3, np.nan, np.nan],
[np.nan, np.nan, 8, 4, 6],
[1, 2, 3, 4, 5]]).T

# Creates dataframe with IntegerArrays with pd.NA
X = pd.DataFrame(X_np, dtype=pd_dtype, columns=['a', 'b', 'c'])
# column c has no nans
X['c'] = X['c'].astype('float')
X_checked = check_array(X, force_all_finite='allow-nan', dtype=dtype)
assert_allclose(X_checked, X_np)
assert X_checked.dtype == expected_dtype

X_checked = check_array(X, force_all_finite=False, dtype=dtype)
assert_allclose(X_checked, X_np)
assert X_checked.dtype == expected_dtype

msg = "Input contains NaN, infinity"
with pytest.raises(ValueError, match=msg):
check_array(X, force_all_finite=True)


def test_check_array_pandas_dtype_object_conversion():
# test that data-frame like objects with dtype object
# get converted
70 changes: 49 additions & 21 deletions sklearn/utils/validation.py
@@ -135,17 +135,20 @@ def as_float_array(X, *, copy=True, force_all_finite=True):
returned if X's dtype is not a floating point type.
force_all_finite : boolean or 'allow-nan', (default=True)
Whether to raise an error on np.inf and np.nan in X. The possibilities
are:
Whether to raise an error on np.inf, np.nan, pd.NA in X. The
possibilities are:
- True: Force all values of X to be finite.
- False: accept both np.inf and np.nan in X.
- 'allow-nan': accept only np.nan values in X. Values cannot be
infinite.
- False: accepts np.inf, np.nan, pd.NA in X.
- 'allow-nan': accepts only np.nan and pd.NA values in X. Values cannot
be infinite.
.. versionadded:: 0.20
``force_all_finite`` accepts the string ``'allow-nan'``.
.. versionchanged:: 0.23
Accepts `pd.NA` and converts it into `np.nan`
Returns
-------
XT : {array, sparse matrix}
@@ -317,17 +320,20 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy,
be triggered by a conversion.
force_all_finite : boolean or 'allow-nan', (default=True)
Whether to raise an error on np.inf and np.nan in X. The possibilities
are:
Whether to raise an error on np.inf, np.nan, pd.NA in X. The
possibilities are:
- True: Force all values of X to be finite.
- False: accept both np.inf and np.nan in X.
- 'allow-nan': accept only np.nan values in X. Values cannot be
infinite.
- False: accepts np.inf, np.nan, pd.NA in X.
- 'allow-nan': accepts only np.nan and pd.NA values in X. Values cannot
be infinite.
.. versionadded:: 0.20
``force_all_finite`` accepts the string ``'allow-nan'``.
.. versionchanged:: 0.23
Accepts `pd.NA` and converts it into `np.nan`
Returns
-------
spmatrix_converted : scipy sparse matrix.
@@ -438,19 +444,20 @@ def check_array(array, accept_sparse=False, *, accept_large_sparse=True,
be triggered by a conversion.
force_all_finite : boolean or 'allow-nan', (default=True)
Whether to raise an error on np.inf and np.nan in array. The
Whether to raise an error on np.inf, np.nan, pd.NA in array. The
possibilities are:
- True: Force all values of array to be finite.
- False: accept both np.inf and np.nan in array.
- 'allow-nan': accept only np.nan values in array. Values cannot
be infinite.
For object dtyped data, only np.nan is checked and not np.inf.
- False: accepts np.inf, np.nan, pd.NA in array.
- 'allow-nan': accepts only np.nan and pd.NA values in array. Values
cannot be infinite.
.. versionadded:: 0.20
``force_all_finite`` accepts the string ``'allow-nan'``.
.. versionchanged:: 0.23
Accepts `pd.NA` and converts it into `np.nan`
ensure_2d : boolean (default=True)
Whether to raise a value error if array is not 2D.
@@ -491,6 +498,7 @@ def check_array(array, accept_sparse=False, *, accept_large_sparse=True,
# check if the object contains several dtypes (typically a pandas
# DataFrame), and store them. If not, store None.
dtypes_orig = None
has_pd_integer_array = False
if hasattr(array, "dtypes") and hasattr(array.dtypes, '__array__'):
# throw warning if columns are sparse. If all columns are sparse, then
# array.sparse exists and sparsity will be preserved (later).
@@ -508,6 +516,19 @@ def check_array(array, accept_sparse=False, *, accept_large_sparse=True,
for i, dtype_iter in enumerate(dtypes_orig):
if dtype_iter.kind == 'b':
dtypes_orig[i] = np.dtype(np.object)
elif dtype_iter.name.startswith(("Int", "UInt")):
# name looks like an Integer Extension Array, now check for
# the dtype
with suppress(ImportError):
from pandas import (Int8Dtype, Int16Dtype,
Int32Dtype, Int64Dtype,
UInt8Dtype, UInt16Dtype,
UInt32Dtype, UInt64Dtype)
if isinstance(dtype_iter, (Int8Dtype, Int16Dtype,
Int32Dtype, Int64Dtype,
UInt8Dtype, UInt16Dtype,
UInt32Dtype, UInt64Dtype)):
has_pd_integer_array = True

if all(isinstance(dtype, np.dtype) for dtype in dtypes_orig):
dtype_orig = np.result_type(*dtypes_orig)
@@ -528,6 +549,10 @@ def check_array(array, accept_sparse=False, *, accept_large_sparse=True,
# list of accepted types.
dtype = dtype[0]

if has_pd_integer_array:
# If there are any pandas integer extension arrays, cast the frame to
# the target dtype so that pd.NA values are converted to np.nan.
array = array.astype(dtype)

if force_all_finite not in (True, False, 'allow-nan'):
raise ValueError('force_all_finite should be a bool or "allow-nan"'
'. Got {!r} instead'.format(force_all_finite))
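
The detection step in the hunk above first filters on the dtype name and only then imports pandas to confirm the type. A standalone sketch of that pattern follows; the function name `has_nullable_integer_dtype` and the sample frames are illustrative and are not part of scikit-learn.

```python
from contextlib import suppress

import pandas as pd


def has_nullable_integer_dtype(dtypes):
    """Return True if any dtype is a pandas nullable integer dtype."""
    for dtype in dtypes:
        # Cheap name check first: NumPy names are lowercase ("int64"),
        # pandas nullable integer names are capitalised ("Int64", "UInt8").
        name = getattr(dtype, "name", "")
        if not name.startswith(("Int", "UInt")):
            continue
        # Confirm with isinstance; skip silently if pandas is unavailable.
        with suppress(ImportError):
            from pandas import (Int8Dtype, Int16Dtype,
                                Int32Dtype, Int64Dtype,
                                UInt8Dtype, UInt16Dtype,
                                UInt32Dtype, UInt64Dtype)
            if isinstance(dtype, (Int8Dtype, Int16Dtype,
                                  Int32Dtype, Int64Dtype,
                                  UInt8Dtype, UInt16Dtype,
                                  UInt32Dtype, UInt64Dtype)):
                return True
    return False


nullable_df = pd.DataFrame({"a": pd.array([1, None], dtype="Int16")})
plain_df = pd.DataFrame({"a": [1, 2]})

print(has_nullable_integer_dtype(nullable_df.dtypes))
print(has_nullable_integer_dtype(plain_df.dtypes))
```

Checking the name before importing keeps pandas an optional dependency: the import only runs for inputs that already look like pandas extension dtypes.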
@@ -712,18 +737,21 @@ def check_X_y(X, y, accept_sparse=False, *, accept_large_sparse=True,
be triggered by a conversion.
force_all_finite : boolean or 'allow-nan', (default=True)
Whether to raise an error on np.inf and np.nan in X. This parameter
does not influence whether y can have np.inf or np.nan values.
Whether to raise an error on np.inf, np.nan, pd.NA in X. This parameter
does not influence whether y can have np.inf, np.nan, pd.NA values.
The possibilities are:
- True: Force all values of X to be finite.
- False: accept both np.inf and np.nan in X.
- 'allow-nan': accept only np.nan values in X. Values cannot be
infinite.
- False: accepts np.inf, np.nan, pd.NA in X.
- 'allow-nan': accepts only np.nan or pd.NA values in X. Values cannot
be infinite.
.. versionadded:: 0.20
``force_all_finite`` accepts the string ``'allow-nan'``.
.. versionchanged:: 0.23
Accepts `pd.NA` and converts it into `np.nan`
ensure_2d : boolean (default=True)
Whether to raise a value error if X is not 2D.