Skip to content

Commit

Permalink
Replace metadata eval with query (#10705)
Browse files Browse the repository at this point in the history
* Replace metadata eval with query

* Reset index

* add test

* update what's new

* Adapt test

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
  • Loading branch information
cbrnr and agramfort committed Jun 7, 2022
1 parent 36ec7f4 commit b2652a6
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Expand Up @@ -112,6 +112,8 @@ Bugs

- Retain epochs metadata when using :func:`mne.channels.combine_channels` (:gh:`10504` by `Clemens Brunner`_)

- Fix epochs indexing with metadata containing boolean type and missing values (:gh:`10705` by `Clemens Brunner`_ and `Alex Gramfort`_)

- Fix reading of fiducial locations in :func:`mne.io.read_raw_eeglab` (:gh:`10521` by `Alex Gramfort`_)

- Prevent creation of montage with invalid ``[x, y, z]`` coordinates with :func:`mne.channels.make_dig_montage` (:gh:`10547` by `Mathieu Scheltienne`_)
Expand Down
19 changes: 18 additions & 1 deletion mne/tests/test_epochs.py
Expand Up @@ -2969,7 +2969,7 @@ def test_default_values():
@requires_pandas
def test_metadata(tmp_path):
"""Test metadata support with pandas."""
from pandas import DataFrame
from pandas import DataFrame, Series, NA

data = np.random.randn(10, 2, 2000)
chs = ['a', 'b']
Expand Down Expand Up @@ -3154,6 +3154,23 @@ def test_metadata(tmp_path):
assert len(epochs) == 1
assert len(epochs.metadata) == 1

# gh-10705: support boolean columns
metadata = DataFrame(
{"A": Series([True, True, True, False, False, NA], dtype="boolean")}
)
rng = np.random.default_rng()
epochs = mne.EpochsArray(
data=rng.standard_normal(size=(6, 8, 500)),
info=mne.create_info(8, 250, "eeg"),
event_id={"A": 1},
metadata=metadata
)

assert len(epochs["A"]) == 6 # epochs of event type A
assert len(epochs["A == True"]) == 3 # epochs for which column A == True
assert len(epochs["not A"]) == 2 # epochs for which column A == False
assert len(epochs["A.isna()"]) == 1 # epochs for NA in column A


def assert_metadata_equal(got, exp):
"""Assert metadata are equal."""
Expand Down
7 changes: 5 additions & 2 deletions mne/utils/mixin.py
Expand Up @@ -254,12 +254,15 @@ def _keys_to_idx(self, keys):
self._check_metadata(metadata=md)
try:
# Try metadata
mask = self.metadata.eval(keys[0], engine='python').values
vals = self.metadata.reset_index().query(
keys[0],
engine='python'
).index.values
except Exception as exp:
msg += (' The epochs.metadata Pandas query did not '
'yield any results: %s' % (exp.args[0],))
else:
return np.where(mask)[0]
return vals
else:
# If not, warn this might be a problem
msg += (' The epochs.metadata Pandas query could not '
Expand Down

0 comments on commit b2652a6

Please sign in to comment.