BUG: preserve categorical & sparse types when grouping / pivot

jreback committed Jun 27, 2019
1 parent d94146c commit 1152507
Showing 12 changed files with 154 additions and 62 deletions.
29 changes: 29 additions & 0 deletions doc/source/whatsnew/v0.25.0.rst
@@ -317,6 +317,35 @@ of ``object`` dtype. :attr:`Series.str` will now infer the dtype data *within* t
    s
    s.str.startswith(b'a')

.. _whatsnew_0250.api_breaking.groupby_categorical:

Categorical dtypes are preserved during groupby
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Previously, columns that were categorical, but not the groupby key(s), would be converted to ``object`` dtype during groupby operations. pandas now preserves these dtypes (:issue:`18502`).

.. ipython:: python

    df = pd.DataFrame(
        {'payload': [-1, -2, -1, -2],
         'col': pd.Categorical(["foo", "bar", "bar", "qux"], ordered=True)})
    df
    df.dtypes

*Previous Behavior*:

.. code-block:: python

    In [5]: df.groupby('payload').first().col.dtype
    Out[5]: dtype('O')

*New Behavior*:

.. ipython:: python

    df.groupby('payload').first().col.dtype
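The commit title also covers sparse types, which the entry above does not demonstrate. A companion example (editor's sketch, not part of this diff; behavior assumed from the ``is_sparse`` fallback added in ``ops.py`` below):

.. code-block:: python

    df = pd.DataFrame({'key': [1, 1, 2],
                       'val': pd.SparseArray([1.0, 0.0, 0.0])})
    # the sparse dtype should now survive the groupby instead of
    # densifying to float64 (assumed: Sparse[float64, nan])
    df.groupby('key').first().val.dtype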
.. _whatsnew_0250.api_breaking.incompatible_index_unions:

Incompatible Index Type Unions
11 changes: 9 additions & 2 deletions pandas/core/groupby/generic.py
@@ -156,12 +156,19 @@ def _cython_agg_blocks(self, how, alt=None, numeric_only=True,

                obj = self.obj[data.items[locs]]
                s = groupby(obj, self.grouper)
-               result = s.aggregate(lambda x: alt(x, axis=self.axis))
+               try:
+                   result = s.aggregate(lambda x: alt(x, axis=self.axis))
+               except Exception:
+                   # we may have an exception in trying to aggregate
+                   # continue and exclude the block
+                   pass

+           finally:

+               dtype = block.values.dtype

                # see if we can cast the block back to the original dtype
-               result = block._try_coerce_and_cast_result(result)
+               result = block._try_coerce_and_cast_result(result, dtype=dtype)
                newb = block.make_block(result)

                new_items.append(locs)
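The hunk settles into a try/except/finally shape: fall back to a per-block Python aggregation, tolerate blocks that cannot aggregate, and always attempt to cast the result back to the block's original dtype. That pattern as a standalone sketch (names hypothetical, not pandas internals):

    def aggregate_block(values, agg, original_dtype):
        """mirror of the fallback-and-cast-back flow above (sketch)"""
        result = None
        try:
            result = agg(values)  # may fail for exotic dtypes
        except Exception:
            pass  # caller would exclude this block from the output
        finally:
            if result is not None:
                try:
                    # best-effort dtype restore, standing in for
                    # _try_coerce_and_cast_result
                    result = result.astype(original_dtype)
                except (TypeError, ValueError):
                    pass
        return result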
42 changes: 32 additions & 10 deletions pandas/core/groupby/groupby.py
@@ -786,6 +786,8 @@ def _try_cast(self, result, obj, numeric_only=False):
        elif is_extension_array_dtype(dtype):
            # The function can return something of any type, so check
            # if the type is compatible with the calling EA.
+
+           # return the same type (Series) as our caller
            try:
                result = obj._values._from_sequence(result, dtype=dtype)
            except Exception:
@@ -1157,7 +1159,8 @@ def mean(self, *args, **kwargs):
        """
        nv.validate_groupby_func('mean', args, kwargs, ['numeric_only'])
        try:
-           return self._cython_agg_general('mean', **kwargs)
+           return self._cython_agg_general(
+               'mean', alt=lambda x, axis: Series(x).mean(**kwargs), **kwargs)
        except GroupByError:
            raise
        except Exception:  # pragma: no cover
@@ -1179,7 +1182,11 @@ def median(self, **kwargs):
        Median of values within each group.
        """
        try:
-           return self._cython_agg_general('median', **kwargs)
+           return self._cython_agg_general(
+               'median',
+               alt=lambda x, axis: Series(x).median(axis=axis, **kwargs),
+               **kwargs)
        except GroupByError:
            raise
        except Exception:  # pragma: no cover
@@ -1235,7 +1242,10 @@ def var(self, ddof=1, *args, **kwargs):
        nv.validate_groupby_func('var', args, kwargs)
        if ddof == 1:
            try:
-               return self._cython_agg_general('var', **kwargs)
+               return self._cython_agg_general(
+                   'var',
+                   alt=lambda x, axis: Series(x).var(ddof=ddof, **kwargs),
+                   **kwargs)
            except Exception:
                f = lambda x: x.var(ddof=ddof, **kwargs)
                with _group_selection_context(self):
@@ -1263,7 +1273,6 @@ def sem(self, ddof=1):
        Series or DataFrame
            Standard error of the mean of values within each group.
        """
-
        return self.std(ddof=ddof) / np.sqrt(self.count())
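For reference, the one-liner above is the whole definition: the standard error of the mean is the groupwise standard deviation scaled by the square root of the group size. A quick illustrative check (editor's example):

    import numpy as np
    import pandas as pd

    s = pd.Series([1.0, 2.0, 3.0, 4.0], index=[0, 0, 1, 1])
    g = s.groupby(level=0)
    # sem == std / sqrt(count), per group
    assert g.sem(ddof=1).equals(g.std(ddof=1) / np.sqrt(g.count()))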

@Substitution(name='groupby')
@@ -1290,7 +1299,7 @@ def _add_numeric_operations(cls):
        """

        def groupby_function(name, alias, npfunc,
-                            numeric_only=True, _convert=False,
+                            numeric_only=True,
                             min_count=-1):

_local_template = """
@@ -1312,17 +1321,30 @@ def f(self, **kwargs):
                kwargs['min_count'] = min_count

            self._set_group_selection()
+
+           # try a cython aggregation if we can
            try:
                return self._cython_agg_general(
                    alias, alt=npfunc, **kwargs)
            except AssertionError as e:
                raise SpecificationError(str(e))
            except Exception:
-               result = self.aggregate(
-                   lambda x: npfunc(x, axis=self.axis))
-               if _convert:
-                   result = result._convert(datetime=True)
-               return result
+               pass
+
+           # apply a non-cython aggregation
+           result = self.aggregate(
+               lambda x: npfunc(x, axis=self.axis))
+
+           # coerce the resulting columns if we can
+           if isinstance(result, DataFrame):
+               for col in result.columns:
+                   result[col] = self._try_cast(
+                       result[col], self.obj[col])
+           else:
+               result = self._try_cast(
+                   result, self.obj)
+
+           return result

set_function_name(f, name, cls)

6 changes: 3 additions & 3 deletions pandas/core/groupby/ops.py
@@ -19,7 +19,7 @@
from pandas.core.dtypes.common import (
    ensure_float64, ensure_int64, ensure_int_or_float, ensure_object,
    ensure_platform_int, is_bool_dtype, is_categorical_dtype, is_complex_dtype,
-   is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype,
+   is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype, is_sparse,
    is_timedelta64_dtype, needs_i8_conversion)
from pandas.core.dtypes.missing import _maybe_fill, isna

@@ -451,9 +451,9 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1,

            # categoricals are only 1d, so we
            # are not set up for dim transforming
-           if is_categorical_dtype(values):
+           if is_categorical_dtype(values) or is_sparse(values):
                raise NotImplementedError(
-                   "categoricals are not support in cython ops ATM")
+                   "{} are not supported in cython ops".format(values.dtype))
            elif is_datetime64_any_dtype(values):
                if how in ['add', 'prod', 'cumsum', 'cumprod']:
                    raise NotImplementedError(
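This guard is the pivot of the change: categorical and sparse blocks are rerouted away from the 2-D cython kernels onto the Python fallback, which can preserve their dtype. A minimal illustration of the dispatch test, using the same predicates the diff imports (editor's sketch, not pandas internals):

    import pandas as pd
    from pandas.core.dtypes.common import is_categorical_dtype, is_sparse

    def takes_cython_path(values):
        # cython group ops operate on plain 2-D numpy arrays; 1-D-only
        # extension types must take the slower, dtype-preserving path
        return not (is_categorical_dtype(values) or is_sparse(values))

    assert not takes_cython_path(pd.Categorical(['a', 'b', 'a']))
    assert not takes_cython_path(pd.SparseArray([1.0, 0.0]))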
2 changes: 1 addition & 1 deletion pandas/core/indexing.py
@@ -10,7 +10,7 @@

from pandas.core.dtypes.common import (
    ensure_platform_int, is_float, is_integer, is_integer_dtype, is_iterator,
-   is_list_like, is_numeric_dtype, is_scalar, is_sequence, is_sparse)
+   is_list_like, is_numeric_dtype, is_scalar, is_sequence)
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
from pandas.core.dtypes.missing import _infer_fill_value, isna

24 changes: 23 additions & 1 deletion pandas/core/internals/blocks.py
@@ -594,7 +594,8 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
        values = self.get_values(dtype=dtype)

        # _astype_nansafe works fine with 1-d only
-       values = astype_nansafe(values.ravel(), dtype, copy=True)
+       values = astype_nansafe(
+           values.ravel(), dtype, copy=True, **kwargs)

        # TODO(extension)
        # should we make this attribute?
@@ -1746,6 +1747,27 @@ def _slice(self, slicer):

        return self.values[slicer]

+   def _try_cast_result(self, result, dtype=None):
+       """
+       if we have an operation that operates on for example floats
+       we want to try to cast back to our EA here if possible
+
+       result could be a 2-D numpy array, e.g. the result of
+       a numeric operation; but it must be shape (1, X) because
+       we by-definition operate on the ExtensionBlocks one-by-one
+
+       result could also be an EA Array itself, in which case it
+       is already a 1-D array
+       """
+       try:
+           result = self._holder._from_sequence(
+               np.asarray(result).ravel(), dtype=dtype)
+       except Exception:
+           pass
+
+       return result
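The cast-back above leans on the ``_from_sequence`` constructor that every ExtensionArray provides. Shown standalone (editor's sketch; this is private API, used here only to illustrate the round-trip):

    import numpy as np
    import pandas as pd

    cat = pd.Categorical(['a', 'b', 'a'], ordered=True)
    raw = np.asarray(cat)  # an op result arrives as a plain ndarray
    # _from_sequence restores the extension dtype when the values allow it
    restored = pd.Categorical._from_sequence(raw.ravel(), dtype=cat.dtype)
    assert restored.dtype == cat.dtype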

def formatting_values(self):
# Deprecating the ability to override _formatting_values.
# Do the warning here, it's only user in pandas, since we
5 changes: 4 additions & 1 deletion pandas/core/internals/construction.py
@@ -666,7 +666,10 @@ def sanitize_array(data, index, dtype=None, copy=False,
    data = np.array(data, dtype=dtype, copy=False)
    subarr = np.array(data, dtype=object, copy=copy)

-   if is_object_dtype(subarr.dtype) and dtype != 'object':
+   if (not (is_extension_array_dtype(subarr.dtype) or
+            is_extension_array_dtype(dtype)) and
+           is_object_dtype(subarr.dtype) and
+           not is_object_dtype(dtype)):
        inferred = lib.infer_dtype(subarr, skipna=False)
        if inferred == 'period':
            try:
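The widened guard means object-dtype re-inference ('period' and friends) is skipped whenever either side is an extension dtype. User-visible effect, roughly (editor's example):

    import pandas as pd

    # constructing with an extension dtype keeps that dtype rather than
    # re-running inference on the intermediate object array
    s = pd.Series(pd.Categorical(['a', 'b', 'a']), dtype='category')
    assert s.dtype == 'category'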
9 changes: 5 additions & 4 deletions pandas/core/nanops.py
@@ -72,11 +72,12 @@ def _f(*args, **kwargs):

class bottleneck_switch:

-   def __init__(self, **kwargs):
+   def __init__(self, name=None, **kwargs):
+       self.name = name
        self.kwargs = kwargs

    def __call__(self, alt):
-       bn_name = alt.__name__
+       bn_name = self.name or alt.__name__

        try:
            bn_func = getattr(bn, bn_name)
@@ -804,7 +805,8 @@ def nansem(values, axis=None, skipna=True, ddof=1, mask=None):


def _nanminmax(meth, fill_value_typ):
-   @bottleneck_switch()
+
+   @bottleneck_switch(name='nan' + meth)
    def reduction(values, axis=None, skipna=True, mask=None):

        values, mask, dtype, dtype_max, fill_value = _get_values(

@@ -824,7 +826,6 @@ def reduction(values, axis=None, skipna=True, mask=None):
        result = _wrap_results(result, dtype, fill_value)
        return _maybe_null_out(result, axis, mask, values.shape)

-   reduction.__name__ = 'nan' + meth
    return reduction
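The ``name=`` hook matters because ``bottleneck_switch`` resolves the bottleneck function via ``getattr(bn, bn_name)``: applied to the inner ``reduction``, the old ``alt.__name__`` lookup would have seen ``'reduction'`` rather than ``'nanmin'``/``'nanmax'``. The same pattern as a standalone sketch:

    def named(name=None):
        """toy decorator that renames the function it wraps"""
        def wrapper(func):
            func.__name__ = name or func.__name__
            return func
        return wrapper

    @named(name='nanmin')
    def reduction(values):
        return min(values)

    assert reduction.__name__ == 'nanmin'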


53 changes: 26 additions & 27 deletions pandas/tests/groupby/test_function.py
@@ -12,7 +12,7 @@
from pandas import (
    DataFrame, Index, MultiIndex, Series, Timestamp, date_range, isna)
import pandas.core.nanops as nanops
-from pandas.util import testing as tm
+from pandas.util import _test_decorators as td, testing as tm


@pytest.mark.parametrize("agg_func", ['any', 'all'])
@@ -144,6 +144,7 @@ def test_arg_passthru():
index=Index([1, 2], name='group'),
columns=['int', 'float', 'category_int',
'datetime', 'datetimetz', 'timedelta'])

for attr in ['mean', 'median']:
f = getattr(df.groupby('group'), attr)
result = f()
@@ -459,35 +460,33 @@ def test_groupby_cumprod():
tm.assert_series_equal(actual, expected)


-def test_ops_general():
-    ops = [('mean', np.mean),
-           ('median', np.median),
-           ('std', np.std),
-           ('var', np.var),
-           ('sum', np.sum),
-           ('prod', np.prod),
-           ('min', np.min),
-           ('max', np.max),
-           ('first', lambda x: x.iloc[0]),
-           ('last', lambda x: x.iloc[-1]),
-           ('count', np.size), ]
-    try:
-        from scipy.stats import sem
-    except ImportError:
-        pass
-    else:
-        ops.append(('sem', sem))
+def scipy_sem(*args, **kwargs):
+    from scipy.stats import sem
+    return sem(*args, ddof=1, **kwargs)
+
+
+@pytest.mark.parametrize(
+    'op,targop',
+    [('mean', np.mean),
+     ('median', np.median),
+     ('std', np.std),
+     ('var', np.var),
+     ('sum', np.sum),
+     ('prod', np.prod),
+     ('min', np.min),
+     ('max', np.max),
+     ('first', lambda x: x.iloc[0]),
+     ('last', lambda x: x.iloc[-1]),
+     ('count', np.size),
+     pytest.param(
+         'sem', scipy_sem, marks=td.skip_if_no_scipy)])
+def test_ops_general(op, targop):
    df = DataFrame(np.random.randn(1000))
    labels = np.random.randint(0, 50, size=1000).astype(float)

-    for op, targop in ops:
-        result = getattr(df.groupby(labels), op)().astype(float)
-        expected = df.groupby(labels).agg(targop)
-        try:
-            tm.assert_frame_equal(result, expected)
-        except BaseException as exc:
-            exc.args += ('operation: %s' % op, )
-            raise
+    result = getattr(df.groupby(labels), op)().astype(float)
+    expected = df.groupby(labels).agg(targop)
+    tm.assert_frame_equal(result, expected)
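The rewrite replaces an in-test loop (which stopped at the first failure and needed manual error annotation) with one pytest case per op; ``pytest.param(..., marks=...)`` keeps the scipy-dependent case skippable. The idiom in isolation (editor's sketch):

    import pytest

    @pytest.mark.parametrize(
        'value',
        [1,
         2,
         pytest.param(3, marks=pytest.mark.skip(reason='illustration'))])
    def test_positive(value):
        assert value > 0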


def test_max_nan_bug():
19 changes: 11 additions & 8 deletions pandas/tests/groupby/test_nth.py
@@ -282,18 +282,21 @@ def test_first_last_tz(data, expected_first, expected_last):
    ])
def test_first_last_tz_multi_column(method, ts, alpha):
    # GH 21603
+   category_string = pd.Series(list('abc')).astype('category')
    df = pd.DataFrame({'group': [1, 1, 2],
-                      'category_string': pd.Series(list('abc')).astype(
-                          'category'),
+                      'category_string': category_string,
                       'datetimetz': pd.date_range('20130101', periods=3,
                                                   tz='US/Eastern')})
    result = getattr(df.groupby('group'), method)()
-   expepcted = pd.DataFrame({'category_string': [alpha, 'c'],
-                             'datetimetz': [ts,
-                                            Timestamp('2013-01-03',
-                                                      tz='US/Eastern')]},
-                            index=pd.Index([1, 2], name='group'))
-   assert_frame_equal(result, expepcted)
+   expected = pd.DataFrame(
+       {'category_string': pd.Categorical(
+           [alpha, 'c'], dtype=category_string.dtype),
+        'datetimetz': [ts,
+                       Timestamp('2013-01-03',
+                                 tz='US/Eastern')]},
+       index=pd.Index([1, 2], name='group'))
+   assert_frame_equal(result, expected)


def test_nth_multi_index_as_expected():
6 changes: 6 additions & 0 deletions pandas/tests/resample/test_datetime_index.py
@@ -112,6 +112,12 @@ def test_resample_integerarray():
                      dtype="Int64")
    assert_series_equal(result, expected)

+   result = ts.resample('3T').mean()
+   expected = Series([1, 4, 7],
+                     index=pd.date_range('1/1/2000', periods=3, freq='3T'),
+                     dtype='Int64')
+   assert_series_equal(result, expected)


def test_resample_basic_grouper(series):
s = series
