Skip to content

Commit

Permalink
BUG: Respect axis when doing DataFrame.expanding
Browse files Browse the repository at this point in the history
  • Loading branch information
gfyoung committed Oct 29, 2018
1 parent 360e727 commit 6470a1a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.24.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,7 @@ Groupby/Resample/Rolling
- Bug in :meth:`SeriesGroupBy.mean` when values were integral but could not fit inside of int64, overflowing instead. (:issue:`22487`)
- :func:`RollingGroupby.agg` and :func:`ExpandingGroupby.agg` now support multiple aggregation functions as parameters (:issue:`15072`)
- Bug in :meth:`DataFrame.resample` and :meth:`Series.resample` when resampling by a weekly offset (``'W'``) across a DST transition (:issue:`9119`, :issue:`21459`)
- Bug in :meth:`DataFrame.expanding` in which the ``axis`` argument was not being respected during aggregations (:issue:`23372`)

Reshaping
^^^^^^^^^
Expand Down
9 changes: 9 additions & 0 deletions pandas/core/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -1867,6 +1867,15 @@ def _constructor(self):

def _get_window(self, other=None):
obj = self._selected_obj
axis = self.obj._get_axis_number(self.axis)

# If axis == 0, leave `obj` alone.
# If axis == 1, transpose.
#
# Other values of axis are invalid.
if axis == 1:
obj = obj.T

if other is None:
return (max(len(obj), self.min_periods) if self.min_periods
else len(obj))
Expand Down
10 changes: 10 additions & 0 deletions pandas/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,16 @@ def test_iter_raises(self, klass):
with pytest.raises(NotImplementedError):
iter(obj.expanding(2))

def test_expanding_axis(self):
# see gh-23372.
df = DataFrame(np.ones((10, 20)))

exp_row = [np.nan] * 2 + [float(i) for i in range(3, 21)]
expected = DataFrame([exp_row] * 10)

result = df.expanding(3, axis=1).sum()
tm.assert_frame_equal(result, expected)


class TestEWM(Base):

Expand Down

0 comments on commit 6470a1a

Please sign in to comment.