Skip to content

Commit

Permalink
ENH combine(): to_list parameter to combing NDVars into list
Browse files Browse the repository at this point in the history
  • Loading branch information
christianbrodbeck committed Jun 24, 2021
1 parent 9a879dd commit c16c258
Showing 1 changed file with 59 additions and 35 deletions.
94 changes: 59 additions & 35 deletions eelbrain/_data_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,7 @@ def combine(
check_dims: bool = True,
incomplete: str = 'raise',
dim_intersection: bool = False,
to_list: bool = False,
):
"""Combine a list of items of the same type into one item.
Expand All @@ -1089,8 +1090,13 @@ def combine(
for numerical variables).
dim_intersection
Only applies to combining :class:`NDVar`: normally, when :class:`NDVar`
have mismatching dimensions, a DimensionMismatchError is raised. With
``dim_intersection``, the intersection is used instead.
have mismatching dimensions, a :exc:`DimensionMismatchError` is raised.
With ``dim_intersection=True``, the intersection is used instead.
to_list
Only applies to combining :class:`NDVar`: normally, when :class:`NDVar`
have mismatching dimensions, a :exc:`DimensionMismatchError` is raised.
With ``to_list=True``, the :class:`NDVar` are added as :class:`list` of
:class:`NDVar` instead.
Notes
-----
Expand Down Expand Up @@ -1143,7 +1149,7 @@ def combine(
for key in keys:
pieces = [ds[key] if key in ds else
_empty_like(sample[key], ds.n_cases) for ds in items]
out[key] = combine(pieces, check_dims=check_dims, dim_intersection=dim_intersection)
out[key] = combine(pieces, check_dims=check_dims, dim_intersection=dim_intersection, to_list=to_list)
else:
keys = set(first_item)
if incomplete == 'raise':
Expand All @@ -1157,7 +1163,7 @@ def combine(
out_keys = (k for k in first_item if k in keys)

for key in out_keys:
out[key] = combine([ds[key] for ds in items], check_dims=check_dims, dim_intersection=dim_intersection)
out[key] = combine([ds[key] for ds in items], check_dims=check_dims, dim_intersection=dim_intersection, to_list=to_list)
return out
elif stype is Var:
x = np.hstack([i.x for i in items])
Expand All @@ -1180,35 +1186,37 @@ def combine(
has_case = True
all_dims = [item.dims[1:] for item in items]
elif any(v_have_case):
raise DimensionMismatchError("Some items have a 'case' dimension, others do not")
raise DimensionMismatchError(f"{name}: Some items have a Case dimension, others do not")
else:
has_case = False
all_dims = [item.dims for item in items]

if dim_intersection:
dims = reduce(lambda x, y: intersect_dims(x, y, check_dims), all_dims)
else:
dims, *other_dims = all_dims
if not all(dims_stackable(dims_i, dims, check_dims) for dims_i in other_dims):
msg = ["Some NDVars have mismatching dimensions", "set dim_intersection=True to discard elements not present in all"]
if not check_dims:
msg.insert(1, "set check_dims=False to ignore non-critical differences (e.g. connectivity)")
raise DimensionMismatchError.from_dims_list('; '.join(msg), all_dims, check_dims)
idx = {d.name: d for d in dims}
# reduce data to common dimension range
sub_items = []
for item in items:
if item.dims[has_case:] == dims:
sub_items.append(item)
dims, *other_dims = all_dims
if not all(dims_stackable(dims_i, dims, check_dims) for dims_i in other_dims):
if to_list:
if has_case:
return list(chain.from_iterable(items))
return items
elif dim_intersection:
dims = reduce(lambda x, y: intersect_dims(x, y, check_dims), all_dims)
idx = {dim.name: dim for dim in dims}
# reduce data to common dimension range
sub_items = []
for item in items:
if item.dims[has_case:] == dims:
sub_items.append(item)
else:
sub_items.append(item.sub(**idx))
elif not check_dims and all(dims_stackable(dims_i, dims, True) for dims_i in other_dims):
raise DimensionMismatchError.from_dims_list("Some NDVars have mismatching dimensions; set check_dims=False to ignore non-critical differences (e.g. connectivity)", all_dims, check_dims)
else:
sub_items.append(item.sub(**idx))
raise DimensionMismatchError.from_dims_list("Some NDVars have mismatching dimensions; use to_list=True to combine them into a list, or dim_intersection=True to discard elements not present in all", all_dims, check_dims)
# combine data
if has_case:
x = np.concatenate([v.x for v in sub_items], axis=0)
else:
x = np.stack([v.x for v in sub_items])
dims = ('case',) + dims
return NDVar(x, dims, name, _info.merge_info(sub_items))
return NDVar(x, (Case, *dims), name, _info.merge_info(sub_items))
elif stype is Datalist:
return Datalist(sum(items, []), name, items[0]._fmt)
else:
Expand Down Expand Up @@ -6196,35 +6204,51 @@ def eval(self, expression):
raise EvalError(expression, exception, ds_repr) from exception

@classmethod
def from_caselist(cls, names, cases, name=None, caption=None, info=None, random=None, check_dims=True, dim_intersection=False):
def from_caselist(
cls,
names: Sequence[str],
cases: Sequence[Sequence[str, Number, NDVar]],
name: str = None,
caption: str = None,
info: dict = None,
random: Union[str, Collection[str]] = None,
check_dims: bool = True,
dim_intersection: bool = False,
to_list: bool = False,
):
"""Create a Dataset from a list of cases
Parameters
----------
names : sequence of str
names
Names for the variables.
cases : sequence of sequence of { str | scalar | NDVar }
cases
A sequence of cases, whereby each case is itself represented as a
sequence of values (str or scalar). Variable type (Factor or Var)
is inferred from whether values are str or not.
name : str
name
Name for the Dataset.
caption : str
caption
Caption for the table.
info : dict
info
Info dictionary, can contain arbitrary entries and can be accessed
as ``.info`` attribute after initialization. The Dataset makes a
shallow copy.
random : str | sequence of str
random
Names of the columns that should be assigned as random factor.
check_dims : bool
check_dims
For :class:`NDVar` columns, check dimensions for consistency between
cases (e.g., channel locations in a :class:`Sensor`). Default is
``True``. Set to ``False`` to ignore mismatches.
dim_intersection : bool
dim_intersection
Only applies to combining :class:`NDVar`: normally, when :class:`NDVar`
have mismatching dimensions, a :exc:`DimensionMismatchError` is raised.
With ``dim_intersection=True``, the intersection is used instead.
to_list
Only applies to combining :class:`NDVar`: normally, when :class:`NDVar`
have mismatching dimensions, a DimensionMismatchError is raised. With
``dim_intersection``, the intersection is used instead.
have mismatching dimensions, a :exc:`DimensionMismatchError` is raised.
With ``to_list=True``, the :class:`NDVar` are added as :class:`list` of
:class:`NDVar` instead.
Examples
--------
Expand All @@ -6246,7 +6270,7 @@ def from_caselist(cls, names, cases, name=None, caption=None, info=None, random=
n_cases = n_cases.pop()
if len(names) != n_cases:
raise ValueError(f'{names=}: {len(names)} names but {n_cases} cases')
items = {key: combine([case[i] for case in cases], check_dims=check_dims, dim_intersection=dim_intersection) for i, key in enumerate(names)}
items = {key: combine([case[i] for case in cases], check_dims=check_dims, dim_intersection=dim_intersection, to_list=to_list) for i, key in enumerate(names)}
for key in random:
item = items[key]
if isinstance(item, Factor):
Expand Down

0 comments on commit c16c258

Please sign in to comment.