From 3ec817ce140ab1ddbb742d47d9fc4156c673b6ec Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Mon, 4 Jan 2016 17:06:14 +0200 Subject: [PATCH] Automatic info['nchan'] and info['ch_names'] The info object has two redundant fields: `nchan` and `ch_names`. They are there as convenience fields. However, whenever the `chs` list is updated, these fields need to be manually updated as well. This PR makes these fields behave more like properties. It does so by making `Info` a subclass of `collections.MutableMapping`, which allows it to redefine `__setitem__` and `__getitem__` while retaining full compatibility with the default Python dict. The `nchan` field just maps to `len(info['chs'])`. The `ch_names` field is a bit more tricky. From the outside, it behaves as a mapping to `[ch['ch_name'] for ch in info['chs']]`. However, in order not to generate a new list every time the field is accessed, the field is an instance of `_ChannelNameList`. The `_ChannelNameList` class is a subclass of `collections.Sequence`, thus implementing a list that is read-only, but otherwise fully compatible with a normal Python list. It overwrites the `__getitem(self, index)__` method to map to `info['chs'][index]['ch_name']` on the fly. The rest of the code is updated to no longer set the `nchan` and `ch_names` fields of Info objects. --- doc/whats_new.rst | 2 + mne/channels/channels.py | 1 - mne/channels/layout.py | 3 +- mne/channels/montage.py | 2 - mne/channels/tests/test_layout.py | 4 +- mne/evoked.py | 1 - mne/forward/_field_interpolation.py | 1 - mne/forward/_make_forward.py | 13 +- mne/forward/forward.py | 3 - mne/io/brainvision/brainvision.py | 13 +- mne/io/bti/bti.py | 2 - mne/io/ctf/info.py | 2 - mne/io/edf/edf.py | 2 - mne/io/eeglab/eeglab.py | 1 - mne/io/egi/egi.py | 3 +- mne/io/fiff/tests/test_raw_fiff.py | 6 +- mne/io/kit/kit.py | 11 +- mne/io/meas_info.py | 222 +++++++++++++++++++++---- mne/io/nicolet/nicolet.py | 5 +- mne/io/pick.py | 14 +- mne/io/reference.py | 3 - mne/io/tests/test_meas_info.py | 145 +++++++++++++++- mne/io/tests/test_pick.py | 8 - mne/minimum_norm/tests/test_inverse.py | 7 +- mne/preprocessing/ica.py | 6 +- mne/preprocessing/maxwell.py | 2 - mne/realtime/fieldtrip_client.py | 5 +- mne/tests/test_epochs.py | 7 +- mne/time_frequency/tfr.py | 5 +- mne/utils.py | 12 +- mne/viz/topomap.py | 4 +- 31 files changed, 389 insertions(+), 126 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 6dc8c9790d3..74e50249eff 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -47,6 +47,8 @@ API - Deprecated function :func:`mne.time_frequency.multitaper_psd` and replaced by :func:`mne.time_frequency.psd_multitaper` by `Chris Holdgraf`_ + - The `'ch_names'` and `'nchan'` fields of the :class:`mne.io.Info` class are now read-only and automatically update to accommodate changes in the `'chs'` field, by `Marijn van Vliet`_ + .. _changes_0_11: Version 0.11 diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 09fdc2ecf18..6daa561e26f 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -651,7 +651,6 @@ def rename_channels(info, mapping): # do the reampping in info info['bads'] = bads - info['ch_names'] = ch_names for ch, ch_name in zip(info['chs'], ch_names): ch['ch_name'] = ch_name info._check_consistency() diff --git a/mne/channels/layout.py b/mne/channels/layout.py index 5059996c85e..c46d643656c 100644 --- a/mne/channels/layout.py +++ b/mne/channels/layout.py @@ -18,6 +18,7 @@ from ..transforms import _polar_to_cartesian, _cartesian_to_sphere from ..io.pick import pick_types from ..io.constants import FIFF +from ..io.meas_info import Info from ..utils import _clean_names from ..externals.six.moves import map @@ -411,7 +412,7 @@ def find_layout(info, ch_type=None, exclude='bads'): layout_name = 'Vectorview-grad' elif ((has_eeg_coils_only and ch_type in [None, 'eeg']) or (has_eeg_coils_and_meg and ch_type == 'eeg')): - if not isinstance(info, dict): + if not isinstance(info, (dict, Info)): raise RuntimeError('Cannot make EEG layout, no measurement info ' 'was passed to `find_layout`') return make_eeg_layout(info, exclude=exclude) diff --git a/mne/channels/montage.py b/mne/channels/montage.py index b6c8453aab5..3516d61b957 100644 --- a/mne/channels/montage.py +++ b/mne/channels/montage.py @@ -619,7 +619,6 @@ def _set_montage(info, montage, update_ch_names=False): """ if isinstance(montage, Montage): if update_ch_names: - info['ch_names'] = montage.ch_names info['chs'] = list() for ii, ch_name in enumerate(montage.ch_names): ch_info = {'cal': 1., 'logno': ii + 1, 'scanno': ii + 1, @@ -638,7 +637,6 @@ def _set_montage(info, montage, update_ch_names=False): continue ch_idx = info['ch_names'].index(ch_name) - info['ch_names'][ch_idx] = ch_name info['chs'][ch_idx]['loc'] = np.r_[pos, [0.] * 9] sensors_found.append(ch_idx) diff --git a/mne/channels/tests/test_layout.py b/mne/channels/tests/test_layout.py index 51663250de9..50639b1cb91 100644 --- a/mne/channels/tests/test_layout.py +++ b/mne/channels/tests/test_layout.py @@ -44,7 +44,6 @@ test_info = _empty_info(1000) test_info.update({ - 'ch_names': ['ICA 001', 'ICA 002', 'EOG 061'], 'chs': [{'cal': 1, 'ch_name': 'ICA 001', 'coil_type': 0, @@ -81,7 +80,7 @@ 'scanno': 376, 'unit': 107, 'unit_mul': 0}], - 'nchan': 3}) +}) def test_io_layout_lout(): @@ -229,7 +228,6 @@ def test_find_layout(): sample_info4 = copy.deepcopy(sample_info) for ii, name in enumerate(sample_info4['ch_names']): new = name.replace(' ', '') - sample_info4['ch_names'][ii] = new sample_info4['chs'][ii]['ch_name'] = new eegs = pick_types(sample_info, meg=False, eeg=True) diff --git a/mne/evoked.py b/mne/evoked.py index fcee8539f97..eef202c549a 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -198,7 +198,6 @@ def __init__(self, fname, condition=None, baseline=None, proj=True, 'channel definitions are different') info['chs'] = chs - info['nchan'] = nchan logger.info(' Found channel information in evoked data. ' 'nchan = %d' % nchan) if sfreq > 0: diff --git a/mne/forward/_field_interpolation.py b/mne/forward/_field_interpolation.py index bfca5b3e8f7..a1ef42e3322 100644 --- a/mne/forward/_field_interpolation.py +++ b/mne/forward/_field_interpolation.py @@ -211,7 +211,6 @@ def _as_meg_type_evoked(evoked, ch_type='grad', mode='fast'): # change channel names to emphasize they contain interpolated data for ch in evoked.info['chs']: ch['ch_name'] += '_virtual' - evoked.info['ch_names'] = [ch['ch_name'] for ch in evoked.info['chs']] return evoked diff --git a/mne/forward/_make_forward.py b/mne/forward/_make_forward.py index 563a4552edd..05688502b00 100644 --- a/mne/forward/_make_forward.py +++ b/mne/forward/_make_forward.py @@ -423,10 +423,9 @@ def _prepare_for_forward(src, mri_head_t, info, bem, mindist, n_jobs, mindist, overwrite, n_jobs, verbose] cmd = 'make_forward_solution(%s)' % (', '.join([str(a) for a in arg_list])) mri_id = dict(machid=np.zeros(2, np.int32), version=0, secs=0, usecs=0) - info = Info(nchan=info['nchan'], chs=info['chs'], comps=info['comps'], - ch_names=info['ch_names'], dev_head_t=info['dev_head_t'], - mri_file=trans, mri_id=mri_id, meas_file=info_extra, - meas_id=None, working_dir=os.getcwd(), + info = Info(chs=info['chs'], comps=info['comps'], + dev_head_t=info['dev_head_t'], mri_file=trans, mri_id=mri_id, + meas_file=info_extra, meas_id=None, working_dir=os.getcwd(), command_line=cmd, bads=info['bads'], mri_head_t=mri_head_t) logger.info('') @@ -549,13 +548,13 @@ def make_forward_solution(info, trans, src, bem, fname=None, meg=True, if fname is not None and op.isfile(fname) and not overwrite: raise IOError('file "%s" exists, consider using overwrite=True' % fname) - if not isinstance(info, (dict, string_types)): - raise TypeError('info should be a dict or string') + if not isinstance(info, (Info, string_types)): + raise TypeError('info should be an instance of Info or string') if isinstance(info, string_types): info_extra = op.split(info)[1] info = read_info(info, verbose=False) else: - info_extra = 'info dict' + info_extra = 'instance of Info' # Report the setup logger.info('Source space : %s' % src) diff --git a/mne/forward/forward.py b/mne/forward/forward.py index 773026dc511..ac50fcb9956 100644 --- a/mne/forward/forward.py +++ b/mne/forward/forward.py @@ -318,9 +318,6 @@ def _read_forward_meas_info(tree, fid): chs.append(tag.data) info['chs'] = chs - info['ch_names'] = [c['ch_name'] for c in chs] - info['nchan'] = len(chs) - # Get the MRI <-> head coordinate transformation tag = find_tag(fid, parent_mri, FIFF.FIFF_COORD_TRANS) coord_head = FIFF.FIFFV_COORD_HEAD diff --git a/mne/io/brainvision/brainvision.py b/mne/io/brainvision/brainvision.py index ba81693099c..c8baa488bda 100644 --- a/mne/io/brainvision/brainvision.py +++ b/mne/io/brainvision/brainvision.py @@ -328,10 +328,10 @@ def _get_vhdr_info(vhdr_fname, eog, misc, scale, montage): fmt = _fmt_dict[fmt] # load channel labels - info['nchan'] = cfg.getint('Common Infos', 'NumberOfChannels') + 1 - ch_names = [''] * info['nchan'] - cals = np.empty(info['nchan']) - ranges = np.empty(info['nchan']) + nchan = cfg.getint('Common Infos', 'NumberOfChannels') + 1 + ch_names = [''] * nchan + cals = np.empty(nchan) + ranges = np.empty(nchan) cals.fill(np.nan) ch_dict = dict() for chan, props in cfg.items('Channel Infos'): @@ -437,13 +437,12 @@ def _get_vhdr_info(vhdr_fname, eog, misc, scale, montage): # Creates a list of dicts of eeg channels for raw.info logger.info('Setting channel info structure...') info['chs'] = [] - info['ch_names'] = ch_names for idx, ch_name in enumerate(ch_names): - if ch_name in eog or idx in eog or idx - info['nchan'] in eog: + if ch_name in eog or idx in eog or idx - nchan in eog: kind = FIFF.FIFFV_EOG_CH coil_type = FIFF.FIFFV_COIL_NONE unit = FIFF.FIFF_UNIT_V - elif ch_name in misc or idx in misc or idx - info['nchan'] in misc: + elif ch_name in misc or idx in misc or idx - nchan in misc: kind = FIFF.FIFFV_MISC_CH coil_type = FIFF.FIFFV_COIL_NONE unit = FIFF.FIFF_UNIT_V diff --git a/mne/io/bti/bti.py b/mne/io/bti/bti.py index 68da36460e3..918e3d0810c 100644 --- a/mne/io/bti/bti.py +++ b/mne/io/bti/bti.py @@ -1120,7 +1120,6 @@ def _get_bti_info(pdf_fname, config_fname, head_shape_fname, rotation_x, info['buffer_size_sec'] = 1. # reasonable default for writing date = bti_info['processes'][0]['timestamp'] info['meas_date'] = [date, 0] - info['nchan'] = len(bti_info['chs']) # browse processing info for filter specs. hp, lp = info['highpass'], info['lowpass'] @@ -1211,7 +1210,6 @@ def _get_bti_info(pdf_fname, config_fname, head_shape_fname, rotation_x, chs.append(chan_info) info['chs'] = chs - info['ch_names'] = neuromag_ch_names if rename_channels else bti_ch_names if head_shape_fname: logger.info('... Reading digitization points from %s' % diff --git a/mne/io/ctf/info.py b/mne/io/ctf/info.py index 2a58d9c9239..a1f6cca5fed 100644 --- a/mne/io/ctf/info.py +++ b/mne/io/ctf/info.py @@ -389,13 +389,11 @@ def _compose_meas_info(res4, coils, trans, eeg): if trans['t_ctf_head_head'] is not None: info['ctf_head_t'] = trans['t_ctf_head_head'] info['chs'] = _convert_channel_info(res4, trans, eeg is None) - info['nchan'] = len(info['chs']) info['comps'] = _convert_comp_data(res4) if eeg is None: # Pick EEG locations from chan info if not read from a separate file eeg = _pick_eeg_pos(info) _add_eeg_pos(eeg, trans, info) - info['ch_names'] = [ch['ch_name'] for ch in info['chs']] logger.info(' Measurement info composed.') info._check_consistency() return info diff --git a/mne/io/edf/edf.py b/mne/io/edf/edf.py index bcf6eb29796..37bf68860b2 100644 --- a/mne/io/edf/edf.py +++ b/mne/io/edf/edf.py @@ -430,9 +430,7 @@ def _get_edf_info(fname, stim_channel, annot, annotmap, eog, misc, preload): info = _empty_info(sfreq) info['filename'] = fname info['meas_date'] = calendar.timegm(date.utctimetuple()) - info['nchan'] = nchan info['chs'] = chs - info['ch_names'] = ch_names if highpass.size == 0: pass diff --git a/mne/io/eeglab/eeglab.py b/mne/io/eeglab/eeglab.py index 72f29063d8a..2df501932c5 100644 --- a/mne/io/eeglab/eeglab.py +++ b/mne/io/eeglab/eeglab.py @@ -63,7 +63,6 @@ def _get_info(eeg, montage, eog=()): """Get measurement info. """ info = _empty_info(sfreq=eeg.srate) - info['nchan'] = eeg.nbchan # add the ch_names and info['chs'][idx]['loc'] path = None diff --git a/mne/io/egi/egi.py b/mne/io/egi/egi.py index 5d62c93d13a..f6cf57cd5d8 100644 --- a/mne/io/egi/egi.py +++ b/mne/io/egi/egi.py @@ -255,8 +255,7 @@ def __init__(self, input_fname, montage=None, eog=None, misc=None, ch_names.extend(list(egi_info['event_codes'])) if self._new_trigger is not None: ch_names.append('STI 014') # our new_trigger - info['nchan'] = nchan = len(ch_names) - info['ch_names'] = ch_names + nchan = len(ch_names) for ii, ch_name in enumerate(ch_names): ch_info = { 'cal': cal, 'logno': ii + 1, 'scanno': ii + 1, 'range': 1.0, diff --git a/mne/io/fiff/tests/test_raw_fiff.py b/mne/io/fiff/tests/test_raw_fiff.py index ddc1096e0c7..5017bb7c6c4 100644 --- a/mne/io/fiff/tests/test_raw_fiff.py +++ b/mne/io/fiff/tests/test_raw_fiff.py @@ -1019,8 +1019,10 @@ def test_add_channels(): raw_meg = raw.pick_types(meg=True, eeg=False, copy=True) raw_stim = raw.pick_types(meg=False, eeg=False, stim=True, copy=True) raw_new = raw_meg.add_channels([raw_eeg, raw_stim], copy=True) - assert_true(all(ch in raw_new.ch_names - for ch in raw_stim.ch_names + raw_meg.ch_names)) + assert_true( + all(ch in raw_new.ch_names + for ch in list(raw_stim.ch_names) + list(raw_meg.ch_names)) + ) raw_new = raw_meg.add_channels([raw_eeg], copy=True) assert_true(ch in raw_new.ch_names for ch in raw.ch_names) diff --git a/mne/io/kit/kit.py b/mne/io/kit/kit.py index 3c1b7740a02..8b334f88c01 100644 --- a/mne/io/kit/kit.py +++ b/mne/io/kit/kit.py @@ -195,12 +195,12 @@ def _set_stimchannels(self, info, stim, stim_code): % (np.max(stim), self._raw_extras[0]['nchan'])) # modify info - info['nchan'] = self._raw_extras[0]['nchan'] + 1 + nchan = self._raw_extras[0]['nchan'] + 1 ch_name = 'STI 014' chan_info = {} chan_info['cal'] = KIT.CALIB_FACTOR - chan_info['logno'] = info['nchan'] - chan_info['scanno'] = info['nchan'] + chan_info['logno'] = nchan + chan_info['scanno'] = nchan chan_info['range'] = 1.0 chan_info['unit'] = FIFF.FIFF_UNIT_NONE chan_info['unit_mul'] = 0 @@ -209,7 +209,6 @@ def _set_stimchannels(self, info, stim, stim_code): chan_info['loc'] = np.zeros(12) chan_info['kind'] = FIFF.FIFFV_STIM_CH info['chs'].append(chan_info) - info['ch_names'].append(ch_name) if self.preload: err = "Can't change stim channel after preloading data" raise NotImplementedError(err) @@ -661,7 +660,7 @@ def get_kit_info(rawfile): info = _empty_info(float(sqd['sfreq'])) info.update(meas_date=int(time.time()), lowpass=sqd['lowpass'], highpass=sqd['highpass'], filename=rawfile, - nchan=sqd['nchan'], buffer_size_sec=1.) + buffer_size_sec=1.) # Creates a list of dicts of meg channels for raw.info logger.info('Setting channel info structure...') @@ -738,8 +737,6 @@ def get_kit_info(rawfile): chan_info['kind'] = FIFF.FIFFV_MISC_CH info['chs'].append(chan_info) - info['ch_names'] = ch_names['MEG'] + ch_names['MISC'] - return info, sqd diff --git a/mne/io/meas_info.py b/mne/io/meas_info.py index 75cdcf150c2..a7db838aeb3 100644 --- a/mne/io/meas_info.py +++ b/mne/io/meas_info.py @@ -4,10 +4,12 @@ # # License: BSD (3-clause) +import collections from warnings import warn from copy import deepcopy from datetime import datetime as dt import os.path as op +import itertools import numpy as np from scipy import linalg @@ -24,7 +26,7 @@ write_coord_trans, write_ch_info, write_name_list, write_julian, write_float_matrix) from .proc_history import _read_proc_history, _write_proc_history -from ..utils import logger, verbose +from ..utils import logger, verbose, object_hash from ..fixes import Counter from .. import __version__ from ..externals.six import b, BytesIO, string_types, text_type @@ -47,7 +49,76 @@ def _summarize_str(st): return st[:56][::-1].split(',', 1)[-1][::-1] + ', ...' -class Info(dict): +class _ChannelNameList(collections.Sequence): + """A read-only list, used to provide convenient access to channel names. + + This list is linked to an Info object. Any changes to the `chs` field of + this object are automatically reflected by this list. + + Parameters + ---------- + info : instance of Info + The Info structure containing the channel list. + """ + def __init__(self, info): + self._channels = info['chs'] + + def __getitem__(self, index): + """Retrieve the name of the channel with the given index""" + if isinstance(index, slice): + return [ch['ch_name'] for ch in self._channels[index]] + else: + return self._channels[index]['ch_name'] + + def __len__(self): + """Length of the list""" + return len(self._channels) + + def __eq__(self, other): + """Test for equality""" + if isinstance(other, _ChannelNameList): + return list(self) == list(other) + elif isinstance(other, list): + return list(self) == other + else: + raise ValueError('Cannot compare _ChannelNameList and %s' + % type(other)) + + def __ne__(self, other): + """Test for non-equality""" + return not self.__eq__(other) + + def __repr__(self): + """String representation""" + if len(self) < 10: + channels = ', '.join(self) + else: + channels = ', '.join(self[:5]) + ' ... ' + ', '.join(self[-5:]) + + return '' % (len(self), channels) + + def __add__(self, other): + """Return a list containing the channels names in both lists""" + return list(self) + list(other) + + # Raise descriptive error when the user tries to modify this list. + _read_only_error = ("This list of channel names is read-only. It is " + "automatically computed by the info object. " + "Instead of modifying this list, make your " + "modifications to the list of channels in the info " + "object (info['chs']).") + + def __setitem__(self, index, value): + raise RuntimeError(_ChannelNameList._read_only_error) + + def __iadd__(self, value): + raise RuntimeError(_ChannelNameList._read_only_error) + + def append(self, value): + raise RuntimeError(_ChannelNameList._read_only_error) + + +class Info(collections.MutableMapping): """Information about the recording. This data structure behaves like a dictionary. It contains all meta-data @@ -60,8 +131,11 @@ class Info(dict): bads : list of str List of bad (noisy/broken) channels, by name. These channels will by default be ignored by many processing steps. - ch_names : list of str + ch_names : list-like of str (read-only) The names of the channels. + This object behaves like a read-only Python list. Behind the scenes + it iterates over the channels dictionaries in `info['chs']`: + `info['ch_names'][x] == info['chs'][x]['ch_name']` chs : list of dict A list of channel information structures. See: :ref:`faq` for details. @@ -148,6 +222,23 @@ class Info(dict): processing logs inside of a raw file. See: :ref:`faq` for details. """ + # These fields are read only and are computed from other fields + _read_only_fields = { + 'nchan': lambda info: len(info['chs']), + 'ch_names': lambda info: _ChannelNameList(info), + } + + def __init__(self, *args, **kwargs): + self._store = dict() + + # When initializing from a dict, silently ignore read-only fields. + # This is to keep backwards compatibility. + if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict): + for key, value in args[0].items(): + if key not in Info._read_only_fields: + self._store[key] = value + else: + self.update(dict(*args, **kwargs)) def copy(self): """Copy the instance @@ -157,7 +248,83 @@ def copy(self): info : instance of Info The copied info. """ - return Info(super(Info, self).copy()) + return Info(self._store) + + def to_dict(self): + """Obtain a version of this info object in the form of a plain Python + dict + + While the Info object behaves like a dict, it is a subclass of + MutableMapping. Use this function to cast it to an actual dict. Be + aware that the read-only fields will become writable and will no longer + auto-update. + + Notes + ----- + - To cast the plain dict back to an instance of Info, use: + `mne.io.Info(info_dict)` + - Always use a propert instance of Info when calling MNE-Python + functions. + """ + info_dict = dict(self) + # Convert non-standard fields + info_dict['ch_names'] = list(info_dict['ch_names']) + return info_dict + + def __getitem__(self, key): + """Retrieve a value from the data store + + Parameters + ---------- + key : str + The key associated with the value to retrieve + + Returns + ------- + value : object + The value associated with the key + """ + try: + return Info._read_only_fields[key](self) + except KeyError: + return self._store[key] + + def __setitem__(self, key, value): + """Store a value in the data store + + Parameters + ---------- + key : str + The key associated with the value for later retrieval + value : object + The value associated with the key + """ + if key in Info._read_only_fields: + raise ValueError('The field ' + key + ' is read only.') + else: + self._store[key] = value + + def __delitem__(self, key): + """Remove a value from the store + + Parameters + ---------- + key : str + The key associated with the value to remove + """ + if key in Info._read_only_fields: + raise ValueError('The field ' + key + ' is read only.') + else: + del self._store[key] + + # Make the read-only fields show up in info.items() + def __iter__(self): + return itertools.chain(self._store.keys(), + Info._read_only_fields.keys()) + + # Read-only fields add to the length of the info object + def __len__(self): + return len(self._store) + len(Info._read_only_fields) def __repr__(self): """Summarize info instead of printing all""" @@ -215,17 +382,22 @@ def _check_consistency(self): if len(missing) > 0: raise RuntimeError('bad channel(s) %s marked do not exist in info' % (missing,)) - chs = [ch['ch_name'] for ch in self['chs']] - if len(self['ch_names']) != len(chs) or any( - ch_1 != ch_2 for ch_1, ch_2 in zip(self['ch_names'], chs)) or \ - self['nchan'] != len(chs): - raise RuntimeError('info channel name inconsistency detected, ' - 'please notify mne-python developers') # make sure we have the proper datatypes for key in ('sfreq', 'highpass', 'lowpass'): if self.get(key) is not None: self[key] = float(self[key]) + # make sure channel names are unique + unique_ids = np.unique(self['ch_names'], return_index=True)[1] + if len(unique_ids) != self['nchan']: + dups = set(self['ch_names'][x] + for x in np.setdiff1d(range(self['nchan']), unique_ids)) + raise RuntimeError('Channel names are not unique, found ' + 'duplicates for: %s' % dups) + + def __hash__(self): + return object_hash(self._store) + def read_fiducials(fname): """Read fiducials from a fiff file @@ -412,7 +584,7 @@ def _make_dig_points(nasion=None, lpa=None, rpa=None, hpi=None, if lpa.shape == (3,): dig.append({'r': lpa, 'ident': FIFF.FIFFV_POINT_LPA, 'kind': FIFF.FIFFV_POINT_CARDINAL, - 'coord_frame': FIFF.FIFFV_COORD_HEAD}) + 'coord_frame': FIFF.FIFFV_COORD_HEAD}) else: msg = ('LPA should have the shape (3,) instead of %s' % (lpa.shape,)) @@ -422,7 +594,7 @@ def _make_dig_points(nasion=None, lpa=None, rpa=None, hpi=None, if nasion.shape == (3,): dig.append({'r': nasion, 'ident': FIFF.FIFFV_POINT_NASION, 'kind': FIFF.FIFFV_POINT_CARDINAL, - 'coord_frame': FIFF.FIFFV_COORD_HEAD}) + 'coord_frame': FIFF.FIFFV_COORD_HEAD}) else: msg = ('Nasion should have the shape (3,) instead of %s' % (nasion.shape,)) @@ -432,7 +604,7 @@ def _make_dig_points(nasion=None, lpa=None, rpa=None, hpi=None, if rpa.shape == (3,): dig.append({'r': rpa, 'ident': FIFF.FIFFV_POINT_RPA, 'kind': FIFF.FIFFV_POINT_CARDINAL, - 'coord_frame': FIFF.FIFFV_COORD_HEAD}) + 'coord_frame': FIFF.FIFFV_COORD_HEAD}) else: msg = ('RPA should have the shape (3,) instead of %s' % (rpa.shape,)) @@ -875,7 +1047,6 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): else: info['meas_date'] = meas_date - info['nchan'] = nchan info['sfreq'] = sfreq info['highpass'] = highpass if highpass is not None else 0. info['lowpass'] = lowpass if lowpass is not None else info['sfreq'] / 2.0 @@ -884,7 +1055,6 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): # Add the channel information and make a list of channel names # for convenience info['chs'] = chs - info['ch_names'] = [ch['ch_name'] for ch in chs] # # Add the coordinate transformations @@ -1269,17 +1439,15 @@ def _merge_info(infos, verbose=None): for info in infos: info._check_consistency() info = Info() - ch_names = _merge_dict_values(infos, 'ch_names') - duplicates = set([ch for ch in ch_names if ch_names.count(ch) > 1]) + info['chs'] = [] + for this_info in infos: + info['chs'].extend(this_info['chs']) + duplicates = set([ch for ch in info['ch_names'] + if info['ch_names'].count(ch) > 1]) if len(duplicates) > 0: msg = ("The following channels are present in more than one input " "measurement info objects: %s" % list(duplicates)) raise ValueError(msg) - info['nchan'] = len(ch_names) - info['ch_names'] = ch_names - info['chs'] = [] - for this_info in infos: - info['chs'].extend(this_info['chs']) transforms = ['ctf_head_t', 'dev_head_t', 'dev_ctf_t'] for trans_name in transforms: @@ -1359,8 +1527,6 @@ def create_info(ch_names, sfreq, ch_types=None, montage=None): raise ValueError('ch_types and ch_names must be the same length') info = _empty_info(sfreq) info['meas_date'] = np.array([0, 0], np.int32) - info['ch_names'] = ch_names - info['nchan'] = nchan loc = np.concatenate((np.zeros(3), np.eye(3).ravel())).astype(np.float32) for ci, (name, kind) in enumerate(zip(ch_names, ch_types)): if not isinstance(name, string_types): @@ -1414,21 +1580,17 @@ def _empty_info(sfreq): 'lowpass', 'meas_date', 'meas_id', 'proj_id', 'proj_name', 'subject_info', ) - _list_keys = ( - 'bads', 'ch_names', 'chs', 'comps', 'events', 'hpi_meas', - 'hpi_results', 'projs', - ) + _list_keys = ('bads', 'chs', 'comps', 'events', 'hpi_meas', 'hpi_results', + 'projs') info = Info() for k in _none_keys: info[k] = None for k in _list_keys: info[k] = list() info['custom_ref_applied'] = False - info['nchan'] = 0 info['dev_head_t'] = Transform('meg', 'head', np.eye(4)) info['highpass'] = 0. info['sfreq'] = float(sfreq) info['lowpass'] = info['sfreq'] / 2. - assert set(info.keys()) == set(RAW_INFO_FIELDS) info._check_consistency() return info diff --git a/mne/io/nicolet/nicolet.py b/mne/io/nicolet/nicolet.py index 85954d41fd9..07fda2eeac0 100644 --- a/mne/io/nicolet/nicolet.py +++ b/mne/io/nicolet/nicolet.py @@ -103,10 +103,9 @@ def _get_nicolet_info(fname, ch_type, eog, ecg, emg, misc): date = datetime.datetime(int(date[0]), int(date[1]), int(date[2]), int(time[0]), int(time[1]), int(sec), int(msec)) info = _empty_info(header_info['sample_freq']) - info.update({'filename': fname, 'nchan': header_info['num_channels'], + info.update({'filename': fname, 'meas_date': calendar.timegm(date.utctimetuple()), - 'ch_names': ch_names, 'description': None, - 'buffer_size_sec': 10.}) + 'description': None, 'buffer_size_sec': 10.}) if ch_type == 'eeg': ch_coil = FIFF.FIFFV_COIL_EEG diff --git a/mne/io/pick.py b/mne/io/pick.py index 8370e531256..f002e64be71 100644 --- a/mne/io/pick.py +++ b/mne/io/pick.py @@ -97,6 +97,11 @@ def pick_channels(ch_names, include, exclude=[]): raise RuntimeError('ch_names is not a unique list, picking is unsafe') _check_excludes_includes(include) _check_excludes_includes(exclude) + if not isinstance(include, set): + include = set(include) + if not isinstance(exclude, set): + exclude = set(exclude) + sel = [] for k, name in enumerate(ch_names): if (len(include) == 0 or name in include) and name not in exclude: @@ -328,8 +333,6 @@ def pick_info(info, sel=[], copy=True): raise ValueError('No channels match the selection.') info['chs'] = [info['chs'][k] for k in sel] - info['ch_names'] = [info['ch_names'][k] for k in sel] - info['nchan'] = len(sel) info['bads'] = [ch for ch in info['bads'] if ch in info['ch_names']] comps = deepcopy(info['comps']) @@ -463,9 +466,7 @@ def pick_channels_forward(orig, include=[], exclude=[], verbose=None): fwd['sol']['row_names'] = ch_names # Pick the appropriate channel names from the info-dict using sel_info - fwd['info']['ch_names'] = [fwd['info']['ch_names'][k] for k in sel_info] fwd['info']['chs'] = [fwd['info']['chs'][k] for k in sel_info] - fwd['info']['nchan'] = nuse fwd['info']['bads'] = [b for b in fwd['info']['bads'] if b in ch_names] if fwd['sol_grad'] is not None: @@ -513,6 +514,7 @@ def pick_types_forward(orig, meg=True, eeg=False, ref_meg=True, seeg=False, if len(sel) == 0: raise ValueError('No valid channels found') include_ch_names = [info['ch_names'][k] for k in sel] + return pick_channels_forward(orig, include_ch_names) @@ -621,8 +623,8 @@ def _check_excludes_includes(chs, info=None, allow_bads=False): Channels to be excluded/excluded. If allow_bads, and chs=="bads", this will be the bad channels found in 'info'. """ - from .meas_info import Info - if not isinstance(chs, (list, tuple, np.ndarray)): + from .meas_info import Info, _ChannelNameList + if not isinstance(chs, (list, tuple, np.ndarray, _ChannelNameList)): if allow_bads is True: if not isinstance(info, Info): raise ValueError('Supply an info object if allow_bads is true') diff --git a/mne/io/reference.py b/mne/io/reference.py index 1585dd43b9e..5f7328078a4 100644 --- a/mne/io/reference.py +++ b/mne/io/reference.py @@ -198,8 +198,6 @@ def add_reference_channels(inst, ref_channels, copy=True): 'coord_frame': FIFF.FIFFV_COORD_HEAD, 'loc': np.zeros(12)} inst.info['chs'].append(chan_info) - inst.info['ch_names'].extend(ref_channels) - inst.info['nchan'] = len(inst.info['ch_names']) if isinstance(inst, _BaseRaw): inst._cals = np.hstack((inst._cals, [1] * len(ref_channels))) @@ -378,7 +376,6 @@ def set_bipolar_reference(inst, anode, cathode, ch_name=None, ch_info=None, an_idx = inst.ch_names.index(an) inst.info['chs'][an_idx] = info inst.info['chs'][an_idx]['ch_name'] = name - inst.info['ch_names'][an_idx] = name logger.info('Bipolar channel added as "%s".' % name) # Drop cathode channels diff --git a/mne/io/tests/test_meas_info.py b/mne/io/tests/test_meas_info.py index 0764fa11a37..f0d38bc2cda 100644 --- a/mne/io/tests/test_meas_info.py +++ b/mne/io/tests/test_meas_info.py @@ -11,7 +11,8 @@ _loc_to_coil_trans, Raw, read_info, write_info) from mne.io.constants import FIFF from mne.io.meas_info import (Info, create_info, _write_dig_points, - _read_dig_points, _make_dig_points, _merge_info) + _read_dig_points, _make_dig_points, _merge_info, + RAW_INFO_FIELDS) from mne.utils import _TempDir, run_tests_if_main from mne.channels.montage import read_montage, read_dig_montage @@ -41,6 +42,8 @@ def test_make_info(): """ n_ch = 1 info = create_info(n_ch, 1000., 'eeg') + assert_equal(sorted(info.keys()), sorted(RAW_INFO_FIELDS)) + coil_types = set([ch['coil_type'] for ch in info['chs']]) assert_true(FIFF.FIFFV_COIL_EEG in coil_types) @@ -126,13 +129,61 @@ def test_info(): info[42] = 'foo' assert_true(info[42] == 'foo') - # test info attribute in API objects + # Test info attribute in API objects for obj in [raw, epochs, evoked]: assert_true(isinstance(obj.info, Info)) info_str = '%s' % obj.info - assert_equal(len(info_str.split('\n')), (len(obj.info.keys()) + 2)) + assert_equal(len(info_str.split('\n')), len(obj.info.keys()) + 2) assert_true(all(k in info_str for k in obj.info.keys())) + # Test read-only fields + info = raw.info.copy() + nchan = len(info['chs']) + ch_names = [ch['ch_name'] for ch in info['chs']] + assert_equal(info['nchan'], nchan) + assert_equal(list(info['ch_names']), ch_names) + + def _assignment_to_nchan(info): + info['nchan'] = 42 + assert_raises(ValueError, _assignment_to_nchan, info) + + def _assignment_to_ch_names(info): + info['ch_names'] = ['foo', 'bar'] + assert_raises(ValueError, _assignment_to_ch_names, info) + + def _del_nchan(info): + del info['nchan'] + assert_raises(ValueError, _del_nchan, info) + + def _del_ch_names(info): + del info['ch_names'] + assert_raises(ValueError, _del_ch_names, info) + + # Deleting of regular fields should work + info['foo'] = 'bar' + del info['foo'] + + # Passing read only fields to the constructor + assert_raises(ValueError, Info, nchan=42) + assert_raises(ValueError, Info, ch_names=['foo', 'bar']) + + # Test automatic updating of read-only fields + del info['chs'][-1] + assert_equal(info['nchan'], nchan - 1) + assert_equal(list(info['ch_names']), ch_names[:-1]) + + info['chs'][0]['ch_name'] = 'foo' + assert_equal(info['ch_names'][0], 'foo') + + # Test casting to and from a dict + info_dict = dict(info) + info2 = Info(info_dict) + assert_equal(info, info2) + + info_dict = info.to_dict() + info2 = Info(info_dict) + assert_equal(info, info2) + def test_read_write_info(): """Test IO of info @@ -209,6 +260,49 @@ def test_make_dig_points(): dig_points[:, :2]) +def test_channel_name_list(): + """Test the _ChannelNamesList object""" + # Indexing + info = create_info(ch_names=['a', 'b', 'c'], sfreq=1000., ch_types=None) + assert_equal(info['ch_names'][0], 'a') + assert_equal(info['ch_names'][1], 'b') + assert_equal(info['ch_names'][2], 'c') + + # Equality + assert_equal(info['ch_names'], info['ch_names']) + assert_equal(info['ch_names'], ['a', 'b', 'c']) + + # No channels in info + info = create_info(ch_names=[], sfreq=1000., ch_types=None) + assert_equal(info['ch_names'], []) + + # List should be read-only + info = create_info(ch_names=['a', 'b', 'c'], sfreq=1000., ch_types=None) + + def _test_assignment(): + info['ch_names'][0] = 'foo' + assert_raises(RuntimeError, _test_assignment) + + def _test_concatenation(): + info['ch_names'] += ['foo'] + assert_raises(RuntimeError, _test_concatenation) + + def _test_appending(): + info['ch_names'].append('foo') + assert_raises(RuntimeError, _test_appending) + + def _test_removal(): + del info['ch_names'][0] + assert_raises(AttributeError, _test_removal) + + # Concatenation + assert_equal(info['ch_names'] + ['d'], ['a', 'b', 'c', 'd']) + + # Representation + assert_equal(repr(info['ch_names']), + "") + + def test_merge_info(): """Test merging of multiple Info objects""" info_a = create_info(ch_names=['a', 'b', 'c'], sfreq=1000., ch_types=None) @@ -219,4 +313,49 @@ def test_merge_info(): assert_raises(ValueError, _merge_info, [info_a, info_a]) +def test_check_consistency(): + """Test consistency check of Info objects""" + info = create_info(ch_names=['a', 'b', 'c'], sfreq=1000.) + + # This should pass + info._check_consistency() + + # Info without any channels + info_empty = create_info(ch_names=[], sfreq=1000.) + info_empty._check_consistency() + + # Bad channels that are not in the info object + info2 = info.copy() + info2['bads'] = ['b', 'foo', 'bar'] + assert_raises(RuntimeError, info2._check_consistency) + + # Bad data types + info2 = info.copy() + info2['sfreq'] = 'foo' + assert_raises(ValueError, info2._check_consistency) + + info2 = info.copy() + info2['highpass'] = 'foo' + assert_raises(ValueError, info2._check_consistency) + + info2 = info.copy() + info2['lowpass'] = 'foo' + assert_raises(ValueError, info2._check_consistency) + + # Silent type conversion to float + info2 = info.copy() + info2['sfreq'] = 1 + info2['highpass'] = 2 + info2['lowpass'] = 2 + info2._check_consistency() + assert_true(isinstance(info2['sfreq'], float)) + assert_true(isinstance(info2['highpass'], float)) + assert_true(isinstance(info2['lowpass'], float)) + + # Duplicate channel names + info2 = info.copy() + info2['chs'][2]['ch_name'] = 'b' + assert_raises(RuntimeError, info2._check_consistency) + + run_tests_if_main() diff --git a/mne/io/tests/test_pick.py b/mne/io/tests/test_pick.py index 2e1f5129909..4132bf523db 100644 --- a/mne/io/tests/test_pick.py +++ b/mne/io/tests/test_pick.py @@ -235,13 +235,5 @@ def test_clean_info_bads(): info._check_consistency() info['bads'] += ['EEG 053'] assert_raises(RuntimeError, info._check_consistency) - info = pick_info(raw.info, picks_meg) - info._check_consistency() - info['ch_names'][0] += 'f' - assert_raises(RuntimeError, info._check_consistency) - info = pick_info(raw.info, picks_meg) - info._check_consistency() - info['nchan'] += 1 - assert_raises(RuntimeError, info._check_consistency) run_tests_if_main() diff --git a/mne/minimum_norm/tests/test_inverse.py b/mne/minimum_norm/tests/test_inverse.py index 22747ceda78..8ff90b23f85 100644 --- a/mne/minimum_norm/tests/test_inverse.py +++ b/mne/minimum_norm/tests/test_inverse.py @@ -16,7 +16,7 @@ from mne import (read_cov, read_forward_solution, read_evokeds, pick_types, pick_types_forward, make_forward_solution, convert_forward_solution, Covariance) -from mne.io import Raw +from mne.io import Raw, Info from mne.minimum_norm.inverse import (apply_inverse, read_inverse_operator, apply_inverse_raw, apply_inverse_epochs, make_inverse_operator, @@ -81,8 +81,8 @@ def _compare(a, b): skip_types = ['whitener', 'proj', 'reginv', 'noisenorm', 'nchan', 'command_line', 'working_dir', 'mri_file', 'mri_id'] try: - if isinstance(a, dict): - assert_true(isinstance(b, dict)) + if isinstance(a, (dict, Info)): + assert_true(isinstance(b, (dict, Info))) for k, v in six.iteritems(a): if k not in b and k not in skip_types: raise ValueError('First one had one second one didn\'t:\n' @@ -225,7 +225,6 @@ def test_inverse_operator_channel_ordering(): randomiser = np.random.RandomState(42) randomiser.shuffle(new_order) evoked.data = evoked.data[new_order] - evoked.info['ch_names'] = [evoked.info['ch_names'][n] for n in new_order] evoked.info['chs'] = [evoked.info['chs'][n] for n in new_order] cov_ch_reorder = [c for c in evoked.info['ch_names'] diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index b416e42b2bf..5f729a4396e 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -713,7 +713,7 @@ def _export_info(self, info, container, add_channels): """Aux method """ # set channel names and info - ch_names = info['ch_names'] = [] + ch_names = [] ch_info = info['chs'] = [] for ii in range(self.n_components_): this_source = 'ICA %03d' % (ii + 1) @@ -732,10 +732,6 @@ def _export_info(self, info, container, add_channels): # re-append additionally picked ch_info ch_info += [k for k in container.info['chs'] if k['ch_name'] in add_channels] - # update number of channels - info['nchan'] = self.n_components_ - if add_channels is not None: - info['nchan'] += len(add_channels) info['bads'] = [ch_names[k] for k in self.exclude] info['projs'] = [] # make sure projections are removed. diff --git a/mne/preprocessing/maxwell.py b/mne/preprocessing/maxwell.py index 6c9cb4e7486..176ef559de0 100644 --- a/mne/preprocessing/maxwell.py +++ b/mne/preprocessing/maxwell.py @@ -617,8 +617,6 @@ def _copy_preload_add_channels(raw, add_channels): cal=1. / 10000., coil_type=FIFF.FWD_COIL_UNKNOWN) for ii in range(len(kinds))] raw.info['chs'].extend(chpi_chs) - raw.info['nchan'] += len(chpi_chs) - raw.info['ch_names'] += [c['ch_name'] for c in chpi_chs] raw.info._check_consistency() assert raw._data.shape == (raw.info['nchan'], len(raw.times)) # Return the pos picks diff --git a/mne/realtime/fieldtrip_client.py b/mne/realtime/fieldtrip_client.py index 6b35cb08020..bf4ee33fdc2 100644 --- a/mne/realtime/fieldtrip_client.py +++ b/mne/realtime/fieldtrip_client.py @@ -142,9 +142,6 @@ def _guess_measurement_info(self): info = _empty_info(self.ft_header.fSample) # create info # modify info attributes according to the FieldTrip Header object - info['nchan'] = self.ft_header.nChannels - info['ch_names'] = self.ft_header.labels - info['comps'] = list() info['projs'] = list() info['bads'] = list() @@ -152,7 +149,7 @@ def _guess_measurement_info(self): # channel dictionary list info['chs'] = [] - for idx, ch in enumerate(info['ch_names']): + for idx, ch in enumerate(self.ft_header.labels): this_info = dict() this_info['scanno'] = idx diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index a395ae27c1b..1adad8d740e 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -1752,8 +1752,6 @@ def make_epochs(picks, proj): [epochs_meg, epochs_eeg[:2]]) epochs_meg.info['chs'].pop(0) - epochs_meg.info['ch_names'].pop(0) - epochs_meg.info['nchan'] -= 1 assert_raises(RuntimeError, add_channels_epochs, [epochs_meg, epochs_eeg]) @@ -1768,9 +1766,8 @@ def make_epochs(picks, proj): [epochs_meg2, epochs_eeg]) epochs_meg2 = epochs_meg.copy() - epochs_meg2.info['ch_names'][1] = epochs_meg2.info['ch_names'][0] - epochs_meg2.info['chs'][1]['ch_name'] = epochs_meg2.info['ch_names'][1] - assert_raises(ValueError, add_channels_epochs, + epochs_meg2.info['chs'][1]['ch_name'] = epochs_meg2.info['ch_names'][0] + assert_raises(RuntimeError, add_channels_epochs, [epochs_meg2, epochs_eeg]) epochs_meg2 = epochs_meg.copy() diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index dabc7440896..cae91f1d294 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1065,8 +1065,9 @@ def save(self, fname, overwrite=False): def _prepare_write_tfr(tfr, condition): """Aux function""" return (condition, dict(times=tfr.times, freqs=tfr.freqs, - data=tfr.data, info=tfr.info, nave=tfr.nave, - comment=tfr.comment, method=tfr.method)) + data=tfr.data, info=tfr.info.to_dict(), + nave=tfr.nave, comment=tfr.comment, + method=tfr.method)) def write_tfrs(fname, tfr, overwrite=False): diff --git a/mne/utils.py b/mne/utils.py index 0bdf54df7fd..660381591f7 100644 --- a/mne/utils.py +++ b/mne/utils.py @@ -101,15 +101,12 @@ def object_hash(x, h=None): """ if h is None: h = hashlib.md5() - if isinstance(x, dict): + if hasattr(x, 'keys'): + # dict-like types keys = _sort_keys(x) for key in keys: object_hash(key, h) object_hash(x[key], h) - elif isinstance(x, (list, tuple)): - h.update(str(type(x)).encode('utf-8')) - for xx in x: - object_hash(xx, h) elif isinstance(x, bytes): # must come before "str" below h.update(x) @@ -121,6 +118,11 @@ def object_hash(x, h=None): h.update(str(x.shape).encode('utf-8')) h.update(str(x.dtype).encode('utf-8')) h.update(x.tostring()) + elif hasattr(x, '__len__'): + # all other list-like types + h.update(str(type(x)).encode('utf-8')) + for xx in x: + object_hash(xx, h) else: raise RuntimeError('unsupported type: %s (%s)' % (type(x), x)) return int(h.hexdigest(), 16) diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index d73771de4f4..bc720a2b7b8 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -38,9 +38,9 @@ def _prepare_topo_plot(inst, ch_type, layout): elif layout == 'auto': layout = None - info['ch_names'] = _clean_names(info['ch_names']) + clean_ch_names = _clean_names(info['ch_names']) for ii, this_ch in enumerate(info['chs']): - this_ch['ch_name'] = info['ch_names'][ii] + this_ch['ch_name'] = clean_ch_names[ii] # special case for merging grad channels if (ch_type == 'grad' and FIFF.FIFFV_COIL_VV_PLANAR_T1 in