From 69b5046bc09582e233d620062e81c4a48c2ba350 Mon Sep 17 00:00:00 2001 From: Kasun Herath Date: Sat, 19 Nov 2016 22:15:11 +0530 Subject: [PATCH] aggregate fields --- beesql/aggregation.py | 44 +++++++++++++++++++++++++++++++++++++ beesql/query/base.py | 51 ++++++++++++++++++++++++++++++------------- 2 files changed, 80 insertions(+), 15 deletions(-) create mode 100644 beesql/aggregation.py diff --git a/beesql/aggregation.py b/beesql/aggregation.py new file mode 100644 index 0000000..c7c049d --- /dev/null +++ b/beesql/aggregation.py @@ -0,0 +1,44 @@ +class AggregationField(object): + def __init__(self, column_name, as_name=None): + self.column_name = column_name + self.as_name = as_name + + +class SumAggregationField(AggregationField): + QUERY_PART_NAME = 'sum_aggregation' + + +class AvgAggregationField(AggregationField): + QUERY_PART_NAME = 'avg_aggregation' + + +class CountAggregationField(AggregationField): + QUERY_PART_NAME = 'count_aggregation' + + +class MaxAggregationField(AggregationField): + QUERY_PART_NAME = 'max_aggregation' + + +class MinAggregationField(AggregationField): + QUERY_PART_NAME = 'min_aggregation' + + +def _sum(column_name, as_name=None): + return SumAggregationField(column_name, as_name) + + +def avg(column_name, as_name=None): + return AvgAggregationField(column_name, as_name) + + +def count(column_name, as_name=None): + return CountAggregationField(column_name, as_name) + + +def _max(column_name, as_name=None): + return MaxAggregationField(column_name, as_name) + + +def _min(column_name, as_name=None): + return MinAggregationField(column_name, as_name) diff --git a/beesql/query/base.py b/beesql/query/base.py index 977ca5f..8a68095 100644 --- a/beesql/query/base.py +++ b/beesql/query/base.py @@ -2,6 +2,7 @@ from .mixins import DataOperatorFuncs, AggregationFuncs from ..exceptions import BeeSQLError +from ..aggregation import AggregationField from .decorators import primary_keyword, secondary_keyword, logical_operator, complete_condition from .decorators import aggregation @@ -257,6 +258,14 @@ class AvgAggregation(AggregationFuncs, Aggregation): FUNCTION_NAME = 'AVG' +class MaxAggregation(AggregationFuncs, Aggregation): + FUNCTION_NAME = 'MAX' + + +class MinAggregation(AggregationFuncs, Aggregation): + FUNCTION_NAME = 'MIN' + + class Statement(object): def __init__(self, query, **kwargs): @@ -295,26 +304,24 @@ def get_sql(self): class Select(StatementWithCondition, Statement): - def __init__(self, query, all=False, *args): + def __init__(self, query, aggregations=None, *args): super().__init__(query) - self.aggregations = [] + self.aggregations = aggregations or [] self.all = all self.fields = list(set(args)) def select(self, *args): - fields = list(args) - if self.all: - return - - all = True if len(fields) == 0 else False - if all: - self.all = all - return - + q_maker = self.query.get_query_maker() + fields = filter(lambda x: isinstance(x, str), args) fields_set = set(self.fields) fields_set.update(fields) self.fields = list(fields_set) + agg_fields = filter(lambda x: isinstance(x, AggregationField), args) + aggregation_queries = [q_maker.make(agg.QUERY_PART_NAME)(agg.column_name, agg.as_name) for agg in agg_fields] + for agg_query in aggregation_queries: + self.add_aggregation(agg_query) + return self @secondary_keyword @@ -364,11 +371,21 @@ def count(self, column_name, as_name=None): AggregationClass = self.query.get_query_maker().make('count_aggregation') return AggregationClass(column_name, as_name) + @aggregation + def max(self, column_name, as_name=None): + AggregationClass = self.query.get_query_maker().make('max_aggregation') + return AggregationClass(column_name, as_name) + + @aggregation + def min(self, column_name, as_name=None): + AggregationClass = self.query.get_query_maker().make('min_aggregation') + return AggregationClass(column_name, as_name) + def add_aggregation(self, aggregation): self.aggregations.append(aggregation) def _get_sql(self): - if self.all: + if not self.fields and not self.aggregations: fields = '*' else: fields = self.fields[:] @@ -542,9 +559,11 @@ def on(self, table): @primary_keyword def select(self, *args): - fields = args[:] - all = True if len(fields) == 0 else False - select = self.get_query_maker().make('select')(self, all, *fields) + q_maker = self.get_query_maker() + fields = filter(lambda x: isinstance(x, str), args) + agg_fields = filter(lambda x: isinstance(x, AggregationField), args) + aggregation_queries = [q_maker.make(agg.QUERY_PART_NAME)(agg.column_name, agg.as_name) for agg in agg_fields] + select = self.get_query_maker().make('select')(self, aggregation_queries, *fields) return select @primary_keyword @@ -623,6 +642,8 @@ class QueryMaker(metaclass=QueryRegistry): 'count_aggregation': CountAggregation, 'sum_aggregation': SumAggregation, 'avg_aggregation': AvgAggregation, + 'max_aggregation': MaxAggregation, + 'min_aggregation': MinAggregation, } @classmethod