Skip to content

Commit

Permalink
Merge pull request #67 from dimagi/mk/add-distinct-fix
Browse files: browse the repository at this point in the history
Distinct clause followup
  • Loading branch information
mkangia committed Mar 26, 2020
2 parents 5e004ee + a308990 commit 9458225
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 24 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setup(
name='sqlagg',
version='0.16.0',
version='0.16.1',
description='SQL aggregation tool',
author='Dimagi',
author_email='dev@dimagi.com',
Expand Down
48 changes: 27 additions & 21 deletions sqlagg/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def __repr__(self):


class QueryMeta(object):
def __init__(self, table_name, filters, group_by, distinct, order_by):
def __init__(self, table_name, filters, group_by, distinct_on, order_by):
self.filters = filters
self.group_by = group_by
self.distinct = distinct
self.distinct_on = distinct_on
self.order_by = order_by
self.table_name = table_name

Expand All @@ -67,8 +67,8 @@ class SimpleQueryMeta(QueryMeta):
"""
Metadata about a query including the table being queried, list of columns, filters and group by columns.
"""
def __init__(self, table_name, filters, group_by, distinct, order_by, start=None, limit=None):
super(SimpleQueryMeta, self).__init__(table_name, filters, group_by, distinct, order_by)
def __init__(self, table_name, filters, group_by, distinct_on, order_by, start=None, limit=None):
super(SimpleQueryMeta, self).__init__(table_name, filters, group_by, distinct_on, order_by)
self.start = start
self.limit = limit
self.columns = []
Expand Down Expand Up @@ -109,7 +109,7 @@ def count(self, connection, filter_values):
assert self.limit is None
self._check()
query = self._build_query_generic(self.columns, group_by=self.group_by, filters=self.filters,
distinct=self.distinct)
distinct_on=self.distinct_on)
query = query.alias().count()
return connection.execute(query, **filter_values).fetchall()[0][0]

Expand All @@ -118,7 +118,7 @@ def totals(self, connection, filter_values, total_columns):
assert self.limit is None
self._check()

subquery = self._build_query_generic(self.columns, self.group_by, self.filters, self.distinct).alias()
subquery = self._build_query_generic(self.columns, self.group_by, self.filters, self.distinct_on).alias()
query = sqlalchemy.select().select_from(subquery)

for total_column in total_columns:
Expand All @@ -134,18 +134,18 @@ def _build_query(self):
self._check()
return self._build_query_generic(
self.columns, self.group_by,
self.filters, self.distinct, self.order_by, self.start, self.limit
self.filters, self.distinct_on, self.order_by, self.start, self.limit
)

def _build_query_generic(self, columns, group_by=None, filters=None, distinct=None,
def _build_query_generic(self, columns, group_by=None, filters=None, distinct_on=None,
order_by=None, start=None, limit=None):
try:
query = sqlalchemy.select()
if group_by or distinct:
if group_by or distinct_on:
cols = [c.column_name for c in columns]
alias = [c.alias for c in columns]
if distinct:
for col_key in distinct:
if distinct_on:
for col_key in distinct_on:
if col_key in cols:
query = query.distinct(column(col_key))
elif col_key in alias:
Expand Down Expand Up @@ -190,16 +190,17 @@ def _build_query_generic(self, columns, group_by=None, filters=None, distinct=No
return query

def __repr__(self):
return "Querymeta(columns=%s, filters=%s, group_by=%s, distinct=%s, order_by=%s, table=%s)" % \
(self.columns, self.filters, self.group_by, self.distinct, self.order_by, self.table_name)
return "Querymeta(columns=%s, filters=%s, group_by=%s, distinct_on=%s, order_by=%s, table=%s)" % \
(self.columns, self.filters, self.group_by, self.distinct_on, self.order_by, self.table_name)


class QueryContext(object):
def __init__(self, table, filters=None, group_by=None, distinct=None, order_by=None, start=None, limit=None):
def __init__(self, table, filters=None, group_by=None, distinct_on=None, order_by=None,
start=None, limit=None):
self.table_name = table
self.filters = filters or []
self.group_by = group_by or []
self.distinct = distinct or []
self.distinct_on = distinct_on or []
self.order_by = order_by or []
self.start = start
self.limit = limit
Expand All @@ -222,14 +223,15 @@ def append_column(self, column):

def _new_query_meta(self, column):
if isinstance(column, QueryColumn):
return column.get_query_meta(self.table_name, self.filters, self.group_by, self.order_by)
return column.get_query_meta(self.table_name, self.filters, self.group_by, self.distinct_on,
self.order_by)
else:
table_name = column.table_name or self.table_name
filters = column.filters or self.filters
group_by = column.group_by or self.group_by
order_by = column.order_by or self.order_by
return SimpleQueryMeta(
table_name, filters, group_by, self.distinct, order_by,
table_name, filters, group_by, self.distinct_on, order_by,
start=self.start, limit=self.limit
)

Expand Down Expand Up @@ -320,19 +322,21 @@ def get_value(self, row):


class QueryColumn(SqlAggColumn):
def get_query_meta(self, table_name, filters, group_by, order_by):
def get_query_meta(self, table_name, filters, group_by, distinct_on, order_by):
raise NotImplementedError()


class BaseColumn(SqlAggColumn):
aggregate_fn = None

def __init__(self, key, alias=None, table_name=None, filters=None, group_by=None, order_by=None):
def __init__(self, key, alias=None, table_name=None, filters=None, group_by=None, distinct_on=None,
order_by=None):
self.key = key
self.alias = alias
self.table_name = table_name
self.filters = filters
self.group_by = group_by
self.distinct_on = distinct_on
self.order_by = order_by

if self.filters:
Expand Down Expand Up @@ -365,12 +369,14 @@ class CustomQueryColumn(BaseColumn, QueryColumn):
query_cls = None
name = None

def get_query_meta(self, default_table_name, default_filters, default_group_by, default_order_by):
def get_query_meta(self, default_table_name, default_filters, default_group_by, default_distinct_on,
default_order_by):
table_name = self.table_name or default_table_name
filters = self.filters or default_filters
group_by = self.group_by or default_group_by
distinct_on = self.distinct_on or default_distinct_on
order_by = self.order_by or default_order_by
return self.query_cls(table_name, filters, group_by, order_by)
return self.query_cls(table_name, filters, group_by, distinct_on, order_by)

@property
def column_key(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,10 @@ def _get_view_data(self, view):
vc.append_column(view)
return vc.resolve(self.session.connection())

def test_distinct(self):
def test_distinct_on(self):
vc = QueryContext(
"user_table",
distinct=['user', 'year'],
distinct_on=['user', 'year'],
order_by=[OrderBy('user'), OrderBy('year'), OrderBy('date', is_ascending=False)],
group_by=['user', 'date']
)
Expand Down

0 comments on commit 9458225

Please sign in to comment.