Skip to content

Commit

Permalink
check info against compensation chans (mne-tools#5133)
Browse files Browse the repository at this point in the history
* ensure comp channels are retained through picked etc.

* make compenstation checks mandatory on all infos

* move error message to pick_info

* add read/write ctf fwd test with compchans

* FIX ica and comp_chans

* Fix: prepare_forward + prepare_cov
  • Loading branch information
bloyl authored and agramfort committed Apr 30, 2018
1 parent 75a0589 commit 786945e
Show file tree
Hide file tree
Showing 16 changed files with 311 additions and 51 deletions.
27 changes: 1 addition & 26 deletions mne/channels/channels.py
Expand Up @@ -793,7 +793,7 @@ def drop_channels(self, ch_names):
idx = np.setdiff1d(np.arange(len(self.ch_names)), bad_idx)
return self._pick_drop_channels(idx)

def _pick_drop_channels(self, idx, check_comps=True):
def _pick_drop_channels(self, idx):
# avoid circular imports
from ..time_frequency import AverageTFR, EpochsTFR

Expand All @@ -805,31 +805,6 @@ def _pick_drop_channels(self, idx, check_comps=True):
if hasattr(self, '_cals'):
self._cals = self._cals[idx]

if check_comps and len(self.info['comps']) > 0:
current_comp = get_current_comp(self.info)
# Check and possibly remove comps
comp_names = sorted(set(
comp_name for comp in self.info['comps']
for comp_name in comp['data']['col_names']))
comp_picks = pick_channels(self.ch_names, comp_names)
assert len(comp_picks) == len(comp_names)
missing = [comp_names[ii]
for ii in np.where(~np.in1d(comp_picks, idx))[0]]
if len(missing) > 0:
names = ', '.join(missing)
names = names[:20] + '...' if len(names) > 20 else names
if current_comp != 0:
raise RuntimeError(
'Compensation grade %d has been applied, but '
'compensation channels are missing: %s\n'
'Either remove compensation or pick compensation '
'channels' % (current_comp, names))
else:
logger.info('Removing %d compensators from info because '
'not all compensation channels were picked'
% (len(self.info['comps']),))
self.info['comps'] = []

pick_info(self.info, idx, copy=False)

if getattr(self, '_projector', None) is not None:
Expand Down
8 changes: 8 additions & 0 deletions mne/cov.py
Expand Up @@ -803,6 +803,8 @@ def _unpack_epochs(epochs):
data_mean = [1.0 / n_epoch * np.dot(mean, mean.T) for n_epoch, mean
in zip(n_epochs, data_mean)]

if info['comps']:
info['comps'] = []
info = pick_info(info, picks_meeg)
tslice = _get_tslice(epochs[0], tmin, tmax)
epochs = [ee.get_data()[:, picks_meeg, tslice] for ee in epochs]
Expand Down Expand Up @@ -1325,6 +1327,7 @@ def prepare_noise_cov(noise_cov, info, ch_names, rank=None,
A copy of the covariance with the good channels subselected
and parameters updated.
"""
info = info.copy()
noise_cov_idx = [noise_cov.ch_names.index(c) for c in ch_names]
n_chan = len(ch_names)
if not noise_cov['diag']:
Expand Down Expand Up @@ -1386,6 +1389,11 @@ def prepare_noise_cov(noise_cov, info, ch_names, rank=None,
else:
ranks_type['meg'] = int(rank)

# XXX cov objects don't include compensation_grade
# so we can't check that the supplied cov matches info
if info['comps']:
info['comps'] = []

for ch_type, this_has in has_type.items():
if not this_has:
continue
Expand Down
2 changes: 2 additions & 0 deletions mne/forward/_make_forward.py
Expand Up @@ -442,6 +442,7 @@ def _prepare_for_forward(src, mri_head_t, info, bem, mindist, n_jobs,
n_jobs, verbose]
cmd = 'make_forward_solution(%s)' % (', '.join([str(a) for a in arg_list]))
mri_id = dict(machid=np.zeros(2, np.int32), version=0, secs=0, usecs=0)

info = Info(chs=info['chs'], comps=info['comps'],
dev_head_t=info['dev_head_t'], mri_file=trans, mri_id=mri_id,
meas_file=info_extra, meas_id=None, working_dir=os.getcwd(),
Expand All @@ -465,6 +466,7 @@ def _prepare_for_forward(src, mri_head_t, info, bem, mindist, n_jobs,
raise RuntimeError('No MEG or EEG channels found.')

# pick out final info
info['comps'] = []
info = pick_info(info, pick_types(info, meg=meg, eeg=eeg, ref_meg=False,
exclude=[]))

Expand Down
8 changes: 8 additions & 0 deletions mne/forward/forward.py
Expand Up @@ -500,6 +500,14 @@ def read_forward_solution(fname, include=(), exclude=(), verbose=None):
#
fwd['info'] = _read_forward_meas_info(tree, fid)

# remove compensation matrcies.
if len(fwd['info'].get('comps', [])) > 0:
fwd['info']['comps'] = []
warn('Removing compensation matrices found in measurement info of '
'forward operator. This should not affect application of the '
'forward model but it will change the model if it\'s written '
'back to disk', UserWarning)

# MNE environment
parent_env = dir_tree_find(tree, FIFF.FIFFB_MNE_ENV)
if len(parent_env) > 0:
Expand Down
17 changes: 12 additions & 5 deletions mne/forward/tests/test_make_forward.py
Expand Up @@ -13,11 +13,11 @@
from mne.datasets import testing
from mne.io import read_raw_fif, read_raw_kit, read_raw_bti, read_info
from mne.io.constants import FIFF
from mne import (read_forward_solution, make_forward_solution,
convert_forward_solution, setup_volume_source_space,
read_source_spaces, make_sphere_model,
pick_types_forward, pick_info, pick_types, Transform,
read_evokeds, read_cov, read_dipole)
from mne import (read_forward_solution, write_forward_solution,
make_forward_solution, convert_forward_solution,
setup_volume_source_space, read_source_spaces,
make_sphere_model, pick_types_forward, pick_info, pick_types,
Transform, read_evokeds, read_cov, read_dipole)
from mne.utils import (requires_mne, requires_nibabel, _TempDir,
run_tests_if_main, run_subprocess)
from mne.forward._make_forward import _create_meg_coils, make_forward_dipole
Expand Down Expand Up @@ -203,6 +203,13 @@ def test_make_forward_solution_kit():
subjects_dir=subjects_dir)
_compare_forwards(fwd, fwd_py, 274, n_src)

temp_dir = _TempDir()
fname_temp = op.join(temp_dir, 'test-ctf-fwd.fif')
write_forward_solution(fname_temp, fwd_py)
fwd_py2 = read_forward_solution(fname_temp)
_compare_forwards(fwd_py, fwd_py2, 274, n_src)
repr(fwd_py)


@pytest.mark.slowtest
@testing.requires_testing_data
Expand Down
53 changes: 53 additions & 0 deletions mne/io/meas_info.py
Expand Up @@ -31,6 +31,7 @@
from ..utils import logger, verbose, warn, object_diff
from .. import __version__
from ..externals.six import b, BytesIO, string_types, text_type
from .compensator import get_current_comp


_kind_dict = dict(
Expand Down Expand Up @@ -480,6 +481,12 @@ def _check_consistency(self):
self['ch_names'][ch_idx] = ch_name
self['chs'][ch_idx]['ch_name'] = ch_name

# make sure required the compensation channels are present
comps_bad, comps_missing = _bad_chans_comp(self, self['ch_names'])
if comps_bad:
msg = 'Compensation channel(s) %s do not exist in info'
raise RuntimeError(msg % (comps_missing,))

if 'filename' in self:
warn('the "filename" key is misleading\
and info should not have it')
Expand Down Expand Up @@ -1881,3 +1888,49 @@ def anonymize_info(info):
value['secs'] = DATE_NONE[0]
value['usecs'] = DATE_NONE[1]
return info


def _bad_chans_comp(info, ch_names):
"""Check if channel names are consistent with current compensation status.
Parameters
----------
info : dict, instance of Info
Measurement information for the dataset.
ch_names : list of str
The channel names to check.
Returns
-------
status : bool
True if compensation is *currently* in use but some compensation
channels are not included in picks
False if compensation is *currently* not being used
or if compensation is being used and all compensation channels
in info and included in picks.
missing_ch_names: array-like of str, shape (n_missing,)
The names of compensation channels not included in picks.
Returns [] if no channels are missing.
"""
if 'comps' not in info:
# should this be thought of as a bug?
return False, []

# only include compensation channels that would affect selected channels
ch_names_s = set(ch_names)
comp_names = []
for comp in info['comps']:
if len(ch_names_s.intersection(comp['data']['row_names'])) > 0:
comp_names.extend(comp['data']['col_names'])
comp_names = sorted(set(comp_names))

missing_ch_names = sorted(set(comp_names).difference(ch_names))

if get_current_comp(info) != 0 and len(missing_ch_names) > 0:
return True, missing_ch_names

return False, missing_ch_names
22 changes: 22 additions & 0 deletions mne/io/pick.py
Expand Up @@ -12,6 +12,7 @@
from .constants import FIFF
from ..utils import logger, verbose
from ..externals.six import string_types
from .compensator import get_current_comp


def get_channel_types():
Expand Down Expand Up @@ -386,13 +387,34 @@ def pick_info(info, sel=(), copy=True):
res : dict
Info structure restricted to a selection of channels.
"""
# avoid circular imports
from .meas_info import _bad_chans_comp

info._check_consistency()
info = info.copy() if copy else info
if sel is None:
return info
elif len(sel) == 0:
raise ValueError('No channels match the selection.')

# make sure required the compensation channels are present
if len(info['comps']) > 0:
ch_names = [info['ch_names'][idx] for idx in sel]
_, comps_missing = _bad_chans_comp(info, ch_names)
current_comp = get_current_comp(info)
if len(comps_missing) > 0:
if current_comp != 0:
raise RuntimeError(
'Compensation grade %d has been applied, but '
'compensation channels are missing: %s\n'
'Either remove compensation or pick compensation '
'channels' % (current_comp, comps_missing))
else:
logger.info('Removing %d compensators from info because '
'not all compensation channels were picked'
% (len(info['comps']),))
info['comps'] = []

info['chs'] = [info['chs'][k] for k in sel]
info._update_redundant()
info['bads'] = [ch for ch in info['bads'] if ch in info['ch_names']]
Expand Down
37 changes: 36 additions & 1 deletion mne/io/tests/test_meas_info.py
Expand Up @@ -9,6 +9,7 @@
from scipy import sparse

from mne import Epochs, read_events, pick_info, pick_types
from mne.event import make_fixed_length_events
from mne.datasets import testing
from mne.io import (read_fiducials, write_fiducials, _coil_trans_to_loc,
_loc_to_coil_trans, read_raw_fif, read_info, write_info,
Expand All @@ -17,7 +18,9 @@
from mne.io.write import DATE_NONE
from mne.io.meas_info import (Info, create_info, _write_dig_points,
_read_dig_points, _make_dig_points, _merge_info,
_force_update_info, RAW_INFO_FIELDS)
_force_update_info, RAW_INFO_FIELDS,
_bad_chans_comp)
from mne.io import read_raw_ctf
from mne.utils import _TempDir, run_tests_if_main
from mne.channels.montage import read_montage, read_dig_montage

Expand All @@ -36,6 +39,7 @@
sss_path = op.join(data_path, 'SSS')
pre = op.join(sss_path, 'test_move_anon_')
sss_ctc_fname = pre + 'crossTalk_raw_sss.fif'
ctf_fname = op.join(data_path, 'CTF', 'testdata_ctf.ds')


def test_coil_trans():
Expand Down Expand Up @@ -456,4 +460,35 @@ def test_csr_csc():
assert_array_equal(ct_read.toarray(), ct.toarray())


@testing.requires_testing_data
def test_check_compensation_consistency():
"""Test check picks compensation."""
raw = read_raw_ctf(ctf_fname, preload=False)
events = make_fixed_length_events(raw, 99999)
picks = pick_types(raw.info, meg=True, exclude=[], ref_meg=True)
pick_ch_names = [raw.info['ch_names'][idx] for idx in picks]
for (comp, expected_result) in zip([0, 1], [False, False]):
raw.apply_gradient_compensation(comp)
ret, missing = _bad_chans_comp(raw.info, pick_ch_names)
assert ret == expected_result
assert len(missing) == 0
Epochs(raw, events, None, -0.2, 0.2, preload=False, picks=picks)

picks = pick_types(raw.info, meg=True, exclude=[], ref_meg=False)
pick_ch_names = [raw.info['ch_names'][idx] for idx in picks]

for (comp, expected_result) in zip([0, 1], [False, True]):
raw.apply_gradient_compensation(comp)
ret, missing = _bad_chans_comp(raw.info, pick_ch_names)
assert ret == expected_result
assert len(missing) == 17
if comp != 0:
with pytest.raises(RuntimeError,
match='Compensation grade 1 has been applied'):
Epochs(raw, events, None, -0.2, 0.2, preload=False,
picks=picks)
else:
Epochs(raw, events, None, -0.2, 0.2, preload=False, picks=picks)


run_tests_if_main()
19 changes: 18 additions & 1 deletion mne/io/tests/test_pick.py
Expand Up @@ -48,7 +48,6 @@ def test_pick_refs():
fname_ctf_raw = op.join(io_dir, 'tests', 'data', 'test_ctf_comp_raw.fif')
raw_ctf = read_raw_fif(fname_ctf_raw)
raw_ctf.apply_gradient_compensation(2)
infos.append(raw_ctf.info)
for info in infos:
info['bads'] = []
assert_raises(ValueError, pick_types, info, meg='foo')
Expand Down Expand Up @@ -77,12 +76,30 @@ def test_pick_refs():
picks_ref_grad])))
assert_array_equal(picks_meg_ref, np.sort(np.concatenate(
[picks_grad, picks_mag, picks_ref_grad, picks_ref_mag])))

for pick in (picks_meg_ref, picks_meg, picks_ref,
picks_grad, picks_ref_grad, picks_meg_ref_grad,
picks_mag, picks_ref_mag, picks_meg_ref_mag):
if len(pick) > 0:
pick_info(info, pick)

# test CTF expected failures directly
info = raw_ctf.info
info['bads'] = []
picks_meg_ref = pick_types(info, meg=True, ref_meg=True)
picks_meg = pick_types(info, meg=True, ref_meg=False)
picks_ref = pick_types(info, meg=False, ref_meg=True)
picks_mag = pick_types(info, meg='mag', ref_meg=False)
picks_ref_mag = pick_types(info, meg=False, ref_meg='mag')
picks_meg_ref_mag = pick_types(info, meg='mag', ref_meg='mag')
for pick in (picks_meg_ref, picks_ref, picks_ref_mag, picks_meg_ref_mag):
if len(pick) > 0:
pick_info(info, pick)

for pick in (picks_meg, picks_mag):
if len(pick) > 0:
assert_raises(RuntimeError, pick_info, info, pick)


def test_pick_channels_regexp():
"""Test pick with regular expression."""
Expand Down
8 changes: 8 additions & 0 deletions mne/minimum_norm/inverse.py
Expand Up @@ -1279,6 +1279,8 @@ def _xyz2lf(Lf_xyz, normals):
def _prepare_forward(forward, info, noise_cov, pca=False, rank=None,
verbose=None):
"""Prepare forward solution for inverse solvers."""
info = info.copy()

# fwd['sol']['row_names'] may be different order from fwd['info']['chs']
fwd_sol_ch_names = forward['sol']['row_names']
ch_names = [c['ch_name'] for c in info['chs']
Expand Down Expand Up @@ -1306,6 +1308,12 @@ def _prepare_forward(forward, info, noise_cov, pca=False, rank=None,
# Any function calling this helper will be using the returned fwd_info
# dict, so fwd['sol']['row_names'] becomes obsolete and is NOT re-ordered

# remove comps if needed.
assert('comps' not in forward['info'] or
len(forward['info']['comps']) == 0)
if info['comps']:
info['comps'] = []

info_idx = [info['ch_names'].index(name) for name in ch_names]
fwd_info = pick_info(info, info_idx)

Expand Down

0 comments on commit 786945e

Please sign in to comment.