Skip to content

Commit

Permalink
Merge pull request #1225 from ioam/mpl_rc_context
Browse files Browse the repository at this point in the history
Apply matplotlib rc parameters correctly throughout
  • Loading branch information
jlstevens authored Mar 24, 2017
2 parents ea02488 + c554db1 commit 8974ed1
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 9 deletions.
2 changes: 2 additions & 0 deletions holoviews/plotting/mpl/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ...core.util import match_spec
from ...core.options import abbreviated_exception
from .element import ElementPlot
from .plot import mpl_rc_context


class AnnotationPlot(ElementPlot):
Expand All @@ -16,6 +17,7 @@ def __init__(self, annotation, **params):
super(AnnotationPlot, self).__init__(annotation, **params)
self.handles['annotations'] = []

@mpl_rc_context
def initialize_plot(self, ranges=None):
annotation = self.hmap.last
key = self.keys[-1]
Expand Down
4 changes: 3 additions & 1 deletion holoviews/plotting/mpl/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
get_min_distance)
from .element import ElementPlot, ColorbarPlot, LegendPlot
from .path import PathPlot
from .plot import AdjoinedPlot
from .plot import AdjoinedPlot, mpl_rc_context


class ChartPlot(ElementPlot):
Expand Down Expand Up @@ -277,6 +277,7 @@ def __init__(self, histograms, **params):
self.cyclic_range = val_dim.range if val_dim.cyclic else None


@mpl_rc_context
def initialize_plot(self, ranges=None):
hist = self.hmap.last
key = self.keys[-1]
Expand Down Expand Up @@ -788,6 +789,7 @@ def get_extents(self, element, ranges):
return 0, np.nanmin([vrange[0], 0]), ngroups, vrange[1]


@mpl_rc_context
def initialize_plot(self, ranges=None):
element = self.hmap.last
vdim = element.vdims[0]
Expand Down
4 changes: 3 additions & 1 deletion holoviews/plotting/mpl/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ...core.options import abbreviated_exception
from ..plot import GenericElementPlot, GenericOverlayPlot
from ..util import dynamic_update
from .plot import MPLPlot
from .plot import MPLPlot, mpl_rc_context
from .util import wrap_formatter
from distutils.version import LooseVersion

Expand Down Expand Up @@ -465,6 +465,7 @@ def update_frame(self, key, ranges=None, element=None):
self._finalize_axis(key, ranges=ranges, **(axis_kwargs if axis_kwargs else {}))


@mpl_rc_context
def initialize_plot(self, ranges=None):
element = self.hmap.last
ax = self.handles['axis']
Expand Down Expand Up @@ -814,6 +815,7 @@ def _adjust_legend(self, overlay, axis):
self.handles['legend_data'] = data


@mpl_rc_context
def initialize_plot(self, ranges=None):
axis = self.handles['axis']
key = self.keys[-1]
Expand Down
16 changes: 16 additions & 0 deletions holoviews/plotting/mpl/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@
from .util import compute_ratios, fix_aspect


def mpl_rc_context(f):
"""
Applies matplotlib rc params while when method is called.
"""
def wrapper(self, *args, **kwargs):
with mpl.rc_context(rc=self.fig_rcparams):
return f(self, *args, **kwargs)
return wrapper


class MPLPlot(DimensionedPlot):
"""
An MPLPlot object draws a matplotlib figure object when called or
Expand Down Expand Up @@ -208,6 +218,7 @@ def anim(self, start=0, stop=None, fps=30):
if self._close_figures: plt.close(figure)
return anim


def update(self, key):
if len(self) == 1 and key == 0 and not self.drawn:
return self.initialize_plot()
Expand All @@ -221,6 +232,7 @@ class CompositePlot(GenericCompositePlot, MPLPlot):
subplots to form a Layout.
"""

@mpl_rc_context
def update_frame(self, key, ranges=None):
ranges = self.compute_ranges(self.layout, key, ranges)
for subplot in self.subplots.values():
Expand Down Expand Up @@ -413,6 +425,7 @@ def _create_subplots(self, layout, axis, ranges, create_axes):
return subplots, subaxes, collapsed_layout


@mpl_rc_context
def initialize_plot(self, ranges=None):
# Get the extent of the layout elements (not the whole layout)
key = self.keys[-1]
Expand Down Expand Up @@ -586,6 +599,7 @@ def __init__(self, layout, layout_type, subaxes, subplots, **params):
super(AdjointLayoutPlot, self).__init__(subplots=subplots, **params)


@mpl_rc_context
def initialize_plot(self, ranges=None):
"""
Plot all the views contained in the AdjointLayout Object using axes
Expand Down Expand Up @@ -657,6 +671,7 @@ def adjust_positions(self, redraw=True):
ax.set_aspect('equal')


@mpl_rc_context
def update_frame(self, key, ranges=None):
for pos in self.view_positions:
subplot = self.subplots.get(pos)
Expand Down Expand Up @@ -1015,6 +1030,7 @@ def _create_subplots(self, layout, positions, layout_dimensions, ranges, axes={}
return subplots, adjoint_clone, projections


@mpl_rc_context
def initialize_plot(self):
key = self.keys[-1]
ranges = self.compute_ranges(self.layout, key, None)
Expand Down
3 changes: 2 additions & 1 deletion holoviews/plotting/mpl/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ...core.util import match_spec, max_range, unique_iterator, unique_array, is_nan
from ...element.raster import Image, Raster, RGB
from .element import ColorbarPlot, OverlayPlot
from .plot import MPLPlot, GridPlot
from .plot import MPLPlot, GridPlot, mpl_rc_context


class RasterPlot(ColorbarPlot):
Expand Down Expand Up @@ -324,6 +324,7 @@ def _get_frame(self, key):
return GridPlot._get_frame(self, key)


@mpl_rc_context
def initialize_plot(self, ranges=None):
_, _, b_w, b_h, widths, heights = self.border_extents

Expand Down
6 changes: 4 additions & 2 deletions holoviews/plotting/mpl/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,13 @@ def __call__(self, obj, fmt='auto'):
if isinstance(plot, tuple(self.widgets.values())):
data = plot()
elif fmt in ['png', 'svg', 'pdf', 'html', 'json']:
data = self._figure_data(plot, fmt, **({'dpi':self.dpi} if self.dpi else {}))
with mpl.rc_context(rc=plot.fig_rcparams):
data = self._figure_data(plot, fmt, **({'dpi':self.dpi} if self.dpi else {}))
else:
if sys.version_info[0] == 3 and mpl.__version__[:-2] in ['1.2', '1.3']:
raise Exception("<b>Python 3 matplotlib animation support broken &lt;= 1.3</b>")
anim = plot.anim(fps=self.fps)
with mpl.rc_context(rc=plot.fig_rcparams):
anim = plot.anim(fps=self.fps)
data = self._anim_data(anim, fmt)

data = self._apply_post_render_hooks(data, obj, fmt)
Expand Down
4 changes: 3 additions & 1 deletion holoviews/plotting/mpl/seaborn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ...core.options import Store
from .element import ElementPlot
from .pandas import DFrameViewPlot
from .plot import MPLPlot, AdjoinedPlot
from .plot import MPLPlot, AdjoinedPlot, mpl_rc_context


class SeabornPlot(ElementPlot):
Expand Down Expand Up @@ -230,6 +230,7 @@ def __init__(self, view, **params):
super(SNSFramePlot, self).__init__(view, **params)


@mpl_rc_context
def initialize_plot(self, ranges=None):
dfview = self.hmap.last
axis = self.handles['axis']
Expand Down Expand Up @@ -258,6 +259,7 @@ def _validate(self, dfview):
raise Exception("Multiple %s plots cannot be composed."
% self.plot_type)

@mpl_rc_context
def update_frame(self, key, ranges=None):
element = self.hmap.get(key, None)
axis = self.handles['axis']
Expand Down
7 changes: 4 additions & 3 deletions holoviews/plotting/mpl/tabular.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from collections import defaultdict
from matplotlib.font_manager import FontProperties
from matplotlib.table import Table as mpl_Table
from holoviews.core.util import unicode

import param

from ...core.util import bytes_to_unicode, unicode
from .element import ElementPlot
from ...core.util import bytes_to_unicode
from .plot import mpl_rc_context



class TablePlot(ElementPlot):
Expand Down Expand Up @@ -93,6 +93,7 @@ def _cell_value(self, element, row, col):
return cell_text


@mpl_rc_context
def initialize_plot(self, ranges=None):
element = self.hmap.last
axis = self.handles['axis']
Expand Down
15 changes: 15 additions & 0 deletions tests/testplotinstantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,21 @@ def test_layout_instantiate_subplots_transposed(self):
if 'main' in adjoint.subplots:
self.assertEqual(adjoint.subplots['main'].layout_num, num)

def test_points_rcparams_do_not_persist(self):
opts = dict(fig_rcparams={'text.usetex': True})
points = Points(([0, 1], [0, 3]))(plot=opts)
plot = mpl_renderer.get_plot(points)
self.assertFalse(pyplot.rcParams['text.usetex'])

def test_points_rcparams_used(self):
opts = dict(fig_rcparams={'grid.color': 'red'})
points = Points(([0, 1], [0, 3]))(plot=opts)
plot = mpl_renderer.get_plot(points)
ax = plot.state.axes[0]
lines = ax.get_xgridlines()
self.assertEqual(lines[0].get_color(), 'red')



class TestBokehPlotInstantiation(ComparisonTestCase):

Expand Down

0 comments on commit 8974ed1

Please sign in to comment.