-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #189 from kayak/matplotlib_widget
Matplotlib widget
- Loading branch information
Showing
9 changed files
with
222 additions
and
123 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,4 +15,4 @@ script: | |
- "coverage run --source=fireant setup.py test" | ||
|
||
after_success: | ||
coveralls | ||
coveralls |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from typing import ( | ||
Iterable, | ||
Union, | ||
) | ||
|
||
from fireant import ( | ||
Metric, | ||
Operation, | ||
utils, | ||
) | ||
from ..exceptions import MetricRequiredException | ||
|
||
|
||
class Series: | ||
type = None | ||
needs_marker = False | ||
stacking = None | ||
|
||
def __init__(self, metric: Union[Metric, Operation], stacking=None): | ||
self.metric = metric | ||
self.stacking = self.stacking or stacking | ||
|
||
def __repr__(self): | ||
return "{}({})".format(self.__class__.__name__, | ||
repr(self.metric)) | ||
|
||
|
||
class ContinuousAxisSeries(Series): | ||
pass | ||
|
||
|
||
class Axis: | ||
def __init__(self, series: Iterable[Series], label=None, y_axis_visible=True): | ||
self._series = series or [] | ||
self.label = label | ||
self.y_axis_visible = y_axis_visible | ||
|
||
def __iter__(self): | ||
return iter(self._series) | ||
|
||
def __len__(self): | ||
return len(self._series) | ||
|
||
def __repr__(self): | ||
return "axis({})".format(", ".join(map(repr, self))) | ||
|
||
|
||
class ChartWidget: | ||
class LineSeries(ContinuousAxisSeries): | ||
type = 'line' | ||
|
||
class AreaSeries(ContinuousAxisSeries): | ||
type = 'area' | ||
|
||
class AreaStackedSeries(AreaSeries): | ||
stacking = "normal" | ||
|
||
class AreaPercentageSeries(AreaSeries): | ||
stacking = "percent" | ||
|
||
class PieSeries(Series): | ||
type = 'pie' | ||
|
||
class BarSeries(Series): | ||
type = 'bar' | ||
|
||
class StackedBarSeries(BarSeries): | ||
stacking = "normal" | ||
|
||
class ColumnSeries(Series): | ||
type = 'column' | ||
|
||
class StackedColumnSeries(ColumnSeries): | ||
stacking = "normal" | ||
|
||
@utils.immutable | ||
def axis(self, *series: Series, **kwargs): | ||
""" | ||
(Immutable) Adds an axis to the Chart. | ||
:param axis: | ||
:return: | ||
""" | ||
|
||
self.items.append(Axis(series, **kwargs)) | ||
|
||
@property | ||
def metrics(self): | ||
""" | ||
:return: | ||
A set of metrics used in this chart. This collects all metrics across all axes. | ||
""" | ||
if 0 == len(self.items): | ||
raise MetricRequiredException(str(self)) | ||
|
||
seen = set() | ||
return [metric | ||
for axis in self.items | ||
for series in axis | ||
for metric in getattr(series.metric, 'metrics', [series.metric]) | ||
if not (metric.key in seen or seen.add(metric.key))] | ||
|
||
@property | ||
def operations(self): | ||
return utils.ordered_distinct_list_by_attr([operation | ||
for axis in self.items | ||
for series in axis | ||
if isinstance(series.metric, Operation) | ||
for operation in [series.metric] + series.metric.operations]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,71 @@ | ||
import itertools | ||
|
||
from fireant import utils | ||
from .base import TransformableWidget | ||
from .chart_base import ChartWidget | ||
from ..references import ( | ||
reference_key, | ||
reference_label, | ||
) | ||
|
||
MAP_SERIES_TO_PLOT_FUNC = { | ||
ChartWidget.LineSeries: 'line', | ||
ChartWidget.AreaSeries: 'area', | ||
ChartWidget.AreaStackedSeries: 'area', | ||
ChartWidget.AreaPercentageSeries: 'area', | ||
ChartWidget.PieSeries: 'pie', | ||
ChartWidget.BarSeries: 'bar', | ||
ChartWidget.StackedBarSeries: 'bar', | ||
ChartWidget.ColumnSeries: 'bar', | ||
ChartWidget.StackedColumnSeries: 'bar', | ||
} | ||
|
||
|
||
class Matplotlib(ChartWidget, TransformableWidget): | ||
def __init__(self, title=None): | ||
super(Matplotlib, self).__init__() | ||
self.title = title | ||
|
||
class Matplotlib(TransformableWidget): | ||
def transform(self, data_frame, slicer, dimensions, references): | ||
raise NotImplementedError() | ||
import matplotlib.pyplot as plt | ||
data_frame = data_frame.copy() | ||
|
||
n_axes = len(self.items) | ||
figsize = (14, 5 * n_axes) | ||
fig, plt_axes = plt.subplots(n_axes, | ||
sharex='row', | ||
figsize=figsize) | ||
fig.suptitle(self.title) | ||
|
||
if not hasattr(plt_axes, '__iter__'): | ||
plt_axes = (plt_axes,) | ||
|
||
colors = itertools.cycle('bgrcmyk') | ||
for axis, plt_axis in zip(self.items, plt_axes): | ||
for series in axis: | ||
series_color = next(colors) | ||
|
||
linestyles = itertools.cycle(['-', '--', '-.', ':']) | ||
for reference in [None] + references: | ||
metric = series.metric | ||
f_metric_key = utils.format_metric_key(reference_key(metric, reference)) | ||
f_metric_label = reference_label(metric, reference) | ||
|
||
plot = self.get_plot_func_for_series_type(data_frame[f_metric_key], f_metric_label, series) | ||
plot(ax=plt_axis, | ||
title=axis.label, | ||
color=series_color, | ||
stacked=series.stacking is not None, | ||
linestyle=next(linestyles)) \ | ||
.legend(loc='center left', | ||
bbox_to_anchor=(1, 0.5)) | ||
|
||
return plt_axes | ||
|
||
@staticmethod | ||
def get_plot_func_for_series_type(pd_series, label, chart_series): | ||
pd_series.name = label | ||
plot = pd_series.plot | ||
plot_func_name = MAP_SERIES_TO_PLOT_FUNC[type(chart_series)] | ||
plot_func = getattr(plot, plot_func_name) | ||
return plot_func |
Oops, something went wrong.