Skip to content

Commit

Permalink
Add default values and default_fill over a row_range.
Browse files Browse the repository at this point in the history
  • Loading branch information
martsberger committed Mar 15, 2019
1 parent 9230f18 commit fa50940
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 14 deletions.
21 changes: 20 additions & 1 deletion Readme.rst
Expand Up @@ -79,7 +79,7 @@ in each region for every shipped date, but only for Golf shirts:

>>> pivot_table = pivot(ShirtSales.objects.filter(style='Golf'), 'region', 'shipped', 'units')

The pivot function takes an option parameter for how to aggregate the data. For example,
The pivot function takes an optional parameter for how to aggregate the data. For example,
instead of the total units sold in each region for every ship date, we might be interested in
the average number of units per order. Then we can pass the Avg aggregation function

Expand Down Expand Up @@ -117,6 +117,25 @@ you want them all lower cased, you can do the following
>>> return s.lower()
>>> pivot_table = pivot(ShirtSales, 'region', 'shipped', 'units', display_transform=lowercase)

If there are no records in the original data table for a particular cell in the pivot result,
SQL will return NULL, which is translated to None in Python. If you want to get zero, or
some other default, you can pass that as a parameter to pivot:

>>> pivot_table = pivot(ShirtSales, 'region', 'shipped', 'units', default=0)

The above call ensures that when there are no units sold in a particular region on a particular
date, we get zero as the result instead of None. However, the results will only contain
shipped dates if at least one region had sales on that date. If it's necessary to get results
for all dates in a range including dates where there are no ShirtSales records, we can pass
a target row_range:

>>> from datetime import date, timedelta
>>> row_range = [date(2005, 1, 1) + timedelta(days) for days in range(59)]
>>> pivot_table = pivot(ShirtSales, 'region', 'shipped', 'units', default=0, row_range=row_range)

This will output a result with every shipped date from January 1st to February 28th, whether
or not there are sales on those days.

*The histogram function*

This library also supports creating histograms from a single column of data with the
Expand Down
21 changes: 15 additions & 6 deletions django_pivot/pivot.py
@@ -1,11 +1,12 @@
import six
from django.db.models import Case, When, Q, F, Sum, CharField, Value
from django.db.models.functions import Coalesce
from django.shortcuts import _get_queryset

from django_pivot.utils import get_column_values, get_field_choices
from django_pivot.utils import get_column_values, get_field_choices, default_fill


def pivot(queryset, rows, column, data, aggregation=Sum, choices='auto', display_transform=lambda s: s):
def pivot(queryset, rows, column, data, aggregation=Sum, choices='auto', display_transform=lambda s: s, default=None, row_range=()):
"""
Takes a queryset and pivots it. The result is a table with one record
per unique value in the `row` column, a column for each unique value in the `column` column
Expand All @@ -17,6 +18,8 @@ def pivot(queryset, rows, column, data, aggregation=Sum, choices='auto', display
:param data: column name or Combinable
:param aggregation: aggregation function to apply to data column
:param display_transform: function that takes a string and returns a string
:param default: default value to pass to the aggregate function when no record is found
:param row_range: iterable with the expected range of rows in the result
:return: ValuesQueryset
"""
values = [rows] if isinstance(rows, six.string_types) else list(rows)
Expand All @@ -25,7 +28,7 @@ def pivot(queryset, rows, column, data, aggregation=Sum, choices='auto', display

column_values = get_column_values(queryset, column, choices)

annotations = _get_annotations(column, column_values, data, aggregation, display_transform)
annotations = _get_annotations(column, column_values, data, aggregation, display_transform, default=default)
for row in values:
row_choices = get_field_choices(queryset, row)
if row_choices:
Expand All @@ -34,12 +37,18 @@ def pivot(queryset, rows, column, data, aggregation=Sum, choices='auto', display
queryset = queryset.annotate(**{'get_' + row + '_display': row_display})
values.append('get_' + row + '_display')

return queryset.values(*values).annotate(**annotations)
values_list = queryset.values(*values).annotate(**annotations)

if row_range:
attributes = [value[0] for value in column_values]
values_list = default_fill(values_list, values[0], row_range, fill_value=default, fill_attributes=attributes)

def _get_annotations(column, column_values, data, aggregation, display_transform=lambda s: s):
return values_list


def _get_annotations(column, column_values, data, aggregation, display_transform=lambda s: s, default=None):
value = data if hasattr(data, 'resolve_expression') else F(data)
return {
display_transform(display_value): aggregation(Case(When(Q(**{column: column_value}), then=value)))
display_transform(display_value): Coalesce(aggregation(Case(When(Q(**{column: column_value}), then=value))), default)
for column_value, display_value in column_values
}
35 changes: 28 additions & 7 deletions django_pivot/tests/pivot/test.py
@@ -1,5 +1,6 @@
from __future__ import absolute_import

from datetime import timedelta, date
from itertools import chain

from django.conf import settings
Expand Down Expand Up @@ -87,6 +88,14 @@ def setUpClass(cls):
shirt_sale.units = units[indx % len(units)]
shirt_sale.price = prices[indx % len(prices)]

shirt_sales.append(ShirtSales(store=Store.objects.first(),
gender=genders[0],
style=styles[0],
shipped='2005-07-05',
units=13,
price=73
))

ShirtSales.objects.bulk_create(shirt_sales)

def test_pivot(self):
Expand Down Expand Up @@ -114,14 +123,14 @@ def test_pivot_on_choice_field_row(self):
def test_pivot_on_date(self):
shirt_sales = ShirtSales.objects.all()

pt = pivot(ShirtSales, 'style', 'shipped', 'units')
pt = pivot(ShirtSales, 'style', 'shipped', 'units', default=0)

for row in pt:
style = row['style']
for dt in dates:
self.assertEqual(row[dt], sum(ss.units for ss in shirt_sales if ss.style == style and force_text(ss.shipped) == dt))

pt = pivot(ShirtSales.objects, 'shipped', 'style', 'units')
pt = pivot(ShirtSales.objects, 'shipped', 'style', 'units', default=0)

for row in pt:
shipped = row['shipped']
Expand All @@ -131,14 +140,14 @@ def test_pivot_on_date(self):
def test_pivot_on_foreignkey(self):
shirt_sales = ShirtSales.objects.all()

pt = pivot(ShirtSales, 'shipped', 'store__region__name', 'units')
pt = pivot(ShirtSales, 'shipped', 'store__region__name', 'units', default=0)

for row in pt:
shipped = row['shipped']
for name in ['North', 'South', 'East', 'West']:
self.assertEqual(row[name], sum(ss.units for ss in shirt_sales if force_text(ss.shipped) == force_text(shipped) and ss.store.region.name == name))

pt = pivot(ShirtSales, 'shipped', 'store__name', 'units')
pt = pivot(ShirtSales, 'shipped', 'store__name', 'units', default=0)

for row in pt:
shipped = row['shipped']
Expand All @@ -160,14 +169,15 @@ def test_monthly_report(self):
return

shirt_sales = ShirtSales.objects.annotate(**annotations).order_by('date_sort')
monthly_report = pivot(shirt_sales, 'Month', 'store__name', 'units')
monthly_report = pivot(shirt_sales, 'Month', 'store__name', 'units', default=0)

# Get the months and assert that the order by that we sent in is respected
months = [record['Month'] for record in monthly_report]
month_strings = ['12-2004', '01-2005', '02-2005', '03-2005', '04-2005', '05-2005']
month_strings = ['12-2004', '01-2005', '02-2005', '03-2005', '04-2005', '05-2005', '07-2005']
self.assertEqual(months, month_strings)

# Check that the aggregations are correct too

for record in monthly_report:
month, year = record['Month'].split('-')
for name in store_names:
Expand All @@ -176,11 +186,22 @@ def test_monthly_report(self):
int(ss.shipped.month) == int(month) and
ss.store.name == name)))

def test_pivot_with_default_fill(self):
    """
    Pivot over a row_range that is wider than the filtered data and check that
    every requested date is present and that each cell holds the expected sum.
    """
    # The filter is exclusive on both ends, so the 14-day row_range below
    # contains dates (the first one and the trailing ones) with no records.
    shirt_sales = ShirtSales.objects.filter(shipped__gt='2005-01-25', shipped__lt='2005-02-03')

    row_range = [date(2005, 1, 25) + timedelta(days=n) for n in range(14)]
    pt = pivot(shirt_sales, 'shipped', 'style', 'units', default=0, row_range=row_range)

    # The result must have one row per requested date, in row_range order;
    # without this check a result missing the filled rows would still pass.
    self.assertEqual([row['shipped'] for row in pt], row_range)

    for row in pt:
        shipped = row['shipped']
        for style in styles:
            self.assertEqual(row[style], sum(ss.units for ss in shirt_sales if force_text(ss.shipped) == force_text(shipped) and ss.style == style))

def test_pivot_aggregate(self):
shirt_sales = ShirtSales.objects.all()

data = ExpressionWrapper(F('units') * F('price'), output_field=DecimalField())
pt = pivot(ShirtSales, 'store__region__name', 'shipped', data, Avg)
pt = pivot(ShirtSales, 'store__region__name', 'shipped', data, Avg, default=0)

for row in pt:
region_name = row['store__region__name']
Expand Down
17 changes: 17 additions & 0 deletions django_pivot/utils.py
Expand Up @@ -32,3 +32,20 @@ def _get_field(model, field_names):
model = field.related_model
else:
return field


def default_fill(data, attribute, target_range, fill_value=None, fill_attributes=()):
    """
    Return a list with an entry for every element of `target_range`, using the
    matching records from `data` where they exist and a filler record otherwise.

    Records are matched on ``record[attribute] == element`` via a dictionary
    lookup, so the result does not depend on the order of `data`, and records
    whose key is not in `target_range` are left out without disturbing the
    remaining rows.

    :param data: iterable of dict-like records (e.g. a values queryset)
    :param attribute: key in each record that is compared to the range elements
    :param target_range: iterable of (hashable) key values expected in the result
    :param fill_value: value assigned to each fill attribute of a filler record
    :param fill_attributes: keys that get `fill_value` in a filler record
    :return: list of records ordered like `target_range`
    """
    # Group the existing records by key in a single pass over data. Several
    # records may share a key, so keep all of them in their original order.
    records_by_key = {}
    for record in data:
        records_by_key.setdefault(record[attribute], []).append(record)

    fill_template = {name: fill_value for name in fill_attributes}

    new_data = []
    for element in target_range:
        matches = records_by_key.get(element)
        if matches:
            new_data.extend(matches)
        else:
            # Build a fresh dict per gap so callers can mutate rows independently.
            fill_record = {attribute: element}
            fill_record.update(fill_template)
            new_data.append(fill_record)

    return new_data

0 comments on commit fa50940

Please sign in to comment.