Skip to content

Commit

Permalink
Adapt to deprecations in pandas 2.2.0 (#3620)
Browse files Browse the repository at this point in the history
* Adapt to deprecations in pandas 2.2.0

* Backcompat for bug where the fix was deprecated?

* More multi-version support
  • Loading branch information
mwaskom committed Jan 21, 2024
1 parent a3cb0f1 commit 7aed2a0
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 14 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,6 @@ exclude = ["doc/_static/*.svg"]
[tool.pytest.ini_options]
filterwarnings = [
"ignore:The --rsyncdir command line argument and rsyncdirs config variable are deprecated.:DeprecationWarning",
"ignore:\\s*Pyarrow will become a required dependency of pandas:DeprecationWarning",
"ignore:datetime.datetime.utcfromtimestamp\\(\\) is deprecated:DeprecationWarning",
]
6 changes: 3 additions & 3 deletions seaborn/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,9 +942,9 @@ def iter_data(

for key in iter_keys:

# Pandas fails with singleton tuple inputs
pd_key = key[0] if len(key) == 1 else key

pd_key = (
key[0] if len(key) == 1 and _version_predates(pd, "2.2.0") else key
)
try:
data_subset = grouped_data.get_group(pd_key)
except KeyError:
Expand Down
7 changes: 7 additions & 0 deletions seaborn/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Literal

import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib.figure import Figure
from seaborn.utils import _version_predates
Expand Down Expand Up @@ -114,3 +115,9 @@ def get_legend_handles(legend):
return legend.legendHandles
else:
return legend.legend_handles


def groupby_apply_include_groups(val):
if _version_predates(pd, "2.2.0"):
return {}
return {"include_groups": val}
8 changes: 5 additions & 3 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from seaborn._core.exceptions import PlotSpecError
from seaborn._core.rules import categorical_order
from seaborn._compat import get_layout_engine, set_layout_engine
from seaborn.utils import _version_predates
from seaborn.rcmod import axes_style, plotting_context
from seaborn.palettes import color_palette

Expand Down Expand Up @@ -1637,9 +1638,10 @@ def split_generator(keep_na=False) -> Generator:

for key in itertools.product(*grouping_keys):

# Pandas fails with singleton tuple inputs
pd_key = key[0] if len(key) == 1 else key

pd_key = (
key[0] if len(key) == 1 and _version_predates(pd, "2.2.0")
else key
)
try:
df_subset = grouped_df.get_group(pd_key)
except KeyError:
Expand Down
7 changes: 4 additions & 3 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
_scatter_legend_artist,
_version_predates,
)
from seaborn._compat import groupby_apply_include_groups
from seaborn._statistics import (
EstimateAggregator,
LetterValues,
Expand Down Expand Up @@ -634,10 +635,10 @@ def get_props(element, artist=mpl.lines.Line2D):
ax = self._get_axes(sub_vars)

grouped = sub_data.groupby(self.orient)[value_var]
positions = sorted(sub_data[self.orient].unique().astype(float))
value_data = [x.to_numpy() for _, x in grouped]
stats = pd.DataFrame(mpl.cbook.boxplot_stats(value_data, whis=whis,
bootstrap=bootstrap))
positions = grouped.grouper.result_index.to_numpy(dtype=float)

orig_width = width * self._native_width
data = pd.DataFrame({self.orient: positions, "width": orig_width})
Expand Down Expand Up @@ -1207,7 +1208,7 @@ def plot_points(
agg_data = sub_data if sub_data.empty else (
sub_data
.groupby(self.orient)
.apply(aggregator, agg_var)
.apply(aggregator, agg_var, **groupby_apply_include_groups(False))
.reindex(pd.Index(positions, name=self.orient))
.reset_index()
)
Expand Down Expand Up @@ -1278,7 +1279,7 @@ def plot_bars(
agg_data = sub_data if sub_data.empty else (
sub_data
.groupby(self.orient)
.apply(aggregator, agg_var)
.apply(aggregator, agg_var, **groupby_apply_include_groups(False))
.reset_index()
)

Expand Down
7 changes: 6 additions & 1 deletion seaborn/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
_get_transform_functions,
_scatter_legend_artist,
)
from ._compat import groupby_apply_include_groups
from ._statistics import EstimateAggregator, WeightedAggregator
from .axisgrid import FacetGrid, _facet_docs
from ._docstrings import DocstringComponents, _core_docs
Expand Down Expand Up @@ -290,7 +291,11 @@ def plot(self, ax, kws):
grouped = sub_data.groupby(orient, sort=self.sort)
# Could pass as_index=False instead of reset_index,
# but that fails on a corner case with older pandas.
sub_data = grouped.apply(agg, other).reset_index()
sub_data = (
grouped
.apply(agg, other, **groupby_apply_include_groups(False))
.reset_index()
)
else:
sub_data[f"{other}min"] = np.nan
sub_data[f"{other}max"] = np.nan
Expand Down
17 changes: 14 additions & 3 deletions tests/_stats/test_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from seaborn._core.groupby import GroupBy
from seaborn._stats.density import KDE, _no_scipy
from seaborn._compat import groupby_apply_include_groups


class TestKDE:
Expand Down Expand Up @@ -93,7 +94,10 @@ def test_common_norm(self, df, common_norm):

areas = (
res.groupby("alpha")
.apply(lambda x: self.integrate(x["density"], x[ori]))
.apply(
lambda x: self.integrate(x["density"], x[ori]),
**groupby_apply_include_groups(False),
)
)

if common_norm:
Expand All @@ -111,11 +115,18 @@ def test_common_norm_variables(self, df):
def integrate_by_color_and_sum(x):
return (
x.groupby("color")
.apply(lambda y: self.integrate(y["density"], y[ori]))
.apply(
lambda y: self.integrate(y["density"], y[ori]),
**groupby_apply_include_groups(False)
)
.sum()
)

areas = res.groupby("alpha").apply(integrate_by_color_and_sum)
areas = (
res
.groupby("alpha")
.apply(integrate_by_color_and_sum, **groupby_apply_include_groups(False))
)
assert_array_almost_equal(areas, [1, 1], decimal=3)

@pytest.mark.parametrize("param", ["norm", "grid"])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2078,7 +2078,7 @@ def test_xy_native_scale_log_transform(self):

def test_datetime_native_scale_axis(self):

x = pd.date_range("2010-01-01", periods=20, freq="m")
x = pd.date_range("2010-01-01", periods=20, freq="MS")
y = np.arange(20)
ax = barplot(x=x, y=y, native_scale=True)
assert "Date" in ax.xaxis.get_major_locator().__class__.__name__
Expand Down

0 comments on commit 7aed2a0

Please sign in to comment.