Skip to content

Commit

Permalink
DEPR: try_cast kwarg in mask, where (pandas-dev#38836)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and luckyvs1 committed Jan 20, 2021
1 parent 5f3c37d commit 276c1b8
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 30 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ Deprecations
- Deprecating allowing scalars passed to the :class:`Categorical` constructor (:issue:`38433`)
- Deprecated allowing subclass-specific keyword arguments in the :class:`Index` constructor, use the specific subclass directly instead (:issue:`14093`,:issue:`21311`,:issue:`22315`,:issue:`26974`)
- Deprecated ``astype`` of datetimelike (``timedelta64[ns]``, ``datetime64[ns]``, ``Datetime64TZDtype``, ``PeriodDtype``) to integer dtypes, use ``values.view(...)`` instead (:issue:`38544`)
-
- Deprecated keyword ``try_cast`` in :meth:`Series.where`, :meth:`Series.mask`, :meth:`DataFrame.where`, :meth:`DataFrame.mask`; cast results manually if desired (:issue:`38836`)
-

.. ---------------------------------------------------------------------------
Expand Down
33 changes: 24 additions & 9 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8781,7 +8781,6 @@ def _where(
axis=None,
level=None,
errors="raise",
try_cast=False,
):
"""
Equivalent to public method `where`, except that `other` is not
Expand Down Expand Up @@ -8932,7 +8931,6 @@ def _where(
cond=cond,
align=align,
errors=errors,
try_cast=try_cast,
axis=block_axis,
)
result = self._constructor(new_data)
Expand All @@ -8954,7 +8952,7 @@ def where(
axis=None,
level=None,
errors="raise",
try_cast=False,
try_cast=lib.no_default,
):
"""
Replace values where the condition is {cond_rev}.
Expand Down Expand Up @@ -8986,9 +8984,12 @@ def where(
- 'raise' : allow exceptions to be raised.
- 'ignore' : suppress exceptions. On error return original object.
try_cast : bool, default False
try_cast : bool, default None
Try to cast the result back to the input type (if possible).
.. deprecated:: 1.3.0
Manually cast back if necessary.
Returns
-------
Same type as caller or None if ``inplace=True``.
Expand Down Expand Up @@ -9077,9 +9078,16 @@ def where(
4 True True
"""
other = com.apply_if_callable(other, self)
return self._where(
cond, other, inplace, axis, level, errors=errors, try_cast=try_cast
)

if try_cast is not lib.no_default:
warnings.warn(
"try_cast keyword is deprecated and will be removed in a "
"future version",
FutureWarning,
stacklevel=2,
)

return self._where(cond, other, inplace, axis, level, errors=errors)

@final
@doc(
Expand All @@ -9098,12 +9106,20 @@ def mask(
axis=None,
level=None,
errors="raise",
try_cast=False,
try_cast=lib.no_default,
):

inplace = validate_bool_kwarg(inplace, "inplace")
cond = com.apply_if_callable(cond, self)

if try_cast is not lib.no_default:
warnings.warn(
"try_cast keyword is deprecated and will be removed in a "
"future version",
FutureWarning,
stacklevel=2,
)

# see gh-21891
if not hasattr(cond, "__invert__"):
cond = np.array(cond)
Expand All @@ -9114,7 +9130,6 @@ def mask(
inplace=inplace,
axis=axis,
level=level,
try_cast=try_cast,
errors=errors,
)

Expand Down
21 changes: 5 additions & 16 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,9 +1290,7 @@ def _maybe_reshape_where_args(self, values, other, cond, axis):

return other, cond

def where(
self, other, cond, errors="raise", try_cast: bool = False, axis: int = 0
) -> List["Block"]:
def where(self, other, cond, errors="raise", axis: int = 0) -> List["Block"]:
"""
evaluate the block; return result block(s) from the result
Expand All @@ -1303,7 +1301,6 @@ def where(
errors : str, {'raise', 'ignore'}, default 'raise'
- ``raise`` : allow exceptions to be raised
- ``ignore`` : suppress exceptions. On error return original object
try_cast: bool, default False
axis : int, default 0
Returns
Expand Down Expand Up @@ -1342,9 +1339,7 @@ def where(
# we cannot coerce, return a compat dtype
# we are explicitly ignoring errors
block = self.coerce_to_target_dtype(other)
blocks = block.where(
orig_other, cond, errors=errors, try_cast=try_cast, axis=axis
)
blocks = block.where(orig_other, cond, errors=errors, axis=axis)
return self._maybe_downcast(blocks, "infer")

if not (
Expand Down Expand Up @@ -1825,9 +1820,7 @@ def shift(
)
]

def where(
self, other, cond, errors="raise", try_cast: bool = False, axis: int = 0
) -> List["Block"]:
def where(self, other, cond, errors="raise", axis: int = 0) -> List["Block"]:

cond = _extract_bool_array(cond)
assert not isinstance(other, (ABCIndex, ABCSeries, ABCDataFrame))
Expand Down Expand Up @@ -2075,9 +2068,7 @@ def to_native_types(self, na_rep="NaT", **kwargs):
result = arr._format_native_types(na_rep=na_rep, **kwargs)
return self.make_block(result)

def where(
self, other, cond, errors="raise", try_cast: bool = False, axis: int = 0
) -> List["Block"]:
def where(self, other, cond, errors="raise", axis: int = 0) -> List["Block"]:
# TODO(EA2D): reshape unnecessary with 2D EAs
arr = self.array_values().reshape(self.shape)

Expand All @@ -2086,9 +2077,7 @@ def where(
try:
res_values = arr.T.where(cond, other).T
except (ValueError, TypeError):
return super().where(
other, cond, errors=errors, try_cast=try_cast, axis=axis
)
return super().where(other, cond, errors=errors, axis=axis)

# TODO(EA2D): reshape not needed with 2D EAs
res_values = res_values.reshape(self.values.shape)
Expand Down
5 changes: 1 addition & 4 deletions pandas/core/internals/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,9 +542,7 @@ def get_axe(block, qs, axes):
def isna(self, func) -> "BlockManager":
return self.apply("apply", func=func)

def where(
self, other, cond, align: bool, errors: str, try_cast: bool, axis: int
) -> "BlockManager":
def where(self, other, cond, align: bool, errors: str, axis: int) -> "BlockManager":
if align:
align_keys = ["other", "cond"]
else:
Expand All @@ -557,7 +555,6 @@ def where(
other=other,
cond=cond,
errors=errors,
try_cast=try_cast,
axis=axis,
)

Expand Down
13 changes: 13 additions & 0 deletions pandas/tests/frame/indexing/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,16 @@ def test_mask_dtype_conversion(self):
expected = bools.astype(float).mask(mask)
result = bools.mask(mask)
tm.assert_frame_equal(result, expected)


def test_mask_try_cast_deprecated(frame_or_series):

obj = DataFrame(np.random.randn(4, 3))
if frame_or_series is not DataFrame:
obj = obj[0]

mask = obj > 0

with tm.assert_produces_warning(FutureWarning):
# try_cast keyword deprecated
obj.mask(mask, -1, try_cast=True)
12 changes: 12 additions & 0 deletions pandas/tests/frame/indexing/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,3 +672,15 @@ def test_where_ea_other(self):
expected["B"] = expected["B"].astype(object)
result = df.where(mask, ser2, axis=1)
tm.assert_frame_equal(result, expected)


def test_where_try_cast_deprecated(frame_or_series):
obj = DataFrame(np.random.randn(4, 3))
if frame_or_series is not DataFrame:
obj = obj[0]

mask = obj > 0

with tm.assert_produces_warning(FutureWarning):
# try_cast keyword deprecated
obj.where(mask, -1, try_cast=False)

0 comments on commit 276c1b8

Please sign in to comment.