Skip to content

Commit 8f490e6

Browse files
feat: pivot_table supports fill_value arg (#2257)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent 20ab469 commit 8f490e6

File tree

4 files changed

+31
-15
lines changed

4 files changed

+31
-15
lines changed

bigframes/core/reshape/pivot.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,11 @@ def crosstab(
7171
columns=tmp_col_names,
7272
aggfunc=aggfunc or "count",
7373
sort=False,
74+
fill_value=0 if (aggfunc is None) else None,
7475
)
76+
# Undo temporary unique level labels
7577
pivot_table.index.names = rownames or [i.name for i in index]
7678
pivot_table.columns.names = colnames or [c.name for c in columns]
77-
if aggfunc is None:
78-
# TODO: Push this into pivot_table itself
79-
pivot_table = pivot_table.fillna(0)
8079
return pivot_table
8180

8281

bigframes/dataframe.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3486,10 +3486,6 @@ def pivot_table(
34863486
observed: bool = False,
34873487
sort: bool = True,
34883488
) -> DataFrame:
3489-
if fill_value is not None:
3490-
raise NotImplementedError(
3491-
"DataFrame.pivot_table fill_value arg not supported. {constants.FEEDBACK_LINK}"
3492-
)
34933489
if margins:
34943490
raise NotImplementedError(
34953491
"DataFrame.pivot_table margins arg not supported. {constants.FEEDBACK_LINK}"
@@ -3549,14 +3545,16 @@ def pivot_table(
35493545
index=index,
35503546
values=values if len(values) > 1 else None,
35513547
)
3548+
if fill_value is not None:
3549+
pivoted = pivoted.fillna(fill_value)
35523550
if sort:
35533551
pivoted = pivoted.sort_index()
35543552

35553553
# TODO: Remove the reordering step once the issue is resolved.
35563554
# The pivot_table method results in multi-index columns that are always ordered.
35573555
# However, the order of the pivoted result columns is not guaranteed to be sorted.
35583556
# Sort and reorder.
3559-
return pivoted[pivoted.columns.sort_values()]
3557+
return pivoted.sort_index(axis=1) # type: ignore
35603558

35613559
def stack(self, level: LevelsType = -1):
35623560
if not isinstance(self.columns, pandas.MultiIndex):

tests/system/small/test_dataframe.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3784,12 +3784,18 @@ def test_df_pivot_hockey(hockey_df, hockey_pandas_df, values, index, columns):
37843784

37853785

37863786
@pytest.mark.parametrize(
3787-
("values", "index", "columns", "aggfunc"),
3787+
("values", "index", "columns", "aggfunc", "fill_value"),
37883788
[
3789-
(("culmen_length_mm", "body_mass_g"), "species", "sex", "std"),
3790-
(["body_mass_g", "culmen_length_mm"], ("species", "island"), "sex", "sum"),
3791-
("body_mass_g", "sex", ["island", "species"], "mean"),
3792-
("culmen_depth_mm", "island", "species", "max"),
3789+
(("culmen_length_mm", "body_mass_g"), "species", "sex", "std", 1.0),
3790+
(
3791+
["body_mass_g", "culmen_length_mm"],
3792+
("species", "island"),
3793+
"sex",
3794+
"sum",
3795+
None,
3796+
),
3797+
("body_mass_g", "sex", ["island", "species"], "mean", None),
3798+
("culmen_depth_mm", "island", "species", "max", -1),
37933799
],
37943800
)
37953801
def test_df_pivot_table(
@@ -3799,12 +3805,21 @@ def test_df_pivot_table(
37993805
index,
38003806
columns,
38013807
aggfunc,
3808+
fill_value,
38023809
):
38033810
bf_result = penguins_df_default_index.pivot_table(
3804-
values=values, index=index, columns=columns, aggfunc=aggfunc
3811+
values=values,
3812+
index=index,
3813+
columns=columns,
3814+
aggfunc=aggfunc,
3815+
fill_value=fill_value,
38053816
).to_pandas()
38063817
pd_result = penguins_pandas_df_default_index.pivot_table(
3807-
values=values, index=index, columns=columns, aggfunc=aggfunc
3818+
values=values,
3819+
index=index,
3820+
columns=columns,
3821+
aggfunc=aggfunc,
3822+
fill_value=fill_value,
38083823
)
38093824
pd.testing.assert_frame_equal(
38103825
bf_result, pd_result, check_dtype=False, check_column_type=False

third_party/bigframes_vendored/pandas/core/frame.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6414,6 +6414,10 @@ def pivot_table(self, values=None, index=None, columns=None, aggfunc="mean"):
64146414
aggfunc (str, default "mean"):
64156415
Aggregation function name to compute summary statistics (e.g., 'sum', 'mean').
64166416
6417+
fill_value (scalar, default None):
6418+
Value to replace missing values with (in the resulting pivot table, after
6419+
aggregation).
6420+
64176421
Returns:
64186422
bigframes.pandas.DataFrame: An Excel style pivot table.
64196423
"""

0 commit comments

Comments
 (0)