Permalink
Browse files

ENH: per column styles to plot #1559

  • Loading branch information...
1 parent 0e85b60 commit 42f2b7139d395bc4130c37df04a353f1ed1740f1 Chang She committed Jul 3, 2012
Showing with 61 additions and 31 deletions.
  1. +21 −0 pandas/tests/test_graphics.py
  2. +40 −31 pandas/tools/plotting.py
@@ -333,6 +333,27 @@ def test_legend_name(self):
def _check_plot_fails(self, f, *args, **kwargs):
self.assertRaises(Exception, f, *args, **kwargs)
+ @slow
+ def test_style_by_column(self):
+ import matplotlib.pyplot as plt
+ fig = plt.gcf()
+ fig.clf()
+ fig.add_subplot(111)
+
+ df = DataFrame(np.random.randn(100, 3))
+ markers = {0: '^', 1: '+', 2: 'o'}
+ ax = df.plot(style=markers)
+ for i, l in enumerate(ax.get_lines()):
+ self.assertEqual(l.get_marker(), markers[i])
+
+ fig.clf()
+ fig.add_subplot(111)
+ df = DataFrame(np.random.randn(100, 3))
+ markers = ['^', '+', 'o']
+ ax = df.plot(style=markers)
+ for i, l in enumerate(ax.get_lines()):
+ self.assertEqual(l.get_marker(), markers[i])
+
class TestDataFrameGroupByPlots(unittest.TestCase):
@classmethod
@@ -342,7 +342,7 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=True,
def _iter_data(self):
from pandas.core.frame import DataFrame
if isinstance(self.data, (Series, np.ndarray)):
- yield com._stringify(self.label), np.asarray(self.data)
+ yield self.label, np.asarray(self.data)
elif isinstance(self.data, DataFrame):
df = self.data
@@ -356,7 +356,7 @@ def _iter_data(self):
# is this right?
values = df[col].values if not empty else np.zeros(len(df))
- col = com._stringify(col)
+
yield col, values
@property
@@ -548,13 +548,28 @@ def _get_index_name(self):
return name
- def _get_ax_and_style(self, i):
+ def _get_ax(self, i):
if self.subplots:
ax = self.axes[i]
- style = 'k'
else:
- style = '' # empty string ignored
ax = self.ax
+ return ax
+
+ def _get_ax_and_style(self, i, col_name):
+ ax = self._get_ax(i)
+
+ if self.subplots:
+ style = 'k'
+ else:
+ style = ''
+
+ if self.style is not None:
+ if isinstance(self.style, list):
+ style = self.style[i]
+ elif isinstance(self.style, dict):
+ style = self.style[col_name]
+ else:
+ style = self.style
return ax, style
@@ -566,11 +581,10 @@ def _make_plot(self):
from scipy.stats import gaussian_kde
plotf = self._get_plot_function()
for i, (label, y) in enumerate(self._iter_data()):
+ ax, style = self._get_ax_and_style(i, label)
- ax, style = self._get_ax_and_style(i)
+ label = com._stringify(label)
- if self.style:
- style = self.style
gkde = gaussian_kde(y)
sample_range = max(y) - min(y)
ind = np.linspace(min(y) - 0.5 * sample_range,
@@ -608,7 +622,7 @@ def _is_dynamic_freq(self, freq):
def _use_dynamic_x(self):
freq = self._index_freq()
- ax, _ = self._get_ax_and_style(0)
+ ax = self._get_ax(0)
ax_freq = getattr(ax, 'freq', None)
if freq is None: # convert irregular if axes has freq info
freq = ax_freq
@@ -629,11 +643,9 @@ def _make_plot(self):
plotf = self._get_plot_function()
for i, (label, y) in enumerate(self._iter_data()):
+ ax, style = self._get_ax_and_style(i, label)
- ax, style = self._get_ax_and_style(i)
-
- if self.style:
- style = self.style
+ label = com._stringify(label)
mask = com.isnull(y)
if mask.any():
@@ -660,7 +672,7 @@ def _maybe_convert_index(self, data):
freq = get_period_alias(freq)
if freq is None:
- ax, _ = self._get_ax_and_style(0)
+ ax = self._get_ax(0)
freq = getattr(ax, 'freq', None)
if freq is None:
@@ -677,17 +689,16 @@ def _make_ts_plot(self, data, **kwargs):
plotf = self._get_plot_function()
if isinstance(data, Series):
- ax, _ = self._get_ax_and_style(0) #self.axes[0]
-
+ ax = self._get_ax(0) #self.axes[0]
+ style = self.style or ''
label = com._stringify(self.label)
tsplot(data, plotf, ax=ax, label=label, style=self.style,
**kwargs)
ax.grid(self.grid)
else:
for i, col in enumerate(data.columns):
- ax, _ = self._get_ax_and_style(i)
- label = com._stringify(col)
- tsplot(data[col], plotf, ax=ax, label=label, **kwargs)
+ ax, style = self._get_ax_and_style(i, col)
+ tsplot(data[col], plotf, ax=ax, label=col, style=style, **kwargs)
ax.grid(self.grid)
# self.fig.subplots_adjust(wspace=0, hspace=0)
@@ -753,7 +764,7 @@ def _make_plot(self):
rects = []
labels = []
- ax, _ = self._get_ax_and_style(0) #self.axes[0]
+ ax = self._get_ax(0) #self.axes[0]
bar_f = self.bar_f
@@ -762,12 +773,12 @@ def _make_plot(self):
K = self.nseries
for i, (label, y) in enumerate(self._iter_data()):
-
+ label = com._stringify(label)
kwds = self.kwds.copy()
kwds['color'] = colors[i % len(colors)]
if self.subplots:
- ax, _ = self._get_ax_and_style(i) #self.axes[i]
+ ax = self._get_ax(i) #self.axes[i]
rect = bar_f(ax, self.ax_pos, y, 0.5, start=pos_prior,
linewidth=1, **kwds)
ax.set_title(label)
@@ -830,13 +841,10 @@ class HistPlot(MPLPlot):
def plot_frame(frame=None, subplots=False, sharex=True, sharey=False,
- use_index=True,
- figsize=None, grid=False, legend=True, rot=None,
- ax=None, title=None,
- xlim=None, ylim=None, logy=False,
- xticks=None, yticks=None,
- kind='line',
- sort_columns=False, fontsize=None, secondary_y=False, **kwds):
+ use_index=True, figsize=None, grid=False, legend=True, rot=None,
+ ax=None, style=None, title=None, xlim=None, ylim=None, logy=False,
+ xticks=None, yticks=None, kind='line', sort_columns=False,
+ fontsize=None, secondary_y=False, **kwds):
"""
Make line or bar plot of DataFrame's series with the index on the x-axis
using matplotlib / pylab.
@@ -863,6 +871,8 @@ def plot_frame(frame=None, subplots=False, sharex=True, sharey=False,
Place legend on axis subplots
ax : matplotlib axis object, default None
+ style : list or dict
+ matplotlib line style per column
kind : {'line', 'bar', 'barh'}
bar : vertical bar plot
barh : horizontal bar plot
@@ -897,7 +907,7 @@ def plot_frame(frame=None, subplots=False, sharex=True, sharey=False,
raise ValueError('Invalid chart type given %s' % kind)
plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
- legend=legend, ax=ax, fontsize=fontsize,
+ legend=legend, ax=ax, style=style, fontsize=fontsize,
use_index=use_index, sharex=sharex, sharey=sharey,
xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim,
title=title, grid=grid, figsize=figsize, logy=logy,
@@ -930,7 +940,6 @@ def plot_series(series, label=None, kind='line', use_index=True, rot=None,
If not passed, uses gca()
style : string, default matplotlib default
matplotlib line style to use
-
ax : matplotlib axis object
If not passed, uses gca()
kind : {'line', 'bar', 'barh'}

0 comments on commit 42f2b71

Please sign in to comment.