Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace metadata eval with query #10705

Merged
merged 5 commits into from Jun 7, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Expand Up @@ -110,6 +110,8 @@ Bugs

- Retain epochs metadata when using :func:`mne.channels.combine_channels` (:gh:`10504` by `Clemens Brunner`_)

- Fix epochs indexing with metadata containing boolean type and missing values (:gh:`10705` by `Clemens Brunner`_ and `Alex Gramfort`_)

- Fix reading of fiducial locations in :func:`mne.io.read_raw_eeglab` (:gh:`10521` by `Alex Gramfort`_)

- Prevent creation of montage with invalid ``[x, y, z]`` coordinates with :func:`mne.channels.make_dig_montage` (:gh:`10547` by `Mathieu Scheltienne`_)
Expand Down
24 changes: 24 additions & 0 deletions mne/tests/test_epochs.py
Expand Up @@ -3155,6 +3155,30 @@ def test_metadata(tmp_path):
assert len(epochs.metadata) == 1


@requires_pandas
def test_metadata_query_bool():
"""Test metadata query with boolean indexing and NaNs gh #10705."""
import mne
import numpy as np
import pandas as pd
from random import choices

n_epochs, n_chans, n_samples = 40, 3, 101
correct = pd.Series(choices([True, False, None], k=n_epochs),
dtype="boolean")
metadata = pd.DataFrame({"correct": correct})
rng = np.random.default_rng()
epochs = mne.EpochsArray(
data=rng.standard_normal(size=(n_epochs, n_chans, n_samples)),
info=mne.create_info(n_chans, 500, "eeg"),
metadata=metadata
)

assert len(epochs) > metadata.sum()["correct"]
epochs = epochs["correct"]
assert len(epochs) == metadata.sum()["correct"]


def assert_metadata_equal(got, exp):
"""Assert metadata are equal."""
if exp is None:
Expand Down
7 changes: 5 additions & 2 deletions mne/utils/mixin.py
Expand Up @@ -254,12 +254,15 @@ def _keys_to_idx(self, keys):
self._check_metadata(metadata=md)
try:
# Try metadata
mask = self.metadata.eval(keys[0], engine='python').values
vals = self.metadata.reset_index().query(
keys[0],
engine='python'
).index.values
except Exception as exp:
msg += (' The epochs.metadata Pandas query did not '
'yield any results: %s' % (exp.args[0],))
else:
return np.where(mask)[0]
return vals
else:
# If not, warn this might be a problem
msg += (' The epochs.metadata Pandas query could not '
Expand Down