Skip to content

Commit

Permalink
FIX PlotData.from_args(): accept iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
christianbrodbeck committed Aug 27, 2021
1 parent 720a85e commit 79a0687
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
3 changes: 3 additions & 0 deletions eelbrain/_data_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -10946,3 +10946,6 @@ def intersect_dims(dims1, dims2, check_dims: bool = True):
IndexArg = Union[Var, np.ndarray, str]
ModelArg = Union[Model, Var, CategorialArg]
UVArg = Union[VarArg, CategorialArg]

# Types that can be coerced to NDVar with asndvar(); lists of those can be too
NDVarTypes = (NDVar, str, MNE_RAW, MNE_EPOCHS, MNE_EVOKED)
29 changes: 13 additions & 16 deletions eelbrain/plot/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,13 @@
from .._celltable import Celltable
from .._colorspaces import LocatedColormap, symmetric_cmaps, zerobased_cmaps, ALPHA_CMAPS
from .._config import CONFIG
from .._data_obj import Dimension, Dataset, Factor, Interaction, NDVar, Var, Case, UTS, NDVarArg, CategorialArg, IndexArg, CellArg, ascategorial, asndvar, assub, isnumeric, isdataobject, combine_cells, cellname
from .._data_obj import Dimension, Dataset, Factor, Interaction, NDVar, Var, Case, UTS, NDVarArg, CategorialArg, IndexArg, CellArg, NDVarTypes, ascategorial, asndvar, assub, isnumeric, isdataobject, combine_cells, cellname
from .._utils.notebooks import use_inline_backend
from .._stats import test, testnd
from .._stats import testnd
from .._utils import IS_WINDOWS, LazyProperty, intervals, ui
from .._ndvar import erode, resample
from .._text import enumeration, ms
from ..fmtxt import FMTextArg, Image, asfmtext, asfmtext_or_none
from ..mne_fixes import MNE_EPOCHS
from ..table import melt_ndvar
from ._decorations import mark_difference
from ._styles import Style, find_cell_styles
Expand Down Expand Up @@ -1131,13 +1130,10 @@ def from_args(
sub = assub(sub, ds)
if hasattr(y, '_default_plot_obj'):
ys = getattr(y, '_default_plot_obj')()
elif isinstance(y, MNE_EPOCHS):
# Epochs are Iterators over arrays
ys = (asndvar(y, sub, ds),)
else:
ys = y

if not isinstance(ys, (tuple, list, Iterator)):
if isinstance(ys, NDVarTypes):
ys = (ys,)

ax_names = None
Expand All @@ -1147,18 +1143,18 @@ def from_args(
for ax in ys:
if ax is None:
axes.append(None)
elif isinstance(ax, (tuple, list, Iterator)):
elif isinstance(ax, NDVarTypes):
ax = asndvar(ax, sub, ds)
agg, dims = find_data_dims(ax, dims)
layer = aggregate(ax, agg)
axes.append([layer])
else:
layers = []
for layer in ax:
layer = asndvar(layer, sub, ds)
agg, dims = find_data_dims(layer, dims)
layers.append(aggregate(layer, agg))
axes.append(layers)
else:
ax = asndvar(ax, sub, ds)
agg, dims = find_data_dims(ax, dims)
layer = aggregate(ax, agg)
axes.append([layer])
x_name = None
# determine y names
y_names = []
Expand All @@ -1168,9 +1164,7 @@ def from_args(
for layer in layers:
if layer.name and layer.name not in y_names:
y_names.append(layer.name)
elif any(isinstance(ax, (tuple, list, Iterator)) for ax in ys):
raise TypeError(f"y={y!r}, xax={xax!r}: y can't be nested list if xax is specified, use single list")
else:
elif all(ax is None or isinstance(ax, NDVarTypes) for ax in ys):
ys = [asndvar(layer, sub, ds) for layer in ys]
y_names = [layer.name for layer in ys]
layers = []
Expand Down Expand Up @@ -1209,6 +1203,9 @@ def from_args(
x_name = xax.name
ax_names = [cellname(cell) for cell in xax.cells]
axes = list(zip(*layers))
else:
raise TypeError(f"{y=}, {xax=}: y can't be nested list if xax is specified, use single list")

if len(y_names) == 0:
y_name = None
elif len(y_names) == 1:
Expand Down

0 comments on commit 79a0687

Please sign in to comment.