-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_decoding_group_gen_across_time.py
90 lines (70 loc) · 2.86 KB
/
run_decoding_group_gen_across_time.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# -*- coding: utf-8 -*-
"""
Created on Thu Jul 27 09:24:19 2017
@author: claire
"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
#from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import LogisticRegression
import mne
from mne.decoding import (SlidingEstimator, GeneralizingEstimator,
cross_val_multiscore, LinearModel, get_coef)
import os
from sklearn.preprocessing import LabelEncoder
# --- Analysis parameters -----------------------------------------------------
cond = 'imag'  # epoch condition tag, selected below with epochs[cond]
ana_path = '/home/claire/DATA/Data_Face_House_new_proc'
results_path = os.path.join(ana_path, 'Decoding_Gen_across_time')
exclude = [7]  # subject numbers skipped in the loading loop

# Create the output directory up front so the figure saves at the end of the
# (long) decoding run do not fail on a missing folder.
if not os.path.isdir(results_path):
    os.makedirs(results_path)

# --- Load every subject's epochs and pool them -------------------------------
all_epochs = list()
for subject_id in range(1, 26):
    if subject_id in exclude:
        continue
    subject = 'S%02d' % subject_id
    data_path = os.path.join(ana_path, subject, 'EEG', 'New_Preproc')
    fname_in = os.path.join(data_path,
                            '%s-causal-highpass-2Hz-epo.fif' % subject)
    epochs = mne.read_epochs(fname_in)
    # Interpolate channels marked as bad (presumably so that every subject
    # contributes the same usable channels to the pooled data).
    epochs.interpolate_bads()
    all_epochs.append(epochs)

epochs = mne.concatenate_epochs(all_epochs)
# The per-subject objects are no longer needed once pooled; release them.
del all_epochs
epochs = epochs[cond]
# Optionally restrict the analysis window, e.g.:
# epochs.crop(tmin=-0.2, tmax=1.5)

# --- Temporal generalization decoding ----------------------------------------
# Map the event codes of the selected epochs onto consecutive integer labels.
le = LabelEncoder()
X = epochs.get_data()  # EEG data: n_epochs x n_channels x n_times
y = le.fit_transform(epochs.events[:, 2])
# ROC AUC scoring (used below) is only defined for a two-class problem; fail
# early with a clear message instead of deep inside cross-validation.
if len(le.classes_) != 2:
    raise ValueError('Expected exactly 2 event types for condition %r, '
                     'got %d: %s' % (cond, len(le.classes_), le.classes_))

# Standardize features, then fit a logistic regression at every training
# time point and score it at every testing time point.
clf = make_pipeline(StandardScaler(), LogisticRegression())
time_gen = GeneralizingEstimator(clf, n_jobs=1, scoring='roc_auc')
scores = cross_val_multiscore(time_gen, X, y, cv=5, n_jobs=6)
# Average over cross-validation folds -> training times x testing times.
scores = np.mean(scores, axis=0)

# Chance level for ROC AUC is 0.5 regardless of class proportions (the
# majority-class rate is the chance level for accuracy, not for AUC).
chance = 0.5

# --- Diagonal: classifier trained and tested at the same time point ----------
fig, ax = plt.subplots()
ax.plot(epochs.times, np.diag(scores), label='score')
ax.axhline(chance, color='k', linestyle='--', label='chance')
ax.set_xlabel('Times')
ax.set_ylabel('AUC')
ax.legend()
ax.axvline(.0, color='k', linestyle='-')
ax.set_title('Decoding EEG sensors over time - %s all subjects' % cond)
# Save before plt.show(): once a blocking show() returns the window has been
# closed, and plt.savefig() would then write a new, empty current figure.
diag_fname = os.path.join(results_path,
                          '%s_gen_across_time_all_subj.pdf' % cond)
fig.savefig(diag_fname, bbox_inches='tight')
plt.show()
# --- Full temporal generalization matrix -------------------------------------
# Rows of `scores` are training times and columns testing times; with
# origin='lower' training time runs along y and testing time along x.
# vmin/vmax of 0 and 1 centre the diverging colormap on 0.5 (AUC chance).
fig, ax = plt.subplots(1, 1)
im = ax.imshow(scores, interpolation='lanczos', origin='lower', cmap='RdBu_r',
               extent=epochs.times[[0, -1, 0, -1]], vmin=0., vmax=1.)
ax.set_xlabel('Testing Time (s)')
ax.set_ylabel('Training Time (s)')
ax.set_title('Temporal Generalization - %s all subjects' % cond)
ax.axvline(0, color='k')
ax.axhline(0, color='k')
plt.colorbar(im, ax=ax)
# Save before plt.show(): once a blocking show() returns the window has been
# closed, and plt.savefig() would then write a new, empty current figure.
matrix_fname = os.path.join(results_path,
                            '%s_gen_across_time_matrix_all_subj.pdf' % cond)
fig.savefig(matrix_fname, bbox_inches='tight')
plt.show()