Skip to content

Commit

Permalink
add conditional columns
Browse files Browse the repository at this point in the history
  • Loading branch information
snopoke committed May 24, 2013
1 parent cd66c6f commit 9f40946
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 13 deletions.
30 changes: 25 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,36 @@ Here `column_a` will be selected from the table configured in the QueryContext w
*table_name* and will be grouped by *user*. This will result in two queries being run on the database.

## As Name
It is possible to use the same column in multiple columns by specifying the `as_name` argument of the column.
It is possible to use the same column in multiple columns by specifying the `alias` argument of the column.

```python
sum_a = SumColumn("column_a", as_name="sum_a")
count_a = CountColumn("column_a", as_name="count_a")
sum_a = SumColumn("column_a", alias="sum_a")
count_a = CountColumn("column_a", alias="count_a")
```

The resulting data will use the `as_name` keys to reference the values.
The resulting data will use the `alias` keys to reference the values.

TODO: custom queries, AliasColumn
## Conditional / Case columns
*Simple*
```python
num_wheels = SumWhen("vehicle", whens={"unicycle": 1, "bicycle": 2, "car": 4}, else_=0, alias="num_wheels")
```

*Complex*
```python
num_children = SumWhen(whens={"users.age < 13", 1}, else_=0, alias="children")
```

## Alias and Aggregate columns
Useful if you want to use a column more than once but don't want to re-calculate its value.
```python
sum_a = SumColumn("column_a")

aggregate = AggregateColumn(lambda x, y: x / y,
AliasColumn("column_a"),
SumColumn("column_b")
```
TODO: custom queries



2 changes: 1 addition & 1 deletion sqlagg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .base import *
from .columns import SumColumn, CountColumn, MaxColumn, MinColumn, MeanColumn, MedianColumn, UniqueColumn
from .columns import SumColumn, CountColumn, MaxColumn, MinColumn, MeanColumn, MedianColumn, UniqueColumn, SumWhen
17 changes: 14 additions & 3 deletions sqlagg/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# -*- coding: utf-8 -*-
import sqlalchemy
import logging


class SqlColumn(object):
def build_column(self, sql_table):
raise NotImplementedError()


class SimpleSqlColumn(SqlColumn):
"""
Simple representation of a column with a name and an aggregation function which can be None.
"""
Expand Down Expand Up @@ -43,7 +47,7 @@ def __init__(self, table_name, filters, group_by):
self.columns = []

def append_column(self, column):
self.columns.append(SqlColumn(column.key, column.aggregate_fn, column.alias))
self.columns.append(column.sql_column)

def _check(self):
if self.group_by:
Expand All @@ -53,7 +57,7 @@ def _check(self):
groups.remove(c.column_name)

for g in groups:
self.columns.append(SqlColumn(g, aggregate_fn=None, alias=g))
self.columns.append(SimpleSqlColumn(g, aggregate_fn=None, alias=g))

def execute(self, metadata, connection, filter_values):
query = self._build_query(metadata)
Expand All @@ -74,6 +78,9 @@ def _build_query(self, metadata):
for filter in self.filters:
query.append_whereclause(filter)

if not query.froms:
query = query.select_from(table)

return query

def __repr__(self):
Expand Down Expand Up @@ -175,6 +182,10 @@ def __init__(self, key, alias=None, table_name=None, filters=None, group_by=None
def column_key(self):
return self.table_name, str(self.filters), str(self.group_by)

@property
def sql_column(self):
return SimpleSqlColumn(self.key, self.aggregate_fn, self.alias)

def get_value(self, row):
row_key = self.alias or self.key
return row.get(row_key, None) if row else None
Expand Down
44 changes: 41 additions & 3 deletions sqlagg/columns.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy import func, distinct
from sqlalchemy import func, distinct, case, text
from queries import MedianQueryMeta
from .base import BaseColumn, CustomQueryColumn
from .base import BaseColumn, CustomQueryColumn, SqlColumn


class SimpleColumn(BaseColumn):
Expand Down Expand Up @@ -33,4 +33,42 @@ class UniqueColumn(BaseColumn):

class MedianColumn(CustomQueryColumn):
query_cls = MedianQueryMeta
name = "median"
name = "median"


class ConditionalAggregation(BaseColumn):
def __init__(self, key=None, whens={}, else_=None, *args, **kwargs):
super(ConditionalAggregation, self).__init__(key, *args, **kwargs)
self.whens = whens
self.else_ = else_

assert self.key or self.alias, "Column must have either a key or an alias"

@property
def sql_column(self):
return ConditionalColumn(self.key, self.whens, self.else_, self.aggregate_fn, self.alias)


class SumWhen(ConditionalAggregation):
aggregate_fn = func.sum


class ConditionalColumn(SqlColumn):
def __init__(self, column_name, whens, else_, aggregate_fn, alias):
self.aggregate_fn = aggregate_fn
self.column_name = column_name
self.whens = whens
self.else_ = else_
self.alias = alias or column_name

def build_column(self, sql_table):
if self.column_name:
expr = case(value=sql_table.c[self.column_name], whens=self.whens, else_=self.else_)
else:
whens = {}
for when, then in self.whens.items():
whens[text(when)] = then

expr = case(whens=whens, else_=self.else_)

return self.aggregate_fn(expr).label(self.alias)
17 changes: 16 additions & 1 deletion tests/test_columns.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from unittest2 import TestCase, skip
from unittest2 import TestCase
from . import BaseTest, engine
from sqlalchemy.orm import sessionmaker

from sqlagg import *

Session = sessionmaker()


Expand Down Expand Up @@ -77,6 +78,20 @@ def test_aggregate_column(self):
SumColumn("indicator_c"))
self._test_view(col, 9)

def test_conditional_column_simple(self):
# sum(case user when 'user1' then 1 when 'user2' then 3 else 0)
col = SumWhen('user', whens={'user1': 1, 'user2': 3}, else_=0)
self._test_view(col, 8)

def test_conditional_column_complex(self):
# sum(case when indicator_a < 1 OR indicator_a > 2 then 1 else 0)
col = SumWhen(whens={'user_table.indicator_a < 1 OR user_table.indicator_a > 2': 1}, alias='a')
self._test_view(col, 2)

# sum(case when indicator_a between 1 and 2 then 0 else 1)
col = SumWhen(whens={'user_table.indicator_a between 1 and 2': 0}, else_=1, alias='a')
self._test_view(col, 2)

def _test_view(self, view, expected):
data = self._get_view_data(view)
value = view.get_value(data)
Expand Down

0 comments on commit 9f40946

Please sign in to comment.