Skip to content

Commit

Permalink
Merge pull request #56 from dimagi/sk/return-all-rows
Browse files Browse the repository at this point in the history
return all rows even when no grouping is taking place
  • Loading branch information
mkangia committed Apr 15, 2020
2 parents 9458225 + 4d79879 commit c6273a5
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 21 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setup(
name='sqlagg',
version='0.16.1',
version='0.17.0-beta',
description='SQL aggregation tool',
author='Dimagi',
author_email='dev@dimagi.com',
Expand Down
29 changes: 15 additions & 14 deletions sqlagg/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,11 @@ def resolve(self, connection, filter_values=None):
"""
Returns a dict containing the data of the following format:
* If group_by == [] or None
return a dict mapping column names to values: {'col_a': 1....}
return a dict mapping row numbers to row data:
e.g. {
1: {'col_a': 1....}
2: {...}
}
* If len(group_by) == 1
return a dict mapping groupings to that group levels data
e.g. {
Expand All @@ -277,22 +281,19 @@ def resolve(self, connection, filter_values=None):
for qm in self.query_meta.values():
result = qm.execute(self.connection, filter_values or {})

for r in result:
for index, sql_row in enumerate(result):
if not qm.group_by:
row_key = None
row_key = index
elif len(qm.group_by) == 1:
row_key = r[qm.group_by[0]]
row_key = sql_row[qm.group_by[0]]
elif len(qm.group_by) > 1:
row_key = tuple([r[group] for group in qm.group_by])

if qm.group_by:
if row_key is None:
# null values coming out of the database wreak havoc elsewhere in the code
row_key = ''
row = data.setdefault(row_key, {})
row.update(kvp for kvp in r.items())
else:
data.update(kvp for kvp in r.items())
row_key = tuple([sql_row[group] for group in qm.group_by])

if row_key is None:
# null values coming out of the database wreak havoc elsewhere in the code
row_key = ''
row = data.setdefault(row_key, {})
row.update(kvp for kvp in sql_row.items())

return data

Expand Down
13 changes: 8 additions & 5 deletions tests/test_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ def test_alias_column(self):
vc.append_column(i_a)
vc.append_column(i_a2)
data = vc.resolve(self.session.connection())
self.assertEqual(i_a.get_value(data), 6)
self.assertEqual(i_a2.get_value(data), 6)
self.assertEqual(len(data), 1)
self.assertEqual(i_a.get_value(data[0]), 6)
self.assertEqual(i_a2.get_value(data[0]), 6)

def test_alias_column_with_aliases(self):
vc = QueryContext("user_table")
Expand All @@ -101,8 +102,9 @@ def test_alias_column_with_aliases(self):
vc.append_column(i_a)
vc.append_column(i_a2)
data = vc.resolve(self.session.connection())
self.assertEqual(i_a.get_value(data), 6)
self.assertEqual(i_a2.get_value(data), 6)
self.assertEqual(len(data), 1)
self.assertEqual(i_a.get_value(data[0]), 6)
self.assertEqual(i_a2.get_value(data[0]), 6)

def test_aggregate_column(self):
col = AggregateColumn(lambda x, y: x + y,
Expand Down Expand Up @@ -216,7 +218,8 @@ def test_quarter(self):

def _test_view(self, view, expected):
data = self._get_view_data(view)
value = view.get_value(data)
self.assertEqual(len(data), 1)
value = view.get_value(data[0])
self.assertAlmostEqual(float(value), float(expected))

def _get_view_data(self, view):
Expand Down
3 changes: 2 additions & 1 deletion tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ class CustomColumn(BaseColumn):
vc.append_column(agg_view)
data = vc.resolve(self.session.connection(), None)

self.assertAlmostEqual(float(data["indicator_a"]), float(0.25))
self.assertEqual(len(data), 1)
self.assertAlmostEqual(float(data[0]["indicator_a"]), float(0.25))

def test_multiple_tables(self):
filters = [LT('date', 'enddate')]
Expand Down

0 comments on commit c6273a5

Please sign in to comment.