Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Support exporting mne objects to xarray DataArray objects #11464

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions environment.yml
Expand Up @@ -55,3 +55,4 @@ dependencies:
- eeglabio
- edflib-python
- pybv
- xarray
2 changes: 1 addition & 1 deletion mne/utils/__init__.py
Expand Up @@ -52,7 +52,7 @@
assert_object_equal, assert_and_remove_boundary_annot,
_raw_annot, assert_dig_allclose, assert_meg_snr,
assert_snr, assert_stcs_equal, _click_ch_name,
requires_openmeeg_mark)
requires_openmeeg_mark, requires_xarray)
from .numerics import (hashfunc, _compute_row_norms,
_reg_pinv, random_permutation, _reject_data_segments,
compute_corr, _get_inst_data, array_split_idx,
Expand Down
1 change: 1 addition & 0 deletions mne/utils/_testing.py
Expand Up @@ -130,6 +130,7 @@ def requires_module(function, name, call=None):
requires_pylsl = partial(requires_module, name='pylsl')
requires_sklearn = partial(requires_module, name='sklearn')
requires_mne = partial(requires_module, name='MNE-C', call=_mne_call)
requires_xarray = partial(requires_module, name='xarray')


def requires_mne_mark():
Expand Down
21 changes: 21 additions & 0 deletions mne/utils/tests/test_xarray.py
@@ -0,0 +1,21 @@
import numpy as np
import mne
from mne.utils import requires_xarray


@requires_xarray
def test_conversion_to_xarray():
"""Test conversion of mne object to xarray DataArray."""
import xarray as xr
from mne.utils.xarray import to_xarray

info = mne.create_info(list('abcd'), sfreq=250)
data = np.random.rand(4, 350)
erp = mne.EvokedArray(data, info, tmin=-0.5)

erp_x = to_xarray(erp)
assert isinstance(erp_x, xr.DataArray)
assert erp_x.shape == (4, 350)
assert erp_x.dims == ('chan', 'time')
assert erp_x.coords['chan'].data.tolist() == erp.ch_names
assert (erp_x.coords['time'].data == erp.times).all()
40 changes: 40 additions & 0 deletions mne/utils/xarray.py
@@ -0,0 +1,40 @@
import numpy as np

from .. import Epochs, Evoked
from ..time_frequency import AverageTFR


def to_xarray(mne_inst):
"""Convert MNE object instance to xarray DataArray.

Parameters
----------
mne_inst : Epochs | Evoked
The MNE object to convert.

Returns
-------
xarr : DataArray
The xarray object.
"""
from xarray import DataArray
from mne.utils import _validate_type

_validate_type(mne_inst, (Epochs, Evoked, AverageTFR))

if isinstance(mne_inst, Epochs):
data = mne_inst.get_data()
dims = ('chan', 'trial', 'time')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we call the dimension "channel"?

Copy link
Member Author

@mmagnuski mmagnuski Feb 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer shorter dimension names as you use them often during various operations, but I will change the name if this would be the common preference. :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we're internally consistent on questions like this. E.g., in function names sometimes we use channel and sometimes ch. Also when exporting to data frames, the spectrum class makes a column freq (not frequency). So I'm not really sure what to do WRT "channel". However, I would say that trial is not ideal, it should be epoch (for e.g. resting state recordings cut into sequential chunks "trial" doesn't make sense)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, yes I will definitelly change trial to epoch, especially that you can have multiple epochs from the same trial (for example successively presented items to remember).

elif isinstance(mne_inst, Evoked):
data = mne_inst.data
dims = ('chan', 'time')
else:
raise ValueError('MNE instance must be Epochs or Evoked.')

coords = dict(chan=mne_inst.ch_names)
if 'time' in dims:
coords['time'] = mne_inst.times
if 'trial' in dims:
coords['trial'] = np.arange(mne_inst.n_epochs)

return DataArray(data, dims=dims, coords=coords)
1 change: 1 addition & 0 deletions requirements.txt
Expand Up @@ -43,3 +43,4 @@ mne-qt-browser
darkdetect
qdarkstyle
threadpoolctl
xarray
2 changes: 2 additions & 0 deletions tools/github_actions_dependencies.sh
Expand Up @@ -28,6 +28,8 @@ else
echo "nilearn and openmeeg"
pip install $STD_ARGS --pre https://github.com/nilearn/nilearn/zipball/main
pip install $STD_ARGS --pre --only-binary ":all:" -i "https://test.pypi.org/simple" openmeeg
echo "xarray"
pip install --progress-bar off --pre xarray
echo "VTK"
pip install $STD_ARGS --pre --only-binary ":all:" vtk
python -c "import vtk"
Expand Down