Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kernel param for apply inverse #6609

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
1645de3
add full_data param to apply_inverse_epochs
DiGyt Jul 16, 2019
cba0585
Merge https://github.com/mne-tools/mne-python into kernel_param_for_a…
DiGyt Jul 27, 2019
5775bec
Added test
DiGyt Jul 27, 2019
3e35749
Improved Tests, docstrings
DiGyt Jul 27, 2019
75cace0
changed param name, added fixture
DiGyt Jul 29, 2019
3424e24
corrected bool param semantics
DiGyt Jul 29, 2019
a5349a3
improved docstring
DiGyt Jul 30, 2019
246fbce
Merge https://github.com/mne-tools/mne-python into kernel_param_for_a…
DiGyt Jul 30, 2019
15bfbda
Update mne/minimum_norm/inverse.py
DiGyt Jul 30, 2019
5c6bec5
Update mne/minimum_norm/inverse.py
DiGyt Jul 30, 2019
a60e33c
changed fixture name...
DiGyt Jul 30, 2019
fd56d56
fixed docstring
DiGyt Jul 30, 2019
4eccbba
raise error if delayed and vector orientation
DiGyt Jul 31, 2019
864c20d
Delayed works for pick_ori = "vector" now
DiGyt Aug 1, 2019
acfe50a
Merge https://github.com/mne-tools/mne-python into kernel_param_for_a…
DiGyt Aug 6, 2019
108bfdf
corrected stuff
DiGyt Aug 6, 2019
05607e0
adapted data, introduced tests
DiGyt Aug 8, 2019
2a29730
also cover VolVectorSourceEstimates
DiGyt Aug 8, 2019
17f9d05
rebase
DiGyt Aug 8, 2019
cc6e612
Merge https://github.com/mne-tools/mne-python into kernel_param_for_a…
DiGyt Aug 8, 2019
4ace218
merged with master
DiGyt Aug 8, 2019
f254a95
only one return statement in _make_stc
DiGyt Aug 20, 2019
6950b95
fixed flake
DiGyt Aug 20, 2019
e41236f
this should work
massich Aug 21, 2019
6bd01a7
change error message
DiGyt Aug 22, 2019
782e61f
Merge https://github.com/mne-tools/mne-python into kernel_param_for_a…
DiGyt Aug 22, 2019
c3427b1
Merge https://github.com/mne-tools/mne-python into kernel_param_for_a…
DiGyt Aug 30, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions mne/minimum_norm/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,12 +1071,19 @@ def apply_inverse_raw(raw, inverse_operator, lambda2, method="dSPM",
def _apply_inverse_epochs_gen(epochs, inverse_operator, lambda2, method='dSPM',
label=None, nave=1, pick_ori=None,
prepared=False, method_params=None,
verbose=None):
delayed=False, verbose=None):
"""Generate inverse solutions for epochs. Used in apply_inverse_epochs."""
_check_option('method', method, INVERSE_METHODS)
_check_ori(pick_ori, inverse_operator['source_ori'])
_check_ch_names(inverse_operator, epochs.info)

is_free_ori = not (is_fixed_orient(inverse_operator) or
pick_ori == 'normal')

if delayed and is_free_ori and pick_ori != "vector":
raise ValueError("delayed must be False for free orientations other "
"than pick_ori='vector' or 'normal'.")

#
# Set up the inverse according to the parameters
#
Expand All @@ -1095,13 +1102,10 @@ def _apply_inverse_epochs_gen(epochs, inverse_operator, lambda2, method='dSPM',
tstep = 1.0 / epochs.info['sfreq']
tmin = epochs.times[0]

is_free_ori = not (is_fixed_orient(inverse_operator) or
pick_ori == 'normal')

if pick_ori == 'vector' and noise_norm is not None:
noise_norm = noise_norm.repeat(3, axis=0)

if not is_free_ori and noise_norm is not None:
if not (is_free_ori and pick_ori != 'vector') and noise_norm is not None:
# premultiply kernel with noise normalization
K *= noise_norm

Expand All @@ -1116,15 +1120,16 @@ def _apply_inverse_epochs_gen(epochs, inverse_operator, lambda2, method='dSPM',
# Compute solution and combine current components (non-linear)
sol = np.dot(K, e[sel]) # apply imaging kernel

if pick_ori != 'vector':
logger.info('combining the current components...')
sol = combine_xyz(sol)
if is_free_ori and pick_ori != 'vector':
logger.info('combining the current components...')
sol = combine_xyz(sol)

if noise_norm is not None:
sol *= noise_norm

else:
# Linear inverse: do computation here or delayed
if len(sel) < K.shape[1]:
if delayed:
DiGyt marked this conversation as resolved.
Show resolved Hide resolved
sol = (K, e[sel])
else:
sol = np.dot(K, e[sel])
Expand All @@ -1143,7 +1148,7 @@ def _apply_inverse_epochs_gen(epochs, inverse_operator, lambda2, method='dSPM',
def apply_inverse_epochs(epochs, inverse_operator, lambda2, method="dSPM",
label=None, nave=1, pick_ori=None,
return_generator=False, prepared=False,
method_params=None, verbose=None):
method_params=None, delayed=False, verbose=None):
"""Apply inverse operator to Epochs.

Parameters
Expand Down Expand Up @@ -1179,6 +1184,19 @@ def apply_inverse_epochs(epochs, inverse_operator, lambda2, method="dSPM",
Additional options for eLORETA. See Notes of :func:`apply_inverse`.

.. versionadded:: 0.16
delayed : bool
If False, the source time courses are computed. If True, they are
stored as a tuple of two smaller arrays in order to save memory. In
this case, the first array in the tuple corresponds to the "kernel"
shape (n_vertices [, n_orientations], n_sensors) and the second array
to the "sens_data" shape (n_sensors, n_times). The full source time
courses field will be automatically computed when stc.data is called
for the first time (see for example: :class:`mne.SourceEstimate`).
`delayed=True` is only implemented for fixed orientations (e.g.
from pick_ori = "normal") as well as pick_ori="vector".
Defaults to False.

.. versionadded:: 0.19
%(verbose)s

Returns
Expand All @@ -1194,7 +1212,7 @@ def apply_inverse_epochs(epochs, inverse_operator, lambda2, method="dSPM",
stcs = _apply_inverse_epochs_gen(
epochs, inverse_operator, lambda2, method=method, label=label,
nave=nave, pick_ori=pick_ori, verbose=verbose, prepared=prepared,
method_params=method_params)
method_params=method_params, delayed=delayed)

if not return_generator:
# return a list
Expand Down
61 changes: 49 additions & 12 deletions mne/minimum_norm/tests/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,23 +783,24 @@ def test_apply_mne_inverse_fixed_raw():
assert_array_almost_equal(stc.data, stc3.data)


@pytest.fixture(scope="module")
def epochs():
DiGyt marked this conversation as resolved.
Show resolved Hide resolved
"""Create an epochs object used for testing."""
raw = read_raw_fif(fname_raw)
picks = pick_types(raw.info, meg=True, eeg=False, stim=True, ecg=True,
eog=True, include=['STI 014'], exclude='bads')
events = read_events(fname_event)[:15]
return Epochs(raw, events, 1, -0.2, 0.5, picks=picks, baseline=(None, 0),
reject=dict(grad=4000e-13, mag=4e-12, eog=150e-6),
flat=dict(grad=1e-15, mag=1e-15))


@testing.requires_testing_data
def test_apply_mne_inverse_epochs():
def test_apply_mne_inverse_epochs(epochs):
"""Test MNE with precomputed inverse operator on Epochs."""
inverse_operator = read_inverse_operator(fname_full)
label_lh = read_label(fname_label % 'Aud-lh')
label_rh = read_label(fname_label % 'Aud-rh')
event_id, tmin, tmax = 1, -0.2, 0.5
raw = read_raw_fif(fname_raw)

picks = pick_types(raw.info, meg=True, eeg=False, stim=True, ecg=True,
eog=True, include=['STI 014'], exclude='bads')
reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)
flat = dict(grad=1e-15, mag=1e-15)

events = read_events(fname_event)[:15]
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
baseline=(None, 0), reject=reject, flat=flat)

inverse_operator = prepare_inverse_operator(inverse_operator, nave=1,
lambda2=lambda2,
Expand Down Expand Up @@ -893,4 +894,40 @@ def test_inverse_ctf_comp():
apply_inverse_raw(raw, inv, 1. / 9.)


def _check_delayed_data(inst, delayed):
"""Check whether data is represented as kernel or not."""
if delayed:
assert isinstance(inst._kernel, np.ndarray)
assert isinstance(inst._sens_data, np.ndarray)
assert inst._data is None
assert not inst._kernel_removed
else:
assert inst._kernel is None
assert inst._sens_data is None
assert isinstance(inst._data, np.ndarray)


@testing.requires_testing_data
@pytest.mark.parametrize('pick_ori, inv_file', [['normal', fname_inv],
['vector', fname_inv],
['vector', fname_vol_inv]])
def test_delayed_data(epochs, pick_ori, inv_file):
"""Test if kernel in apply_inverse_epochs was properly applied."""
inverse_operator = \
prepare_inverse_operator(read_inverse_operator(inv_file), nave=1,
lambda2=lambda2, method="dSPM")

full_stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2,
pick_ori=pick_ori, delayed=False,
prepared=True)
kernel_stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2,
pick_ori=pick_ori, delayed=True,
prepared=True)

for full_stc, kern_stc in zip(full_stcs, kernel_stcs):
_check_delayed_data(full_stc, delayed=False)
_check_delayed_data(kern_stc, delayed=True)
assert_allclose(kern_stc.data, full_stc.data)


run_tests_if_main()
9 changes: 5 additions & 4 deletions mne/source_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ def guess_src_type():
'arrays or an array')

# massage the data
SENTINEL = object() # a sentinel for the non kernel case
data, sens_data = data if isinstance(data, tuple) else (data, SENTINEL)
if src_type == 'surface' and vector:
n_vertices = len(vertices[0]) + len(vertices[1])
data = np.matmul(
Expand All @@ -416,9 +418,8 @@ def guess_src_type():
else:
pass # noqa

return Klass(
data=data, vertices=vertices, tmin=tmin, tstep=tstep, subject=subject
)
return Klass(data=data if sens_data is SENTINEL else (data, sens_data),
vertices=vertices, tmin=tmin, tstep=tstep, subject=subject)


def _verify_source_estimate_compat(a, b):
Expand Down Expand Up @@ -485,7 +486,7 @@ def __init__(self, data, vertices=None, tmin=None, tstep=None,
raise ValueError('If data is a tuple it has to be length 2')
kernel, sens_data = data
data = None
if kernel.shape[1] != sens_data.shape[0]:
if kernel.shape[-1] != sens_data.shape[0]:
raise ValueError('kernel and sens_data have invalid '
'dimensions')
if sens_data.ndim != 2:
Expand Down