Skip to content

Commit

Permalink
Merge pull request #189 from kayak/matplotlib_widget
Browse files Browse the repository at this point in the history
Matplotlib widget
  • Loading branch information
twheys authored Sep 14, 2018
2 parents a244eef + 5ff8aa8 commit 53301bd
Show file tree
Hide file tree
Showing 9 changed files with 222 additions and 123 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ script:
- "coverage run --source=fireant setup.py test"

after_success:
coveralls
coveralls
2 changes: 1 addition & 1 deletion fireant/slicer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ def __eq__(self, other):
and str(self.definition) == str(other.definition)

def __hash__(self):
return hash('{}({})'.format(self.__class__.__name__, self.key, self.definition))
return hash('{}({})'.format(self.__class__.__name__, self.key))
5 changes: 0 additions & 5 deletions fireant/slicer/dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,6 @@ def __init__(self, key, label=None, definition=None, display_definition=None):
def has_display_field(self):
return hasattr(self, 'display')

def __hash__(self):
if self.has_display_field:
return hash('{}({},{})'.format(self.__class__.__name__, self.definition, self.display_definition))
return super(UniqueDimension, self).__hash__()

def like(self, pattern, *patterns):
if not self.has_display_field:
raise QueryException('No value set for display_definition.')
Expand Down
4 changes: 2 additions & 2 deletions fireant/slicer/queries/makers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from functools import partial
from pypika import JoinType

from fireant.utils import (
flatten,
format_key,
format_dimension_key,
format_key,
format_metric_key,
)
from pypika import JoinType
from .finders import (
find_joins_for_tables,
find_required_tables_to_join,
Expand Down
109 changes: 109 additions & 0 deletions fireant/slicer/widgets/chart_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from typing import (
Iterable,
Union,
)

from fireant import (
Metric,
Operation,
utils,
)
from ..exceptions import MetricRequiredException


class Series:
type = None
needs_marker = False
stacking = None

def __init__(self, metric: Union[Metric, Operation], stacking=None):
self.metric = metric
self.stacking = self.stacking or stacking

def __repr__(self):
return "{}({})".format(self.__class__.__name__,
repr(self.metric))


class ContinuousAxisSeries(Series):
pass


class Axis:
def __init__(self, series: Iterable[Series], label=None, y_axis_visible=True):
self._series = series or []
self.label = label
self.y_axis_visible = y_axis_visible

def __iter__(self):
return iter(self._series)

def __len__(self):
return len(self._series)

def __repr__(self):
return "axis({})".format(", ".join(map(repr, self)))


class ChartWidget:
class LineSeries(ContinuousAxisSeries):
type = 'line'

class AreaSeries(ContinuousAxisSeries):
type = 'area'

class AreaStackedSeries(AreaSeries):
stacking = "normal"

class AreaPercentageSeries(AreaSeries):
stacking = "percent"

class PieSeries(Series):
type = 'pie'

class BarSeries(Series):
type = 'bar'

class StackedBarSeries(BarSeries):
stacking = "normal"

class ColumnSeries(Series):
type = 'column'

class StackedColumnSeries(ColumnSeries):
stacking = "normal"

@utils.immutable
def axis(self, *series: Series, **kwargs):
"""
(Immutable) Adds an axis to the Chart.
:param axis:
:return:
"""

self.items.append(Axis(series, **kwargs))

@property
def metrics(self):
"""
:return:
A set of metrics used in this chart. This collects all metrics across all axes.
"""
if 0 == len(self.items):
raise MetricRequiredException(str(self))

seen = set()
return [metric
for axis in self.items
for series in axis
for metric in getattr(series.metric, 'metrics', [series.metric])
if not (metric.key in seen or seen.add(metric.key))]

@property
def operations(self):
return utils.ordered_distinct_list_by_attr([operation
for axis in self.items
for series in axis
if isinstance(series.metric, Operation)
for operation in [series.metric] + series.metric.operations])
119 changes: 9 additions & 110 deletions fireant/slicer/widgets/highcharts.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,22 @@
import itertools
from datetime import (
datetime,
)
from typing import (
Iterable,
Union,
)
from datetime import datetime

import pandas as pd

from fireant import (
DatetimeDimension,
Metric,
Operation,
formats,
utils,
)
from .base import (
TransformableWidget,
from .base import TransformableWidget
from .chart_base import (
ChartWidget,
ContinuousAxisSeries,
)
from .helpers import (
dimensional_metric_label,
extract_display_values,
)
from ..exceptions import MetricRequiredException
from ..references import (
reference_key,
reference_label,
Expand Down Expand Up @@ -64,70 +57,10 @@
"triangle-down",
)

SERIES_NEEDING_MARKER = (ChartWidget.LineSeries, ChartWidget.AreaSeries)

class Series:
type = None
needs_marker = False
stacking = None

def __init__(self, metric: Union[Metric, Operation], name=None, stacking=None):
self.metric = metric
self.name = name
self.stacking = self.stacking or stacking

def __repr__(self):
return "{}({})".format(self.__class__.__name__,
repr(self.metric))


class ContinuousAxisSeries(Series):
pass


class HighCharts(TransformableWidget):
class Axis:
def __init__(self, series: Iterable[Series], y_axis_visible=True):
self._series = series or []
self.y_axis_visible = y_axis_visible

def __iter__(self):
return iter(self._series)

def __len__(self):
return len(self._series)

def __repr__(self):
return "axis({})".format(", ".join(map(repr, self)))

class LineSeries(ContinuousAxisSeries):
type = 'line'
needs_marker = True

class AreaSeries(ContinuousAxisSeries):
type = 'area'
needs_marker = True

class AreaStackedSeries(AreaSeries):
stacking = "normal"

class AreaPercentageSeries(AreaSeries):
stacking = "percent"

class PieSeries(Series):
type = 'pie'

class BarSeries(Series):
type = 'bar'

class StackedBarSeries(BarSeries):
stacking = "normal"

class ColumnSeries(Series):
type = 'column'

class StackedColumnSeries(ColumnSeries):
stacking = "normal"

class HighCharts(ChartWidget, TransformableWidget):
def __init__(self, title=None, colors=None, x_axis_visible=True, tooltip_visible=True):
super(HighCharts, self).__init__()
self.title = title
Expand All @@ -138,41 +71,6 @@ def __init__(self, title=None, colors=None, x_axis_visible=True, tooltip_visible
def __repr__(self):
return ".".join(["HighCharts()"] + [repr(axis) for axis in self.items])

@utils.immutable
def axis(self, *series: Series, **kwargs):
"""
(Immutable) Adds an axis to the Chart.
:param axis:
:return:
"""

self.items.append(self.Axis(series, **kwargs))

@property
def metrics(self):
"""
:return:
A set of metrics used in this chart. This collects all metrics across all axes.
"""
if 0 == len(self.items):
raise MetricRequiredException(str(self))

seen = set()
return [metric
for axis in self.items
for series in axis
for metric in getattr(series.metric, 'metrics', [series.metric])
if not (metric.key in seen or seen.add(metric.key))]

@property
def operations(self):
return utils.ordered_distinct_list_by_attr([operation
for axis in self.items
for series in axis
if isinstance(series.metric, Operation)
for operation in [series.metric] + series.metric.operations])

def transform(self, data_frame, slicer, dimensions, references):
"""
- Main entry point -
Expand Down Expand Up @@ -327,6 +225,7 @@ def _render_series(self, axis, axis_idx, axis_color, colors, data_frame_groups,
hc_series = []
for series in axis:
symbols = itertools.cycle(MARKER_SYMBOLS)
series_color = next(colors) if has_multi_axis else None

for (dimension_values, group_df), symbol in zip(data_frame_groups, symbols):
dimension_values = utils.wrap_list(dimension_values)
Expand Down Expand Up @@ -363,7 +262,7 @@ def _render_series(self, axis, axis_idx, axis_color, colors, data_frame_groups,
else str(axis_idx)),

"marker": ({"symbol": symbol, "fillColor": axis_color or series_color}
if series.needs_marker
if isinstance(series, SERIES_NEEDING_MARKER)
else {}),

"stacking": series.stacking,
Expand Down
69 changes: 67 additions & 2 deletions fireant/slicer/widgets/matplotlib.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,71 @@
import itertools

from fireant import utils
from .base import TransformableWidget
from .chart_base import ChartWidget
from ..references import (
reference_key,
reference_label,
)

MAP_SERIES_TO_PLOT_FUNC = {
ChartWidget.LineSeries: 'line',
ChartWidget.AreaSeries: 'area',
ChartWidget.AreaStackedSeries: 'area',
ChartWidget.AreaPercentageSeries: 'area',
ChartWidget.PieSeries: 'pie',
ChartWidget.BarSeries: 'bar',
ChartWidget.StackedBarSeries: 'bar',
ChartWidget.ColumnSeries: 'bar',
ChartWidget.StackedColumnSeries: 'bar',
}


class Matplotlib(ChartWidget, TransformableWidget):
def __init__(self, title=None):
super(Matplotlib, self).__init__()
self.title = title

class Matplotlib(TransformableWidget):
def transform(self, data_frame, slicer, dimensions, references):
raise NotImplementedError()
import matplotlib.pyplot as plt
data_frame = data_frame.copy()

n_axes = len(self.items)
figsize = (14, 5 * n_axes)
fig, plt_axes = plt.subplots(n_axes,
sharex='row',
figsize=figsize)
fig.suptitle(self.title)

if not hasattr(plt_axes, '__iter__'):
plt_axes = (plt_axes,)

colors = itertools.cycle('bgrcmyk')
for axis, plt_axis in zip(self.items, plt_axes):
for series in axis:
series_color = next(colors)

linestyles = itertools.cycle(['-', '--', '-.', ':'])
for reference in [None] + references:
metric = series.metric
f_metric_key = utils.format_metric_key(reference_key(metric, reference))
f_metric_label = reference_label(metric, reference)

plot = self.get_plot_func_for_series_type(data_frame[f_metric_key], f_metric_label, series)
plot(ax=plt_axis,
title=axis.label,
color=series_color,
stacked=series.stacking is not None,
linestyle=next(linestyles)) \
.legend(loc='center left',
bbox_to_anchor=(1, 0.5))

return plt_axes

@staticmethod
def get_plot_func_for_series_type(pd_series, label, chart_series):
pd_series.name = label
plot = pd_series.plot
plot_func_name = MAP_SERIES_TO_PLOT_FUNC[type(chart_series)]
plot_func = getattr(plot, plot_func_name)
return plot_func
Loading

0 comments on commit 53301bd

Please sign in to comment.