BUG: fix groupby.aggregate resulting dtype coercion, xref pandas-dev#11444, pandas-dev#13046

make sure .size includes the name of the grouped
jreback committed Feb 27, 2017
1 parent 251826f commit 61fa8be
Showing 5 changed files with 72 additions and 13 deletions.
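
For context, a minimal sketch of the user-facing behaviour this commit targets (not part of the diff; it assumes the post-fix semantics exercised by the tests added below, and mirrors test_agg_cast_results_dtypes):

    import pandas as pd
    from datetime import datetime
    from pandas.util import testing as tm

    # a datetime column aggregated with a plain Python UDF (len)
    u = [datetime(2015, m, 1) for m in range(1, 13)]
    v = list('aaabbbbbbccd')
    df = pd.DataFrame({'X': v, 'Y': u})

    # with the fix, the result of the UDF aggregation is no longer coerced
    # back to the original datetime64 dtype; it matches count() (int64)
    result = df.groupby('X')['Y'].agg(len)
    expected = df.groupby('X')['Y'].count()
    tm.assert_series_equal(result, expected)
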
4 changes: 2 additions & 2 deletions doc/source/whatsnew/v0.20.0.txt
@@ -632,12 +632,12 @@ Bug Fixes

- Bug in ``DataFrame.to_stata()`` and ``StataWriter`` which produces incorrectly formatted files to be produced for some locales (:issue:`13856`)
- Bug in ``pd.concat()`` in which concatting with an empty dataframe with ``join='inner'`` was being improperly handled (:issue:`15328`)
- Bug in ``groupby.agg()`` incorrectly localizing timezone on ``datetime`` (:issue:`15426`, :issue:`10668`)
- Bug in ``groupby.agg()`` incorrectly localizing timezone on ``datetime`` (:issue:`15426`, :issue:`10668`, :issue:`13046`)



- Bug in ``.read_csv()`` with ``parse_dates`` when multiline headers are specified (:issue:`15376`)
- Bug in ``groupby.transform()`` that would coerce the resultant dtypes back to the original (:issue:`10972`)
- Bug in ``groupby.transform()`` that would coerce the resultant dtypes back to the original (:issue:`10972`, :issue:`11444`)

- Bug in ``DataFrame.hist`` where ``plt.tight_layout`` caused an ``AttributeError`` (use ``matplotlib >= 0.2.0``) (:issue:`9351`)
- Bug in ``DataFrame.boxplot`` where ``fontsize`` was not applied to the tick labels on both axes (:issue:`15108`)
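
A hedged illustration of the timezone entry above (GH15426 / GH10668): after the fix, aggregating a tz-aware datetime column should hand back the original tz-aware values rather than incorrectly localized ones. Sketch only; the column names and values here are illustrative:

    import pandas as pd

    df = pd.DataFrame({'a': [1, 1, 2],
                       'ts': pd.to_datetime(['2016-01-01 12:00',
                                             '2016-01-02 12:00',
                                             '2016-01-03 12:00']).tz_localize('US/Eastern')})

    out = df.groupby('a')['ts'].agg('first')
    # the tz-aware dtype survives and the values are unchanged
    assert str(out.dtype) == 'datetime64[ns, US/Eastern]'
    assert out.loc[1] == df['ts'].iloc[0]
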
23 changes: 17 additions & 6 deletions pandas/core/groupby.py
@@ -767,19 +767,23 @@ def _index_with_as_index(self, b):
new.names = gp.names + original.names
return new

def _try_cast(self, result, obj):
def _try_cast(self, result, obj, numeric_only=False):
"""
try to cast the result to our obj original type,
we may have roundtripped through object in the meantime
if numeric_only is True, then only try to cast numerics
and not datetimelikes
"""
if obj.ndim > 1:
dtype = obj.values.dtype
else:
dtype = obj.dtype

if not is_scalar(result):
result = _possibly_downcast_to_dtype(result, dtype)
if numeric_only and is_numeric_dtype(dtype) or not numeric_only:
result = _possibly_downcast_to_dtype(result, dtype)

return result
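
The guard added above reads tersely because of operator precedence; a standalone sketch of the same decision (an illustrative helper using the public pandas.api.types check, not a pandas API):

    import numpy as np
    from pandas.api.types import is_numeric_dtype

    def should_try_downcast(dtype, numeric_only=False):
        # ``and`` binds tighter than ``or``, so the guard means:
        # (numeric_only and is_numeric_dtype(dtype)) or (not numeric_only),
        # i.e. always downcast when numeric_only is False, otherwise only
        # downcast numeric dtypes and leave datetimelikes alone
        return (numeric_only and is_numeric_dtype(dtype)) or not numeric_only

    assert should_try_downcast(np.dtype('int64'), numeric_only=True)
    assert not should_try_downcast(np.dtype('datetime64[ns]'), numeric_only=True)
    assert should_try_downcast(np.dtype('datetime64[ns]'), numeric_only=False)
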

@@ -830,7 +834,7 @@ def _python_agg_general(self, func, *args, **kwargs):
for name, obj in self._iterate_slices():
try:
result, counts = self.grouper.agg_series(obj, f)
output[name] = self._try_cast(result, obj)
output[name] = self._try_cast(result, obj, numeric_only=True)
except TypeError:
continue

@@ -1117,7 +1121,11 @@ def sem(self, ddof=1):
@Appender(_doc_template)
def size(self):
"""Compute group sizes"""
return self.grouper.size()
result = self.grouper.size()

if isinstance(self.obj, Series):
result.name = getattr(self, 'name', None)
return result

sum = _groupby_function('sum', 'add', np.sum)
prod = _groupby_function('prod', 'prod', np.prod)
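
At the user level, the ``size`` change means the grouped Series' name now survives. This mirrors the test added in pandas/tests/groupby/test_aggregate.py below; the data here is illustrative:

    import pandas as pd

    df = pd.DataFrame({'class': list('AABBCCDD'),
                       'time': pd.date_range('2016-01-01', periods=8)})

    sizes = df.groupby('class')['time'].size()
    assert sizes.name == 'time'            # previously the name was dropped
    assert sizes.index.name == 'class'
    assert sizes.tolist() == [2, 2, 2, 2]
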
@@ -1689,7 +1697,9 @@ def size(self):
ids, _, ngroup = self.group_info
ids = _ensure_platform_int(ids)
out = np.bincount(ids[ids != -1], minlength=ngroup or None)
return Series(out, index=self.result_index, dtype='int64')
return Series(out,
index=self.result_index,
dtype='int64')

@cache_readonly
def _max_groupsize(self):
@@ -2908,7 +2918,8 @@ def transform(self, func, *args, **kwargs):
result = concat(results).sort_index()

# we will only try to coerce the result type if
# we have a numeric dtype
# we have a numeric dtype, as these are *always* udfs
# the cython paths take a different route (and do the casting)
dtype = self._selected_obj.dtype
if is_numeric_dtype(dtype):
result = _possibly_downcast_to_dtype(result, dtype)
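
The effect of the numeric-dtype guard on a user-defined transform, mirroring test_transform_casting in pandas/tests/groupby/test_transform.py below (the frame here is a trimmed-down, illustrative version):

    import pandas as pd
    from pandas.api.types import is_timedelta64_dtype

    df = pd.DataFrame({'ID3': ['a', 'a', 'b'],
                       'DATETIME': pd.to_datetime(['2014-10-08 13:43:27',
                                                   '2014-10-08 14:26:19',
                                                   '2014-10-08 14:29:01'])})

    # a UDF that maps datetimes to timedeltas keeps the timedelta64 dtype
    # instead of being coerced back to the original datetime64 dtype
    out = df.groupby('ID3')['DATETIME'].transform(lambda x: x.diff())
    assert is_timedelta64_dtype(out.dtype)
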
23 changes: 23 additions & 0 deletions pandas/tests/groupby/test_aggregate.py
@@ -154,6 +154,29 @@ def test_agg_dict_parameter_cast_result_dtypes(self):
assert_series_equal(grouped.time.last(), exp['time'])
assert_series_equal(grouped.time.agg('last'), exp['time'])

# count
exp = pd.Series([2, 2, 2, 2],
index=Index(list('ABCD'), name='class'),
name='time')
assert_series_equal(grouped.time.agg(len), exp)
assert_series_equal(grouped.time.size(), exp)

exp = pd.Series([0, 1, 1, 2],
index=Index(list('ABCD'), name='class'),
name='time')
assert_series_equal(grouped.time.count(), exp)

def test_agg_cast_results_dtypes(self):
# similar to GH12821
# xref #11444
u = [datetime(2015, x + 1, 1) for x in range(12)]
v = list('aaabbbbbbccd')
df = pd.DataFrame({'X': v, 'Y': u})

result = df.groupby('X')['Y'].agg(len)
expected = df.groupby('X')['Y'].count()
assert_series_equal(result, expected)

def test_agg_must_agg(self):
grouped = self.df.groupby('A')['C']
self.assertRaises(Exception, grouped.agg, lambda x: x.describe())
29 changes: 28 additions & 1 deletion pandas/tests/groupby/test_transform.py
@@ -4,7 +4,8 @@
import pandas as pd
from pandas.util import testing as tm
from pandas import Series, DataFrame, Timestamp, MultiIndex, concat, date_range
from pandas.types.common import _ensure_platform_int
from pandas.types.common import _ensure_platform_int, is_timedelta64_dtype
from pandas.compat import StringIO
from .common import MixIn, assert_fp_equal

from pandas.util.testing import assert_frame_equal, assert_series_equal
@@ -227,6 +228,32 @@ def test_transform_datetime_to_numeric(self):
expected = Series([0, 1], name='b')
assert_series_equal(result, expected)

def test_transform_casting(self):
# 13046
data = """
idx A ID3 DATETIME
0 B-028 b76cd912ff "2014-10-08 13:43:27"
1 B-054 4a57ed0b02 "2014-10-08 14:26:19"
2 B-076 1a682034f8 "2014-10-08 14:29:01"
3 B-023 b76cd912ff "2014-10-08 18:39:34"
4 B-023 f88g8d7sds "2014-10-08 18:40:18"
5 B-033 b76cd912ff "2014-10-08 18:44:30"
6 B-032 b76cd912ff "2014-10-08 18:46:00"
7 B-037 b76cd912ff "2014-10-08 18:52:15"
8 B-046 db959faf02 "2014-10-08 18:59:59"
9 B-053 b76cd912ff "2014-10-08 19:17:48"
10 B-065 b76cd912ff "2014-10-08 19:21:38"
"""
df = pd.read_csv(StringIO(data), sep='\s+',
index_col=[0], parse_dates=['DATETIME'])

result = df.groupby('ID3')['DATETIME'].transform(lambda x: x.diff())
assert is_timedelta64_dtype(result.dtype)

result = df[['ID3', 'DATETIME']].groupby('ID3').transform(
lambda x: x.diff())
assert is_timedelta64_dtype(result.DATETIME.dtype)

def test_transform_multiple(self):
grouped = self.ts.groupby([lambda x: x.year, lambda x: x.month])

6 changes: 2 additions & 4 deletions pandas/tests/tseries/test_resample.py
@@ -757,10 +757,8 @@ def test_resample_empty_series(self):
freq in ['M', 'D']):
# GH12871 - TODO: name should propagate, but currently
# doesn't on lower / same frequency with PeriodIndex
assert_series_equal(result, expected, check_dtype=False,
check_names=False)
# this assert will break when fixed
self.assertTrue(result.name is None)
assert_series_equal(result, expected, check_dtype=False)

else:
assert_series_equal(result, expected, check_dtype=False)

