Skip to content

Commit

Permalink
BUG: Filter/transform fail in some cases when multi-grouping with a datetime-like key (GH pandas-dev#10114)
Browse files Browse the repository at this point in the history
  • Loading branch information
Evan Wright committed Jul 29, 2015
1 parent 92da9ed commit c84ab54
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 31 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v0.17.0.txt
Expand Up @@ -390,7 +390,7 @@ Bug Fixes
- Bug in ``pd.get_dummies`` with `sparse=True` not returning ``SparseDataFrame`` (:issue:`10531`)
- Bug in ``Index`` subtypes (such as ``PeriodIndex``) not returning their own type for ``.drop`` and ``.insert`` methods (:issue:`10620`)


- Bug in ``filter`` (regression from 0.16.0) and ``transform`` when grouping on multiple keys, one of which is datetime-like (:issue:`10114`)



Expand Down
66 changes: 36 additions & 30 deletions pandas/core/groupby.py
Expand Up @@ -422,46 +422,55 @@ def indices(self):
""" dict {group name -> group indices} """
return self.grouper.indices

def _get_indices(self, names):
    """
    Safe get of multiple group indices; translates datelike keys
    (Timestamp / np.datetime64) to the underlying representation
    used as keys of ``self.indices``.

    Parameters
    ----------
    names : list
        Group keys to look up (tuples when grouping on multiple keys).

    Returns
    -------
    list
        One index array per name; an empty list for a missing key.

    Raises
    ------
    ValueError
        If grouping is on multiple keys and a name is not a tuple, or
        is a tuple of the wrong length.
    """

    def get_converter(s):
        # possibly convert to the actual key types in the indices;
        # could be a Timestamp or a np.datetime64
        if isinstance(s, (Timestamp, datetime.datetime)):
            return lambda key: Timestamp(key)
        elif isinstance(s, np.datetime64):
            return lambda key: Timestamp(key).asm8
        else:
            return lambda key: key

    if len(names) == 0:
        return []

    if len(self.indices) > 0:
        index_sample = next(iter(self.indices))
    else:
        index_sample = None  # Dummy sample

    name_sample = names[0]
    if isinstance(index_sample, tuple):
        if not isinstance(name_sample, tuple):
            msg = ("must supply a tuple to get_group with multiple"
                   " grouping keys")
            raise ValueError(msg)
        if not len(name_sample) == len(index_sample):
            try:
                # If the original grouper was a tuple
                return [self.indices[name] for name in names]
            except KeyError:
                # turns out it wasn't a tuple
                msg = ("must supply a same-length tuple to get_group"
                       " with multiple grouping keys")
                raise ValueError(msg)

        converters = [get_converter(s) for s in index_sample]
        names = [tuple([f(n) for f, n in zip(converters, name)])
                 for name in names]

    else:
        converter = get_converter(index_sample)
        names = [converter(name) for name in names]

    # .get with [] default: missing groups yield empty indexers rather
    # than raising, so callers (filter/transform) can skip them
    return [self.indices.get(name, []) for name in names]

def _get_index(self, name):
    """ safe get index, translate keys for datelike to underlying repr """
    return self._get_indices([name])[0]

@property
def name(self):
Expand Down Expand Up @@ -507,7 +516,7 @@ def _set_result_index_ordered(self, result):

# shortcut of we have an already ordered grouper
if not self.grouper.is_monotonic:
index = Index(np.concatenate([ indices.get(v, []) for v in self.grouper.result_index]))
index = Index(np.concatenate(self._get_indices(self.grouper.result_index)))
result.index = index
result = result.sort_index()

Expand Down Expand Up @@ -612,6 +621,9 @@ def get_group(self, name, obj=None):
obj = self._selected_obj

inds = self._get_index(name)
if not len(inds):
raise KeyError(name)

return obj.take(inds, axis=self.axis, convert=False)

def __iter__(self):
Expand Down Expand Up @@ -2457,9 +2469,6 @@ def transform(self, func, *args, **kwargs):

wrapper = lambda x: func(x, *args, **kwargs)
for i, (name, group) in enumerate(self):
if name not in self.indices:
continue

object.__setattr__(group, 'name', name)
res = wrapper(group)

Expand All @@ -2474,7 +2483,7 @@ def transform(self, func, *args, **kwargs):
except:
pass

indexer = self.indices[name]
indexer = self._get_index(name)
result[indexer] = res

result = _possibly_downcast_to_dtype(result, dtype)
Expand Down Expand Up @@ -2528,11 +2537,8 @@ def true_and_notnull(x, *args, **kwargs):
return b and notnull(b)

try:
indices = []
for name, group in self:
if true_and_notnull(group) and name in self.indices:
indices.append(self.indices[name])

indices = [self._get_index(name) for name, group in self
if true_and_notnull(group)]
except ValueError:
raise TypeError("the filter must return a boolean result")
except TypeError:
Expand Down Expand Up @@ -3060,8 +3066,8 @@ def transform(self, func, *args, **kwargs):
results = np.empty_like(obj.values, result.values.dtype)
indices = self.indices
for (name, group), (i, row) in zip(self, result.iterrows()):
if name in indices:
indexer = indices[name]
indexer = self._get_index(name)
if len(indexer) > 0:
results[indexer] = np.tile(row.values,len(indexer)).reshape(len(indexer),-1)

counts = self.size().fillna(0).values
Expand Down Expand Up @@ -3162,8 +3168,8 @@ def filter(self, func, dropna=True, *args, **kwargs):

# interpret the result of the filter
if is_bool(res) or (lib.isscalar(res) and isnull(res)):
if res and notnull(res) and name in self.indices:
indices.append(self.indices[name])
if res and notnull(res):
indices.append(self._get_index(name))
else:
# non scalars aren't allowed
raise TypeError("filter function returned a %s, "
Expand Down
26 changes: 26 additions & 0 deletions pandas/tests/test_groupby.py
Expand Up @@ -4477,6 +4477,32 @@ def test_filter_maintains_ordering(self):
expected = s.iloc[[1, 2, 4, 7]]
assert_series_equal(actual, expected)

def test_filter_multiple_timestamp(self):
    # GH 10114: grouping on multiple keys where one is datetime-like
    frame = DataFrame({'A' : np.arange(5),
                       'B' : ['foo','bar','foo','bar','bar'],
                       'C' : Timestamp('20130101') })
    gb = frame.groupby(['B', 'C'])

    # SeriesGroupBy: an always-true filter round-trips the column
    actual = gb['A'].filter(lambda x: True)
    assert_series_equal(frame['A'], actual)

    # SeriesGroupBy: transform with a callable yields group sizes
    actual = gb['A'].transform(len)
    assert_series_equal(actual, Series([2, 3, 2, 3, 3], name='A'))

    # DataFrameGroupBy: an always-true filter round-trips the frame
    actual = gb.filter(lambda x: True)
    assert_frame_equal(frame, actual)

    # DataFrameGroupBy: transform via string method name
    actual = gb.transform('sum')
    assert_frame_equal(actual, DataFrame({'A' : [2, 8, 2, 8, 8]}))

    # DataFrameGroupBy: transform via callable
    actual = gb.transform(len)
    assert_frame_equal(actual, DataFrame({'A' : [2, 3, 2, 3, 3]}))

def test_filter_and_transform_with_non_unique_int_index(self):
# GH4620
index = [1, 1, 1, 2, 1, 1, 0, 1]
Expand Down

0 comments on commit c84ab54

Please sign in to comment.