From 786945eab780a04f9ffb9ea44de6034c09ed7bb2 Mon Sep 17 00:00:00 2001 From: Luke Bloy Date: Mon, 30 Apr 2018 14:30:44 -0400 Subject: [PATCH] check info against compensation chans (#5133) * ensure comp channels are retained through picked etc. * make compenstation checks mandatory on all infos * move error message to pick_info * add read/write ctf fwd test with compchans * FIX ica and comp_chans * Fix: prepare_forward + prepare_cov --- mne/channels/channels.py | 27 +----------- mne/cov.py | 8 ++++ mne/forward/_make_forward.py | 2 + mne/forward/forward.py | 8 ++++ mne/forward/tests/test_make_forward.py | 17 +++++--- mne/io/meas_info.py | 53 ++++++++++++++++++++++++ mne/io/pick.py | 22 ++++++++++ mne/io/tests/test_meas_info.py | 37 ++++++++++++++++- mne/io/tests/test_pick.py | 19 ++++++++- mne/minimum_norm/inverse.py | 8 ++++ mne/preprocessing/ica.py | 57 +++++++++++++++++++------- mne/preprocessing/tests/test_ica.py | 46 ++++++++++++++++++++- mne/realtime/fieldtrip_client.py | 3 ++ mne/tests/test_cov.py | 36 +++++++++++++++- mne/utils.py | 9 ++++ mne/viz/topomap.py | 10 ++++- 16 files changed, 311 insertions(+), 51 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index d78ae31d625..f5f88b563e3 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -793,7 +793,7 @@ def drop_channels(self, ch_names): idx = np.setdiff1d(np.arange(len(self.ch_names)), bad_idx) return self._pick_drop_channels(idx) - def _pick_drop_channels(self, idx, check_comps=True): + def _pick_drop_channels(self, idx): # avoid circular imports from ..time_frequency import AverageTFR, EpochsTFR @@ -805,31 +805,6 @@ def _pick_drop_channels(self, idx, check_comps=True): if hasattr(self, '_cals'): self._cals = self._cals[idx] - if check_comps and len(self.info['comps']) > 0: - current_comp = get_current_comp(self.info) - # Check and possibly remove comps - comp_names = sorted(set( - comp_name for comp in self.info['comps'] - for comp_name in comp['data']['col_names'])) - comp_picks = pick_channels(self.ch_names, comp_names) - assert len(comp_picks) == len(comp_names) - missing = [comp_names[ii] - for ii in np.where(~np.in1d(comp_picks, idx))[0]] - if len(missing) > 0: - names = ', '.join(missing) - names = names[:20] + '...' if len(names) > 20 else names - if current_comp != 0: - raise RuntimeError( - 'Compensation grade %d has been applied, but ' - 'compensation channels are missing: %s\n' - 'Either remove compensation or pick compensation ' - 'channels' % (current_comp, names)) - else: - logger.info('Removing %d compensators from info because ' - 'not all compensation channels were picked' - % (len(self.info['comps']),)) - self.info['comps'] = [] - pick_info(self.info, idx, copy=False) if getattr(self, '_projector', None) is not None: diff --git a/mne/cov.py b/mne/cov.py index bcd71053183..c0fc78f70d0 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -803,6 +803,8 @@ def _unpack_epochs(epochs): data_mean = [1.0 / n_epoch * np.dot(mean, mean.T) for n_epoch, mean in zip(n_epochs, data_mean)] + if info['comps']: + info['comps'] = [] info = pick_info(info, picks_meeg) tslice = _get_tslice(epochs[0], tmin, tmax) epochs = [ee.get_data()[:, picks_meeg, tslice] for ee in epochs] @@ -1325,6 +1327,7 @@ def prepare_noise_cov(noise_cov, info, ch_names, rank=None, A copy of the covariance with the good channels subselected and parameters updated. """ + info = info.copy() noise_cov_idx = [noise_cov.ch_names.index(c) for c in ch_names] n_chan = len(ch_names) if not noise_cov['diag']: @@ -1386,6 +1389,11 @@ def prepare_noise_cov(noise_cov, info, ch_names, rank=None, else: ranks_type['meg'] = int(rank) + # XXX cov objects don't include compensation_grade + # so we can't check that the supplied cov matches info + if info['comps']: + info['comps'] = [] + for ch_type, this_has in has_type.items(): if not this_has: continue diff --git a/mne/forward/_make_forward.py b/mne/forward/_make_forward.py index c4b574c8d66..5963397edd2 100644 --- a/mne/forward/_make_forward.py +++ b/mne/forward/_make_forward.py @@ -442,6 +442,7 @@ def _prepare_for_forward(src, mri_head_t, info, bem, mindist, n_jobs, n_jobs, verbose] cmd = 'make_forward_solution(%s)' % (', '.join([str(a) for a in arg_list])) mri_id = dict(machid=np.zeros(2, np.int32), version=0, secs=0, usecs=0) + info = Info(chs=info['chs'], comps=info['comps'], dev_head_t=info['dev_head_t'], mri_file=trans, mri_id=mri_id, meas_file=info_extra, meas_id=None, working_dir=os.getcwd(), @@ -465,6 +466,7 @@ def _prepare_for_forward(src, mri_head_t, info, bem, mindist, n_jobs, raise RuntimeError('No MEG or EEG channels found.') # pick out final info + info['comps'] = [] info = pick_info(info, pick_types(info, meg=meg, eeg=eeg, ref_meg=False, exclude=[])) diff --git a/mne/forward/forward.py b/mne/forward/forward.py index 615590c89fc..f3ce6dc1d46 100644 --- a/mne/forward/forward.py +++ b/mne/forward/forward.py @@ -500,6 +500,14 @@ def read_forward_solution(fname, include=(), exclude=(), verbose=None): # fwd['info'] = _read_forward_meas_info(tree, fid) + # remove compensation matrcies. + if len(fwd['info'].get('comps', [])) > 0: + fwd['info']['comps'] = [] + warn('Removing compensation matrices found in measurement info of ' + 'forward operator. This should not affect application of the ' + 'forward model but it will change the model if it\'s written ' + 'back to disk', UserWarning) + # MNE environment parent_env = dir_tree_find(tree, FIFF.FIFFB_MNE_ENV) if len(parent_env) > 0: diff --git a/mne/forward/tests/test_make_forward.py b/mne/forward/tests/test_make_forward.py index ab471264810..24c33a79ed5 100644 --- a/mne/forward/tests/test_make_forward.py +++ b/mne/forward/tests/test_make_forward.py @@ -13,11 +13,11 @@ from mne.datasets import testing from mne.io import read_raw_fif, read_raw_kit, read_raw_bti, read_info from mne.io.constants import FIFF -from mne import (read_forward_solution, make_forward_solution, - convert_forward_solution, setup_volume_source_space, - read_source_spaces, make_sphere_model, - pick_types_forward, pick_info, pick_types, Transform, - read_evokeds, read_cov, read_dipole) +from mne import (read_forward_solution, write_forward_solution, + make_forward_solution, convert_forward_solution, + setup_volume_source_space, read_source_spaces, + make_sphere_model, pick_types_forward, pick_info, pick_types, + Transform, read_evokeds, read_cov, read_dipole) from mne.utils import (requires_mne, requires_nibabel, _TempDir, run_tests_if_main, run_subprocess) from mne.forward._make_forward import _create_meg_coils, make_forward_dipole @@ -203,6 +203,13 @@ def test_make_forward_solution_kit(): subjects_dir=subjects_dir) _compare_forwards(fwd, fwd_py, 274, n_src) + temp_dir = _TempDir() + fname_temp = op.join(temp_dir, 'test-ctf-fwd.fif') + write_forward_solution(fname_temp, fwd_py) + fwd_py2 = read_forward_solution(fname_temp) + _compare_forwards(fwd_py, fwd_py2, 274, n_src) + repr(fwd_py) + @pytest.mark.slowtest @testing.requires_testing_data diff --git a/mne/io/meas_info.py b/mne/io/meas_info.py index 566387adb9b..3da62d512cb 100644 --- a/mne/io/meas_info.py +++ b/mne/io/meas_info.py @@ -31,6 +31,7 @@ from ..utils import logger, verbose, warn, object_diff from .. import __version__ from ..externals.six import b, BytesIO, string_types, text_type +from .compensator import get_current_comp _kind_dict = dict( @@ -480,6 +481,12 @@ def _check_consistency(self): self['ch_names'][ch_idx] = ch_name self['chs'][ch_idx]['ch_name'] = ch_name + # make sure required the compensation channels are present + comps_bad, comps_missing = _bad_chans_comp(self, self['ch_names']) + if comps_bad: + msg = 'Compensation channel(s) %s do not exist in info' + raise RuntimeError(msg % (comps_missing,)) + if 'filename' in self: warn('the "filename" key is misleading\ and info should not have it') @@ -1881,3 +1888,49 @@ def anonymize_info(info): value['secs'] = DATE_NONE[0] value['usecs'] = DATE_NONE[1] return info + + +def _bad_chans_comp(info, ch_names): + """Check if channel names are consistent with current compensation status. + + Parameters + ---------- + info : dict, instance of Info + Measurement information for the dataset. + + ch_names : list of str + The channel names to check. + + Returns + ------- + status : bool + True if compensation is *currently* in use but some compensation + channels are not included in picks + + False if compensation is *currently* not being used + or if compensation is being used and all compensation channels + in info and included in picks. + + missing_ch_names: array-like of str, shape (n_missing,) + The names of compensation channels not included in picks. + Returns [] if no channels are missing. + + """ + if 'comps' not in info: + # should this be thought of as a bug? + return False, [] + + # only include compensation channels that would affect selected channels + ch_names_s = set(ch_names) + comp_names = [] + for comp in info['comps']: + if len(ch_names_s.intersection(comp['data']['row_names'])) > 0: + comp_names.extend(comp['data']['col_names']) + comp_names = sorted(set(comp_names)) + + missing_ch_names = sorted(set(comp_names).difference(ch_names)) + + if get_current_comp(info) != 0 and len(missing_ch_names) > 0: + return True, missing_ch_names + + return False, missing_ch_names diff --git a/mne/io/pick.py b/mne/io/pick.py index a7cc1c2f3ac..088c64e4071 100644 --- a/mne/io/pick.py +++ b/mne/io/pick.py @@ -12,6 +12,7 @@ from .constants import FIFF from ..utils import logger, verbose from ..externals.six import string_types +from .compensator import get_current_comp def get_channel_types(): @@ -386,6 +387,9 @@ def pick_info(info, sel=(), copy=True): res : dict Info structure restricted to a selection of channels. """ + # avoid circular imports + from .meas_info import _bad_chans_comp + info._check_consistency() info = info.copy() if copy else info if sel is None: @@ -393,6 +397,24 @@ def pick_info(info, sel=(), copy=True): elif len(sel) == 0: raise ValueError('No channels match the selection.') + # make sure required the compensation channels are present + if len(info['comps']) > 0: + ch_names = [info['ch_names'][idx] for idx in sel] + _, comps_missing = _bad_chans_comp(info, ch_names) + current_comp = get_current_comp(info) + if len(comps_missing) > 0: + if current_comp != 0: + raise RuntimeError( + 'Compensation grade %d has been applied, but ' + 'compensation channels are missing: %s\n' + 'Either remove compensation or pick compensation ' + 'channels' % (current_comp, comps_missing)) + else: + logger.info('Removing %d compensators from info because ' + 'not all compensation channels were picked' + % (len(info['comps']),)) + info['comps'] = [] + info['chs'] = [info['chs'][k] for k in sel] info._update_redundant() info['bads'] = [ch for ch in info['bads'] if ch in info['ch_names']] diff --git a/mne/io/tests/test_meas_info.py b/mne/io/tests/test_meas_info.py index 288b16ac549..236af6fc1dd 100644 --- a/mne/io/tests/test_meas_info.py +++ b/mne/io/tests/test_meas_info.py @@ -9,6 +9,7 @@ from scipy import sparse from mne import Epochs, read_events, pick_info, pick_types +from mne.event import make_fixed_length_events from mne.datasets import testing from mne.io import (read_fiducials, write_fiducials, _coil_trans_to_loc, _loc_to_coil_trans, read_raw_fif, read_info, write_info, @@ -17,7 +18,9 @@ from mne.io.write import DATE_NONE from mne.io.meas_info import (Info, create_info, _write_dig_points, _read_dig_points, _make_dig_points, _merge_info, - _force_update_info, RAW_INFO_FIELDS) + _force_update_info, RAW_INFO_FIELDS, + _bad_chans_comp) +from mne.io import read_raw_ctf from mne.utils import _TempDir, run_tests_if_main from mne.channels.montage import read_montage, read_dig_montage @@ -36,6 +39,7 @@ sss_path = op.join(data_path, 'SSS') pre = op.join(sss_path, 'test_move_anon_') sss_ctc_fname = pre + 'crossTalk_raw_sss.fif' +ctf_fname = op.join(data_path, 'CTF', 'testdata_ctf.ds') def test_coil_trans(): @@ -456,4 +460,35 @@ def test_csr_csc(): assert_array_equal(ct_read.toarray(), ct.toarray()) +@testing.requires_testing_data +def test_check_compensation_consistency(): + """Test check picks compensation.""" + raw = read_raw_ctf(ctf_fname, preload=False) + events = make_fixed_length_events(raw, 99999) + picks = pick_types(raw.info, meg=True, exclude=[], ref_meg=True) + pick_ch_names = [raw.info['ch_names'][idx] for idx in picks] + for (comp, expected_result) in zip([0, 1], [False, False]): + raw.apply_gradient_compensation(comp) + ret, missing = _bad_chans_comp(raw.info, pick_ch_names) + assert ret == expected_result + assert len(missing) == 0 + Epochs(raw, events, None, -0.2, 0.2, preload=False, picks=picks) + + picks = pick_types(raw.info, meg=True, exclude=[], ref_meg=False) + pick_ch_names = [raw.info['ch_names'][idx] for idx in picks] + + for (comp, expected_result) in zip([0, 1], [False, True]): + raw.apply_gradient_compensation(comp) + ret, missing = _bad_chans_comp(raw.info, pick_ch_names) + assert ret == expected_result + assert len(missing) == 17 + if comp != 0: + with pytest.raises(RuntimeError, + match='Compensation grade 1 has been applied'): + Epochs(raw, events, None, -0.2, 0.2, preload=False, + picks=picks) + else: + Epochs(raw, events, None, -0.2, 0.2, preload=False, picks=picks) + + run_tests_if_main() diff --git a/mne/io/tests/test_pick.py b/mne/io/tests/test_pick.py index 6e720394510..c065c8fe455 100644 --- a/mne/io/tests/test_pick.py +++ b/mne/io/tests/test_pick.py @@ -48,7 +48,6 @@ def test_pick_refs(): fname_ctf_raw = op.join(io_dir, 'tests', 'data', 'test_ctf_comp_raw.fif') raw_ctf = read_raw_fif(fname_ctf_raw) raw_ctf.apply_gradient_compensation(2) - infos.append(raw_ctf.info) for info in infos: info['bads'] = [] assert_raises(ValueError, pick_types, info, meg='foo') @@ -77,12 +76,30 @@ def test_pick_refs(): picks_ref_grad]))) assert_array_equal(picks_meg_ref, np.sort(np.concatenate( [picks_grad, picks_mag, picks_ref_grad, picks_ref_mag]))) + for pick in (picks_meg_ref, picks_meg, picks_ref, picks_grad, picks_ref_grad, picks_meg_ref_grad, picks_mag, picks_ref_mag, picks_meg_ref_mag): if len(pick) > 0: pick_info(info, pick) + # test CTF expected failures directly + info = raw_ctf.info + info['bads'] = [] + picks_meg_ref = pick_types(info, meg=True, ref_meg=True) + picks_meg = pick_types(info, meg=True, ref_meg=False) + picks_ref = pick_types(info, meg=False, ref_meg=True) + picks_mag = pick_types(info, meg='mag', ref_meg=False) + picks_ref_mag = pick_types(info, meg=False, ref_meg='mag') + picks_meg_ref_mag = pick_types(info, meg='mag', ref_meg='mag') + for pick in (picks_meg_ref, picks_ref, picks_ref_mag, picks_meg_ref_mag): + if len(pick) > 0: + pick_info(info, pick) + + for pick in (picks_meg, picks_mag): + if len(pick) > 0: + assert_raises(RuntimeError, pick_info, info, pick) + def test_pick_channels_regexp(): """Test pick with regular expression.""" diff --git a/mne/minimum_norm/inverse.py b/mne/minimum_norm/inverse.py index 9487480fd5e..502457e70f2 100644 --- a/mne/minimum_norm/inverse.py +++ b/mne/minimum_norm/inverse.py @@ -1279,6 +1279,8 @@ def _xyz2lf(Lf_xyz, normals): def _prepare_forward(forward, info, noise_cov, pca=False, rank=None, verbose=None): """Prepare forward solution for inverse solvers.""" + info = info.copy() + # fwd['sol']['row_names'] may be different order from fwd['info']['chs'] fwd_sol_ch_names = forward['sol']['row_names'] ch_names = [c['ch_name'] for c in info['chs'] @@ -1306,6 +1308,12 @@ def _prepare_forward(forward, info, noise_cov, pca=False, rank=None, # Any function calling this helper will be using the returned fwd_info # dict, so fwd['sol']['row_names'] becomes obsolete and is NOT re-ordered + # remove comps if needed. + assert('comps' not in forward['info'] or + len(forward['info']['comps']) == 0) + if info['comps']: + info['comps'] = [] + info_idx = [info['ch_names'].index(name) for name in ch_names] fwd_info = pick_info(info, info_idx) diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 48f59dd57b2..a45c876a946 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -48,7 +48,7 @@ _reject_data_segments, check_random_state, compute_corr, _get_inst_data, _ensure_int, copy_function_doc_to_method_doc, _pl, warn, - _check_preload) + _check_preload, _check_compensation_grade) from ..fixes import _get_args from ..filter import filter_data @@ -441,9 +441,10 @@ def _fit_raw(self, raw, picks, start, stop, decim, reject, flat, tstep, self.max_pca_components = len(picks) logger.info('Inferring max_pca_components from picks') - self.info = pick_info(raw.info, picks) - if self.info['comps']: - self.info['comps'] = [] + info = raw.info.copy() + if info['comps']: + info['comps'] = [] + self.info = pick_info(info, picks) self.ch_names = self.info['ch_names'] start, stop = _check_start_stop(raw, start, stop) @@ -463,7 +464,7 @@ def _fit_raw(self, raw, picks, start, stop, decim, reject, flat, tstep, self.n_samples_ = data.shape[1] # this may operate inplace or make a copy - data, self.pre_whitener_ = self._pre_whiten(data, raw.info, picks) + data, self.pre_whitener_ = self._pre_whiten(data, info, picks) self._fit(data, self.max_pca_components, 'raw') @@ -481,9 +482,10 @@ def _fit_epochs(self, epochs, picks, decim, verbose): '(please be patient, this may take a while)' % len(picks)) # filter out all the channels the raw wouldn't have initialized - self.info = pick_info(epochs.info, picks) - if self.info['comps']: - self.info['comps'] = [] + info = epochs.info.copy() + if info['comps']: + info['comps'] = [] + self.info = pick_info(info, picks) self.ch_names = self.info['ch_names'] if self.max_pca_components is None: @@ -501,7 +503,7 @@ def _fit_epochs(self, epochs, picks, decim, verbose): # This will make at least one copy (one from hstack, maybe one # more from _pre_whiten) data, self.pre_whitener_ = \ - self._pre_whiten(np.hstack(data), epochs.info, picks) + self._pre_whiten(np.hstack(data), info, picks) self._fit(data, self.max_pca_components, 'epochs') @@ -653,7 +655,14 @@ def _transform_raw(self, raw, start, stop, reject_by_annotation=False): data = raw.get_data(picks, start, stop, 'omit') else: data = raw[picks, start:stop][0] - data, _ = self._pre_whiten(data, raw.info, picks) + + # remove comp matrices + assert(raw.compensation_grade == self.compensation_grade) + info = raw.info.copy() + if info['comps']: + info['comps'] = [] + + data, _ = self._pre_whiten(data, info, picks) return self._transform(data) def _transform_epochs(self, epochs, concatenate): @@ -670,9 +679,14 @@ def _transform_epochs(self, epochs, concatenate): 'provide Epochs compatible with ' 'ica.ch_names' % (len(self.ch_names), len(picks))) + # remove comp matrices + assert(epochs.compensation_grade == self.compensation_grade) + info = epochs.info.copy() + if info['comps']: + info['comps'] = [] data = np.hstack(epochs.get_data()[:, picks]) - data, _ = self._pre_whiten(data, epochs.info, picks) + data, _ = self._pre_whiten(data, info, picks) sources = self._transform(data) if not concatenate: @@ -696,7 +710,13 @@ def _transform_evoked(self, evoked): 'ica.ch_names' % (len(self.ch_names), len(picks))) - data, _ = self._pre_whiten(evoked.data[picks], evoked.info, picks) + # remove comp matrices + assert(evoked.compensation_grade == self.compensation_grade) + info = evoked.info.copy() + if info['comps']: + info['comps'] = [] + + data, _ = self._pre_whiten(evoked.data[picks], info, picks) sources = self._transform(data) return sources @@ -742,15 +762,17 @@ def get_sources(self, inst, add_channels=None, start=None, stop=None): The ICA sources time series. """ if isinstance(inst, BaseRaw): + _check_compensation_grade(self, inst, 'ICA', 'Raw') sources = self._sources_as_raw(inst, add_channels, start, stop) elif isinstance(inst, BaseEpochs): + _check_compensation_grade(self, inst, 'ICA', 'Epochs') sources = self._sources_as_epochs(inst, add_channels, False) elif isinstance(inst, Evoked): + _check_compensation_grade(self, inst, 'ICA', 'Evoked') sources = self._sources_as_evoked(inst, add_channels) else: raise ValueError('Data input must be of Raw, Epochs or Evoked ' 'type') - return sources def _sources_as_raw(self, raw, add_channels, start, stop): @@ -903,14 +925,18 @@ def score_sources(self, inst, target=None, score_func='pearsonr', scores for each source as returned from score_func """ if isinstance(inst, BaseRaw): + _check_compensation_grade(self, inst, 'ICA', 'Raw') sources = self._transform_raw(inst, start, stop, reject_by_annotation) elif isinstance(inst, BaseEpochs): + _check_compensation_grade(self, inst, 'ICA', 'Epochs') sources = self._transform_epochs(inst, concatenate=True) elif isinstance(inst, Evoked): + _check_compensation_grade(self, inst, 'ICA', 'Evoked') sources = self._transform_evoked(inst) else: - raise ValueError('Input must be of Raw, Epochs or Evoked type') + raise ValueError('Data input must be of Raw, Epochs or Evoked ' + 'type') if target is not None: # we can have univariate metrics without target target = self._check_target(target, inst, start, stop, @@ -1214,15 +1240,18 @@ def apply(self, inst, include=None, exclude=None, n_pca_components=None, The processed data. """ if isinstance(inst, BaseRaw): + _check_compensation_grade(self, inst, 'ICA', 'Raw') out = self._apply_raw(raw=inst, include=include, exclude=exclude, n_pca_components=n_pca_components, start=start, stop=stop) elif isinstance(inst, BaseEpochs): + _check_compensation_grade(self, inst, 'ICA', 'Epochs') out = self._apply_epochs(epochs=inst, include=include, exclude=exclude, n_pca_components=n_pca_components) elif isinstance(inst, Evoked): + _check_compensation_grade(self, inst, 'ICA', 'Evoked') out = self._apply_evoked(evoked=inst, include=include, exclude=exclude, n_pca_components=n_pca_components) diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py index 85041fa2fb2..76f3da8bbd7 100644 --- a/mne/preprocessing/tests/test_ica.py +++ b/mne/preprocessing/tests/test_ica.py @@ -26,12 +26,14 @@ read_ica, run_ica) from mne.preprocessing.ica import (get_score_funcs, corrmap, _sort_components, _ica_explained_variance) -from mne.io import read_raw_fif, Info, RawArray +from mne.io import read_raw_fif, Info, RawArray, read_raw_ctf from mne.io.meas_info import _kind_dict from mne.io.pick import _DATA_CH_TYPES_SPLIT from mne.tests.common import assert_naming from mne.utils import (catch_logging, _TempDir, requires_sklearn, run_tests_if_main) +from mne.datasets import testing +from mne.event import make_fixed_length_events # Set our plotters to test mode import matplotlib @@ -44,6 +46,9 @@ event_name = op.join(data_dir, 'test-eve.fif') test_cov_name = op.join(data_dir, 'test-cov.fif') +ctf_fname = op.join(testing.data_path(download=False), 'CTF', + 'testdata_ctf.ds') + event_id, tmin, tmax = 1, -0.2, 0.2 # if stop is too small pca may fail in some cases, but we're okay on this file start, stop = 0, 6 @@ -866,4 +871,43 @@ def test_n_components_and_max_pca_components_none(method): assert_is_none(ica.n_components) +@requires_sklearn +@testing.requires_testing_data +def test_ica_ctf(): + """Test run ICA computation on ctf data with/without compensation.""" + method = 'fastica' + raw = read_raw_ctf(ctf_fname, preload=True) + events = make_fixed_length_events(raw, 99999) + for comp in [0, 1]: + raw.apply_gradient_compensation(comp) + epochs = Epochs(raw, events, None, -0.2, 0.2, preload=True) + evoked = epochs.average() + + # test fit + for inst in [raw, epochs]: + ica = ICA(n_components=2, random_state=0, max_iter=2, + method=method) + with warnings.catch_warnings(record=True): # convergence + ica.fit(raw) + + # test apply and get_sources + for inst in [raw, epochs, evoked]: + ica.apply(inst) + ica.get_sources(inst) + + # test mixed compensation case + raw.apply_gradient_compensation(0) + ica = ICA(n_components=2, random_state=0, max_iter=2, method=method) + with warnings.catch_warnings(record=True): # convergence + ica.fit(raw) + raw.apply_gradient_compensation(1) + epochs = Epochs(raw, events, None, -0.2, 0.2, preload=True) + evoked = epochs.average() + for inst in [raw, epochs, evoked]: + with pytest.raises(RuntimeError, match='Compensation grade of ICA'): + ica.apply(inst) + with pytest.raises(RuntimeError, match='Compensation grade of ICA'): + ica.get_sources(inst) + + run_tests_if_main() diff --git a/mne/realtime/fieldtrip_client.py b/mne/realtime/fieldtrip_client.py index fc25717bafb..e92f212fdc7 100644 --- a/mne/realtime/fieldtrip_client.py +++ b/mne/realtime/fieldtrip_client.py @@ -181,6 +181,9 @@ def _guess_measurement_info(self): this_info['kind'] = FIFF.FIFFV_MISC_CH chs_unknown.append(ch) + # Set coil_type (does FT supply this information somehow?) + this_info['coil_type'] = FIFF.FIFFV_COIL_NONE + # Fieldtrip already does calibration this_info['range'] = 1.0 this_info['cal'] = 1.0 diff --git a/mne/tests/test_cov.py b/mne/tests/test_cov.py index 1b4c86083c5..3aaa7ed61f3 100644 --- a/mne/tests/test_cov.py +++ b/mne/tests/test_cov.py @@ -25,11 +25,13 @@ compute_covariance, read_evokeds, compute_proj_raw, pick_channels_cov, pick_types, pick_info, make_ad_hoc_cov) from mne.fixes import _get_args -from mne.io import read_raw_fif, RawArray, read_info +from mne.io import read_raw_fif, RawArray, read_info, read_raw_ctf from mne.tests.common import assert_naming, assert_snr from mne.utils import _TempDir, requires_version, run_tests_if_main from mne.io.proc_history import _get_sss_rank from mne.io.pick import channel_type, _picks_by_type, _DATA_CH_TYPES_SPLIT +from mne.datasets import testing +from mne.event import make_fixed_length_events warnings.simplefilter('always') # enable b/c these tests throw warnings @@ -42,6 +44,9 @@ erm_cov_fname = op.join(base_dir, 'test_erm-cov.fif') hp_fif_fname = op.join(base_dir, 'test_chpi_raw_sss.fif') +ctf_fname = op.join(testing.data_path(download=False), 'CTF', + 'testdata_ctf.ds') + def test_cov_mismatch(): """Test estimation with MEG<->Head mismatch.""" @@ -644,4 +649,33 @@ def test_compute_covariance_auto_reg(): scalings=dict(misc=123)) +@testing.requires_testing_data +@requires_version('sklearn', '0.15') +def test_cov_ctf(): + """Test basic cov computation on ctf data with/without compensation.""" + raw = read_raw_ctf(ctf_fname, preload=True) + events = make_fixed_length_events(raw, 99999) + ch_names = [raw.info['ch_names'][pick] + for pick in pick_types(raw.info, meg=True, eeg=False, + ref_meg=False)] + + for comp in [0, 1]: + raw.apply_gradient_compensation(comp) + epochs = Epochs(raw, events, None, -0.2, 0.2, preload=True) + noise_cov = compute_covariance(epochs, tmax=0., method=['shrunk']) + prepare_noise_cov(noise_cov, raw.info, ch_names) + + raw.apply_gradient_compensation(0) + epochs = Epochs(raw, events, None, -0.2, 0.2, preload=True) + noise_cov = compute_covariance(epochs, tmax=0., method=['shrunk']) + raw.apply_gradient_compensation(1) + + # TODO This next call in principle should fail. + prepare_noise_cov(noise_cov, raw.info, ch_names) + + # make sure comps matrices was not removed from raw + if not raw.info['comps']: + raise RuntimeError('Comps matrices removed') + + run_tests_if_main() diff --git a/mne/utils.py b/mne/utils.py index 771677ec367..f8f071d9a57 100644 --- a/mne/utils.py +++ b/mne/utils.py @@ -2204,6 +2204,15 @@ def _check_preload(inst, msg): '%s.load_data().' % (name, name)) +def _check_compensation_grade(inst, inst2, name, name2): + """Ensure that objects have same compenstation_grade.""" + if (None not in [inst.info, inst2.info]) and (inst.compensation_grade != + inst2.compensation_grade): + msg = ('Compensation grade of %s (%d) and %s (%d) don\'t match') + raise RuntimeError(msg % (name, inst.compensation_grade, + name2, inst2.compensation_grade)) + + def _check_pandas_installed(strict=True): """Aux function.""" try: diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index 42e97bec0c6..caf556a653a 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -46,6 +46,9 @@ def _prepare_topo_plot(inst, ch_type, layout): for ii, this_ch in enumerate(info['chs']): this_ch['ch_name'] = clean_ch_names[ii] info['bads'] = _clean_names(info['bads']) + for comp in info['comps']: + comp['data']['col_names'] = _clean_names(comp['data']['col_names']) + info._update_redundant() info._check_consistency() @@ -1453,8 +1456,11 @@ def plot_evoked_topomap(evoked, times="auto", ch_type=None, layout=None, else: data = evoked.data - # Skip comps check here by using private function - evoked = evoked.copy()._pick_drop_channels(picks, check_comps=False) + # because we are only plotting we can safely remove compensation matrices + # regardless of compensation status. + evoked = evoked.copy() + evoked.info['comps'] = [] + evoked = evoked._pick_drop_channels(picks) interactive = isinstance(times, string_types) and times == 'interactive' if axes is not None: