Skip to content

Commit

Permalink
ENH plot.UTSStat(): only show color legend for unique cells
Browse files Browse the repository at this point in the history
  • Loading branch information
christianbrodbeck committed Nov 24, 2021
1 parent ec5f062 commit 0789ec5
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 7 deletions.
14 changes: 10 additions & 4 deletions eelbrain/plot/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,7 @@ class PlotData:
ct: Celltable = None
x: Union[Factor, Interaction] = None
xax: Union[Factor, Interaction] = None
styles: Dict[CellArg, Style] = None

def __post_init__(self):
self.n_plots = len(self.plot_data)
Expand Down Expand Up @@ -1272,16 +1273,21 @@ def from_stats(
if agg:
raise NotImplementedError
# reconstruct x/xax
if x is not None:
x = ct._align(x, ds=ds, coerce=ascategorial)
if xax is None:
default_color_cells = None
ax_cells = [None]
else:
xax = ct._align(xax, ds=ds, coerce=ascategorial)
ax_cells = xax.cells
if x is not None:
x = ct._align(x, ds=ds, coerce=ascategorial)
if x is None:
default_color_cells = None
else:
default_color_cells = x.cells
title = frame_title(y, x, xax)
# find styles
styles = find_cell_styles(ct.cells, colors)
styles = find_cell_styles(ct.cells, colors, default_cells=default_color_cells)
# find masks
if mask is None:
masks = defaultdict(lambda: None)
Expand All @@ -1304,7 +1310,7 @@ def from_stats(
cells = [cell for cell in cells if cell in ct.data]
layers = [StatLayer(ct.data[cell], style=styles[cell], ct=ct, cell=cell, mask=masks[cell]) for cell in cells]
axes.append(AxisData(layers, cellname(ax_cell)))
return cls(axes, dims, title, ct=ct, x=x, xax=xax)
return cls(axes, dims, title, ct=ct, x=x, xax=xax, styles=styles)

@classmethod
def empty(cls, plots: Union[int, List[bool]], dims: Sequence[str], title: str):
Expand Down
15 changes: 12 additions & 3 deletions eelbrain/plot/_styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class Style:
edgecolor: Any = None
edgewidth: float = 0,
linemarker: Union[bool, str] = False
alias: CellArg = None # This style should not appear in legends

@cached_property
def line_args(self):
Expand Down Expand Up @@ -131,6 +132,7 @@ def find_cell_styles(
cells: Sequence[CellArg] = None,
colors: ColorsArg = None,
fallback: bool = True,
default_cells: Sequence[CellArg] = None,
) -> StylesDict:
"""Process the colors arg from plotting functions
Expand All @@ -148,6 +150,10 @@ def find_cell_styles(
<http://matplotlib.org/api/colors_api.html>`_.
fallback
If a cell is missing, fall back on partial cells (on by default).
default_cells
Alternative cells to use if ``colors`` is ``None`` or a :class:`dict`.
For example, when plots use the ``xax`` parameter, cells will contain a
``xax`` component, but colors should be consistent across axes.
"""
if cells in (None, (None,)):
if isinstance(colors, dict):
Expand All @@ -171,13 +177,16 @@ def find_cell_styles(
super_cells = chain((cell[:-i] for i in range(1, len(cell))), (cell[0],))
for super_cell in super_cells:
if super_cell in out:
out[cell] = out[super_cell]
out[cell] = replace(out[super_cell], alias=super_cell)
missing.remove(cell)
break
if missing:
raise KeysMissing(missing, 'colors', colors)
elif colors is None or isinstance(colors, str):
if all(isinstance(cell, str) for cell in cells):
if default_cells is not None:
default_colors = find_cell_styles(default_cells, colors)
return find_cell_styles(cells, default_colors)
elif all(isinstance(cell, str) for cell in cells):
out = colors_for_oneway(cells, cmap=colors)
elif all(isinstance(cell, tuple) for cell in cells):
ns = {len(cell) for cell in cells}
Expand All @@ -188,7 +197,7 @@ def find_cell_styles(
else:
raise NotImplementedError(f"{cells=}: unequal cell size")
else:
raise TypeError(f"colors={colors!r}")
raise TypeError(f"{colors=}")
return to_styles_dict(out)


Expand Down
5 changes: 5 additions & 0 deletions eelbrain/plot/_uts.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ def __init__(
ymax = p.vmax if ymax is None else max(ymax, p.vmax)
self._set_axtitle(axtitle, data)

# The legend should only display cells with distinct styles: remap legend handles to source cells
alias_cells = {cell: style.alias for cell, style in data.styles.items() if style.alias}
if alias_cells:
legend_handles = {alias_cells[cell]: handle for cell, handle in legend_handles.items()}

# axes limits
if top is not None:
ymax = top
Expand Down
1 change: 1 addition & 0 deletions eelbrain/plot/tests/test_uts.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_uts_stat():
assert [len(pl.stat_plots) for pl in p._plots] == [2, 2]
assert p.figure.axes[0].get_title() == 'b0'
assert p.figure.axes[1].get_title() == 'b1'
assert set(p._LegendMixin__handles) == {'a0', 'a1'}
p.close()
p = plot.UTSStat('uts', 'A', 'B', match='rm', ds=ds, pool_error=False)
assert [len(pl.stat_plots) for pl in p._plots] == [2, 2]
Expand Down

0 comments on commit 0789ec5

Please sign in to comment.