Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse Grid.__init__ in ImageGrid.__init__. #15670

Merged
merged 1 commit into from
Jan 6, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
94 changes: 27 additions & 67 deletions lib/mpl_toolkits/axes_grid1/axes_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,10 @@ def __init__(self, fig,
share_all=False,
share_x=True,
share_y=True,
#aspect=True,
label_mode="L",
axes_class=None,
*,
aspect=False,
):
"""
Parameters
Expand All @@ -136,6 +137,7 @@ def __init__(self, fig,
- "all": all axes are labelled.

axes_class : subclass of `matplotlib.axes.Axes`, default: None
aspect : bool, default: False
"""
self._nrows, self._ncols = nrows_ncols

Expand All @@ -155,7 +157,7 @@ def __init__(self, fig,
if axes_class is None:
axes_class = self._defaultAxesClass

kw = dict(horizontal=[], vertical=[], aspect=False)
kw = dict(horizontal=[], vertical=[], aspect=aspect)
if isinstance(rect, (str, Number, SubplotSpec)):
self._divider = SubplotDivider(fig, rect, **kw)
elif len(rect) == 3:
Expand All @@ -182,7 +184,7 @@ def __init__(self, fig,
self.axes_row = axes_array.tolist()
self.axes_llc = self.axes_column[0][-1]

self._update_locators()
self._init_locators()

if add_all:
for ax in self.axes_all:
Expand All @@ -195,7 +197,7 @@ def _init_axes_pad(self, axes_pad):
self._horiz_pad_size = Size.Fixed(axes_pad[0])
self._vert_pad_size = Size.Fixed(axes_pad[1])

def _update_locators(self):
def _init_locators(self):

h = []
h_ax_pos = []
Expand Down Expand Up @@ -401,73 +403,20 @@ def __init__(self, fig,
to associated *cbar_axes*.
axes_class : subclass of `matplotlib.axes.Axes`, default: None
"""
self._nrows, self._ncols = nrows_ncols

if ngrids is None:
ngrids = self._nrows * self._ncols
else:
if not 0 < ngrids <= self._nrows * self._ncols:
raise Exception

self.ngrids = ngrids

self._init_axes_pad(axes_pad)

self._colorbar_mode = cbar_mode
self._colorbar_location = cbar_location
if cbar_pad is None:
# horizontal or vertical arrangement?
if cbar_location in ("left", "right"):
self._colorbar_pad = self._horiz_pad_size.fixed_size
else:
self._colorbar_pad = self._vert_pad_size.fixed_size
else:
self._colorbar_pad = cbar_pad

self._colorbar_pad = cbar_pad
self._colorbar_size = cbar_size
# The colorbar axes are created in _init_locators().

cbook._check_in_list(["column", "row"], direction=direction)
self._direction = direction

if axes_class is None:
axes_class = self._defaultAxesClass

kw = dict(horizontal=[], vertical=[], aspect=aspect)
if isinstance(rect, (str, Number, SubplotSpec)):
self._divider = SubplotDivider(fig, rect, **kw)
elif len(rect) == 3:
self._divider = SubplotDivider(fig, *rect, **kw)
elif len(rect) == 4:
self._divider = Divider(fig, rect, **kw)
else:
raise Exception("")

rect = self._divider.get_position()

axes_array = np.full((self._nrows, self._ncols), None, dtype=object)
for i in range(self.ngrids):
col, row = self._get_col_row(i)
if share_all:
sharex = sharey = axes_array[0, 0]
else:
sharex = axes_array[0, col]
sharey = axes_array[row, 0]
axes_array[row, col] = axes_class(
fig, rect, sharex=sharex, sharey=sharey)
self.axes_all = axes_array.ravel().tolist()
self.axes_column = axes_array.T.tolist()
self.axes_row = axes_array.tolist()
self.axes_llc = self.axes_column[0][-1]

self.cbar_axes = [
self._defaultCbarAxesClass(fig, rect,
orientation=self._colorbar_location)
for _ in range(self.ngrids)]

self._update_locators()
super().__init__(
fig, rect, nrows_ncols, ngrids,
direction=direction, axes_pad=axes_pad, add_all=add_all,
share_all=share_all, share_x=True, share_y=True, aspect=aspect,
label_mode=label_mode, axes_class=axes_class)

if add_all:
for ax in self.axes_all+self.cbar_axes:
for ax in self.cbar_axes:
fig.add_axes(ax)

if cbar_set_cax:
Expand All @@ -485,9 +434,20 @@ def __init__(self, fig,
for ax, cax in zip(self.axes_all, self.cbar_axes):
ax.cax = cax

self.set_label_mode(label_mode)
def _init_locators(self):
# Slightly abusing this method to inject colorbar creation into init.

def _update_locators(self):
if self._colorbar_pad is None:
timhoffm marked this conversation as resolved.
Show resolved Hide resolved
# horizontal or vertical arrangement?
if self._colorbar_location in ("left", "right"):
self._colorbar_pad = self._horiz_pad_size.fixed_size
else:
self._colorbar_pad = self._vert_pad_size.fixed_size
self.cbar_axes = [
self._defaultCbarAxesClass(
self.axes_all[0].figure, self._divider.get_position(),
orientation=self._colorbar_location)
for _ in range(self.ngrids)]

cb_mode = self._colorbar_mode
cb_location = self._colorbar_location
Expand Down