Skip to content

Commit

Permalink
aggregate fields
Browse files Browse the repository at this point in the history
  • Loading branch information
kasun committed Nov 19, 2016
1 parent 848bf6e commit 69b5046
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 15 deletions.
44 changes: 44 additions & 0 deletions beesql/aggregation.py
@@ -0,0 +1,44 @@
class AggregationField(object):
def __init__(self, column_name, as_name=None):
self.column_name = column_name
self.as_name = as_name


class SumAggregationField(AggregationField):
QUERY_PART_NAME = 'sum_aggregation'


class AvgAggregationField(AggregationField):
QUERY_PART_NAME = 'avg_aggregation'


class CountAggregationField(AggregationField):
QUERY_PART_NAME = 'count_aggregation'


class MaxAggregationField(AggregationField):
QUERY_PART_NAME = 'max_aggregation'


class MinAggregationField(AggregationField):
QUERY_PART_NAME = 'min_aggregation'


def _sum(column_name, as_name=None):
return SumAggregationField(column_name, as_name)


def avg(column_name, as_name=None):
return AvgAggregationField(column_name, as_name)


def count(column_name, as_name=None):
return CountAggregationField(column_name, as_name)


def _max(column_name, as_name=None):
return MaxAggregationField(column_name, as_name)


def _min(column_name, as_name=None):
return MinAggregationField(column_name, as_name)
51 changes: 36 additions & 15 deletions beesql/query/base.py
Expand Up @@ -2,6 +2,7 @@

from .mixins import DataOperatorFuncs, AggregationFuncs
from ..exceptions import BeeSQLError
from ..aggregation import AggregationField
from .decorators import primary_keyword, secondary_keyword, logical_operator, complete_condition
from .decorators import aggregation

Expand Down Expand Up @@ -257,6 +258,14 @@ class AvgAggregation(AggregationFuncs, Aggregation):
FUNCTION_NAME = 'AVG'


class MaxAggregation(AggregationFuncs, Aggregation):
FUNCTION_NAME = 'MAX'


class MinAggregation(AggregationFuncs, Aggregation):
FUNCTION_NAME = 'MIN'


class Statement(object):

def __init__(self, query, **kwargs):
Expand Down Expand Up @@ -295,26 +304,24 @@ def get_sql(self):

class Select(StatementWithCondition, Statement):

def __init__(self, query, all=False, *args):
def __init__(self, query, aggregations=None, *args):
super().__init__(query)
self.aggregations = []
self.aggregations = aggregations or []
self.all = all
self.fields = list(set(args))

def select(self, *args):
fields = list(args)
if self.all:
return

all = True if len(fields) == 0 else False
if all:
self.all = all
return

q_maker = self.query.get_query_maker()
fields = filter(lambda x: isinstance(x, str), args)
fields_set = set(self.fields)
fields_set.update(fields)
self.fields = list(fields_set)

agg_fields = filter(lambda x: isinstance(x, AggregationField), args)
aggregation_queries = [q_maker.make(agg.QUERY_PART_NAME)(agg.column_name, agg.as_name) for agg in agg_fields]
for agg_query in aggregation_queries:
self.add_aggregation(agg_query)

return self

@secondary_keyword
Expand Down Expand Up @@ -364,11 +371,21 @@ def count(self, column_name, as_name=None):
AggregationClass = self.query.get_query_maker().make('count_aggregation')
return AggregationClass(column_name, as_name)

@aggregation
def max(self, column_name, as_name=None):
AggregationClass = self.query.get_query_maker().make('max_aggregation')
return AggregationClass(column_name, as_name)

@aggregation
def min(self, column_name, as_name=None):
AggregationClass = self.query.get_query_maker().make('min_aggregation')
return AggregationClass(column_name, as_name)

def add_aggregation(self, aggregation):
self.aggregations.append(aggregation)

def _get_sql(self):
if self.all:
if not self.fields and not self.aggregations:
fields = '*'
else:
fields = self.fields[:]
Expand Down Expand Up @@ -542,9 +559,11 @@ def on(self, table):

@primary_keyword
def select(self, *args):
fields = args[:]
all = True if len(fields) == 0 else False
select = self.get_query_maker().make('select')(self, all, *fields)
q_maker = self.get_query_maker()
fields = filter(lambda x: isinstance(x, str), args)
agg_fields = filter(lambda x: isinstance(x, AggregationField), args)
aggregation_queries = [q_maker.make(agg.QUERY_PART_NAME)(agg.column_name, agg.as_name) for agg in agg_fields]
select = self.get_query_maker().make('select')(self, aggregation_queries, *fields)
return select

@primary_keyword
Expand Down Expand Up @@ -623,6 +642,8 @@ class QueryMaker(metaclass=QueryRegistry):
'count_aggregation': CountAggregation,
'sum_aggregation': SumAggregation,
'avg_aggregation': AvgAggregation,
'max_aggregation': MaxAggregation,
'min_aggregation': MinAggregation,
}

@classmethod
Expand Down

0 comments on commit 69b5046

Please sign in to comment.