Skip to content

Commit

Permalink
ENH: Make Colorbar inherit from an Axes base class
Browse files Browse the repository at this point in the history
  • Loading branch information
greglucas committed Jun 2, 2021
1 parent 0b741ee commit 569fbc1
Showing 1 changed file with 38 additions and 44 deletions.
82 changes: 38 additions & 44 deletions lib/matplotlib/colorbar.py
Expand Up @@ -310,7 +310,7 @@ def draw(self, renderer):
return ret


class ColorbarBase:
class ColorbarBase(ColorbarAxes):
r"""
Draw a colorbar in an existing axes.
Expand Down Expand Up @@ -417,9 +417,9 @@ def __init__(self, ax, *, cmap=None,
['uniform', 'proportional'], spacing=spacing)

# wrap the axes so that it can be positioned as an inset axes:
ax = ColorbarAxes(ax, userax=userax)
self.ax = ax
ax.set(navigate=False)
super().__init__(ax, userax=userax)
self.ax = self
self.set(navigate=False)

if cmap is None:
cmap = cm.get_cmap()
Expand Down Expand Up @@ -448,24 +448,19 @@ def __init__(self, ax, *, cmap=None,
self.extendrect = extendrect
self.solids = None
self.solids_patches = []
self.lines = []
# self.lines = []

for spine in self.ax.spines.values():
for spine in self.spines.values():
spine.set_visible(False)
for spine in self.ax.outer_ax.spines.values():
for spine in self.outer_ax.spines.values():
spine.set_visible(False)
self.outline = self.ax.spines['outline'] = _ColorbarSpine(self.ax)

self.patch = mpatches.Polygon(
np.empty((0, 2)),
color=mpl.rcParams['axes.facecolor'], linewidth=0.01, zorder=-1)
ax.add_artist(self.patch)
self.outline = self.spines['outline'] = _ColorbarSpine(self)

self.dividers = collections.LineCollection(
[],
colors=[mpl.rcParams['axes.edgecolor']],
linewidths=[0.5 * mpl.rcParams['axes.linewidth']])
self.ax.add_collection(self.dividers)
self.add_collection(self.dividers)

self.locator = None
self.formatter = None
Expand Down Expand Up @@ -519,8 +514,8 @@ def draw_all(self):
# also adds the outline path to self.outline spine:
self._do_extends(extendlen)

self.ax.set_xlim(self.vmin, self.vmax)
self.ax.set_ylim(self.vmin, self.vmax)
self.set_xlim(self.vmin, self.vmax)
self.set_ylim(self.vmin, self.vmax)

# set up the tick locators and formatters. A bit complicated because
# boundary norms + uniform spacing requires a manual locator.
Expand Down Expand Up @@ -548,7 +543,7 @@ def _add_solids(self, X, Y, C):
and any(hatch is not None for hatch in mappable.hatches)):
self._add_solids_patches(X, Y, C, mappable)
else:
self.solids = self.ax.pcolormesh(
self.solids = self.pcolormesh(
X, Y, C, cmap=self.cmap, norm=self.norm, alpha=self.alpha,
edgecolors='none', shading='flat')
if not self.drawedges:
Expand All @@ -569,7 +564,7 @@ def _add_solids_patches(self, X, Y, C, mappable):
facecolor=self.cmap(self.norm(C[i][0])),
hatch=hatches[i], linewidth=0,
antialiased=False, alpha=self.alpha)
self.ax.add_patch(patch)
self.add_patch(patch)
patches.append(patch)
self.solids_patches = patches

Expand Down Expand Up @@ -605,7 +600,7 @@ def _do_extends(self, extendlen):
if self.orientation == 'horizontal':
bounds = bounds[[1, 0, 3, 2]]
xyout = xyout[:, ::-1]
self.ax._set_inner_bounds(bounds)
self._set_inner_bounds(bounds)

# xyout is the path for the spine:
self.outline.set_xy(xyout)
Expand Down Expand Up @@ -634,9 +629,9 @@ def _do_extends(self, extendlen):
color = self.cmap(self.norm(self._values[0]))
patch = mpatches.PathPatch(
mpath.Path(xy), facecolor=color, linewidth=0,
antialiased=False, transform=self.ax.outer_ax.transAxes,
antialiased=False, transform=self.outer_ax.transAxes,
hatch=hatches[0])
self.ax.outer_ax.add_patch(patch)
self.outer_ax.add_patch(patch)
if self._extend_upper():
if not self.extendrect:
# triangle
Expand All @@ -651,8 +646,8 @@ def _do_extends(self, extendlen):
patch = mpatches.PathPatch(
mpath.Path(xy), facecolor=color,
linewidth=0, antialiased=False,
transform=self.ax.outer_ax.transAxes, hatch=hatches[-1])
self.ax.outer_ax.add_patch(patch)
transform=self.outer_ax.transAxes, hatch=hatches[-1])
self.outer_ax.add_patch(patch)
return

def add_lines(self, levels, colors, linewidths, erase=True):
Expand Down Expand Up @@ -699,25 +694,24 @@ def add_lines(self, levels, colors, linewidths, erase=True):
# make a clip path that is just a linewidth bigger than the axes...
fac = np.max(linewidths) / 72
xy = np.array([[0, 0], [1, 0], [1, 1], [0, 1], [0, 0]])
inches = self.ax.get_figure().dpi_scale_trans
inches = self.get_figure().dpi_scale_trans
# do in inches:
xy = inches.inverted().transform(self.ax.transAxes.transform(xy))
xy = inches.inverted().transform(self.transAxes.transform(xy))
xy[[0, 1, 4], 1] -= fac
xy[[2, 3], 1] += fac
# back to axes units...
xy = self.ax.transAxes.inverted().transform(inches.transform(xy))
xy = self.transAxes.inverted().transform(inches.transform(xy))
if self.orientation == 'horizontal':
xy = xy.T
col.set_clip_path(mpath.Path(xy, closed=True),
self.ax.transAxes)
self.ax.add_collection(col)
self.transAxes)
self.add_collection(col)
self.stale = True

def update_ticks(self):
"""
Setup the ticks and ticklabels. This should not be needed by users.
"""
ax = self.ax
# Get the locator and formatter; defaults to self.locator if not None.
self._get_ticker_locator_formatter()
self._long_axis().set_major_locator(self.locator)
Expand Down Expand Up @@ -832,7 +826,7 @@ def minorticks_on(self):
"""
Turn on colorbar minor ticks.
"""
self.ax.minorticks_on()
super().minorticks_on()
self.minorlocator = self._long_axis().get_minor_locator()
self._short_axis().set_minor_locator(ticker.NullLocator())

Expand Down Expand Up @@ -863,9 +857,9 @@ def set_label(self, label, *, loc=None, **kwargs):
Supported keywords are *labelpad* and `.Text` properties.
"""
if self.orientation == "vertical":
self.ax.set_ylabel(label, loc=loc, **kwargs)
self.set_ylabel(label, loc=loc, **kwargs)
else:
self.ax.set_xlabel(label, loc=loc, **kwargs)
self.set_xlabel(label, loc=loc, **kwargs)
self.stale = True

def set_alpha(self, alpha):
Expand All @@ -874,8 +868,8 @@ def set_alpha(self, alpha):

def remove(self):
"""Remove this colorbar from the figure."""
self.ax.inner_ax.remove()
self.ax.outer_ax.remove()
self.inner_ax.remove()
self.outer_ax.remove()

def _ticker(self, locator, formatter):
"""
Expand Down Expand Up @@ -1009,16 +1003,16 @@ def _reset_locator_formatter_scale(self):
((self.boundaries is not None) or
isinstance(self.norm, colors.BoundaryNorm))):
funcs = (self._forward_boundaries, self._inverse_boundaries)
self.ax.set_xscale('function', functions=funcs)
self.ax.set_yscale('function', functions=funcs)
self.set_xscale('function', functions=funcs)
self.set_yscale('function', functions=funcs)
self.__scale = 'function'
elif hasattr(self.norm, '_scale') and (self.norm._scale is not None):
self.ax.set_xscale(self.norm._scale)
self.ax.set_yscale(self.norm._scale)
self.set_xscale(self.norm._scale)
self.set_yscale(self.norm._scale)
self.__scale = self.norm._scale.name
else:
self.ax.set_xscale('linear')
self.ax.set_yscale('linear')
self.set_xscale('linear')
self.set_yscale('linear')
if type(self.norm) is colors.Normalize:
self.__scale = 'linear'
else:
Expand Down Expand Up @@ -1130,14 +1124,14 @@ def _extend_upper(self):
def _long_axis(self):
"""Return the long axis"""
if self.orientation == 'vertical':
return self.ax.yaxis
return self.ax.xaxis
return self.yaxis
return self.xaxis

def _short_axis(self):
"""Return the short axis"""
if self.orientation == 'vertical':
return self.ax.xaxis
return self.ax.yaxis
return self.xaxis
return self.yaxis


class Colorbar(ColorbarBase):
Expand Down

0 comments on commit 569fbc1

Please sign in to comment.