From cce10df0b812f088bb0bd78dafc967df5d8f7464 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 18 Mar 2020 10:37:47 -0400 Subject: [PATCH] ENH: Add eSSS --- doc/changes/latest.inc | 2 + mne/chpi.py | 2 +- mne/fixes.py | 17 ++ mne/preprocessing/maxwell.py | 144 ++++++++++--- mne/preprocessing/tests/test_maxwell.py | 257 +++++++++++++++++++----- mne/proj.py | 3 - mne/tests/test_proj.py | 2 - mne/utils/docs.py | 7 + 8 files changed, 346 insertions(+), 88 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 43c8132c138..f0b914e9319 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -107,6 +107,8 @@ Changelog - Add automatic T3 magnetometer detection and application of :meth:`mne.io.Raw.fix_mag_coil_types` to :func:`mne.preprocessing.maxwell_filter` by `Eric Larson`_ +- Add extended SSS (eSSS) support to :func:`mne.preprocessing.maxwell_filter` by `Eric Larson`_ + - Add ``'auto'`` option to :meth:`mne.preprocessing.ICA.find_bads_ecg` to automatically determine the threshold for CTPS method by `Yu-Han Luo`_ - Add a ``notebook`` 3d backend for visualization in jupyter notebook with :func:`mne.viz.set_3d_backend` by `Guillaume Favelier`_ diff --git a/mne/chpi.py b/mne/chpi.py index 4dbffd08439..854d3ca1a43 100644 --- a/mne/chpi.py +++ b/mne/chpi.py @@ -517,7 +517,7 @@ def _setup_ext_proj(info, ext_order): ext = _sss_basis( dict(origin=(0., 0., 0.), int_order=0, ext_order=ext_order), mf_coils).T - out_removes = _regularize_out(0, 1, mag_or_fine) + out_removes = _regularize_out(0, 1, mag_or_fine, []) ext = ext[~np.in1d(np.arange(len(ext)), out_removes)] ext = linalg.orth(ext.T).T assert ext.shape[1] == len(meg_picks) diff --git a/mne/fixes.py b/mne/fixes.py index 3b0b5bfc2b2..7fb58539316 100644 --- a/mne/fixes.py +++ b/mne/fixes.py @@ -20,6 +20,7 @@ import warnings import numpy as np +import scipy from scipy import linalg from scipy.linalg import LinAlgError @@ -314,6 +315,22 @@ def _get_dpss(): from numpy.fft import fft, ifft, fftfreq, rfft, irfft, rfftfreq, ifftshift +############################################################################### +# Orth with rcond argument (SciPy 1.1) + +if LooseVersion(scipy.__version__) >= '1.1': + from scipy.linalg import orth +else: + def orth(A, rcond=None): # noqa + u, s, vh = linalg.svd(A, full_matrices=False) + M, N = u.shape[0], vh.shape[1] + if rcond is None: + rcond = numpy.finfo(s.dtype).eps * max(M, N) + tol = np.amax(s) * rcond + num = np.sum(s > tol, dtype=int) + Q = u[:, :num] + return Q + ############################################################################### # NumPy Generator (NumPy 1.17) diff --git a/mne/preprocessing/maxwell.py b/mne/preprocessing/maxwell.py index a97953eeaf6..793b6823418 100644 --- a/mne/preprocessing/maxwell.py +++ b/mne/preprocessing/maxwell.py @@ -28,11 +28,12 @@ from ..io.meas_info import _simplify_info from ..io.proc_history import _read_ctc from ..io.write import _generate_meas_id, DATE_NONE -from ..io import _loc_to_coil_trans, _coil_trans_to_loc, BaseRaw, RawArray +from ..io import (_loc_to_coil_trans, _coil_trans_to_loc, BaseRaw, RawArray, + Projection) from ..io.pick import pick_types, pick_info from ..utils import (verbose, logger, _clean_names, warn, _time_mask, _pl, _check_option, _ensure_int, _validate_type, use_log_level) -from ..fixes import _get_args, _safe_svd, einsum, bincount +from ..fixes import _get_args, _safe_svd, einsum, bincount, orth from ..channels.channels import _get_T1T2_mag_inds, fix_mag_coil_types @@ -48,7 +49,8 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3, st_correlation=0.98, coord_frame='head', destination=None, regularize='in', ignore_ref=False, bad_condition='error', head_pos=None, st_fixed=True, st_only=False, mag_scale=100., - skip_by_annotation=('edge', 'bad_acq_skip'), verbose=None): + skip_by_annotation=('edge', 'bad_acq_skip'), + extended_proj=(), verbose=None): """Maxwell filter data using multipole moments. Parameters @@ -93,6 +95,7 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3, %(maxwell_skip)s .. versionadded:: 0.17 + %(maxwell_extended)s %(verbose)s Returns @@ -161,6 +164,8 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3, +-----------------------------------------------------------------------------+-----+-----------+ | Certified for clinical use | | ✓ | +-----------------------------------------------------------------------------+-----+-----------+ + | Extended external basis (eSSS) | ✓ | | + +-----------------------------------------------------------------------------+-----+-----------+ Epoch-based movement compensation is described in [1]_. @@ -214,7 +219,7 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3, regularize=regularize, ignore_ref=ignore_ref, bad_condition=bad_condition, head_pos=head_pos, st_fixed=st_fixed, st_only=st_only, mag_scale=mag_scale, - skip_by_annotation=skip_by_annotation) + skip_by_annotation=skip_by_annotation, extended_proj=extended_proj) raw_sss = _run_maxwell_filter(raw, **params) # Update info _update_sss_info(raw_sss, **params['update_kwargs']) @@ -230,7 +235,7 @@ def _prep_maxwell_filter( regularize='in', ignore_ref=False, bad_condition='error', head_pos=None, st_fixed=True, st_only=False, mag_scale=100., - skip_by_annotation=('edge', 'bad_acq_skip'), + skip_by_annotation=('edge', 'bad_acq_skip'), extended_proj=(), reconstruct='in', verbose=None): # There are an absurd number of different possible notations for spherical # coordinates, which confounds the notation for spherical harmonics. Here, @@ -280,6 +285,28 @@ def _prep_maxwell_filter( coil_scale, mag_scale = _get_coil_scale( meg_picks, mag_picks, grad_picks, mag_scale, info) + # + # Extended projection vectors + # + _validate_type(extended_proj, (list, tuple), 'extended_proj') + good_names = [info['ch_names'][c] for c in meg_picks[good_mask]] + if len(extended_proj) > 0: + extended_proj_ = list() + for pi, proj in enumerate(extended_proj): + item = 'extended_proj[%d]' % (pi,) + _validate_type(proj, Projection, item) + got_names = proj['data']['col_names'] + diff = sorted(set(got_names).symmetric_difference(set(good_names))) + if diff: + raise ValueError('%s channel names (length %d) do not match ' + 'the good MEG channel names (length %d):\n%s' + % (item, len(got_names), len(good_names), + ', '.join(diff))) + extended_proj_.append(proj['data']['data']) + extended_proj = np.concatenate(extended_proj_) + logger.info(' Extending external SSS basis using %d projection ' + 'vectors' % (len(extended_proj),)) + # # Fine calibration processing (load fine cal and overwrite sensor geometry) # @@ -298,7 +325,8 @@ def _prep_maxwell_filter( origin_head = origin update_kwargs = dict( origin=origin, coord_frame=coord_frame, sss_cal=sss_cal, - int_order=int_order, ext_order=ext_order) + int_order=int_order, ext_order=ext_order, + extended_proj=extended_proj) del origin, coord_frame, sss_cal origin_head.setflags(write=False) @@ -315,6 +343,7 @@ def _prep_maxwell_filter( # between old and new fif files if meg_ch_names[0] not in ctc_chs: ctc_chs = _clean_names(ctc_chs, remove_whitespace=True) + meg_ch_names = _clean_names(meg_ch_names, remove_whitespace=True) missing = sorted(list(set(meg_ch_names) - set(ctc_chs))) if len(missing) != 0: raise RuntimeError('Missing MEG channels in cross-talk matrix:\n%s' @@ -322,7 +351,7 @@ def _prep_maxwell_filter( missing = sorted(list(set(ctc_chs) - set(meg_ch_names))) if len(missing) > 0: warn('Not all cross-talk channels in raw:\n%s' % missing) - ctc_picks = [ctc_chs.index(info['ch_names'][c]) for c in meg_picks] + ctc_picks = [ctc_chs.index(name) for name in meg_ch_names] ctc = sss_ctc['decoupler'][ctc_picks][:, ctc_picks] # I have no idea why, but MF transposes this for storage.. sss_ctc['decoupler'] = sss_ctc['decoupler'].T.tocsc() @@ -336,6 +365,8 @@ def _prep_maxwell_filter( all_coils = _prep_mf_coils(info, ignore_ref) S_recon = _trans_sss_basis(exp, all_coils, recon_trans, coil_scale) exp['ext_order'] = ext_order + exp['extended_proj'] = extended_proj + del extended_proj # Reconstruct data from internal space only (Eq. 38), and rescale S_recon S_recon /= coil_scale if recon_trans is not None: @@ -836,15 +867,52 @@ def _get_decomp(trans, all_coils, cal, regularize, exp, ignore_ref, exp, all_coils, trans, coil_scale, cal, ignore_ref, grad_picks, mag_picks, mag_scale) S_decomp = S_decomp_full[good_mask] + # + # Extended SSS basis (eSSS) + # + extended_proj = exp.get('extended_proj', ()) + if len(extended_proj) > 0: + rcond = 1e-4 + thresh = 1e-4 + extended_proj = extended_proj.T * coil_scale[good_mask] + extended_proj /= np.linalg.norm(extended_proj, axis=0) + n_int = _get_n_moments(exp['int_order']) + if S_decomp.shape[1] > n_int: + S_ext = S_decomp[:, n_int:].copy() + S_ext /= np.linalg.norm(S_ext, axis=0) + S_ext_orth = orth(S_ext, rcond=rcond) + assert S_ext_orth.shape[1] == S_ext.shape[1] + extended_proj -= np.dot(S_ext_orth, + np.dot(S_ext_orth.T, extended_proj)) + scale = np.mean(np.linalg.norm(S_decomp[n_int:], axis=0)) + else: + scale = np.mean(np.linalg.norm(S_decomp[:n_int], axis=0)) + mask = np.linalg.norm(extended_proj, axis=0) > thresh + extended_remove = list(np.where(~mask)[0] + S_decomp.shape[1]) + logger.debug(' Reducing %d -> %d' + % (extended_proj.shape[1], mask.sum())) + extended_proj /= np.linalg.norm(extended_proj, axis=0) / scale + S_decomp = np.concatenate([S_decomp, extended_proj], axis=-1) + if extended_proj.shape[1]: + S_decomp_full = np.pad( + S_decomp_full, ((0, 0), (0, extended_proj.shape[1])), + 'constant') + S_decomp_full[good_mask, -extended_proj.shape[1]:] = extended_proj + else: + extended_remove = list() + del extended_proj # # Regularization # - S_decomp, pS_decomp, sing, reg_moments, n_use_in = _regularize( - regularize, exp, S_decomp, mag_or_fine, t=t) + S_decomp, reg_moments, n_use_in = _regularize( + regularize, exp, S_decomp, mag_or_fine, extended_remove, t=t) S_decomp_full = S_decomp_full.take(reg_moments, axis=1) + # # Pseudo-inverse of total multipolar moment basis set (Part of Eq. 37) + # + pS_decomp, sing = _col_norm_pinv(S_decomp.copy()) cond = sing[0] / sing[-1] if bad_condition != 'ignore' and cond >= 1000.: msg = 'Matrix is badly conditioned: %0.0f >= 1000' % cond @@ -870,42 +938,46 @@ def _get_s_decomp(exp, all_coils, trans, coil_scale, cal, ignore_ref, # Compute point-like mags to incorporate gradiometer imbalance grad_cals = _sss_basis_point(exp, trans, cal, ignore_ref, mag_scale) # Add point like magnetometer data to bases. - S_decomp[grad_picks, :] += grad_cals + if len(grad_picks) > 0: + S_decomp[grad_picks, :] += grad_cals # Scale magnetometers by calibration coefficient - S_decomp[mag_picks, :] /= cal['mag_cals'] + if len(mag_picks) > 0: + S_decomp[mag_picks, :] /= cal['mag_cals'] # We need to be careful about KIT gradiometers return S_decomp @verbose -def _regularize(regularize, exp, S_decomp, mag_or_fine, t, verbose=None): +def _regularize(regularize, exp, S_decomp, mag_or_fine, extended_remove, t, + verbose=None): """Regularize a decomposition matrix.""" # ALWAYS regularize the out components according to norm, since # gradiometer-only setups (e.g., KIT) can have zero first-order # (homogeneous field) components int_order, ext_order = exp['int_order'], exp['ext_order'] - n_in, n_out = _get_n_moments([int_order, ext_order]) + n_in = _get_n_moments(int_order) + n_out = S_decomp.shape[1] - n_in t_str = '%8.3f' % t if regularize is not None: # regularize='in' in_removes, out_removes = _regularize_in( - int_order, ext_order, S_decomp, mag_or_fine) + int_order, ext_order, S_decomp, mag_or_fine, extended_remove) else: in_removes = [] - out_removes = _regularize_out(int_order, ext_order, mag_or_fine) + out_removes = _regularize_out(int_order, ext_order, mag_or_fine, + extended_remove) reg_in_moments = np.setdiff1d(np.arange(n_in), in_removes) - reg_out_moments = np.setdiff1d(np.arange(n_in, n_in + n_out), + reg_out_moments = np.setdiff1d(np.arange(n_in, S_decomp.shape[1]), out_removes) n_use_in = len(reg_in_moments) n_use_out = len(reg_out_moments) reg_moments = np.concatenate((reg_in_moments, reg_out_moments)) S_decomp = S_decomp.take(reg_moments, axis=1) - pS_decomp, sing = _col_norm_pinv(S_decomp.copy()) if regularize is not None or n_use_out != n_out: logger.info(' Using %s/%s harmonic components for %s ' '(%s/%s in, %s/%s out)' % (n_use_in + n_use_out, n_in + n_out, t_str, n_use_in, n_in, n_use_out, n_out)) - return S_decomp, pS_decomp, sing, reg_moments, n_use_in + return S_decomp, reg_moments, n_use_in @verbose @@ -1038,6 +1110,7 @@ def _sss_basis_basic(exp, coils, mag_scale=100., method='standard'): from scipy.special import sph_harm int_order, ext_order = exp['int_order'], exp['ext_order'] origin = exp['origin'] + assert 'extended_proj' not in exp # advanced option not supported # Compute vector between origin and coil, convert to spherical coords if method == 'standard': # Get position, normal, weights, and number of integration pts. @@ -1428,7 +1501,7 @@ def _check_info(info, sss=True, tsss=True, calibration=True, ctc=True): def _update_sss_info(raw, origin, int_order, ext_order, nchan, coord_frame, sss_ctc, sss_cal, max_st, reg_moments, st_only, - recon_trans): + recon_trans, extended_proj): """Update info inplace after Maxwell filtering. Parameters @@ -1456,10 +1529,12 @@ def _update_sss_info(raw, origin, int_order, ext_order, nchan, coord_frame, Whether tSSS only was performed. recon_trans : instance of Transformation The reconstruction trans. + extended_proj : ndarray + Extended external bases. """ n_in, n_out = _get_n_moments([int_order, ext_order]) raw.info['maxshield'] = False - components = np.zeros(n_in + n_out).astype('int32') + components = np.zeros(n_in + n_out + len(extended_proj)).astype('int32') components[reg_moments] = 1 sss_info_dict = dict(in_order=int_order, out_order=ext_order, nchan=nchan, origin=origin.astype('float32'), @@ -1578,7 +1653,7 @@ def _update_sensor_geometry(info, fine_cal, ignore_ref): # Determine gradiometer imbalances and magnetometer calibrations grad_imbalances = np.array([fine_cal['imb_cals'][info_to_cal[gi]] for gi in grad_picks]).T - if grad_imbalances.shape[0] not in [1, 3]: + if grad_imbalances.shape[0] not in [0, 1, 3]: raise ValueError('Must have 1 (x) or 3 (x, y, z) point-like ' + 'magnetometers. Currently have %i' % grad_imbalances.shape[0]) @@ -1660,8 +1735,10 @@ def _get_grad_point_coilsets(info, n_types, ignore_ref): y=np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1.]]), z=np.eye(4)) grad_coilsets = list() - grad_info = pick_info( - _simplify_info(info), pick_types(info, meg='grad', exclude=[])) + grad_picks = pick_types(info, meg='grad', exclude=[]) + if len(grad_picks) == 0: + return grad_coilsets + grad_info = pick_info(_simplify_info(info), grad_picks) # Coil_type values for x, y, z point magnetometers # Note: 1D correction files only have x-direction corrections for ch in grad_info['chs']: @@ -1693,14 +1770,15 @@ def _sss_basis_point(exp, trans, cal, ignore_ref=False, mag_scale=100.): return S_tot -def _regularize_out(int_order, ext_order, mag_or_fine): +def _regularize_out(int_order, ext_order, mag_or_fine, extended_remove): """Regularize out components based on norm.""" n_in = _get_n_moments(int_order) - out_removes = list(np.arange(0 if mag_or_fine.any() else 3) + n_in) - return list(out_removes) + remove_homog = ext_order > 0 and not mag_or_fine.any() + return list(range(n_in, n_in + 3 * remove_homog)) + extended_remove -def _regularize_in(int_order, ext_order, S_decomp, mag_or_fine): +def _regularize_in(int_order, ext_order, S_decomp, mag_or_fine, + extended_remove): """Regularize basis set using idealized SNR measure.""" n_in, n_out = _get_n_moments([int_order, ext_order]) @@ -1713,8 +1791,9 @@ def _regularize_in(int_order, ext_order, S_decomp, mag_or_fine): I_tots = np.zeros(n_in) # we might not traverse all, so use np.zeros in_keepers = list(range(n_in)) - out_removes = _regularize_out(int_order, ext_order, mag_or_fine) - out_keepers = list(np.setdiff1d(np.arange(n_in, n_in + n_out), + out_removes = _regularize_out(int_order, ext_order, mag_or_fine, + extended_remove) + out_keepers = list(np.setdiff1d(np.arange(n_in, S_decomp.shape[1]), out_removes)) remove_order = [] S_decomp = S_decomp.copy() @@ -1786,7 +1865,7 @@ def _regularize_in(int_order, ext_order, S_decomp, mag_or_fine): logger.debug(' Condition %0.3f/%0.3f = %03.1f, ' 'Removing in component %s: l=%s, m=%+0.0f' % (tuple(eigs[ii]) + (eigs[ii, 0] / eigs[ii, 1], - ri, degrees[ri], orders[ri]))) + ri, degrees[ri], orders[ri]))) logger.debug(' Resulting information: %0.1f bits/sample ' '(%0.1f%% of peak %0.1f)' % (I_tots[lim_idx], 100 * I_tots[lim_idx] / max_info, @@ -1864,7 +1943,7 @@ def find_bad_channels_maxwell( cross_talk=None, coord_frame='head', regularize='in', ignore_ref=False, bad_condition='error', head_pos=None, mag_scale=100., skip_by_annotation=('edge', 'bad_acq_skip'), h_freq=40.0, - verbose=None): + extended_proj=(), verbose=None): r"""Find bad channels using Maxwell filtering. Parameters @@ -1908,6 +1987,7 @@ def find_bad_channels_maxwell( applied before processing the data. This defaults to ``40.``, which should provide similar results to MaxFilter. If you do not wish to apply a filter, set this to ``None``. + %(maxwell_extended)s %(verbose)s Returns @@ -2034,7 +2114,7 @@ def find_bad_channels_maxwell( calibration=calibration, cross_talk=cross_talk, coord_frame=coord_frame, regularize=regularize, ignore_ref=ignore_ref, bad_condition=bad_condition, head_pos=head_pos, - mag_scale=mag_scale) + mag_scale=mag_scale, extended_proj=extended_proj) del origin, int_order, ext_order, calibration, cross_talk, coord_frame del regularize, ignore_ref, bad_condition, head_pos, mag_scale good_meg_picks = params['meg_picks'][params['good_mask']] diff --git a/mne/preprocessing/tests/test_maxwell.py b/mne/preprocessing/tests/test_maxwell.py index acb3d8b11c4..35fce3df8cb 100644 --- a/mne/preprocessing/tests/test_maxwell.py +++ b/mne/preprocessing/tests/test_maxwell.py @@ -2,8 +2,10 @@ # # License: BSD (3-clause) +from contextlib import contextmanager import os.path as op import pathlib +import re import numpy as np from numpy.testing import assert_allclose, assert_array_equal @@ -12,7 +14,7 @@ from scipy.special import sph_harm import mne -from mne import compute_raw_covariance, pick_types, concatenate_raws +from mne import compute_raw_covariance, pick_types, concatenate_raws, pick_info from mne.annotations import _annotations_starts_stops from mne.chpi import read_head_pos, filter_chpi from mne.forward import _prep_meg_channels @@ -27,7 +29,7 @@ _bases_real_to_complex, _prep_mf_coils, find_bad_channels_maxwell) from mne.rank import _get_rank_sss, _compute_rank_int from mne.utils import (assert_meg_snr, run_tests_if_main, catch_logging, - object_diff, buggy_mkl_svd) + object_diff, buggy_mkl_svd, use_log_level) data_path = testing.data_path(download=False) sss_path = op.join(data_path, 'SSS') @@ -808,8 +810,10 @@ def test_head_translation(): # that calculates the localization error: # http://ieeexplore.ieee.org/xpl/articleDetails.jsp?arnumber=1495874 -def _assert_shielding(raw_sss, erm_power, shielding_factor, meg='mag'): +def _assert_shielding(raw_sss, erm_power, min_factor, max_factor=np.inf, + meg='mag'): """Assert a minimum shielding factor using empty-room power.""" + __tracebackhide__ = True picks = pick_types(raw_sss.info, meg=meg, ref_meg=False) if isinstance(erm_power, BaseRaw): picks_erm = pick_types(raw_sss.info, meg=meg, ref_meg=False) @@ -818,8 +822,69 @@ def _assert_shielding(raw_sss, erm_power, shielding_factor, meg='mag'): sss_power = raw_sss[picks][0].ravel() sss_power = np.sqrt(np.sum(sss_power * sss_power)) factor = erm_power / sss_power - assert factor >= shielding_factor, \ - 'Shielding factor %0.3f < %0.3f' % (factor, shielding_factor) + assert min_factor <= factor < max_factor, ( + 'Shielding factor not %0.3f <= %0.3f < %0.3f' + % (min_factor, factor, max_factor)) + + +@buggy_mkl_svd +@testing.requires_testing_data +@pytest.mark.parametrize('regularize', ('in', None)) +@pytest.mark.parametrize('bads', ([], ['MEG0111'])) +def test_esss(regularize, bads): + """Test extended-basis SSS.""" + # Make some fake "projectors" that actually contain external SSS bases + raw_erm = read_crop(erm_fname).load_data().pick_types(meg=True) + raw_erm.info['bads'] = bads + proj_sss = mne.compute_proj_raw(raw_erm, meg='combined', verbose='error', + n_mag=15, n_grad=15) + good_info = pick_info(raw_erm.info, pick_types(raw_erm.info, meg=True)) + S_tot = _trans_sss_basis( + dict(int_order=0, ext_order=3, origin=(0., 0., 0.)), + all_coils=_prep_mf_coils(good_info), coil_scale=1., trans=None) + assert S_tot.shape[-1] == len(proj_sss) + for a, b in zip(proj_sss, S_tot.T): + a['data']['data'][:] = b + with catch_logging() as log: + raw_sss = maxwell_filter(raw_erm, coord_frame='meg', + regularize=regularize, verbose=True) + log = log.getvalue() + assert 'xtend' not in log + with catch_logging() as log: + raw_sss_2 = maxwell_filter(raw_erm, coord_frame='meg', + regularize=regularize, ext_order=0, + extended_proj=proj_sss, verbose=True) + log = log.getvalue() + assert 'Extending external SSS basis using 15 projection' in log + assert_allclose(raw_sss_2._data, raw_sss._data, atol=1e-20) + + # Degenerate condititons + proj_sss = proj_sss[:2] + proj_sss[0]['data']['col_names'] = proj_sss[0]['data']['col_names'][:-1] + with pytest.raises(ValueError, match='do not match the good MEG'): + maxwell_filter(raw_erm, coord_frame='meg', extended_proj=proj_sss) + proj_sss[0] = 1. + with pytest.raises(TypeError, match=r'extended_proj\[0\] must be an inst'): + maxwell_filter(raw_erm, coord_frame='meg', extended_proj=proj_sss) + with pytest.raises(TypeError, match='extended_proj must be an inst'): + maxwell_filter(raw_erm, coord_frame='meg', extended_proj=1.) + + +@contextmanager +def get_n_projected(): + """Get the number of projected tSSS components from the log.""" + count = list() + with use_log_level(True): + with catch_logging() as log: + yield count + log = log.getvalue() + assert 'Processing data using tSSS' in log + log = log.splitlines() + reg = re.compile(r'\s+Projecting\s+([0-9])+\s+intersecting tSSS .*') + for line in log: + m = reg.match(line) + if m: + count.append(int(m.group(1))) @buggy_mkl_svd @@ -836,95 +901,176 @@ def test_shielding_factor(tmpdir): # Vanilla SSS (second value would be for meg=True instead of meg='mag') _assert_shielding(read_crop(sss_erm_std_fname), erm_power, 10) # 1.5) raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None) - _assert_shielding(raw_sss, erm_power, 12) # 1.5) - _assert_shielding(raw_sss, erm_power_grad, 0.45, 'grad') # 1.5) + _assert_shielding(raw_sss, erm_power, 12, 13) # 1.5) + _assert_shielding(raw_sss, erm_power_grad, 0.45, 0.55, 'grad') # 1.5) + + # No external basis + raw_sss_0 = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, + ext_order=0) + _assert_shielding(raw_sss_0, erm_power, 1.0, 1.1) + del raw_sss_0 + + # Regularization + _assert_shielding(read_crop(sss_erm_std_fname), erm_power, 10) # 1.5) + raw_sss = maxwell_filter(raw_erm, coord_frame='meg') + _assert_shielding(raw_sss, erm_power, 14.5, 15.5) + + # + # Extended (eSSS) + # + + # Show that using empty-room projectors increase shielding factor + proj = mne.compute_proj_raw(raw_erm, meg='combined', verbose='error', + n_mag=15, n_grad=15) + raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, + extended_proj=proj[:3]) + _assert_shielding(raw_sss, erm_power, 38, 39) + raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, + extended_proj=proj) + _assert_shielding(raw_sss, erm_power, 49, 51) + # Now with reg + raw_sss = maxwell_filter(raw_erm, coord_frame='meg', + extended_proj=proj[:3]) + _assert_shielding(raw_sss, erm_power, 42, 44) + raw_sss = maxwell_filter(raw_erm, coord_frame='meg', + extended_proj=proj) + _assert_shielding(raw_sss, erm_power, 59, 61) - # Using different mag_scale values + # + # Different mag_scale values + # raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, mag_scale='auto') - _assert_shielding(raw_sss, erm_power, 12) - _assert_shielding(raw_sss, erm_power_grad, 0.48, 'grad') + _assert_shielding(raw_sss, erm_power, 12, 13) + _assert_shielding(raw_sss, erm_power_grad, 0.48, 0.58, 'grad') raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, mag_scale=1.) # not a good choice - _assert_shielding(raw_sss, erm_power, 7.3) - _assert_shielding(raw_sss, erm_power_grad, 0.2, 'grad') + _assert_shielding(raw_sss, erm_power, 7.3, 8.) + _assert_shielding(raw_sss, erm_power_grad, 0.2, 0.3, 'grad') raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, mag_scale=1000., bad_condition='ignore') - _assert_shielding(raw_sss, erm_power, 4.0) - _assert_shielding(raw_sss, erm_power_grad, 0.1, 'grad') + _assert_shielding(raw_sss, erm_power, 4.0, 5.0) + _assert_shielding(raw_sss, erm_power_grad, 0.1, 0.2, 'grad') + # # Fine cal + # _assert_shielding(read_crop(sss_erm_fine_cal_fname), erm_power, 12) # 2.0) raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, origin=mf_meg_origin, calibration=pathlib.Path(fine_cal_fname)) - _assert_shielding(raw_sss, erm_power, 12) # 2.0) + _assert_shielding(raw_sss, erm_power, 12, 13) # 2.0) + # # Crosstalk + # _assert_shielding(read_crop(sss_erm_ctc_fname), erm_power, 12) # 2.1) raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, origin=mf_meg_origin, cross_talk=ctc_fname) - _assert_shielding(raw_sss, erm_power, 12) # 2.1) + _assert_shielding(raw_sss, erm_power, 12, 13) # 2.1) # Fine cal + Crosstalk raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, calibration=fine_cal_fname, origin=mf_meg_origin, cross_talk=ctc_fname) - _assert_shielding(raw_sss, erm_power, 13) # 2.2) + _assert_shielding(raw_sss, erm_power, 13, 14) # 2.2) + # Fine cal + Crosstalk + Extended + raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, + calibration=fine_cal_fname, + origin=mf_meg_origin, + cross_talk=ctc_fname, extended_proj=proj) + _assert_shielding(raw_sss, erm_power, 28, 30) + raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, + calibration=fine_cal_fname, + origin=mf_meg_origin, + cross_talk=ctc_fname, extended_proj=proj[:3]) + _assert_shielding(raw_sss, erm_power, 25, 27) # tSSS _assert_shielding(read_crop(sss_erm_st_fname), erm_power, 37) # 5.8) raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, origin=mf_meg_origin, st_duration=1.) - _assert_shielding(raw_sss, erm_power, 37) # 5.8) + _assert_shielding(raw_sss, erm_power, 37, 38) # 5.8) # Crosstalk + tSSS - raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, - cross_talk=ctc_fname, origin=mf_meg_origin, - st_duration=1.) - _assert_shielding(raw_sss, erm_power, 38) # 5.91) + with get_n_projected() as counts: + raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, + cross_talk=ctc_fname, origin=mf_meg_origin, + st_duration=1.) + _assert_shielding(raw_sss, erm_power, 38, 39) # 5.91) + assert counts[0] == 4 # Fine cal + tSSS - raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, - calibration=fine_cal_fname, - origin=mf_meg_origin, st_duration=1.) - _assert_shielding(raw_sss, erm_power, 38) # 5.98) + with get_n_projected() as counts: + raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, + calibration=fine_cal_fname, + origin=mf_meg_origin, st_duration=1.) + _assert_shielding(raw_sss, erm_power, 38, 39) # 5.98) + assert counts[0] == 4 + + # Extended + tSSS + with get_n_projected() as counts: + raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, + origin=mf_meg_origin, st_duration=1., + extended_proj=proj) + _assert_shielding(raw_sss, erm_power, 40, 42) + assert counts[0] == 0 + with get_n_projected() as counts: + raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, + origin=mf_meg_origin, st_duration=1., + extended_proj=proj[:3]) + _assert_shielding(raw_sss, erm_power, 35, 37) + assert counts[0] == 1 # Fine cal + Crosstalk + tSSS _assert_shielding(read_crop(sss_erm_st1FineCalCrossTalk_fname), - erm_power, 39) # 6.07) + erm_power, 39, 40) # 6.07) raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, calibration=fine_cal_fname, origin=mf_meg_origin, cross_talk=ctc_fname, st_duration=1.) - _assert_shielding(raw_sss, erm_power, 39) # 6.05) + _assert_shielding(raw_sss, erm_power, 39, 40) # 6.05) + + # Fine cal + Crosstalk + tSSS + Extended (a bit worse) + _assert_shielding(read_crop(sss_erm_st1FineCalCrossTalk_fname), + erm_power, 39, 40) # 6.07) + raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None, + calibration=fine_cal_fname, origin=mf_meg_origin, + cross_talk=ctc_fname, st_duration=1., + extended_proj=proj[:3]) + _assert_shielding(raw_sss, erm_power, 34, 36) # Fine cal + Crosstalk + tSSS + Reg-in _assert_shielding(read_crop(sss_erm_st1FineCalCrossTalkRegIn_fname), - erm_power, 57) # 6.97) + erm_power, 57, 58) # 6.97) raw_sss = maxwell_filter(raw_erm, calibration=fine_cal_fname, cross_talk=ctc_fname, st_duration=1., origin=mf_meg_origin, coord_frame='meg', regularize='in') - _assert_shielding(raw_sss, erm_power, 53) # 6.64) - raw_sss = maxwell_filter(raw_erm, calibration=fine_cal_fname, - cross_talk=ctc_fname, st_duration=1., - coord_frame='meg', regularize='in') - _assert_shielding(raw_sss, erm_power, 58) # 7.0) - _assert_shielding(raw_sss, erm_power_grad, 1.6, 'grad') - raw_sss = maxwell_filter(raw_erm, calibration=fine_cal_fname, - cross_talk=ctc_fname, st_duration=1., - coord_frame='meg', regularize='in', - mag_scale='auto') - _assert_shielding(raw_sss, erm_power, 51) - _assert_shielding(raw_sss, erm_power_grad, 1.5, 'grad') - raw_sss = maxwell_filter(raw_erm, calibration=fine_cal_fname_3d, - cross_talk=ctc_fname, st_duration=1., - coord_frame='meg', regularize='in') - + _assert_shielding(raw_sss, erm_power, 53, 54) # 6.64) + with get_n_projected() as counts: + raw_sss = maxwell_filter(raw_erm, calibration=fine_cal_fname, + cross_talk=ctc_fname, st_duration=1., + coord_frame='meg', regularize='in') + _assert_shielding(raw_sss, erm_power, 58, 59) # 7.0) + _assert_shielding(raw_sss, erm_power_grad, 1.6, 1.7, 'grad') + assert counts[0] == 4 + with get_n_projected() as counts: + raw_sss = maxwell_filter(raw_erm, calibration=fine_cal_fname, + cross_talk=ctc_fname, st_duration=1., + coord_frame='meg', regularize='in', + mag_scale='auto') + _assert_shielding(raw_sss, erm_power, 51, 52) + _assert_shielding(raw_sss, erm_power_grad, 1.5, 1.6, 'grad') + assert counts[0] == 3 + with get_n_projected() as counts: + raw_sss = maxwell_filter(raw_erm, calibration=fine_cal_fname_3d, + cross_talk=ctc_fname, st_duration=1., + coord_frame='meg', regularize='in') # Our 3D cal has worse defaults for this ERM than the 1D file - _assert_shielding(raw_sss, erm_power, 54) + _assert_shielding(raw_sss, erm_power, 57, 58) + assert counts[0] == 3 # Show it by rewriting the 3D as 1D and testing it temp_dir = str(tmpdir) temp_fname = op.join(temp_dir, 'test_cal.dat') @@ -932,11 +1078,22 @@ def test_shielding_factor(tmpdir): with open(temp_fname, 'w') as fid_out: for line in fid: fid_out.write(' '.join(line.strip().split(' ')[:14]) + '\n') - raw_sss = maxwell_filter(raw_erm, calibration=temp_fname, - cross_talk=ctc_fname, st_duration=1., - coord_frame='meg', regularize='in') + with get_n_projected() as counts: + raw_sss = maxwell_filter(raw_erm, calibration=temp_fname, + cross_talk=ctc_fname, st_duration=1., + coord_frame='meg', regularize='in') # Our 3D cal has worse defaults for this ERM than the 1D file - _assert_shielding(raw_sss, erm_power, 44) + _assert_shielding(raw_sss, erm_power, 44, 45) + assert counts[0] == 3 + + # Fine cal + Crosstalk + tSSS + Reg-in + Extended + with get_n_projected() as counts: + raw_sss = maxwell_filter(raw_erm, calibration=fine_cal_fname, + cross_talk=ctc_fname, st_duration=1., + coord_frame='meg', regularize='in', + extended_proj=proj[:3]) + _assert_shielding(raw_sss, erm_power, 48, 50) + assert counts[0] == 1 @pytest.mark.slowtest diff --git a/mne/proj.py b/mne/proj.py index fde5b9fdb18..62cc785236c 100644 --- a/mne/proj.py +++ b/mne/proj.py @@ -18,7 +18,6 @@ from .forward import (is_fixed_orient, _subject_from_forward, convert_forward_solution) from .source_estimate import _make_stc -from .rank import _get_rank_sss def read_proj(fname): @@ -82,8 +81,6 @@ def _compute_proj(data, info, n_grad, n_mag, n_eeg, desc_prefix, _check_option('meg', meg, ['separate', 'combined']) if meg == 'combined': - _get_rank_sss(info, msg='meg="combined" can only be used with ' - 'Maxfiltered data', verbose=False) if n_grad != n_mag: raise ValueError('n_grad (%d) must be equal to n_mag (%d) when ' 'using meg="combined"') diff --git a/mne/tests/test_proj.py b/mne/tests/test_proj.py index 3b5790aea9d..f47c09925d9 100644 --- a/mne/tests/test_proj.py +++ b/mne/tests/test_proj.py @@ -376,8 +376,6 @@ def test_sss_proj(): raw = read_raw_fif(raw_fname) raw.crop(0, 1.0).load_data().pick_types(meg=True, exclude=()) raw.pick_channels(raw.ch_names[:51]).del_proj() - with pytest.raises(ValueError, match='can only be used with Maxfiltered'): - compute_proj_raw(raw, meg='combined') raw_sss = maxwell_filter(raw, int_order=5, ext_order=2) sss_rank = 21 # really low due to channel picking assert len(raw_sss.info['projs']) == 0 diff --git a/mne/utils/docs.py b/mne/utils/docs.py index ecb56076593..378d2654a48 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -522,6 +522,13 @@ or :meth:`mne.io.Raw.append`, or separated during acquisition. To disable, provide an empty list. """ +docdict['maxwell_extended'] = """ +extended_proj : list + The empty-room projection vectors used to extend the external + SSS basis (i.e., use eSSS). + + .. versionadded:: 0.21 +""" # Rank docdict['rank'] = """