diff --git a/Readme.rst b/Readme.rst index 8a9faba..3337354 100644 --- a/Readme.rst +++ b/Readme.rst @@ -79,7 +79,7 @@ in each region for every shipped date, but only for Golf shirts: >>> pivot_table = pivot(ShirtSales.objects.filter(style='Golf'), 'region', 'shipped', 'units') -The pivot function takes an option parameter for how to aggregate the data. For example, +The pivot function takes an optional parameter for how to aggregate the data. For example, instead of the total units sold in each region for every ship date, we might be interested in the average number of units per order. Then we can pass the Avg aggregation function @@ -117,6 +117,25 @@ you want them all lower cased, you can do the following >>> return s.lower() >>> pivot_table = pivot(ShirtSales, 'region', 'shipped', 'units', display_transform=lowercase) +If there are no records in the original data table for a particular cell in the pivot result, +SQL will return NULL and this gets translated to None in Python. If you want to get zero, or +some other default, you can pass that as a parameter to pivot: + +>>> pivot_table = pivot(ShirtSales, 'region', 'shipped', 'units', default=0) + +The above call ensures that when there are no units sold in a particular region on a particular +date, we get zero as the result instead of None. However, the results will only contain +shipped dates if at least one region had sales on that date. If it's necessary to get results +for all dates in a range, including dates where there are no ShirtSales records, we can pass +a target row_range: + +>>> from datetime import date, timedelta +>>> row_range = [date(2005, 1, 1) + timedelta(days) for days in range(59)] +>>> pivot_table = pivot(ShirtSales, 'region', 'shipped', 'units', default=0, row_range=row_range) + +This will output a result with every shipped date from January 1st to February 28th, whether there are +sales on those days or not. 
+ *The histogram function* This library also supports creating histograms from a single column of data with the diff --git a/django_pivot/pivot.py b/django_pivot/pivot.py index 4605bd8..6a4b7da 100644 --- a/django_pivot/pivot.py +++ b/django_pivot/pivot.py @@ -1,11 +1,12 @@ import six from django.db.models import Case, When, Q, F, Sum, CharField, Value +from django.db.models.functions import Coalesce from django.shortcuts import _get_queryset -from django_pivot.utils import get_column_values, get_field_choices +from django_pivot.utils import get_column_values, get_field_choices, default_fill -def pivot(queryset, rows, column, data, aggregation=Sum, choices='auto', display_transform=lambda s: s): +def pivot(queryset, rows, column, data, aggregation=Sum, choices='auto', display_transform=lambda s: s, default=None, row_range=()): """ Takes a queryset and pivots it. The result is a table with one record per unique value in the `row` column, a column for each unique value in the `column` column @@ -17,6 +18,8 @@ def pivot(queryset, rows, column, data, aggregation=Sum, choices='auto', display :param data: column name or Combinable :param aggregation: aggregation function to apply to data column :param display_transform: function that takes a string and returns a string + :param default: default value to pass to the aggregate function when no record is found + :param row_range: iterable with the expected range of rows in the result :return: ValuesQueryset """ values = [rows] if isinstance(rows, six.string_types) else list(rows) @@ -25,7 +28,7 @@ def pivot(queryset, rows, column, data, aggregation=Sum, choices='auto', display column_values = get_column_values(queryset, column, choices) - annotations = _get_annotations(column, column_values, data, aggregation, display_transform) + annotations = _get_annotations(column, column_values, data, aggregation, display_transform, default=default) for row in values: row_choices = get_field_choices(queryset, row) if row_choices: @@ 
-34,12 +37,18 @@ def pivot(queryset, rows, column, data, aggregation=Sum, choices='auto', display queryset = queryset.annotate(**{'get_' + row + '_display': row_display}) values.append('get_' + row + '_display') - return queryset.values(*values).annotate(**annotations) + values_list = queryset.values(*values).annotate(**annotations) + if row_range: + attributes = [value[0] for value in column_values] + values_list = default_fill(values_list, values[0], row_range, fill_value=default, fill_attributes=attributes) -def _get_annotations(column, column_values, data, aggregation, display_transform=lambda s: s): + return values_list + + +def _get_annotations(column, column_values, data, aggregation, display_transform=lambda s: s, default=None): value = data if hasattr(data, 'resolve_expression') else F(data) return { - display_transform(display_value): aggregation(Case(When(Q(**{column: column_value}), then=value))) + display_transform(display_value): Coalesce(aggregation(Case(When(Q(**{column: column_value}), then=value))), default) for column_value, display_value in column_values } diff --git a/django_pivot/tests/pivot/test.py b/django_pivot/tests/pivot/test.py index 0f77682..321663f 100644 --- a/django_pivot/tests/pivot/test.py +++ b/django_pivot/tests/pivot/test.py @@ -1,5 +1,6 @@ from __future__ import absolute_import +from datetime import timedelta, date from itertools import chain from django.conf import settings @@ -87,6 +88,14 @@ def setUpClass(cls): shirt_sale.units = units[indx % len(units)] shirt_sale.price = prices[indx % len(prices)] + shirt_sales.append(ShirtSales(store=Store.objects.first(), + gender=genders[0], + style=styles[0], + shipped='2005-07-05', + units=13, + price=73 + )) + ShirtSales.objects.bulk_create(shirt_sales) def test_pivot(self): @@ -114,14 +123,14 @@ def test_pivot_on_choice_field_row(self): def test_pivot_on_date(self): shirt_sales = ShirtSales.objects.all() - pt = pivot(ShirtSales, 'style', 'shipped', 'units') + pt = pivot(ShirtSales, 
'style', 'shipped', 'units', default=0) for row in pt: style = row['style'] for dt in dates: self.assertEqual(row[dt], sum(ss.units for ss in shirt_sales if ss.style == style and force_text(ss.shipped) == dt)) - pt = pivot(ShirtSales.objects, 'shipped', 'style', 'units') + pt = pivot(ShirtSales.objects, 'shipped', 'style', 'units', default=0) for row in pt: shipped = row['shipped'] @@ -131,14 +140,14 @@ def test_pivot_on_date(self): def test_pivot_on_foreignkey(self): shirt_sales = ShirtSales.objects.all() - pt = pivot(ShirtSales, 'shipped', 'store__region__name', 'units') + pt = pivot(ShirtSales, 'shipped', 'store__region__name', 'units', default=0) for row in pt: shipped = row['shipped'] for name in ['North', 'South', 'East', 'West']: self.assertEqual(row[name], sum(ss.units for ss in shirt_sales if force_text(ss.shipped) == force_text(shipped) and ss.store.region.name == name)) - pt = pivot(ShirtSales, 'shipped', 'store__name', 'units') + pt = pivot(ShirtSales, 'shipped', 'store__name', 'units', default=0) for row in pt: shipped = row['shipped'] @@ -160,14 +169,15 @@ def test_monthly_report(self): return shirt_sales = ShirtSales.objects.annotate(**annotations).order_by('date_sort') - monthly_report = pivot(shirt_sales, 'Month', 'store__name', 'units') + monthly_report = pivot(shirt_sales, 'Month', 'store__name', 'units', default=0) # Get the months and assert that the order by that we sent in is respected months = [record['Month'] for record in monthly_report] - month_strings = ['12-2004', '01-2005', '02-2005', '03-2005', '04-2005', '05-2005'] + month_strings = ['12-2004', '01-2005', '02-2005', '03-2005', '04-2005', '05-2005', '07-2005'] self.assertEqual(months, month_strings) # Check that the aggregations are correct too + for record in monthly_report: month, year = record['Month'].split('-') for name in store_names: @@ -176,11 +186,22 @@ def test_monthly_report(self): int(ss.shipped.month) == int(month) and ss.store.name == name))) + def 
test_pivot_with_default_fill(self): + shirt_sales = ShirtSales.objects.filter(shipped__gt='2005-01-25', shipped__lt='2005-02-03') + + row_range = [date(2005, 1, 25) + timedelta(days=n) for n in range(14)] + pt = pivot(shirt_sales, 'shipped', 'style', 'units', default=0, row_range=row_range) + + for row in pt: + shipped = row['shipped'] + for style in styles: + self.assertEqual(row[style], sum(ss.units for ss in shirt_sales if force_text(ss.shipped) == force_text(shipped) and ss.style == style)) + def test_pivot_aggregate(self): shirt_sales = ShirtSales.objects.all() data = ExpressionWrapper(F('units') * F('price'), output_field=DecimalField()) - pt = pivot(ShirtSales, 'store__region__name', 'shipped', data, Avg) + pt = pivot(ShirtSales, 'store__region__name', 'shipped', data, Avg, default=0) for row in pt: region_name = row['store__region__name'] diff --git a/django_pivot/utils.py b/django_pivot/utils.py index c03b704..87f2119 100644 --- a/django_pivot/utils.py +++ b/django_pivot/utils.py @@ -32,3 +32,20 @@ def _get_field(model, field_names): model = field.related_model else: return field + + +def default_fill(data, attribute, target_range, fill_value=None, fill_attributes=()): + indx = 0 + new_data = list() + fill_dict = {attribute: fill_value for attribute in fill_attributes} + for element in target_range: + + if indx < len(data) and data[indx][attribute] == element: + new_data.append(data[indx]) + indx += 1 + else: + fill_record = {attribute: element} + fill_record.update(fill_dict) + new_data.append(fill_record) + + return new_data