Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Support exporting mne objects to xarray DataArray objects #11464

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,4 @@ dependencies:
- eeglabio
- edflib-python
- pybv
- xarray
23 changes: 22 additions & 1 deletion mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
fill_doc, _check_option, _build_data_frame,
_check_pandas_installed, _check_pandas_index_arguments,
_convert_times, _scale_dataframe_data, _check_time_format,
_check_preload, _check_fname, TimeMixin)
_check_preload, _check_fname, TimeMixin, to_xarray,
_check_xarray_installed)
from .viz import (plot_evoked, plot_evoked_topomap, plot_evoked_field,
plot_evoked_image, plot_evoked_topo)
from .viz.evoked import plot_evoked_white, plot_evoked_joint
Expand Down Expand Up @@ -876,6 +877,26 @@ def to_data_frame(self, picks=None, index=None,
default_index=['time'])
return df

@fill_doc
def to_xarray(self, picks=None):
"""Export a copy of evoked data as an xarray DataArray.

Parameters
----------
%(picks_all)s

Returns
-------
xarr : DataArray
The xarray object.

Notes
-----
.. versionadded:: 0.1.3
"""
_check_xarray_installed()
return to_xarray(self, picks=picks)


@fill_doc
class EvokedArray(Evoked):
Expand Down
6 changes: 4 additions & 2 deletions mne/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
_check_edflib_installed, _to_rgb, _soft_import,
_check_dict_keys, _check_pymatreader_installed,
_import_h5py, _import_h5io_funcs,
_import_pymatreader_funcs, _check_head_radius)
_import_pymatreader_funcs, _check_head_radius,
_check_xarray_installed)
from .config import (set_config, get_config, get_config_path, set_cache_dir,
set_memmap_min_size, get_subjects_dir, _get_stim_channel,
sys_info, _get_extra_data_path, _get_root_dir,
Expand Down Expand Up @@ -52,7 +53,7 @@
assert_object_equal, assert_and_remove_boundary_annot,
_raw_annot, assert_dig_allclose, assert_meg_snr,
assert_snr, assert_stcs_equal, _click_ch_name,
requires_openmeeg_mark)
requires_openmeeg_mark, requires_xarray)
from .numerics import (hashfunc, _compute_row_norms,
_reg_pinv, random_permutation, _reject_data_segments,
compute_corr, _get_inst_data, array_split_idx,
Expand All @@ -73,3 +74,4 @@
_get_blas_funcs)
from .dataframe import (_set_pandas_dtype, _scale_dataframe_data,
_convert_times, _build_data_frame)
from .xarray import to_xarray
1 change: 1 addition & 0 deletions mne/utils/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def requires_module(function, name, call=None):
requires_pylsl = partial(requires_module, name='pylsl')
requires_sklearn = partial(requires_module, name='sklearn')
requires_mne = partial(requires_module, name='MNE-C', call=_mne_call)
requires_xarray = partial(requires_module, name='xarray')


def requires_mne_mark():
Expand Down
5 changes: 5 additions & 0 deletions mne/utils/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,11 @@ def _check_pandas_installed(strict=True):
return _soft_import('pandas', 'dataframe integration', strict=strict)


def _check_xarray_installed(strict=True):
"""Aux function."""
return _soft_import('xarray', 'converting to xarray', strict=strict)


def _check_eeglabio_installed(strict=True):
"""Aux function."""
return _soft_import('eeglabio', 'exporting to EEGLab', strict=strict)
Expand Down
25 changes: 25 additions & 0 deletions mne/utils/tests/test_xarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import numpy as np
import mne
from mne.utils import requires_xarray


@requires_xarray
def test_conversion_to_xarray():
"""Test conversion of mne object to xarray DataArray."""
import xarray as xr

info = mne.create_info(list('abcd'), sfreq=250)
data = np.random.rand(4, 350)
evoked = mne.EvokedArray(data, info, tmin=-0.5)

xarr = evoked.to_xarray()
assert isinstance(xarr, xr.DataArray)
assert xarr.shape == evoked.data.shape
assert xarr.dims == ('chan', 'time')
assert xarr.coords['chan'].data.tolist() == evoked.ch_names
assert (xarr.coords['time'].data == evoked.times).all()

xarr = evoked.to_xarray(picks=['b', 'd'])
assert xarr.shape == (2, 350)
assert xarr.coords['chan'].data.tolist() == ['b', 'd']
assert (xarr.data == evoked.data[[1, 3]]).all()
51 changes: 51 additions & 0 deletions mne/utils/xarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np


def to_xarray(inst, picks=None):
"""Convert MNE object instance to xarray DataArray.

Parameters
----------
inst : Epochs | Evoked
The MNE object to convert.
picks : list of str | array-like of int | None
Channels to include. If None only good data channels are kept.
(I couldn't use @fill_doc here, so I put this temporarily here)

Returns
-------
xarr : DataArray
The xarray object.
"""
from xarray import DataArray
from .. import Epochs, Evoked
from . import _validate_type
from ..io.pick import _picks_to_idx

_validate_type(inst, (Epochs, Evoked))

if isinstance(inst, Epochs):
dims = ('chan', 'epoch', 'time')
elif isinstance(inst, Evoked):
dims = ('chan', 'time')
else:
raise ValueError('MNE instance must be Epochs or Evoked.')

coords = dict(chan=inst.ch_names)
if picks is not None:
picks = _picks_to_idx(inst.info, picks, exclude=[])
coords['chan'] = [inst.ch_names[pck] for pck in picks]

if 'time' in dims:
coords['time'] = inst.times
if 'epoch' in dims:
coords['epoch'] = np.arange(inst.n_epochs)

data = inst.get_data(picks=picks)
xarr = DataArray(data, dims=dims, coords=coords)

# add channel types as additional dimension coordinate
xarr = xarr.assign_coords(
ch_type=('chan', inst.get_channel_types(picks=picks))
)
return xarr
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@ mne-qt-browser
darkdetect
qdarkstyle
threadpoolctl
xarray
2 changes: 2 additions & 0 deletions tools/github_actions_dependencies.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ else
echo "nilearn and openmeeg"
pip install $STD_ARGS --pre https://github.com/nilearn/nilearn/zipball/main
pip install $STD_ARGS --pre --only-binary ":all:" -i "https://test.pypi.org/simple" openmeeg
echo "xarray"
pip install --progress-bar off --pre xarray
echo "VTK"
pip install $STD_ARGS --pre --only-binary ":all:" vtk
python -c "import vtk"
Expand Down