Skip to content

Commit

Permalink
Merge 5cc5172 into 7e5b90e
Browse files Browse the repository at this point in the history
  • Loading branch information
dengemann committed Nov 3, 2014
2 parents 7e5b90e + 5cc5172 commit dce6fff
Show file tree
Hide file tree
Showing 9 changed files with 1,235 additions and 72 deletions.
118 changes: 54 additions & 64 deletions examples/decoding/plot_decoding_time_generalization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
========================================================
Decoding sensor space data with over-time generalization
========================================================
==========================================================
Decoding sensor space data with generalization across time
==========================================================
This example runs the analysis computed in:
Expand All @@ -14,78 +14,68 @@
"""
print(__doc__)

# Authors: Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
# Authors: Jean-Remi King <jeanremi.king@gmail.com>
# Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
# Denis Engemann <denis.engemann@gmail.com>
#
# License: BSD (3-clause)

import numpy as np
import matplotlib.pyplot as plt

import mne
from mne.datasets import spm_face
from mne.decoding import time_generalization
from mne.datasets import sample
from mne.decoding import GeneralizationAcrossTime

data_path = spm_face.data_path()
# --------------------------------------------------------------
# Preprocess data
# --------------------------------------------------------------

###############################################################################
data_path = sample.data_path()
# Load and filter data, set up epochs
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
events_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
raw = mne.io.Raw(raw_fname, preload=True)
picks = mne.pick_types(raw.info, meg=True, exclude='bads') # Pick MEG channels
raw.filter(1, 30, method='iir') # Band pass filtering signals
events = mne.read_events(events_fname)
event_id = {'AudL': 1, 'AudR': 2, 'VisL': 3, 'VisR': 4}
decim = 3 # decimate to make the example faster to run
epochs = mne.Epochs(raw, events, event_id, -0.050, 0.400, proj=True,
picks=picks, baseline=None, preload=True,
reject=dict(mag=5e-12), decim=decim)

raw_fname = data_path + '/MEG/spm/SPM_CTF_MEG_example_faces%d_3D_raw.fif'

raw = mne.io.Raw(raw_fname % 1, preload=True) # Take first run
raw.append(mne.io.Raw(raw_fname % 2, preload=True)) # Take second run too
# ----------------------------------------------------------------------------
# Generalization across time (GAT)
# ----------------------------------------------------------------------------
# The function implements the method used in:
# King, Gramfort, Schurger, Naccache & Dehaene, "Two distinct dynamic modes
# subtend the detection of unexpected sounds", PLOS ONE, 2013

picks = mne.pick_types(raw.info, meg=True, exclude='bads')
raw.filter(1, 45, method='iir')
# Define events of interest
y_vis_audio = (epochs.events[:, 2] <= 2).astype(np.int)

events = mne.find_events(raw, stim_channel='UPPT001')
event_id = {"faces": 1, "scrambled": 2}
tmin, tmax = -0.1, 0.5
gat = GeneralizationAcrossTime()
gat.fit(epochs, y=y_vis_audio)
gat.score(epochs, y=y_vis_audio)
gat.plot_diagonal() # plot decoding across time (correspond to GAT diagonal)
gat.plot() # plot full GAT matrix

# Set up pick list
picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=True, eog=True,
ref_meg=False, exclude='bads')

# Read epochs
decim = 4 # decimate to make the example faster to run
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
picks=picks, baseline=None, preload=True,
reject=dict(mag=1.5e-12), decim=decim)

epochs_list = [epochs[k] for k in event_id]
mne.epochs.equalize_epoch_counts(epochs_list)

###############################################################################
# Run decoding

# Compute Area Under the Curver (AUC) Receiver Operator Curve (ROC) score
# of time generalization. A perfect decoding would lead to AUCs of 1.
# Chance level is at 0.5.
# The default classifier is a linear SVM (C=1) after feature scaling.
scores = time_generalization(epochs_list, clf=None, cv=5, scoring="roc_auc",
shuffle=True, n_jobs=2)

###############################################################################
# Now visualize
times = 1e3 * epochs.times # convert times to ms

plt.figure()
plt.imshow(scores, interpolation='nearest', origin='lower',
extent=[times[0], times[-1], times[0], times[-1]],
vmin=0.1, vmax=0.9, cmap='RdBu_r')
plt.xlabel('Times Test (ms)')
plt.ylabel('Times Train (ms)')
plt.title('Time generalization (%s vs. %s)' % tuple(event_id.keys()))
plt.axvline(0, color='k')
plt.axhline(0, color='k')
plt.colorbar()

plt.figure()
plt.plot(times, np.diag(scores), label="Classif. score")
plt.axhline(0.5, color='k', linestyle='--', label="Chance level")
plt.axvline(0, color='r', label='stim onset')
plt.legend()
plt.xlabel('Time (ms)')
plt.ylabel('ROC classification score')
plt.title('Decoding (%s vs. %s)' % tuple(event_id.keys()))
plt.show()
# ----------------------------------------------------------------------------
# Generalization across time and across conditions
# ----------------------------------------------------------------------------
# As proposed in King & Dehaene (2014) 'Characterizing the dynamics of mental
# representations: the temporal generalization method', Trends In Cognitive
# Sciences, 18(4), 203-210.

gat = GeneralizationAcrossTime(predict_mode='independent')

# Train on visual versus audio: left stimuli only.
gat.fit(epochs[('AudL', 'VisL')])

# Test on visual versus audio: right stimuli only.
# In this case, because the test data is independent, we test the
# classifier of each folds and average their respective prediction:

gat.score(epochs[('AudR', 'VisR')])
gat.plot()
3 changes: 2 additions & 1 deletion mne/decoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
from .mixin import TransformerMixin
from .csp import CSP
from .ems import compute_ems
from .time_gen import time_generalization
from .time_gen import time_generalization ## to be deprecated
from .time_gen import GeneralizationAcrossTime
98 changes: 97 additions & 1 deletion mne/decoding/tests/test_time_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,25 @@
import warnings
import os.path as op

from nose.tools import assert_true
from nose.tools import assert_equal, assert_true, assert_raises
import numpy as np

from mne import io, Epochs, read_events, pick_types
from mne.utils import requires_sklearn
from mne.decoding import time_generalization
from mne.decoding import GeneralizationAcrossTime


data_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
raw_fname = op.join(data_dir, 'test_raw.fif')
event_name = op.join(data_dir, 'test-eve.fif')

tmin, tmax = -0.2, 0.5
event_id = dict(aud_l=1, vis_l=3)
event_id_gen = dict(aud_l=2, vis_l=4)


@requires_sklearn
@requires_sklearn
def test_time_generalization():
"""Test time generalization decoding
Expand All @@ -40,3 +45,94 @@ def test_time_generalization():
assert_true(scores.shape == (n_times, n_times))
assert_true(scores.max() <= 1.)
assert_true(scores.min() >= 0.)


@requires_sklearn
def test_generalization_across_time():
"""Test time generalization decoding
"""
from sklearn.svm import SVC

raw = io.Raw(raw_fname, preload=False)
events = read_events(event_name)
picks = pick_types(raw.info, meg='mag', stim=False, ecg=False,
eog=False, exclude='bads')
picks = picks[1:13:3]
decim = 30

# Test on time generalization within one condition
with warnings.catch_warnings(record=True) as w:
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
baseline=(None, 0), preload=True, decim=decim)

# Test default running
gat = GeneralizationAcrossTime()
gat.fit(epochs)
gat.predict(epochs)
gat.score(epochs)
gat.fit(epochs, y=epochs.events[:, 2])
gat.score(epochs, y=epochs.events[:, 2])

# Test basics
# --- number of trials
assert_true(gat.y_train_.shape[0] ==
gat.y_true_.shape[0] ==
gat.y_pred_.shape[2] == 14)
# --- number of folds
assert_true(np.shape(gat.estimators_)[1] == gat.cv)
# --- length training size
assert_true(len(gat.train_times['slices']) == 15 ==
np.shape(gat.estimators_)[0])
# --- length testing sizes
assert_true(len(gat.test_times_['slices']) == 15 ==
np.shape(gat.scores_)[0])
assert_true(len(gat.test_times_['slices'][0]) == 15
== np.shape(gat.scores_)[1])

# Test longer time window
gat = GeneralizationAcrossTime(train_times={'length': .100})
gat2 = gat.fit(epochs)
assert_true(gat is gat2) # return self
scores = gat.score(epochs)
assert_true(isinstance(scores, list)) # type check
assert_equal(len(scores[0]), len(scores)) # shape check

assert_equal(len(gat.test_times_['slices'][0][0]), 2)
# Decim training steps
gat = GeneralizationAcrossTime(train_times={'step': .100})
gat.fit(epochs)
gat.score(epochs)
assert_equal(len(gat.scores_), 8)

# Test start stop training
gat = GeneralizationAcrossTime(train_times={'start': 0.090,
'stop': 0.250})
# predict without fit
assert_raises(RuntimeError, gat.predict, epochs)
gat.fit(epochs)
gat.score(epochs)
assert_equal(len(gat.scores_), 4)
assert_equal(gat.train_times['times_'][0], epochs.times[6])
assert_equal(gat.train_times['times_'][-1], epochs.times[9])

# Test diagonal decoding
gat = GeneralizationAcrossTime()
gat.fit(epochs)
scores = gat.score(epochs, test_times='diagonal')
assert_true(scores is gat.scores_)
assert_equal(np.shape(gat.scores_), (15, 1))

# Test generalization across conditions
gat = GeneralizationAcrossTime(predict_mode='independent')
gat.fit(epochs[0:6])
gat.predict(epochs[7:])
assert_raises(ValueError, gat.predict, epochs, test_times='hahahaha')
gat.score(epochs[7:])

svc = SVC(C=1, kernel='linear', probability=True)
gat = GeneralizationAcrossTime(clf=svc, predict_type='proba')
gat.fit(epochs)
scores = gat.score(epochs)
scores = sum(scores, []) # flatten
assert_true(0.0 <= min(scores) <= 1.0)
assert_true(0.0 <= max(scores) <= 1.0)
Loading

0 comments on commit dce6fff

Please sign in to comment.