Skip to content

Commit

Permalink
FIX SourceSpace.index_for_label(): 'lh', 'rh'
Browse files Browse the repository at this point in the history
  • Loading branch information
christianbrodbeck committed Jun 12, 2021
1 parent b7b7d99 commit 928eb3c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
14 changes: 7 additions & 7 deletions eelbrain/_data_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@
from ._utils.numpy_utils import (
INT_TYPES, FULL_SLICE, FULL_AXIS_SLICE,
aslice, apply_numpy_index, deep_array, digitize_index, digitize_slice_endpoint,
index_length, index_to_int_array, newaxis, slice_to_arange)
index_length, index_to_bool_array, index_to_int_array, newaxis, slice_to_arange)
from .mne_fixes import MNE_EPOCHS, MNE_EVOKED, MNE_RAW, MNE_LABEL
from functools import reduce

Expand Down Expand Up @@ -9955,7 +9955,7 @@ def index_for_label(self, label):
name = label.name
else:
raise TypeError(f"{label!r}")
return NDVar(idx, (self,), name)
return NDVar(index_to_bool_array(idx, len(self)), (self,), name)

def _is_superset_of(self, dim):
self._assert_same_base(dim)
Expand Down Expand Up @@ -10311,19 +10311,19 @@ def _read_surf(self, hemi, surf='orig'):
path = Path(f'{self.subjects_dir}/{self.subject}/surf/{hemi}.{surf}')
return read_geometry(path)

def index_for_label(self, label):
def index_for_label(self, label: Union[str, Sequence[str], mne.Label, mne.BiHemiLabel]) -> NDVar:
"""Return the index for a label
Parameters
----------
label : str | sequence of str | Label | BiHemiLabel
The name of a region in the current parcellation, or an :mod:`mne`
:class:`~mne.label.Label` object. If the label does not
label
The name of a region in the current parcellation, ``'lh'``, ``'rh'``,
or an :mod:`mne`:class:`~mne.label.Label` object. If the label does not
match any sources in the SourceEstimate, a ValueError is raised.
Returns
-------
index : NDVar of bool
index : boolean NDVar
Index into the source space dim that corresponds to the label.
"""
return SourceSpaceBase.index_for_label(self, label)
Expand Down
9 changes: 9 additions & 0 deletions eelbrain/_utils/numpy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ def index(index, at: int):
return FULL_AXIS_SLICE * at + (index,)


def index_to_bool_array(index, n):
if isinstance(index, np.ndarray):
if index.dtype.kind == 'b':
return index
out = np.zeros(n, bool)
out[index] = True
return out


def index_to_int_array(index, n):
if isinstance(index, np.ndarray):
if index.dtype.kind == 'i':
Expand Down

0 comments on commit 928eb3c

Please sign in to comment.