diff --git a/mne/minimum_norm/inverse.py b/mne/minimum_norm/inverse.py index fe313167360..9714637b167 100644 --- a/mne/minimum_norm/inverse.py +++ b/mne/minimum_norm/inverse.py @@ -1071,12 +1071,19 @@ def apply_inverse_raw(raw, inverse_operator, lambda2, method="dSPM", def _apply_inverse_epochs_gen(epochs, inverse_operator, lambda2, method='dSPM', label=None, nave=1, pick_ori=None, prepared=False, method_params=None, - verbose=None): + delayed=False, verbose=None): """Generate inverse solutions for epochs. Used in apply_inverse_epochs.""" _check_option('method', method, INVERSE_METHODS) _check_ori(pick_ori, inverse_operator['source_ori']) _check_ch_names(inverse_operator, epochs.info) + is_free_ori = not (is_fixed_orient(inverse_operator) or + pick_ori == 'normal') + + if delayed and is_free_ori and pick_ori != "vector": + raise ValueError("delayed must be False for free orientations other " + "than pick_ori='vector' or 'normal'.") + # # Set up the inverse according to the parameters # @@ -1095,13 +1102,10 @@ def _apply_inverse_epochs_gen(epochs, inverse_operator, lambda2, method='dSPM', tstep = 1.0 / epochs.info['sfreq'] tmin = epochs.times[0] - is_free_ori = not (is_fixed_orient(inverse_operator) or - pick_ori == 'normal') - if pick_ori == 'vector' and noise_norm is not None: noise_norm = noise_norm.repeat(3, axis=0) - if not is_free_ori and noise_norm is not None: + if not (is_free_ori and pick_ori != 'vector') and noise_norm is not None: # premultiply kernel with noise normalization K *= noise_norm @@ -1116,15 +1120,16 @@ def _apply_inverse_epochs_gen(epochs, inverse_operator, lambda2, method='dSPM', # Compute solution and combine current components (non-linear) sol = np.dot(K, e[sel]) # apply imaging kernel - if pick_ori != 'vector': - logger.info('combining the current components...') - sol = combine_xyz(sol) + if is_free_ori and pick_ori != 'vector': + logger.info('combining the current components...') + sol = combine_xyz(sol) if noise_norm is not None: sol *= noise_norm + else: # Linear inverse: do computation here or delayed - if len(sel) < K.shape[1]: + if delayed: sol = (K, e[sel]) else: sol = np.dot(K, e[sel]) @@ -1143,7 +1148,7 @@ def _apply_inverse_epochs_gen(epochs, inverse_operator, lambda2, method='dSPM', def apply_inverse_epochs(epochs, inverse_operator, lambda2, method="dSPM", label=None, nave=1, pick_ori=None, return_generator=False, prepared=False, - method_params=None, verbose=None): + method_params=None, delayed=False, verbose=None): """Apply inverse operator to Epochs. Parameters @@ -1179,6 +1184,19 @@ def apply_inverse_epochs(epochs, inverse_operator, lambda2, method="dSPM", Additional options for eLORETA. See Notes of :func:`apply_inverse`. .. versionadded:: 0.16 + delayed : bool + If False, the source time courses are computed. If True, they are + stored as a tuple of two smaller arrays in order to save memory. In + this case, the first array in the tuple corresponds to the "kernel" + shape (n_vertices [, n_orientations], n_sensors) and the second array + to the "sens_data" shape (n_sensors, n_times). The full source time + courses field will be automatically computed when stc.data is called + for the first time (see for example: :class:`mne.SourceEstimate`). + `delayed=True` is only implemented for fixed orientations (e.g. + from pick_ori = "normal") as well as pick_ori="vector". + Defaults to False. + + .. versionadded:: 0.19 %(verbose)s Returns @@ -1194,7 +1212,7 @@ def apply_inverse_epochs(epochs, inverse_operator, lambda2, method="dSPM", stcs = _apply_inverse_epochs_gen( epochs, inverse_operator, lambda2, method=method, label=label, nave=nave, pick_ori=pick_ori, verbose=verbose, prepared=prepared, - method_params=method_params) + method_params=method_params, delayed=delayed) if not return_generator: # return a list diff --git a/mne/minimum_norm/tests/test_inverse.py b/mne/minimum_norm/tests/test_inverse.py index 452e4e325a4..7bdc737e771 100644 --- a/mne/minimum_norm/tests/test_inverse.py +++ b/mne/minimum_norm/tests/test_inverse.py @@ -783,23 +783,24 @@ def test_apply_mne_inverse_fixed_raw(): assert_array_almost_equal(stc.data, stc3.data) +@pytest.fixture(scope="module") +def epochs(): + """Create an epochs object used for testing.""" + raw = read_raw_fif(fname_raw) + picks = pick_types(raw.info, meg=True, eeg=False, stim=True, ecg=True, + eog=True, include=['STI 014'], exclude='bads') + events = read_events(fname_event)[:15] + return Epochs(raw, events, 1, -0.2, 0.5, picks=picks, baseline=(None, 0), + reject=dict(grad=4000e-13, mag=4e-12, eog=150e-6), + flat=dict(grad=1e-15, mag=1e-15)) + + @testing.requires_testing_data -def test_apply_mne_inverse_epochs(): +def test_apply_mne_inverse_epochs(epochs): """Test MNE with precomputed inverse operator on Epochs.""" inverse_operator = read_inverse_operator(fname_full) label_lh = read_label(fname_label % 'Aud-lh') label_rh = read_label(fname_label % 'Aud-rh') - event_id, tmin, tmax = 1, -0.2, 0.5 - raw = read_raw_fif(fname_raw) - - picks = pick_types(raw.info, meg=True, eeg=False, stim=True, ecg=True, - eog=True, include=['STI 014'], exclude='bads') - reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6) - flat = dict(grad=1e-15, mag=1e-15) - - events = read_events(fname_event)[:15] - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=reject, flat=flat) inverse_operator = prepare_inverse_operator(inverse_operator, nave=1, lambda2=lambda2, @@ -893,4 +894,40 @@ def test_inverse_ctf_comp(): apply_inverse_raw(raw, inv, 1. / 9.) +def _check_delayed_data(inst, delayed): + """Check whether data is represented as kernel or not.""" + if delayed: + assert isinstance(inst._kernel, np.ndarray) + assert isinstance(inst._sens_data, np.ndarray) + assert inst._data is None + assert not inst._kernel_removed + else: + assert inst._kernel is None + assert inst._sens_data is None + assert isinstance(inst._data, np.ndarray) + + +@testing.requires_testing_data +@pytest.mark.parametrize('pick_ori, inv_file', [['normal', fname_inv], + ['vector', fname_inv], + ['vector', fname_vol_inv]]) +def test_delayed_data(epochs, pick_ori, inv_file): + """Test if kernel in apply_inverse_epochs was properly applied.""" + inverse_operator = \ + prepare_inverse_operator(read_inverse_operator(inv_file), nave=1, + lambda2=lambda2, method="dSPM") + + full_stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2, + pick_ori=pick_ori, delayed=False, + prepared=True) + kernel_stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2, + pick_ori=pick_ori, delayed=True, + prepared=True) + + for full_stc, kern_stc in zip(full_stcs, kernel_stcs): + _check_delayed_data(full_stc, delayed=False) + _check_delayed_data(kern_stc, delayed=True) + assert_allclose(kern_stc.data, full_stc.data) + + run_tests_if_main() diff --git a/mne/source_estimate.py b/mne/source_estimate.py index 5af0b43409b..a09bfd66fae 100644 --- a/mne/source_estimate.py +++ b/mne/source_estimate.py @@ -405,6 +405,8 @@ def guess_src_type(): 'arrays or an array') # massage the data + SENTINEL = object() # a sentinel for the non kernel case + data, sens_data = data if isinstance(data, tuple) else (data, SENTINEL) if src_type == 'surface' and vector: n_vertices = len(vertices[0]) + len(vertices[1]) data = np.matmul( @@ -416,9 +418,8 @@ def guess_src_type(): else: pass # noqa - return Klass( - data=data, vertices=vertices, tmin=tmin, tstep=tstep, subject=subject - ) + return Klass(data=data if sens_data is SENTINEL else (data, sens_data), + vertices=vertices, tmin=tmin, tstep=tstep, subject=subject) def _verify_source_estimate_compat(a, b): @@ -485,7 +486,7 @@ def __init__(self, data, vertices=None, tmin=None, tstep=None, raise ValueError('If data is a tuple it has to be length 2') kernel, sens_data = data data = None - if kernel.shape[1] != sens_data.shape[0]: + if kernel.shape[-1] != sens_data.shape[0]: raise ValueError('kernel and sens_data have invalid ' 'dimensions') if sens_data.ndim != 2: