-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
241 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
""" | ||
============================================== | ||
Compute effect-matched-spatial filtering (EMS) | ||
============================================== | ||
This example computes the EMS to reconstruct the time course of | ||
the experimental effect as described in: | ||
Aaron Schurger, Sebastien Marti, and Stanislas Dehaene, "Reducing multi-sensor | ||
data to a single time course that reveals experimental effects", | ||
BMC Neuroscience 2013, 14:122 | ||
""" | ||
|
||
# Author: Denis Engemann <denis.engemann@gmail.com> | ||
# | ||
# License: BSD (3-clause) | ||
|
||
|
||
print(__doc__) | ||
|
||
import os.path as op | ||
import numpy as np | ||
|
||
import mne | ||
from mne import fiff | ||
from mne.datasets import sample | ||
from mne.epochs import combine_event_ids | ||
data_path = sample.data_path() | ||
|
||
# Set parameters | ||
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif' | ||
event_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif' | ||
event_ids = {'AudL': 1, 'VisL': 2} | ||
tmin = -0.2 | ||
tmax = 0.5 | ||
|
||
# Setup for reading the raw data | ||
raw = fiff.Raw(raw_fname, preload=True) | ||
raw.filter(1, 45) | ||
events = mne.read_events(event_fname) | ||
|
||
# Set up pick list: EEG + STI 014 - bad channels (modify to your needs) | ||
include = [] # or stim channels ['STI 014'] | ||
raw.info['bads'] += ['EEG 053'] # bads + 1 more | ||
|
||
# pick EEG channels | ||
picks = fiff.pick_types(raw.info, meg='grad', eeg=False, stim=False, eog=True, | ||
include=include, exclude='bads') | ||
# Read epochs | ||
|
||
reject = dict(grad=4000e-13, eog=150e-6) | ||
# reject = dict(mag=4e-12, eog=150e-6) | ||
|
||
epochs = mne.Epochs(raw, events, event_ids, tmin, tmax, picks=picks, | ||
baseline=None, reject=reject) | ||
|
||
# Let's equalize the trial counts in each condition | ||
epochs.equalize_event_counts(['AudL', 'VisL'], copy=False) | ||
# Now let's combine some conditions | ||
|
||
picks2 = fiff.pick_types(epochs.info, meg='grad', exclude='bads') | ||
|
||
data = epochs.get_data()[:, picks2, :] | ||
|
||
# the matlab routine expects n_sensors, n_times, n_epochs | ||
|
||
data2 = np.transpose(data, [1, 2, 0]) | ||
|
||
# # create bool indices | ||
conditions = [epochs.events[:, 2] == 1, epochs.events[:, 2] == 2] | ||
|
||
# # matlab io functions don't deal with bool values | ||
# # so we need tom make a detour via int | ||
conditions = [c.astype(int) for c in conditions] | ||
|
||
|
||
############################################################################### | ||
# Now it's time for some hacking ... | ||
|
||
from scipy import io | ||
|
||
io.savemat('epochs_data.mat', {'data': data2, | ||
'conditions': conditions}) | ||
|
||
var_name1, var_name2 = 'surrogates', 'spatial_filter' | ||
my_pwd = op.abspath(op.curdir) # expand path | ||
|
||
# this requires | ||
# https://gist.github.com/dengemann/640d202f84befff1545d | ||
# in the local directory | ||
|
||
my_matlab_code = """ | ||
disp('reading data ...'); | ||
epochs = load('epochs_data.mat'); | ||
conditions = boolean(epochs.conditions'); | ||
disp('computing trial surrogates'); | ||
[{0}, {1}] = ems_ncond(epochs.data, conditions); | ||
disp('saving results ...'); | ||
save('{pwd}/{0}.mat', '{0}'); | ||
save('{pwd}/{1}.mat', '{1}'); | ||
quit; | ||
""".format(var_name1, var_name2, pwd=my_pwd).strip('\n').replace('\n', '') | ||
|
||
run_matlab = ['matlab', '-nojvm', '-nodesktop', '-nodisplay', '-r'] | ||
|
||
run_matlab.append(my_matlab_code) | ||
|
||
from subprocess import Popen, PIPE | ||
|
||
process = Popen(run_matlab, stdin=PIPE, stdout=None, shell=False) | ||
|
||
process.communicate() # call and quit matlab | ||
|
||
surrogates = io.loadmat(var_name1 + '.mat')[var_name1] | ||
spatial_filter = io.loadmat(var_name2 + '.mat')[var_name2] | ||
|
||
from mne.decoding import compute_ems | ||
|
||
iter_comparisons = [ | ||
(surrogates, spatial_filter), | ||
compute_ems(data, conditions) | ||
] | ||
|
||
import matplotlib.pyplot as plt | ||
|
||
for ii, (tsurrogate, sfilter) in enumerate(iter_comparisons): | ||
|
||
lang = 'python' if ii > 0 else 'matlab' | ||
|
||
order = epochs.events[:, 2].argsort() | ||
times = epochs.times * 1e3 | ||
|
||
plt.figure() | ||
plt.title('single surrogate trial - %s' % lang) | ||
plt.imshow(surrogates[order], origin='lower', aspect='auto', | ||
extent=[times[0], times[-1], 1, len(epochs)]) | ||
plt.xlabel('Time (ms)') | ||
plt.ylabel('Trials (reordered by condition)') | ||
plt.savefig('fig-%s-1.png' % lang) | ||
|
||
plt.figure() | ||
plt.title('Average EMS signal - %s' % lang) | ||
for key, value in epochs.event_id.items(): | ||
ems_ave = surrogates[epochs.events[:, 2] == value] | ||
ems_ave /= 4e-11 | ||
plt.plot(times, ems_ave.mean(0), label=key) | ||
plt.xlabel('Time (ms)') | ||
plt.ylabel('fT/cm') | ||
plt.legend(loc='best') | ||
plt.savefig('fig-%s-2.png' % lang) | ||
|
||
# visualize spatial filter | ||
evoked = epochs.average() | ||
evoked.data = spatial_filter | ||
evoked.plot_topomap(ch_type='grad', title=lang) | ||
plt.savefig('fig-%s-3.png' % lang) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import numpy as np | ||
# import scipy | ||
# import sklearn | ||
from scipy.linalg import norm | ||
|
||
|
||
def compute_ems(data, conditions, objective_function=None): | ||
"""Compute event-matched spatial filter | ||
This version operates on the entire timecourse. No time window needs to | ||
be specified. The result is a spatial filter at each time point and a | ||
corresponding timecourse. Intuitively, the result gives the similarity | ||
between the filter at each time point and the data vector (sensors) at | ||
that timepoint. | ||
References | ||
---------- | ||
[1] Aaron Schurger, Sebastien Marti, and Stanislas Dehaene, "Reducing | ||
multi-sensor data to a single time course that reveals experimental | ||
effects", BMC Neuroscience 2013, 14:122 | ||
Parameters | ||
---------- | ||
data : numpy.ndarray (n_epochs, n_channels, n_times) | ||
The data matrix | ||
conditions : list-like | ||
a list or an array of indices or bool arrays. | ||
objective_function : callable | ||
The objective function to maximize. Must comply with the following | ||
API: | ||
def objective_function(data, conditions, **kwargs): | ||
... | ||
return numpy.ndarray (n_channels, n_times) | ||
If None, the difference function as described in [1] | ||
Returns | ||
------- | ||
surrogate_trials : numpy.ndarray (trials, n_trials, n_time_points) | ||
The trial surrogates. | ||
mean_spatial_filter : instance of numpy.ndarray (n_channels, n_times) | ||
The set of spatial filters. | ||
""" | ||
|
||
n_epochs, n_channels, n_times = data.shape | ||
spatial_filter = np.zeros((n_channels, n_times)) | ||
surrogate_trials = np.zeros((n_epochs, n_times)) | ||
|
||
if objective_function is None: | ||
objective_function = _ems_diff | ||
|
||
from sklearn.cross_validation import LeaveOneOut | ||
|
||
loo = LeaveOneOut(n_epochs) | ||
for train_indices, epoch_idx in loo: | ||
print('.. processing epoch %i' % epoch_idx) | ||
d = objective_function(data, conditions, train_indices) | ||
for time_idx in np.arange(n_times): | ||
d[:, time_idx] /= norm(d[:, time_idx]) | ||
|
||
# update spatial filter | ||
spatial_filter += d | ||
# take norm over channels | ||
surrogate_trials[epoch_idx] = np.sum(np.squeeze(data[epoch_idx]) | ||
* spatial_filter, axis=0) | ||
|
||
# compute surrogates | ||
|
||
spatial_filter /= n_epochs | ||
|
||
return surrogate_trials, spatial_filter | ||
|
||
|
||
def _ems_diff(data, conditions, train): | ||
"""defaut diff objective function | ||
""" | ||
|
||
sum1, sum2 = [data[conditions[i]].sum(axis=0) for i in [0, 1]] | ||
n1, n2 = conditions[0].sum(), conditions[1].sum() | ||
m1 = (sum1 - data[train].sum(axis=0)) / (n1 - len(train)) | ||
m2 = (sum2 - data[train].sum(axis=0)) / (n2 - len(train)) | ||
|
||
return m1 - m2 |