Skip to content

Commit

Permalink
ENH plot.GolorGrid: accept styles
Browse files Browse the repository at this point in the history
  • Loading branch information
christianbrodbeck committed May 5, 2021
1 parent 07b8e8b commit c6f42ce
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 26 deletions.
4 changes: 2 additions & 2 deletions eelbrain/plot/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1258,7 +1258,7 @@ def from_stats(
x = ct._align(x, ds=ds, coerce=ascategorial)
title = frame_title(y, x, xax)
# find styles
styles = find_cell_styles(ct.x, colors)
styles = find_cell_styles(ct.cells, colors)
# find masks
if mask is None:
masks = defaultdict(lambda: None)
Expand Down Expand Up @@ -3479,7 +3479,7 @@ def __init__(self, xmin, xmax, xlim=None, axes=None):

def _init_with_data(
self,
epochs: List[List[NDVar]],
epochs: Sequence[Sequence[NDVar]],
xdim: str,
xlim: Union[float, Tuple[float, float]] = None,
axes: List[matplotlib.axes.Axes] = None,
Expand Down
11 changes: 9 additions & 2 deletions eelbrain/plot/_colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .._data_obj import CellArg, cellname
from .._utils import IS_WINDOWS
from ._base import EelFigure, Layout, AxisScale, CMapArg, ColorArg, fix_vlim_for_cmap
from ._styles import find_cell_styles


POINT_SIZE = 0.0138889 # 1 point in inches
Expand Down Expand Up @@ -76,6 +77,12 @@ def __init__(
else:
raise KeyError(f"Neither {(row_cell_0, col_cell_0)} nor {(col_cell_0, row_cell_0)} exist as a key in colors")

if row_first:
cells = list(zip(row_cells, column_cells))
else:
cells = list(zip(column_cells, row_cells))
styles = find_cell_styles(cells, colors)

# reverse rows so we can plot upwards
row_cells = tuple(reversed(row_cells))
n_rows = len(row_cells)
Expand Down Expand Up @@ -113,11 +120,11 @@ def __init__(
cell = (column_cells[col], row_cells[row])

if shape == 'box':
patch = mpl.patches.Rectangle((col, row), 1, 1, fc=colors[cell], ec='none')
patch = mpl.patches.Rectangle((col, row), 1, 1, ec='none', **styles[cell].patch_args)
ax.add_patch(patch)
elif shape == 'line':
y = row + 0.5
ax.plot([col, col + 1], [y, y], color=colors[cell])
ax.plot([col, col + 1], [y, y], **styles[cell].line_args)
else:
raise ValueError(f"shape={shape!r}")

Expand Down
37 changes: 20 additions & 17 deletions eelbrain/plot/_styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from itertools import chain, product
from math import ceil
import operator
from typing import Any, Dict, Sequence, Tuple, Union
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import numpy as np
import matplotlib as mpl
Expand Down Expand Up @@ -77,18 +77,16 @@ def to_styles_dict(colors: Dict[CellArg, Any]) -> StylesDict:


def find_cell_styles(
x: CategorialVariable,
colors: ColorsArg,
cells: Sequence[CellArg] = None,
colors: ColorsArg = None,
fallback: bool = True,
) -> StylesDict:
"""Process the colors arg from plotting functions
Parameters
----------
x
Model for which colors are needed. ``None`` if only a single value is
plotted.
cells
Cells for which colors are needed.
colors
Colors for the plots if multiple categories of data are plotted.
**str**: A colormap name; cells are mapped onto the colormap in
Expand All @@ -97,21 +95,17 @@ def find_cell_styles(
**dict**: A dictionary mapping each cell to a color.
Colors are specified as `matplotlib compatible color arguments
<http://matplotlib.org/api/colors_api.html>`_.
cells
In case only a subset of cells is used.
fallback
If a cell is missing, fall back on partial cells (on by default).
"""
if x is None:
if not isinstance(colors, dict):
if cells in (None, (None,)):
if isinstance(colors, dict):
out = colors
else:
if colors is None:
colors = 'k'
colors = {None: colors}
return to_styles_dict(colors)
elif cells is None:
cells = x.cells

if isinstance(colors, (list, tuple)):
out = {None: colors}
elif isinstance(colors, (list, tuple)):
if len(colors) < len(cells):
raise ValueError(f"colors={colors!r}: only {len(colors)} colors for {len(cells)} cells.")
out = dict(zip(cells, colors))
Expand All @@ -132,7 +126,16 @@ def find_cell_styles(
if missing:
raise KeysMissing(missing, 'colors', colors)
elif colors is None or isinstance(colors, str):
out = colors_for_categorial(x, cmap=colors)
if all(isinstance(cell, str) for cell in cells):
out = colors_for_oneway(cells, cmap=colors)
elif all(isinstance(cell, tuple) for cell in cells):
ns = {len(cell) for cell in cells}
if len(ns) == 1:
out = colors_for_nway(list(zip(*cells)))
else:
raise NotImplementedError(f"{cells=}: unequal cell size")
else:
raise NotImplementedError(f"{cells=}: unequal cell size")
else:
raise TypeError(f"colors={colors!r}")
return to_styles_dict(out)
Expand Down
10 changes: 5 additions & 5 deletions eelbrain/plot/_uv.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def __init__(
if colors is False:
styles = False
else:
styles = find_cell_styles(ct.x, colors, ct.cells)
styles = find_cell_styles(ct.cells, colors)
if label_fliers and ct.match is None:
raise TypeError(f"label_fliers={label_fliers!r} without specifying the match parameter: match is needed to determine labels")
if ct.x is None and test is True:
Expand Down Expand Up @@ -416,7 +416,7 @@ def __init__(
if colors is False:
styles = False
else:
styles = find_cell_styles(ct.x, colors, ct.cells)
styles = find_cell_styles(ct.cells, colors)
if pool_error is None:
pool_error = ct.all_within

Expand Down Expand Up @@ -552,7 +552,7 @@ def __init__(
if colors is False:
styles = False
else:
styles = find_cell_styles(ct.x, colors, ct.cells)
styles = find_cell_styles(ct.cells, colors)
if pool_error is None:
pool_error = ct.all_within

Expand Down Expand Up @@ -860,7 +860,7 @@ def __init__(
if not line_plot:
legend = False

styles = find_cell_styles(categories, colors)
styles = find_cell_styles(categories.cells, colors)

# get axes
layout = Layout(1, 1, 5, **kwargs)
Expand Down Expand Up @@ -1074,7 +1074,7 @@ def __init__(
cmap = colors
else:
cat = color
styles = find_cell_styles(color, colors)
styles = find_cell_styles(color.cells, colors)

# size
if size is not None:
Expand Down

0 comments on commit c6f42ce

Please sign in to comment.