|
| 1 | +from collections.abc import Iterable |
| 2 | +from itertools import cycle |
| 3 | + |
| 4 | +import numpy as np |
| 5 | + |
| 6 | +from .. import decorators as decs |
| 7 | +from .. import utils |
| 8 | +from . import tools |
| 9 | + |
| 10 | +major_grid_style = 'solid' |
| 11 | +minor_grid_style = (0, (1, 2)) |
| 12 | + |
| 13 | + |
| 14 | +def layout( |
| 15 | + axs, |
| 16 | + *, |
| 17 | + axis: str = 'both', |
| 18 | + title: str | list = None, |
| 19 | + x_label: str = None, |
| 20 | + y_label: str = None, |
| 21 | + abc: str | bool = None, |
| 22 | + make_square: bool = None, |
| 23 | + margins: float = None, |
| 24 | + aspect: str | float | tuple = None, |
| 25 | + ticks: tools._tick_vis = None, |
| 26 | + grid: tools._grid_vis = None, |
| 27 | + minor: bool = None, |
| 28 | + spines: tools._tick_vis = None, |
| 29 | + x_breaks: list[float] = None, |
| 30 | + y_breaks: list[float] = None, |
| 31 | + x_lims: list[float] = None, |
| 32 | + y_lims: list[float] = None, |
| 33 | + x_scale: str = None, |
| 34 | + y_scale: str = None, |
| 35 | + x_tick_labels: list[str] = None, |
| 36 | + y_tick_labels: list[str] = None, |
| 37 | + **kwargs, |
| 38 | +): |
| 39 | + # decompose kwargs into title, label, tick, and grid settings |
| 40 | + title_settings = utils.get_hook_dict(kwargs, 'title', remove_hook=True) |
| 41 | + label_settings = utils.get_hook_dict(kwargs, 'label', remove_hook=True) |
| 42 | + |
| 43 | + tick_settings = utils.get_hook_dict(kwargs, 'tick', remove_hook=True) |
| 44 | + grid_settings = utils.get_hook_dict(kwargs, 'grid', remove_hook=False) |
| 45 | + tick_settings.update(grid_settings) |
| 46 | + |
| 47 | + # ensure axs is a list |
| 48 | + if not isinstance(axs, Iterable): |
| 49 | + axs = [axs] |
| 50 | + if not isinstance(title, list): |
| 51 | + title = [title] |
| 52 | + |
| 53 | + handle_abc_labels(axs, abc, **kwargs) |
| 54 | + |
| 55 | + pairs = list(zip(axs, cycle(title))) |
| 56 | + |
| 57 | + for ax, title in pairs: |
| 58 | + # handle ticks, grid, and spine visibility |
| 59 | + # NOTE when axis != 'both', ticks=True does weird things... seems to be a matplotlib issue |
| 60 | + handle_tick_settings(ax, axis, ticks, minor, grid, tick_settings) |
| 61 | + |
| 62 | + # handle other layout elements |
| 63 | + handle_title(ax, title, title_settings) |
| 64 | + handle_labels(ax, axis, x_label, y_label, label_settings) |
| 65 | + handle_tick_labels(ax, x_tick_labels, y_tick_labels) |
| 66 | + |
| 67 | + handle_spines(ax, spines) |
| 68 | + handle_breaks(ax, x_breaks, y_breaks) |
| 69 | + handle_scales(ax, x_scale, y_scale) |
| 70 | + handle_lims(ax, x_lims, y_lims) |
| 71 | + |
| 72 | + handle_aspect(ax, aspect) |
| 73 | + |
| 74 | + # TODO when x_lim/y_lim are set, margins don't work as expected |
| 75 | + handle_margins(ax, margins, make_square) |
| 76 | + |
| 77 | + if make_square is True: |
| 78 | + tools.axis_ratio(ax, yx_ratio=1, margins=margins, how='lims') |
| 79 | + |
| 80 | + |
| 81 | +def handle_abc_labels(axs, abc=None, **kwargs): |
| 82 | + if abc: |
| 83 | + ax_labels = np.arange(1, len(axs) + 1) |
| 84 | + if abc == 'ABC': |
| 85 | + ax_labels = [chr(64 + num) for num in ax_labels] |
| 86 | + elif abc == 'abc': |
| 87 | + ax_labels = [chr(96 + num) for num in ax_labels] |
| 88 | + |
| 89 | + abc_params = utils.get_hook_dict(kwargs, 'abc') |
| 90 | + abc_params['loc'] = abc_params['loc'] if 'loc' in abc_params else 'tl' |
| 91 | + abc_params['size'] = abc_params['size'] if 'size' in abc_params else 18 |
| 92 | + for i, ax in enumerate(axs): |
| 93 | + decs.place_abc_label( |
| 94 | + ax, |
| 95 | + label=ax_labels[i], |
| 96 | + **abc_params, |
| 97 | + ) |
| 98 | + |
| 99 | + |
| 100 | +def handle_tick_grid_vis(ax, axis, ticks, minor, grid, tick_settings): |
| 101 | + tools.set_minor_ticks_by_axis(ax, axis=axis) |
| 102 | + minor = False if minor is None else minor |
| 103 | + |
| 104 | + tools.set_tick_visibility(ax, axis=axis, ticks=ticks, minor=minor) |
| 105 | + |
| 106 | + # set axis below if no grid zorder is specified to make sure grid lines are below other plot elements |
| 107 | + ax_below = False if 'grid_zorder' in tick_settings else True |
| 108 | + ax.set_axisbelow(ax_below) |
| 109 | + |
| 110 | + maj_grid, min_grid = tools.set_grid_visibility(ax, axis=axis, grid=grid, minor=minor, apply=False) |
| 111 | + tick_settings['gridOn'] = [maj_grid, min_grid] |
| 112 | + |
| 113 | + # Set default grid style, since rcParams don't offer minor grid style |
| 114 | + if 'grid_linestyle' not in tick_settings: |
| 115 | + tick_settings['grid_linestyle'] = [major_grid_style, minor_grid_style] |
| 116 | + |
| 117 | + |
| 118 | +def handle_text_element(getter, setter, text: str = None, params: dict = {}): |
| 119 | + """Generic helper to get current text if needed and set it with params.""" |
| 120 | + if text is None and len(params) == 0: |
| 121 | + return |
| 122 | + |
| 123 | + if text is None and len(params) > 0: |
| 124 | + text = getter() |
| 125 | + |
| 126 | + setter(text, **params) |
| 127 | + |
| 128 | + |
| 129 | +def handle_tick_settings(ax, axis, ticks, minor, grid, tick_settings): |
| 130 | + if ticks is None and minor is None and grid is None and len(tick_settings) == 0: |
| 131 | + return |
| 132 | + |
| 133 | + # first all the visibility settings |
| 134 | + handle_tick_grid_vis(ax, axis, ticks, minor, grid, tick_settings) |
| 135 | + |
| 136 | + # tick (and grid) settings are applied separately for major and minor ticks |
| 137 | + majmin_settings = {k: utils.maj_min_args(maj_min=v) for k, v in tick_settings.items()} |
| 138 | + |
| 139 | + for i, which in enumerate(['major', 'minor']): |
| 140 | + tick_settings_select = {k: v[i] for k, v in majmin_settings.items()} |
| 141 | + ax.tick_params(axis=axis, which=which, **tick_settings_select) |
| 142 | + |
| 143 | + |
| 144 | +def handle_spines(ax, spines): |
| 145 | + if spines is not None: |
| 146 | + tools.set_spine_visibility(ax, spines) |
| 147 | + |
| 148 | + |
| 149 | +def handle_aspect(ax, aspect): |
| 150 | + if aspect is not None: |
| 151 | + aspect = [aspect] if not isinstance(aspect, (list, tuple)) else aspect |
| 152 | + adjustable = None if len(aspect) < 2 else aspect[1] |
| 153 | + aspect_params = {'aspect': aspect[0], 'adjustable': adjustable} |
| 154 | + ax.set_aspect(**aspect_params) |
| 155 | + |
| 156 | + |
| 157 | +def handle_breaks(ax, x_breaks, y_breaks): |
| 158 | + if x_breaks is not None: |
| 159 | + ax.set_xticks(x_breaks) |
| 160 | + |
| 161 | + if y_breaks is not None: |
| 162 | + ax.set_yticks(y_breaks) |
| 163 | + |
| 164 | + |
| 165 | +def handle_scales(ax, x_scale, y_scale): |
| 166 | + if y_scale is not None: |
| 167 | + scale_params = tools.parse_scale(scale=y_scale) |
| 168 | + ax.set_yscale(**scale_params) |
| 169 | + |
| 170 | + if x_scale is not None: |
| 171 | + scale_params = tools.parse_scale(scale=x_scale) |
| 172 | + ax.set_xscale(**scale_params) |
| 173 | + |
| 174 | + |
| 175 | +def handle_lims(ax, x_lims, y_lims): |
| 176 | + if y_lims is not None: |
| 177 | + ax.set_ylim(y_lims) |
| 178 | + |
| 179 | + if x_lims is not None: |
| 180 | + ax.set_xlim(x_lims) |
| 181 | + |
| 182 | + |
| 183 | +def handle_title(ax, title, title_params): |
| 184 | + if title is None and len(title_params) == 0: |
| 185 | + return |
| 186 | + |
| 187 | + if title is None: |
| 188 | + title = ax.get_title() |
| 189 | + title = None if len(title) == 0 else title |
| 190 | + |
| 191 | + if title is not None or len(title_params) > 0: |
| 192 | + handle_text_element(ax.get_title, ax.set_title, title, title_params) |
| 193 | + |
| 194 | + |
| 195 | +def handle_labels(ax, axis, x_label, y_label, label_params): |
| 196 | + if x_label is None and y_label is None and len(label_params) == 0: |
| 197 | + return |
| 198 | + |
| 199 | + loc_lookup = { |
| 200 | + 'x': {'start': 'left', 'center': 'center', 'end': 'right'}, |
| 201 | + 'y': {'start': 'bottom', 'center': 'center', 'end': 'top'}, |
| 202 | + } |
| 203 | + |
| 204 | + def normalize_params(axis_key: str, params: dict | None) -> dict: |
| 205 | + params = params or {} |
| 206 | + loc = params.get('loc') |
| 207 | + if loc is not None: |
| 208 | + try: |
| 209 | + params['loc'] = loc_lookup[axis_key][loc] |
| 210 | + except KeyError: |
| 211 | + raise ValueError( |
| 212 | + f"Invalid {axis_key} label loc '{loc}'. Valid options are {list(loc_lookup[axis_key])}." |
| 213 | + ) |
| 214 | + return params |
| 215 | + |
| 216 | + x_label_params = normalize_params('x', label_params.copy()) if axis in ['x', 'both'] else {} |
| 217 | + y_label_params = normalize_params('y', label_params.copy()) if axis in ['y', 'both'] else {} |
| 218 | + |
| 219 | + handle_text_element(ax.get_xlabel, ax.set_xlabel, x_label, x_label_params) |
| 220 | + handle_text_element(ax.get_ylabel, ax.set_ylabel, y_label, y_label_params) |
| 221 | + |
| 222 | + |
| 223 | +def handle_tick_labels(ax, x_tick_labels, y_tick_labels): |
| 224 | + if x_tick_labels is not None: |
| 225 | + ax.set_xticklabels(x_tick_labels) |
| 226 | + |
| 227 | + if y_tick_labels is not None: |
| 228 | + ax.set_yticklabels(y_tick_labels) |
| 229 | + |
| 230 | + |
| 231 | +def handle_margins(ax, margins, make_square): |
| 232 | + if margins is not None and not make_square: |
| 233 | + xmargin, ymargin = utils.maj_min_args(margins) |
| 234 | + |
| 235 | + ax.set_xmargin(xmargin) |
| 236 | + ax.set_ymargin(ymargin) |
0 commit comments