Skip to content

Commit

Permalink
Merge c909542 into 0cc08cf
Browse files Browse the repository at this point in the history
  • Loading branch information
snopoke committed Nov 6, 2018
2 parents 0cc08cf + c909542 commit b531833
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 28 deletions.
52 changes: 31 additions & 21 deletions sqlagg/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,67 +105,77 @@ def get_query_string(self, metadata, connection):
def count(self, metadata, connection, filter_values):
assert self.start is None
assert self.limit is None
query = self._build_query(metadata).alias().count()
self._check()
query = self._build_query_generic(metadata, self.columns, group_by=self.group_by, filters=self.filters)
query = query.alias().count()
return connection.execute(query, **filter_values).fetchall()[0][0]

def totals(self, metadata, connection, filter_values, total_columns):
assert self.start is None
assert self.limit is None
self._check()

def _generate_total_column(column_name, selectable):
from sqlagg import SumColumn
return SumColumn(column_name).sql_column.build_column(selectable)

subquery = self._build_query(metadata).alias()
subquery = self._build_query_generic(metadata, self.columns, self.group_by, self.filters).alias()
query = sqlalchemy.select().select_from(subquery)

for total_column in total_columns:
query.append_column(_generate_total_column(total_column, subquery))
column = SimpleSqlColumn(total_column, sqlalchemy.func.sum)
query.append_column(column.build_column(subquery))

return dict(zip(
total_columns,
connection.execute(query, **filter_values).fetchall()[0]
))

def _build_query(self, metadata):
self._check()
return self._build_query_generic(
metadata, self.columns, self.group_by,
self.filters, self.order_by, self.start, self.limit
)

def _build_query_generic(self, metadata, columns, group_by=None, filters=None, order_by=None, start=None, limit=None):
try:
table = metadata.tables[self.table_name]
except KeyError:
raise TableNotFoundException("Unable to query table, table not found: %s" % self.table_name)

try:
query = sqlalchemy.select()
if self.group_by:
cols = [c.column_name for c in self.columns]
alias = [c.alias for c in self.columns]
for group_key in self.group_by:
if group_by:
cols = [c.column_name for c in columns]
alias = [c.alias for c in columns]
for group_key in group_by:
if group_key in cols:
query.append_group_by(table.c[group_key])
elif group_key in alias:
aliased_columns = [col.build_column(table) for col in self.columns if col.alias == group_key]
aliased_columns = [col.build_column(table) for col in columns if col.alias == group_key]
assert len(aliased_columns) == 1, "Only one column should have this alias"
query.append_group_by(aliased_columns[0])
else:
raise SqlAggException("Group by column not present in query columns or aliases")

for c in self.columns:
for c in columns:
query.append_column(c.build_column(table))
except KeyError as e:
raise ColumnNotFoundException("Missing column in table (%s): %s" % (self.table_name, e))

if self.filters:
for filter in self.filters:
if filters:
for filter in filters:
query.append_whereclause(filter.build_expression(table))

if not query.froms:
query = query.select_from(table)

if self.order_by:
for order_by_column in self.order_by:
if order_by:
for order_by_column in order_by:
order = order_by_column.build_expression()
query = query.order_by(order)

if self.start is not None:
query = query.offset(self.start)
if self.limit is not None:
query = query.limit(self.limit)
if start is not None:
query = query.offset(start)
if limit is not None:
query = query.limit(limit)

return query

Expand Down
32 changes: 25 additions & 7 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest import TestCase

from sqlagg.exceptions import DuplicateColumnsException
from sqlagg.sorting import OrderBy
from . import BaseTest
from sqlalchemy.orm import scoped_session, sessionmaker
from datetime import date
Expand Down Expand Up @@ -177,12 +178,9 @@ def test_totals_no_filter(self):
group_by=["user"],
)

for column_name in [
'indicator_a',
'indicator_b',
'indicator_c',
]:
vc.append_column(SumColumn(column_name))
vc.append_column(MeanColumn('indicator_a'))
vc.append_column(SumColumn('indicator_b'))
vc.append_column(SumColumn('indicator_c'))

self.assertEqual(
vc.totals(
Expand All @@ -194,7 +192,7 @@ def test_totals_no_filter(self):
],
),
{
'indicator_a': 6,
'indicator_a': 3,
'indicator_b': 5,
'indicator_c': 3,
},
Expand Down Expand Up @@ -231,6 +229,26 @@ def test_totals_with_filter(self):
},
)

def test_count_group_by(self):
vc = QueryContext(
"user_table",
group_by=["user"],
order_by=[OrderBy("user", is_ascending=True)]
)

vc.append_column(SumColumn('indicator_a'))
self.assertEqual(2, vc.count(self.session.connection()))

def test_count_with_filter(self):
vc = QueryContext(
"user_table",
filters=[EQ('user', 'username')],
group_by=["user"],
)

vc.append_column(SumColumn('indicator_a'))
self.assertEqual(1, vc.count(self.session.connection(), {'username': 'user1'},))

def test_user_view_data(self):
data = self._get_user_view_data(None, None)

Expand Down

0 comments on commit b531833

Please sign in to comment.