Skip to content

Commit

Permalink
Reimplemented csv transformer to work with new data tables transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
twheys committed Aug 23, 2016
1 parent 5ed8ca5 commit 5e65c60
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 39 deletions.
130 changes: 92 additions & 38 deletions fireant/slicer/transformers/datatables.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,16 @@ def _render_data(self, dataframe, display_schema):

def _render_dimension_data(self, idx, dimensions):
i = 0
for key, schema in dimensions:
for key, dimension in dimensions:
dimension_value = _format_data_point(idx[i])

if 'label_field' in schema:
if 'label_field' in dimension:
i += 1
yield key, {'display': dimension_value,
'value': _format_data_point(idx[i])}

elif 'label_options' in schema:
yield key, {'display': schema['label_options'].get(dimension_value, dimension_value) or 'Total',
elif 'label_options' in dimension:
yield key, {'display': dimension['label_options'].get(dimension_value, dimension_value) or 'Total',
'value': dimension_value}

else:
Expand Down Expand Up @@ -173,9 +173,12 @@ def _render_column_level(self, metric_column, display_schema):
if not (isinstance(level_key, float) and np.isnan(level_key))
else 'Total')

elif 'label_field' in dimension:
else:
level_key = metric_column[i]
i += 1

if 'label_field' in dimension:
i += 1

level_label = metric_column[i]

# the metric key must remain last
Expand Down Expand Up @@ -218,50 +221,101 @@ def _recurse_dimensions(self, df, dimensions, metrics, reference=None):


class CSVRowIndexTransformer(DataTablesRowIndexTransformer):
def transform(self, data_frame, display_schema):
dim_ordinal = {name: ordinal
for ordinal, name in enumerate(data_frame.index.names)}

csv_df = self._format_columns(data_frame, dim_ordinal, display_schema)
def transform(self, dataframe, display_schema):
csv_df = self._format_columns(dataframe, display_schema['metrics'], display_schema['dimensions'])

if isinstance(data_frame.index, pd.RangeIndex):
if isinstance(dataframe.index, pd.RangeIndex):
# If there are no dimensions, just serialize to csv without the index
return csv_df.to_csv(index=False)

csv_df = self._format_index(csv_df, dim_ordinal, display_schema)
csv_df = self._format_index(csv_df, display_schema['dimensions'])

row_dimensions = display_schema['dimensions'][:None if self.table_type == 'row' else 1]
return csv_df.to_csv(index_label=[dimension['label']
for dimension in row_dimensions])
row_dimension_labels = self._row_dimension_labels(display_schema['dimensions'])
return csv_df.to_csv(index_label=row_dimension_labels)

def _format_index(self, csv_df, dim_ordinal, display_schema):
if isinstance(csv_df.index, pd.MultiIndex):
csv_df.index = pd.MultiIndex.from_tuples(
[[self._format_dimension_label(idx, dim_ordinal, dimension)
for dimension in display_schema['dimensions']]
for idx in list(csv_df.index)]
)
def _format_index(self, csv_df, dimensions):
levels = list(dimensions.items())[:None if isinstance(csv_df.index, pd.MultiIndex) else 1]

csv_df.index = [self.get_level_values(csv_df, key, dimension)
for key, dimension in levels]
csv_df.index.names = [key
for key, dimension in levels]
return csv_df

def get_level_values(self, csv_df, key, dimension):
if 'label_options' in dimension:
return [_format_data_point(dimension['label_options'].get(value, value))
for value in csv_df.index.get_level_values(key)]

if 'label_field' in dimension:
return [_format_data_point(data_point)
for data_point in csv_df.index.get_level_values(dimension['label_field'])]

return [_format_data_point(data_point)
for data_point in csv_df.index.get_level_values(key)]

@staticmethod
def _format_dimension_label(idx, dim_ordinal, dimension):
if 'label_field' in dimension:
label_field = dimension['label_field']
return idx[dim_ordinal[label_field]]

if isinstance(idx, tuple):
id_field = dimension['id_fields'][0]
dimension_label = idx[dim_ordinal[id_field]]

else:
csv_df.reindex(
self._format_dimension_label(idx, dim_ordinal, display_schema['dimensions'][0])
for idx in list(csv_df.index)
)
dimension_label = idx

return csv_df
if 'label_options' in dimension:
dimension_label = dimension['label_options'].get(dimension_label, dimension_label)

def _format_columns(self, data_frame, dim_ordinal, display_schema):
if 1 < len(display_schema['dimensions']) and self.table_type == 'column':
csv_df = self._prepare_data_frame(data_frame, display_schema['dimensions'])
return dimension_label

csv_df.columns = pd.Index([self._format_series_labels(column, dim_ordinal, display_schema)
for column in csv_df.columns])
return csv_df
def _format_columns(self, dataframe, metrics, dimensions):
return dataframe.rename(columns=lambda metric: metrics.get(metric, metric))

return data_frame.rename(
columns=lambda metric: display_schema['metrics'].get(metric, metric)
)
def _row_dimension_labels(self, dimensions):
return [dimension['label']
for dimension in dimensions.values()]


class CSVColumnIndexTransformer(DataTablesColumnIndexTransformer, CSVRowIndexTransformer):
pass
def _format_columns(self, dataframe, metrics, dimensions):
if 1 < len(dimensions):
csv_df = self._prepare_dataframe(dataframe, dimensions)
csv_df.columns = self._format_column_labels(csv_df, metrics, dimensions)
return csv_df

return super(CSVColumnIndexTransformer, self)._format_columns(dataframe, metrics, dimensions)

def _row_dimension_labels(self, dimensions):
return [dimension['label']
for dimension in list(dimensions.values())[:1]]

def _format_column_labels(self, csv_df, metrics, dimensions):
column_labels = []
for idx in list(csv_df.columns):
metric_value, dimension_values = idx[0], idx[1:]

dimension_labels = []
for dimension_level, dimension_value in zip(csv_df.columns.names[1:], dimension_values):
if 'label_options' in dimensions[dimension_level]:
dimension_label = dimensions[dimension_level]['label_options'].get(dimension_value, dimension_value)
else:
dimension_label = dimension_value

dimension_label = _format_data_point(dimension_label)

if dimension_label is not None:
dimension_labels.append(dimension_label)

if dimension_labels:
column_labels.append('{metric} ({dimensions})'.format(
metric=metrics[metric_value],
dimensions=', '.join(dimension_labels),
))
else:
column_labels.append(metrics[metric_value])

return column_labels
2 changes: 1 addition & 1 deletion fireant/slicer/transformers/highcharts.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def _make_categories(data_frame, dim_ordinal, display_schema):

if 'label_field' in category_dimension:
label_field = category_dimension['label_field']
return data_frame.index.get_level_values(label_field).unique().tolist()
return data_frame.index.get_level_values(label_field, ).unique().tolist()


class HighchartsBarTransformer(HighchartsColumnTransformer):
Expand Down
33 changes: 33 additions & 0 deletions fireant/tests/slicer/transformers/test_datatables.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
# coding: utf-8

from datetime import date, datetime
from unittest import TestCase

import numpy as np
import pandas as pd

from fireant.slicer.transformers import DataTablesRowIndexTransformer, DataTablesColumnIndexTransformer
from fireant.slicer.transformers import datatables
from fireant.tests.slicer.transformers.base import BaseTransformerTests


Expand Down Expand Up @@ -639,3 +646,29 @@ def test_rollup_cont_cat_cat_dims_multi_metric_df(self):
{'a': {'y': {'one': 28, 'two': 56}, 'z': {'one': 29, 'two': 58}},
'b': {'y': {'one': 30, 'two': 60}, 'z': {'one': 31, 'two': 62}},
'cont': {'display': 7}}]}, result)


class DatatablesUtilityTests(TestCase):
def test_nan_data_point(self):
# Needs to be cast to python int
result = datatables._format_data_point(np.nan)
self.assertIsNone(result)

def test_str_data_point(self):
result = datatables._format_data_point(u'abc')
self.assertEqual('abc', result)

def test_int64_data_point(self):
# Needs to be cast to python int
result = datatables._format_data_point(np.int64(1))
self.assertEqual(int(1), result)

def test_date_data_point(self):
# Needs to be converted to milliseconds
result = datatables._format_data_point(pd.Timestamp(date(2000, 1, 1)))
self.assertEqual('2000-01-01', result)

def test_datetime_data_point(self):
# Needs to be converted to milliseconds
result = datatables._format_data_point(pd.Timestamp(datetime(2000, 1, 1, 1)))
self.assertEqual('2000-01-01T01:00:00', result)

0 comments on commit 5e65c60

Please sign in to comment.