Skip to content

Commit

Permalink
generic group bys
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgaspar committed May 29, 2014
1 parent e512d4b commit 8a8e74d
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 17 deletions.
12 changes: 11 additions & 1 deletion examples/quickcharts2/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ def __repr__(self):
return self.name


class PoliticalType(Model):
id = Column(Integer, primary_key=True)
name = Column(String(50), unique = True, nullable=False)

def __repr__(self):
return self.name


class CountryStats(Model):
id = Column(Integer, primary_key=True)
stat_date = Column(Date, nullable=True)
Expand All @@ -19,6 +27,8 @@ class CountryStats(Model):
college = Column(Float)
country_id = Column(Integer, ForeignKey('country.id'), nullable=False)
country = relationship("Country")
political_type_id = Column(Integer, ForeignKey('political_type.id'), nullable=False)
political_type = relationship("PoliticalType")

def __repr__(self):
return str(self.stat_date)
return "{0}:{1}:{2}:{3}".format(self.country, self.political_type, self.population, self.college)
22 changes: 19 additions & 3 deletions examples/quickcharts2/app/views.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from flask.ext.appbuilder.models.datamodel import SQLAModel
from flask.ext.appbuilder.views import ModelView
from flask_appbuilder.charts.views import DirectChartView
from models import CountryStats, Country
from flask_appbuilder.charts.views import DirectChartView, GroupByChartView
from models import CountryStats, Country, PoliticalType
from app import appbuilder, db

from flask_appbuilder.models.group import aggregate_count, aggregate_sum

class CountryStatsModelView(ModelView):
datamodel = SQLAModel(CountryStats)
Expand All @@ -13,6 +13,10 @@ class CountryModelView(ModelView):
datamodel = SQLAModel(Country)


class PoliticalTypeModelView(ModelView):
datamodel = SQLAModel(PoliticalType)


class CountryStatsDirectChart(DirectChartView):
datamodel = SQLAModel(CountryStats)
chart_title = 'Statistics'
Expand All @@ -21,9 +25,21 @@ class CountryStatsDirectChart(DirectChartView):
base_order = ('stat_date', 'asc')


class CountryGroupByChartView(GroupByChartView):
datamodel = SQLAModel(CountryStats)
chart_title = 'Statistics'
chart_type = 'ColumnChart'
group_by_columns = ['country.name', 'political_type.name']
# [{'column':'<COL NAME>','group_class':<CLASS>]
aggregate_by_column = [(aggregate_count, ''), (aggregate_sum, 'population')]
# [{'aggr_func':<FUNC>,'column':'<COL NAME>'}]


db.create_all()
appbuilder.add_view(CountryModelView, "List Countries", icon="fa-folder-open-o", category="Statistics")
appbuilder.add_view(PoliticalTypeModelView, "List Political Types", icon="fa-folder-open-o", category="Statistics")
appbuilder.add_view(CountryStatsModelView, "List Country Stats", icon="fa-folder-open-o", category="Statistics")
appbuilder.add_separator("Statistics")
appbuilder.add_view(CountryStatsDirectChart, "Show Country Chart", icon="fa-dashboard", category="Statistics")
appbuilder.add_view(CountryGroupByChartView, "Group Country Chart", icon="fa-dashboard", category="Statistics")

17 changes: 17 additions & 0 deletions flask_appbuilder/charts/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ..models.filters import Filters, FilterRelationOneToManyEqual
from ..baseviews import BaseModelView, expose
from ..urltools import *
from ..models.group import GroupBys

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -156,6 +157,22 @@ class GroupByChartView(BaseChartView):
# [{'aggr_func':<FUNC>,'column':'<COL NAME>'}]
chart_widget = DirectChartWidget

def __init__(self, **kwargs):
super(BaseChartView, self).__init__(**kwargs)


@expose('/chart/')
@has_access
def chart(self):
form = self.search_form.refresh()
get_filter_args(self._filters)

group = GroupBys(self.group_by_columns, self.aggregate_by_column)
joined_filters = self._filters.get_joined_filters(self._base_filters)

count, lst = self.datamodel.query(filters=joined_filters)
log.debug("Group Data1 {0}".format(lst))
log.debug("GROUP: {0}".format(group.apply(lst)))


class ChartView(BaseSimpleGroupByChartView):
Expand Down
38 changes: 25 additions & 13 deletions flask_appbuilder/models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import calendar
import logging
from itertools import groupby
from operator import itemgetter, attrgetter

log = logging.getLogger(__name__)

Expand All @@ -11,10 +12,7 @@ def aggregate_count(items, col):


def aggregate_sum(items, col):
value = 0
for item in items:
value = value + getattr(item, col)
return value
return sum(getattr(item, col) for item in items)


def aggregate_avg(items, col):
Expand Down Expand Up @@ -65,12 +63,11 @@ def __repr__(self):


class GroupByCol(BaseGroupBy):

def _apply(self, data):
data = sorted(data, key=self.get_group_col)
json_data = dict()
json_data['cols'] = [{'id': self.column_name,
'label': self.column_name,
'label': self.column_name,
'type': 'string'},
{'id': self.aggregate_func.__name__ + '_' + self.column_name,
'label': self.aggregate_func.__name__ + '_' + self.column_name,
Expand Down Expand Up @@ -117,18 +114,33 @@ def apply(self, data):
def get_group_col(self, item):
value = getattr(item, self.column_name)
if value:
return value.year, value.month
return value.year, value.month

def get_format_group_col(self, item):
return calendar.month_name[item[1]] + ' ' + str(item[0])


class GroupBys(object):
group_bys = None
"""
[['COLNAME',GROUP_CLASS, AGR_FUNC,'AGR_COLNAME'],]
"""
group_bys_cols = None
# ['<COLNAME>',<FUNC>, ....]
aggr_by_cols = None
# [(<AGGR FUNC>),'<COLNAME>',...]

def __init__(self, group_by_cols, aggr_by_cols):
self.group_bys_cols = group_by_cols
self.aggr_by_cols = aggr_by_cols

def get_group_col(self, item):
return getattr(item, self.column_name)

def __init__(self, group_bys):
self.group_bys = group_bys

def apply(self, data):
data = sorted(data, key=attrgetter(*self.group_bys_cols))
result = []
for (grouped, items) in groupby(data, key=attrgetter(*self.group_bys_cols)):
items = list(items)
result_item = [grouped]
for aggr_by_col in self.aggr_by_cols:
result_item.append(aggr_by_col[0](items, aggr_by_col[1]))
result.append(result_item)
return result

0 comments on commit 8a8e74d

Please sign in to comment.