Skip to content

Commit

Permalink
Merge pull request #81 from sseyler/aggcategories
Browse files: browse the repository at this point in the history
Fixes #72.

Fixed behavior of groupby to preserve order of multiple keys
  • Loading branch information
dotsdl committed Jul 13, 2016
2 parents 20f724f + b1ac988 commit cf25408
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 27 deletions.
11 changes: 5 additions & 6 deletions src/datreant/core/agglimbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ def groupby(self, keys):
corresponding Bundles are groupings of members in the collection having
the same category values (for the category specified by `keys`).
If `keys` is a list or set of keys, returns a dict of Bundles whose
If `keys` is a list of keys, returns a dict of Bundles whose
(new) keys are tuples of category values. The corresponding Bundles
contain the members in the collection that have the same set of
category values (for the categories specified by `keys`); members in
Expand All @@ -581,7 +581,7 @@ def groupby(self, keys):
Parameters
----------
keys : str, list, set
keys : str, list
Valid key(s) of categories in this collection.
Returns
Expand All @@ -603,9 +603,8 @@ def groupby(self, keys):
k in m.categories and m.categories[k] in groupkeys)
for m, catval in gen:
groups[catval].add(m)

elif isinstance(keys, (list, set)):
keys = sorted(keys)
# Note: redundant code in if/elif block can be consolidated in future
elif isinstance(keys, list):
catvals = list(zip(*members.categories[keys]))
groupkeys = [v for v in catvals if None not in v]
groups = {k: Bundle() for k in groupkeys}
Expand All @@ -617,6 +616,6 @@ def groupby(self, keys):
groups[catvals[i]].add(m)

else:
raise TypeError("Keys must be a string or a list or set of"
raise TypeError("Keys must be a string or a list of"
" strings")
return groups
31 changes: 10 additions & 21 deletions src/datreant/core/tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,17 +788,10 @@ def test_categories_groupby(self, collection, testtreant, testgroup,
assert {t3} == set(age_bark[('old', 'mossy')])
assert {t4} == set(age_bark[('young', 'mossy')])

age_bark = collection.categories.groupby({'age', 'bark'})
assert len(age_bark) == 4
assert {t1} == set(age_bark[('young', 'smooth')])
assert {t2} == set(age_bark[('adult', 'fibrous')])
assert {t3} == set(age_bark[('old', 'mossy')])
assert {t4} == set(age_bark[('young', 'mossy')])

type_health = collection.categories.groupby(['type', 'health'])
assert len(type_health) == 2
assert {t3} == set(type_health[('poor', 'deciduous')])
assert {t4} == set(type_health[('good', 'deciduous')])
assert {t3} == set(type_health[('deciduous', 'poor')])
assert {t4} == set(type_health[('deciduous', 'good')])
for bundle in type_health.values():
assert {t1, t2}.isdisjoint(set(bundle))

Expand Down Expand Up @@ -826,29 +819,25 @@ def test_categories_groupby(self, collection, testtreant, testgroup,
keys = ['age', 'bark', 'type', 'nickname']
abtn = collection.categories.groupby(keys)
assert len(abtn) == 1
assert {t2} == set(abtn[('adult', 'fibrous', 'redwood',
'evergreen')])
assert {t2} == set(abtn[('adult', 'fibrous', 'evergreen',
'redwood')])
for bundle in abtn.values():
assert {t1, t3, t4}.isdisjoint(set(bundle))

keys = ['bark', 'nickname', 'type', 'age']
abtn2 = collection.categories.groupby(keys)
assert len(abtn2) == 1
assert {t2} == set(abtn2[('adult', 'fibrous', 'redwood',
'evergreen')])
assert {t2} == set(abtn2[('fibrous', 'redwood', 'evergreen',
'adult')])
for bundle in abtn2.values():
assert {t1, t3, t4}.isdisjoint(set(bundle))

keys = {'age', 'bark', 'type', 'nickname'}
abtn_set = collection.categories.groupby(keys)
assert len(abtn_set) == 1
assert {t2} == set(abtn_set[('adult', 'fibrous', 'redwood',
'evergreen')])
for bundle in abtn_set.values():
assert {t1, t3, t4}.isdisjoint(set(bundle))

keys = ['health', 'nickname']
health_nick = collection.categories.groupby(keys)
assert len(health_nick) == 0
for bundle in health_nick.values():
assert {t1, t2, t3, t4}.isdisjoint(set(bundle))

# Test key TypeError in groupby
with pytest.raises(TypeError) as e:
collection.categories.groupby({'health', 'nickname'})

0 comments on commit cf25408

Please sign in to comment.