Skip to content

Commit

Permalink
Changed from lists to Grouper
Browse files Browse the repository at this point in the history
  • Loading branch information
jklymak committed Jan 13, 2018
1 parent a8f72f3 commit 17d1025
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 18 deletions.
53 changes: 40 additions & 13 deletions lib/matplotlib/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,6 @@ def __init__(self, axes, pickradius=15):

self._autolabelpos = True
self._smart_bounds = False
self._align_label_siblings = [self]

self.label = self._get_label()
self.labelpad = rcParams['axes.labelpad']
Expand Down Expand Up @@ -1673,22 +1672,14 @@ def set_ticks(self, ticks, minor=False):
self.set_major_locator(mticker.FixedLocator(ticks))
return self.get_major_ticks(len(ticks))

def _get_tick_boxes_siblings(self, renderer):
def _get_tick_boxes_siblings(self, xdir, renderer):
"""
Get the bounding boxes for this `.axis` and its siblings
as set by `.Figure.align_xlabels` or `.Figure.align_ylablels`.
By default it just gets bboxes for self.
"""
bboxes = []
bboxes2 = []
# if we want to align labels from other axes:
for axx in self._align_label_siblings:
ticks_to_draw = axx._update_ticks(renderer)
tlb, tlb2 = axx._get_tick_bboxes(ticks_to_draw, renderer)
bboxes.extend(tlb)
bboxes2.extend(tlb2)
return bboxes, bboxes2
raise NotImplementedError('Derived must override')

def _update_label_position(self, renderer):
"""
Expand Down Expand Up @@ -1866,6 +1857,24 @@ def set_label_position(self, position):
self.label_position = position
self.stale = True

def _get_tick_boxes_siblings(self, renderer):
"""
Get the bounding boxes for this `.axis` and its siblings
as set by `.Figure.align_xlabels` or `.Figure.align_ylablels`.
By default it just gets bboxes for self.
"""
bboxes = []
bboxes2 = []
grp = self.figure._align_xlabel_grp
# if we want to align labels from other axes:
for axx in grp.get_siblings(self.axes):
ticks_to_draw = axx.xaxis._update_ticks(renderer)
tlb, tlb2 = axx.xaxis._get_tick_bboxes(ticks_to_draw, renderer)
bboxes.extend(tlb)
bboxes2.extend(tlb2)
return bboxes, bboxes2

def _update_label_position(self, renderer):
"""
Update the label position based on the bounding box enclosing
Expand All @@ -1876,7 +1885,7 @@ def _update_label_position(self, renderer):

# get bounding boxes for this axis and any siblings
# that have been set by `fig.align_xlabels()`
bboxes, bboxes2 = self._get_tick_boxes_siblings(renderer)
bboxes, bboxes2 = self._get_tick_boxes_siblings(renderer=renderer)

x, y = self.label.get_position()
if self.label_position == 'bottom':
Expand Down Expand Up @@ -2216,6 +2225,24 @@ def set_label_position(self, position):
self.label_position = position
self.stale = True

def _get_tick_boxes_siblings(self, renderer):
"""
Get the bounding boxes for this `.axis` and its siblings
as set by `.Figure.align_xlabels` or `.Figure.align_ylablels`.
By default it just gets bboxes for self.
"""
bboxes = []
bboxes2 = []
grp = self.figure._align_ylabel_grp
# if we want to align labels from other axes:
for axx in grp.get_siblings(self.axes):
ticks_to_draw = axx.yaxis._update_ticks(renderer)
tlb, tlb2 = axx.yaxis._get_tick_bboxes(ticks_to_draw, renderer)
bboxes.extend(tlb)
bboxes2.extend(tlb2)
return bboxes, bboxes2

def _update_label_position(self, renderer):
"""
Update the label position based on the bounding box enclosing
Expand All @@ -2226,7 +2253,7 @@ def _update_label_position(self, renderer):

# get bounding boxes for this axis and any siblings
# that have been set by `fig.align_ylabels()`
bboxes, bboxes2 = self._get_tick_boxes_siblings(renderer)
bboxes, bboxes2 = self._get_tick_boxes_siblings(renderer=renderer)

x, y = self.label.get_position()
if self.label_position == 'left':
Expand Down
24 changes: 19 additions & 5 deletions lib/matplotlib/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,10 @@ def __init__(self,
self.clf()
self._cachedRenderer = None

# groupers to keep track of x and y labels we want to align.
self._align_xlabel_grp = cbook.Grouper()
self._align_ylabel_grp = cbook.Grouper()

@property
@cbook.deprecated("2.1", alternative="Figure.patch")
def figurePatch(self):
Expand Down Expand Up @@ -2103,6 +2107,11 @@ def align_xlabels(self, axs=None):
Optional list of (or ndarray) `~matplotlib.axes.Axes` to align
the xlabels. Default is to align all axes on the figure.
Note
----
This assumes that ``axs`` are from the same `~.GridSpec`, so that
their `~.SubplotSpec` positions correspond to figure positions.
See Also
--------
matplotlib.figure.Figure.align_ylabels
Expand All @@ -2124,7 +2133,6 @@ def align_xlabels(self, axs=None):

if axs is None:
axs = self.axes

axs = np.asarray(np.array(axs)).flatten().tolist()

for ax in axs:
Expand All @@ -2136,7 +2144,7 @@ def align_xlabels(self, axs=None):
# loop through other axes, and search for label positions
# that are same as this one, and that share the appropriate
# row number.
# Add to a list associated with each axes of sibblings.
# Add to a grouper associated with each axes of sibblings.
# This list is inspected in `axis.draw` by
# `axis._update_label_position`.
for axc in axs:
Expand All @@ -2146,7 +2154,8 @@ def align_xlabels(self, axs=None):
ss.get_rows_columns()
if (labpo == 'bottom' and rowc1 == row1 or
labpo == 'top' and rowc0 == row0):
axc.xaxis._align_label_siblings += [ax.xaxis]
# grouper for groups of xlabels to align
self._align_xlabel_grp.join(ax, axc)

def align_ylabels(self, axs=None):
"""
Expand All @@ -2167,6 +2176,11 @@ def align_ylabels(self, axs=None):
Optional list (or ndarray) of `~matplotlib.axes.Axes` to align
the ylabels. Default is to align all axes on the figure.
Note
----
This assumes that ``axs`` are from the same `~.GridSpec`, so that
their `~.SubplotSpec` positions correspond to figure positions.
See Also
--------
matplotlib.figure.Figure.align_xlabels
Expand All @@ -2187,7 +2201,6 @@ def align_ylabels(self, axs=None):

if axs is None:
axs = self.axes

axs = np.asarray(np.array(axs)).flatten().tolist()
for ax in axs:
_log.debug(' Working on: %s', ax.get_ylabel())
Expand All @@ -2209,7 +2222,8 @@ def align_ylabels(self, axs=None):
ss.get_rows_columns()
if (labpo == 'left' and colc0 == col0 or
labpo == 'right' and colc1 == col1):
axc.yaxis._align_label_siblings += [ax.yaxis]
# grouper for groups of ylabels to align
self._align_ylabel_grp.join(ax, axc)

def align_labels(self, axs=None):
"""
Expand Down

0 comments on commit 17d1025

Please sign in to comment.