From 9f40946c443a1e9f53906c6b0c6b3ffbb1891107 Mon Sep 17 00:00:00 2001 From: Simon Kelly Date: Thu, 23 May 2013 17:52:45 +0200 Subject: [PATCH] add conditional columns --- README.md | 30 ++++++++++++++++++++++++----- sqlagg/__init__.py | 2 +- sqlagg/base.py | 17 ++++++++++++++--- sqlagg/columns.py | 44 ++++++++++++++++++++++++++++++++++++++++--- tests/test_columns.py | 17 ++++++++++++++++- 5 files changed, 97 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index dd567e1..6f29e63 100644 --- a/README.md +++ b/README.md @@ -112,16 +112,36 @@ Here `column_a` will be selected from the table configured in the QueryContext w *table_name* and will be grouped by *user*. This will result in two queries being run on the database. ## As Name -It is possible to use the same column in multiple columns by specifying the `as_name` argument of the column. +It is possible to use the same column in multiple columns by specifying the `alias` argument of the column. ```python -sum_a = SumColumn("column_a", as_name="sum_a") -count_a = CountColumn("column_a", as_name="count_a") +sum_a = SumColumn("column_a", alias="sum_a") +count_a = CountColumn("column_a", alias="count_a") ``` -The resulting data will use the `as_name` keys to reference the values. +The resulting data will use the `alias` keys to reference the values. -TODO: custom queries, AliasColumn +## Conditional / Case columns +*Simple* +```python +num_wheels = SumWhen("vehicle", whens={"unicycle": 1, "bicycle": 2, "car": 4}, else_=0, alias="num_wheels") +``` + +*Complex* +```python +num_children = SumWhen(whens={"users.age < 13": 1}, else_=0, alias="children") +``` + +## Alias and Aggregate columns +Useful if you want to use a column more than once but don't want to re-calculate its value. 
+```python +sum_a = SumColumn("column_a") + +aggregate = AggregateColumn(lambda x, y: x / y, + AliasColumn("column_a"), + SumColumn("column_b")) +``` +TODO: custom queries diff --git a/sqlagg/__init__.py b/sqlagg/__init__.py index 405672d..128371d 100644 --- a/sqlagg/__init__.py +++ b/sqlagg/__init__.py @@ -1,2 +1,2 @@ from .base import * -from .columns import SumColumn, CountColumn, MaxColumn, MinColumn, MeanColumn, MedianColumn, UniqueColumn \ No newline at end of file +from .columns import SumColumn, CountColumn, MaxColumn, MinColumn, MeanColumn, MedianColumn, UniqueColumn, SumWhen \ No newline at end of file diff --git a/sqlagg/base.py b/sqlagg/base.py index 1b65d21..333ba3a 100644 --- a/sqlagg/base.py +++ b/sqlagg/base.py @@ -1,9 +1,13 @@ # -*- coding: utf-8 -*- import sqlalchemy -import logging class SqlColumn(object): + def build_column(self, sql_table): + raise NotImplementedError() + + +class SimpleSqlColumn(SqlColumn): """ Simple representation of a column with a name and an aggregation function which can be None. 
""" @@ -43,7 +47,7 @@ def __init__(self, table_name, filters, group_by): self.columns = [] def append_column(self, column): - self.columns.append(SqlColumn(column.key, column.aggregate_fn, column.alias)) + self.columns.append(column.sql_column) def _check(self): if self.group_by: @@ -53,7 +57,7 @@ def _check(self): groups.remove(c.column_name) for g in groups: - self.columns.append(SqlColumn(g, aggregate_fn=None, alias=g)) + self.columns.append(SimpleSqlColumn(g, aggregate_fn=None, alias=g)) def execute(self, metadata, connection, filter_values): query = self._build_query(metadata) @@ -74,6 +78,9 @@ def _build_query(self, metadata): for filter in self.filters: query.append_whereclause(filter) + if not query.froms: + query = query.select_from(table) + return query def __repr__(self): @@ -175,6 +182,10 @@ def __init__(self, key, alias=None, table_name=None, filters=None, group_by=None def column_key(self): return self.table_name, str(self.filters), str(self.group_by) + @property + def sql_column(self): + return SimpleSqlColumn(self.key, self.aggregate_fn, self.alias) + def get_value(self, row): row_key = self.alias or self.key return row.get(row_key, None) if row else None diff --git a/sqlagg/columns.py b/sqlagg/columns.py index c801a09..bd112f6 100644 --- a/sqlagg/columns.py +++ b/sqlagg/columns.py @@ -1,6 +1,6 @@ -from sqlalchemy import func, distinct +from sqlalchemy import func, distinct, case, text from queries import MedianQueryMeta -from .base import BaseColumn, CustomQueryColumn +from .base import BaseColumn, CustomQueryColumn, SqlColumn class SimpleColumn(BaseColumn): @@ -33,4 +33,42 @@ class UniqueColumn(BaseColumn): class MedianColumn(CustomQueryColumn): query_cls = MedianQueryMeta - name = "median" \ No newline at end of file + name = "median" + + +class ConditionalAggregation(BaseColumn): + def __init__(self, key=None, whens=None, else_=None, *args, **kwargs): + super(ConditionalAggregation, self).__init__(key, *args, **kwargs) + self.whens = {} if whens is None else whens + 
self.else_ = else_ + + assert self.key or self.alias, "Column must have either a key or an alias" + + @property + def sql_column(self): + return ConditionalColumn(self.key, self.whens, self.else_, self.aggregate_fn, self.alias) + + +class SumWhen(ConditionalAggregation): + aggregate_fn = func.sum + + +class ConditionalColumn(SqlColumn): + def __init__(self, column_name, whens, else_, aggregate_fn, alias): + self.aggregate_fn = aggregate_fn + self.column_name = column_name + self.whens = whens + self.else_ = else_ + self.alias = alias or column_name + + def build_column(self, sql_table): + if self.column_name: + expr = case(value=sql_table.c[self.column_name], whens=self.whens, else_=self.else_) + else: + whens = {} + for when, then in self.whens.items(): + whens[text(when)] = then + + expr = case(whens=whens, else_=self.else_) + + return self.aggregate_fn(expr).label(self.alias) diff --git a/tests/test_columns.py b/tests/test_columns.py index 1414aad..38ab78d 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -1,8 +1,9 @@ -from unittest2 import TestCase, skip +from unittest2 import TestCase from . 
import BaseTest, engine from sqlalchemy.orm import sessionmaker from sqlagg import * + Session = sessionmaker() @@ -77,6 +78,20 @@ def test_aggregate_column(self): SumColumn("indicator_c")) self._test_view(col, 9) + def test_conditional_column_simple(self): + # sum(case user when 'user1' then 1 when 'user2' then 3 else 0) + col = SumWhen('user', whens={'user1': 1, 'user2': 3}, else_=0) + self._test_view(col, 8) + + def test_conditional_column_complex(self): + # sum(case when indicator_a < 1 OR indicator_a > 2 then 1 else 0) + col = SumWhen(whens={'user_table.indicator_a < 1 OR user_table.indicator_a > 2': 1}, alias='a') + self._test_view(col, 2) + + # sum(case when indicator_a between 1 and 2 then 0 else 1) + col = SumWhen(whens={'user_table.indicator_a between 1 and 2': 0}, else_=1, alias='a') + self._test_view(col, 2) + def _test_view(self, view, expected): data = self._get_view_data(view) value = view.get_value(data)