Skip to content

Uses tight_layout.get_subplotspec_list to check if all axes are compatible w/ tight_layout #1170

Merged
merged 2 commits into from Sep 1, 2012
View
15 lib/matplotlib/figure.py
@@ -1422,17 +1422,20 @@ def tight_layout(self, renderer=None, pad=1.08, h_pad=None, w_pad=None, rect=Non
labels) will fit into. Default is (0, 0, 1, 1).
"""
- from tight_layout import get_renderer, get_tight_layout_figure
+ from tight_layout import (get_renderer, get_tight_layout_figure,
+ get_subplotspec_list)
- subplot_axes = [ax for ax in self.axes if isinstance(ax, SubplotBase)]
- if len(subplot_axes) < len(self.axes):
- warnings.warn("tight_layout can only process Axes that descend "
- "from SubplotBase; results might be incorrect.")
+ subplotspec_list = get_subplotspec_list(self.axes)
+ if None in subplotspec_list:
+ warnings.warn("This figure includes Axes that are not "
+ "compatible with tight_layout, so its "
+ "results might be incorrect.")
if renderer is None:
renderer = get_renderer(self)
- kwargs = get_tight_layout_figure(self, subplot_axes, renderer,
+ kwargs = get_tight_layout_figure(self, self.axes, subplotspec_list,
+ renderer,
pad=pad, h_pad=h_pad, w_pad=w_pad,
rect=rect)
View
56 lib/matplotlib/tight_layout.py
@@ -209,7 +209,33 @@ def get_renderer(fig):
return renderer
-def get_tight_layout_figure(fig, axes_list, renderer,
+def get_subplotspec_list(axes_list):
+ """
+ Return a list of subplotspec from the given list of axes. For an
+ instance of axes that does not support subplotspec, None is
+ inserted in the list.
+
+ """
+ subplotspec_list = []
+ for ax in axes_list:
+ axes_or_locator = ax.get_axes_locator()
+ if axes_or_locator is None:
+ axes_or_locator = ax
+
+ if hasattr(axes_or_locator, "get_subplotspec"):
+ subplotspec = axes_or_locator.get_subplotspec()
+ subplotspec = subplotspec.get_topmost_subplotspec()
+ if subplotspec.get_gridspec().locally_modified_subplot_params():
+ subplotspec = None
+ else:
+ subplotspec = None
+
+ subplotspec_list.append(subplotspec)
+
+ return subplotspec_list
+
+
+def get_tight_layout_figure(fig, axes_list, subplotspec_list, renderer,
pad=1.08, h_pad=None, w_pad=None, rect=None):
"""
Return subplot parameters for tight-layouted-figure with specified
@@ -221,6 +247,9 @@ def get_tight_layout_figure(fig, axes_list, renderer,
*axes_list* : a list of axes
+ *subplotspec_list* : a list of subplotspec associated with each
+ axes in axes_list
+
*renderer* : renderer instance
*pad* : float
@@ -238,27 +267,20 @@ def get_tight_layout_figure(fig, axes_list, renderer,
"""
- subplotspec_list = []
subplot_list = []
nrows_list = []
ncols_list = []
ax_bbox_list = []
- subplot_dict = {} # for axes_grid1, multiple axes can share
- # same subplot_interface. Thus we need to
- # join them together.
+ subplot_dict = {} # multiple axes can share
+ # same subplot_interface (e.g, axes_grid1). Thus
+ # we need to join them together.
- for ax in axes_list:
- locator = ax.get_axes_locator()
- if hasattr(locator, "get_subplotspec"):
- subplotspec = locator.get_subplotspec().get_topmost_subplotspec()
- elif hasattr(ax, "get_subplotspec"):
- subplotspec = ax.get_subplotspec().get_topmost_subplotspec()
- else:
- continue
+ subplotspec_list2 = []
- if (subplotspec is None) or \
- subplotspec.get_gridspec().locally_modified_subplot_params():
+ for ax, subplotspec in zip(axes_list,
+ subplotspec_list):
+ if subplotspec is None:
continue
subplots = subplot_dict.setdefault(subplotspec, [])
@@ -267,7 +289,7 @@ def get_tight_layout_figure(fig, axes_list, renderer,
myrows, mycols, _, _ = subplotspec.get_geometry()
nrows_list.append(myrows)
ncols_list.append(mycols)
- subplotspec_list.append(subplotspec)
+ subplotspec_list2.append(subplotspec)
subplot_list.append(subplots)
ax_bbox_list.append(subplotspec.get_position(fig))
@@ -277,7 +299,7 @@ def get_tight_layout_figure(fig, axes_list, renderer,
max_ncols = max(ncols_list)
num1num2_list = []
- for subplotspec in subplotspec_list:
+ for subplotspec in subplotspec_list2:
rows, cols, num1, num2 = subplotspec.get_geometry()
div_row, mod_row = divmod(max_nrows, rows)
div_col, mod_col = divmod(max_ncols, cols)
Something went wrong with that request. Please try again.