Skip to content

Commit

Permalink
MRG: refactor stuff (mne-tools#5084)
Browse files Browse the repository at this point in the history
* refactor stuff
  • Loading branch information
jona-sassenhagen authored and agramfort committed May 9, 2018
1 parent a277bdb commit a11c465
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 152 deletions.
22 changes: 10 additions & 12 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ..externals.six import string_types
from ..utils import verbose, logger, warn, copy_function_doc_to_method_doc
from ..utils import _check_preload
from ..utils import _check_preload, _validate_type
from ..io.compensator import get_current_comp
from ..io.constants import FIFF
from ..io.meas_info import anonymize_info, Info
Expand Down Expand Up @@ -69,9 +69,7 @@ def _contains_ch_type(info, ch_type):
has_ch_type : bool
Whether the channel type is present or not.
"""
if not isinstance(ch_type, string_types):
raise ValueError('`ch_type` is of class {actual_class}. It must be '
'`str`'.format(actual_class=type(ch_type)))
_validate_type(ch_type, string_types, "ch_type")

meg_extras = ['mag', 'grad', 'planar1', 'planar2']
fnirs_extras = ['hbo', 'hbr']
Expand Down Expand Up @@ -127,10 +125,11 @@ def equalize_channels(candidates, verbose=None):
from ..evoked import Evoked
from ..time_frequency import AverageTFR

if not all(isinstance(c, (BaseRaw, BaseEpochs, Evoked, AverageTFR))
for c in candidates):
raise ValueError('candidates must be Raw, Epochs, Evoked, or '
'AverageTFR')
for candidate in candidates:
_validate_type(candidate,
(BaseRaw, BaseEpochs, Evoked, AverageTFR),
"Instances to be modified",
"Raw, Epochs, Evoked or TFR")

chan_max_idx = np.argmax([c.info['nchan'] for c in candidates])
chan_template = candidates[chan_max_idx].ch_names
Expand Down Expand Up @@ -847,8 +846,8 @@ def add_channels(self, add_list, force_update_info=False):
raise AssertionError('Input must be a list or tuple of objs')

# Object-specific checks
if not all([inst.preload for inst in add_list] + [self.preload]):
raise AssertionError('All data must be preloaded')
for inst in add_list + [self]:
_check_preload(inst, "adding channels")
if isinstance(self, BaseRaw):
con_axis = 0
comp_class = BaseRaw
Expand Down Expand Up @@ -921,8 +920,7 @@ def interpolate_bads(self, reset_bads=True, mode='accurate',
"""
from .interpolation import _interpolate_bads_eeg, _interpolate_bads_meg

if getattr(self, 'preload', None) is False:
raise ValueError('Data must be preloaded.')
_check_preload(self, "interpolation")

if len(self.info['bads']) == 0:
warn('No bad channels to interpolate. Doing nothing...')
Expand Down
4 changes: 2 additions & 2 deletions mne/channels/tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_interpolation():

# check that interpolation fails when preload is False
epochs_eeg.preload = False
assert_raises(ValueError, epochs_eeg.interpolate_bads)
assert_raises(RuntimeError, epochs_eeg.interpolate_bads)
epochs_eeg.preload = True

# check that interpolation changes the data in raw
Expand All @@ -98,7 +98,7 @@ def test_interpolation():
assert hasattr(inst, 'preload')
inst.preload = False
inst.info['bads'] = [inst.ch_names[1]]
assert_raises(ValueError, inst.interpolate_bads)
assert_raises(RuntimeError, inst.interpolate_bads)

# check that interpolation works with few channels
raw_few = raw.copy().crop(0, 0.1).load_data()
Expand Down
2 changes: 1 addition & 1 deletion mne/io/fiff/tests/test_raw_fiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,7 @@ def test_add_channels():
raw_badsf.info['sfreq'] = 3.1415927
raw_eeg.crop(.5)

assert_raises(AssertionError, raw_meg.add_channels, [raw_nopre])
assert_raises(RuntimeError, raw_meg.add_channels, [raw_nopre])
assert_raises(RuntimeError, raw_meg.add_channels, [raw_badsf])
assert_raises(AssertionError, raw_meg.add_channels, [raw_eeg])
assert_raises(ValueError, raw_meg.add_channels, [raw_meg])
Expand Down
15 changes: 6 additions & 9 deletions mne/io/meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .proc_history import _read_proc_history, _write_proc_history
from ..transforms import _to_const
from ..transforms import invert_transform
from ..utils import logger, verbose, warn, object_diff
from ..utils import logger, verbose, warn, object_diff, _validate_type
from .. import __version__
from ..externals.six import b, BytesIO, string_types, text_type
from .compensator import get_current_comp
Expand Down Expand Up @@ -1741,8 +1741,8 @@ def create_info(ch_names, sfreq, ch_types=None, montage=None, verbose=None):
"""
if isinstance(ch_names, int):
ch_names = list(np.arange(ch_names).astype(str))
if not isinstance(ch_names, (list, tuple)):
raise TypeError('ch_names must be a list, tuple, or int')
_validate_type(ch_names, (list, tuple), "ch_names",
("list, tuple, or int"))
sfreq = float(sfreq)
if sfreq <= 0:
raise ValueError('sfreq must be positive')
Expand All @@ -1756,10 +1756,8 @@ def create_info(ch_names, sfreq, ch_types=None, montage=None, verbose=None):
'(%s != %s)' % (len(ch_types), nchan))
info = _empty_info(sfreq)
for ci, (name, kind) in enumerate(zip(ch_names, ch_types)):
if not isinstance(name, string_types):
raise TypeError('each entry in ch_names must be a string')
if not isinstance(kind, string_types):
raise TypeError('each entry in ch_types must be a string')
_validate_type(name, string_types, "each entry in ch_names")
_validate_type(kind, string_types, "each entry in ch_types")
if kind not in _kind_dict:
raise KeyError('kind must be one of %s, not %s'
% (list(_kind_dict.keys()), kind))
Expand Down Expand Up @@ -1876,8 +1874,7 @@ def anonymize_info(info):
-----
Operates in place.
"""
if not isinstance(info, Info):
raise ValueError('self must be an Info instance.')
_validate_type(info, Info, "self", "Info")
if info.get('subject_info') is not None:
del info['subject_info']
info['meas_date'] = None
Expand Down
2 changes: 1 addition & 1 deletion mne/io/tests/test_meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def test_check_consistency():

def test_anonymize():
"""Test that sensitive information can be anonymized."""
pytest.raises(ValueError, anonymize_info, 'foo')
pytest.raises(TypeError, anonymize_info, 'foo')

# Fake some subject data
raw = read_raw_fif(raw_fname)
Expand Down
5 changes: 3 additions & 2 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1837,7 +1837,7 @@ def test_contains():
assert_true(not any(o in epochs for o in others))

assert_raises(ValueError, epochs.__contains__, 'foo')
assert_raises(ValueError, epochs.__contains__, 1)
assert_raises(TypeError, epochs.__contains__, 1)


def test_drop_channels_mixin():
Expand Down Expand Up @@ -2190,7 +2190,8 @@ def test_add_channels():
epoch_badsf.info['sfreq'] = 3.1415927
epoch_eeg = epoch_eeg.crop(-.1, .1)

assert_raises(AssertionError, epoch_meg.add_channels, [epoch_nopre])
epoch_meg.load_data()
assert_raises(RuntimeError, epoch_meg.add_channels, [epoch_nopre])
assert_raises(RuntimeError, epoch_meg.add_channels, [epoch_badsf])
assert_raises(AssertionError, epoch_meg.add_channels, [epoch_eeg])
assert_raises(ValueError, epoch_meg.add_channels, [epoch_meg])
Expand Down
14 changes: 14 additions & 0 deletions mne/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2740,3 +2740,17 @@ def open_docs(kind=None, version=None):
raise ValueError('version must be one of %s, got %s'
% (version, versions))
webbrowser.open_new_tab('https://martinos.org/mne/%s/%s' % (version, kind))


def _validate_type(item, types, item_name, type_name=None):
"""Validate that `item` is an instance of `types`."""
if type_name is None:
if types == string_types:
type_name = "str"
else:
iter_types = ([types] if not isinstance(types, (list, tuple))
else types)
type_name = ', '.join(cls.__name__ for cls in iter_types)
if not isinstance(item, types):
raise TypeError(item_name, ' must be an instance of ', type_name,
', got %s instead.' % (type(item),))
25 changes: 10 additions & 15 deletions mne/viz/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
_connection_line, COLORS, _setup_ax_spines,
_setup_plot_projector, _prepare_joint_axes,
_set_title_multiple_electrodes, _check_time_unit)
from ..utils import logger, _clean_names, warn, _pl, verbose
from ..utils import logger, _clean_names, warn, _pl, verbose, _validate_type

from .topo import _plot_evoked_topo
from .topomap import (_prepare_topo_plot, plot_topomap, _check_outlines,
Expand Down Expand Up @@ -1497,18 +1497,15 @@ def _format_evokeds_colors(evokeds, cmap, colors):
raise ValueError('If evokeds is a dict and a cmap is passed, '
'you must specify the colors.')
for cond in evokeds.keys():
if not isinstance(cond, string_types):
raise TypeError('Conditions must be str, not %s' % (type(cond),))
_validate_type(cond, string_types, "Conditions")
# Now make sure all values are list of Evoked objects
evokeds = {condition: [v] if isinstance(v, Evoked) else v
for condition, v in evokeds.items()}

# Check that all elements are of type evoked
for this_evoked in evokeds.values():
for ev in this_evoked:
if not isinstance(ev, Evoked):
raise ValueError("Not all elements are Evoked "
"object. Got %s" % type(this_evoked))
_validate_type(ev, Evoked, "All evokeds entries ", "Evoked")

# Check that all evoked objects have the same time axis and channels
all_evoked = sum(evokeds.values(), [])
Expand Down Expand Up @@ -1798,9 +1795,7 @@ def plot_compare_evokeds(evokeds, picks=None, gfp=False, colors=None,

if vlines == "auto" and (tmin < 0 and tmax > 0):
vlines = [0.]
if not isinstance(vlines, (list, tuple)):
raise TypeError(
"vlines must be a list or tuple, not %s" % type(vlines))
_validate_type(vlines, (list, tuple), "vlines", "list or tuple")

if isinstance(picks, Integral):
picks = [picks]
Expand All @@ -1809,9 +1804,10 @@ def plot_compare_evokeds(evokeds, picks=None, gfp=False, colors=None,
gfp = True
picks = _pick_data_channels(info)

if not isinstance(picks, (list, np.ndarray)):
raise TypeError("picks should be a list or np.array of integers. "
"Got %s." % type(picks))
_validate_type(picks, (list, np.ndarray), "picks",
"list or np.array of integers")
for entry in picks:
_validate_type(entry, Integral, "entries of picks", "integers")

if len(picks) == 0:
raise ValueError("No valid channels were found to plot the GFP. " +
Expand Down Expand Up @@ -2052,9 +2048,8 @@ def plot_compare_evokeds(evokeds, picks=None, gfp=False, colors=None,
head_pos = {'center': (0, 0), 'scale': (0.5, 0.5)}
pos, outlines = _check_outlines(pos, np.array([1, 1]), head_pos)

if not isinstance(show_sensors, (np.int, bool, str)):
raise TypeError("show_sensors must be numeric, str or bool, "
"not " + str(type(show_sensors)))
_validate_type(show_sensors, (np.int, bool, str),
"show_sensors", "numeric, str or bool")
show_sensors = _check_loc_legal(show_sensors, "show_sensors")
_plot_legend(pos, ["k"] * len(picks), ax, list(), outlines,
show_sensors, size=25)
Expand Down
Loading

0 comments on commit a11c465

Please sign in to comment.