<font size="+12"><center>
    MVPA analysis: Recursive Feature Elimination
</font></center>

Author:
Egor Ananyev

# Preparation

## Loading packages

In [1]:
%matplotlib qt
import numpy as np
import matplotlib.pyplot as plt
import mne
import pandas as pd
import os

In [2]:
from sklearn import svm
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.model_selection import ShuffleSplit
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score
from sklearn.feature_selection import RFE

In [3]:
mne.set_log_level('warning')  # less verbose output

## Setting parameters

In [4]:
interactive = False  # Whether to render interactive figures such as channel plots
# ...should be set to True on first pass

debug = False

cur_subj = 1
numof_runs = 7  # TEMP; is 7 for subj 01

## Loading evoked data

In [5]:
# data_path = os.path.expanduser("C:\\Users\\egora\\Downloads\\meg\\")
data_path = os.path.expanduser('E:\\meg\\')

cur_subj_str = str(cur_subj).zfill(2)

for cur_run in range(numof_runs):
    cur_run_str = str(cur_run+1).zfill(2)
    fname = os.path.join(data_path,
                             'derivatives\\meg_derivatives\\sub-' + cur_subj_str + '\\ses-meg\\meg\\sub-' + \
                             cur_subj_str + '_ses-meg_experimental_run-' + cur_run_str + '_proc-sss_300_epo.fif')
    print(fname)
    epochs_run = mne.read_epochs(fname)
    # example append: https://www.programcreek.com/python/example/92634/mne.Epochs
    epochs_run = epochs_run.pick_types(meg=True)
    epochs_run_df = epochs_run.to_data_frame()
    # Normalizing the signal by the baseline:
    epochs_run_std = epochs_run_df.sort_index(level=['condition', 'epoch', 'time'], ascending=[1, 1, 1])
    epochs_run_std = epochs_run_std.loc[pd.IndexSlice[:, :, -100:-1], :].groupby(['condition', 'epoch']).std()
    epochs_run_norm = epochs_run_df / epochs_run_std
    if cur_run == 0:
        epochs_df = epochs_run_norm
    else: 
        epochs_df = epochs_df.append(epochs_run_norm)

E:\meg\derivatives\meg_derivatives\sub-01\ses-meg\meg\sub-01_ses-meg_experimental_run-01_proc-sss_300_epo.fif
E:\meg\derivatives\meg_derivatives\sub-01\ses-meg\meg\sub-01_ses-meg_experimental_run-02_proc-sss_300_epo.fif
E:\meg\derivatives\meg_derivatives\sub-01\ses-meg\meg\sub-01_ses-meg_experimental_run-03_proc-sss_300_epo.fif
E:\meg\derivatives\meg_derivatives\sub-01\ses-meg\meg\sub-01_ses-meg_experimental_run-04_proc-sss_300_epo.fif
E:\meg\derivatives\meg_derivatives\sub-01\ses-meg\meg\sub-01_ses-meg_experimental_run-05_proc-sss_300_epo.fif
E:\meg\derivatives\meg_derivatives\sub-01\ses-meg\meg\sub-01_ses-meg_experimental_run-06_proc-sss_300_epo.fif
E:\meg\derivatives\meg_derivatives\sub-01\ses-meg\meg\sub-01_ses-meg_experimental_run-07_proc-sss_300_epo.fif


## Output path

The path to store MVPA-related data and visualization.

In [6]:
mvpa_path = os.path.join(data_path, 'derivatives\\meg_derivatives\\sub-' + cur_subj_str + '\\ses-meg\\meg-mvpa\\')
mvpa_fname = os.path.join(mvpa_path, 'sub-' + cur_subj_str)
print(mvpa_fname)

E:\meg\derivatives\meg_derivatives\sub-01\ses-meg\meg-mvpa\sub-01


In [12]:
all_times = np.round(epochs_run.times * 1000).astype(int)
# print(all_times)
print(np.shape(all_times))
print(all_times[56])

(211,)
87


# Recurrent Feature Elimination

In [13]:
t = 87
X_right = epochs_df[epochs_df.index.get_level_values('condition').str.contains('right')].loc[pd.IndexSlice[:, :, t], :]
X_left = epochs_df[epochs_df.index.get_level_values('condition').str.contains('left')].loc[pd.IndexSlice[:, :, t], :]
X_ = np.concatenate((X_right, X_left))
y_ = np.concatenate((np.repeat(1, len(X_right)), np.repeat(0, len(X_left))))

In [14]:
# def run_svm(X_, y_, n_splits=10, track=True):
n_splits=10
track = True
acc_list = []
ss = ShuffleSplit(n_splits=n_splits, test_size=0.2)
for train_index, test_index in ss.split(X_):
    X_train = X_[train_index]
    y_train = y_[train_index]
    X_test = X_[test_index]
    y_test = y_[test_index]
    clf = svm.SVC(kernel='linear', cache_size=2000).fit(X_train, y_train)  #, gamma=0.001
    y_pred = clf.predict(X_test)
    this_acc = accuracy_score(y_test, y_pred)
#     print(this_acc)
    acc_list.append(this_acc)
if track:
    print('{0:.2f}'.format(np.mean(acc_list)), end=' ')

0.54 

In [17]:
# def run_svm(X_, y_, n_splits=10, track=True):
n_splits=1
track = True
# acc_list = []
ss = ShuffleSplit(n_splits=n_splits, test_size=0.2)
for train_index, test_index in ss.split(X_):
    X_train = X_[train_index]
    y_train = y_[train_index]
    X_test = X_[test_index]
    y_test = y_[test_index]
    clf = svm.SVC(kernel='linear', cache_size=2000)  #.fit(X_train, y_train)  #, gamma=0.001
    rfe = RFE(estimator=clf, n_features_to_select=1, step=1)
    rfe.fit(X_train, y_train)
    ranking = rfe.ranking_  #.reshape(digits.images[0].shape)
#     print(accuracy_score(y_test, y_pred))
#     y_pred = clf.predict(X_test)
#     this_acc = accuracy_score(y_test, y_pred)
#     acc_list.append(this_acc)
# if track:
#     print('{0:.2f}'.format(np.mean(acc_list)), end=' ')
# return np.mean(acc_list)

In [18]:
print(ranking)

[119 257 228 150 117 293 235 259 185 118 188 166  28 224 147 260  68 158
 141 273  29 193 225 132  32  76 112  31  67 200 253  88 145  75 286  30
 153 192 144 277 137 149  94  50  98 218 229 285  39 271  41  36  34  42
 199 221  38 268 220 169  54 219  53  58 276 109 146  26 148  89 278 143
 242  99  25 249 100  56 251 306 240 177  24 263 208 163 135 198 248 250
  33 201 136  73  72 234 139  74 187 255 246  95  40 213  52 284  96  55
  59 296 120 140 264 262 275 116  27  60 113 159 151 197 212 300 115 160
 203 279 205 108 130 304 206  71 129 173  57 131 196  70 114 106 254 128
  47 124 127 195 288 292 182 178 289 180 245 179 244 189 183 176 266 267
 283  21 152   3 238  19 138 110 269 298  20 207 103 239 214 105 155 142
 126 122   7   4   2 299 215 297 230  45 237 181 265 236  69   6  93  16
 252 156 301   8  49 302   9  87 233 291 186 202 305 184 190  81 281  77
   5  14  66  13 270 295 104 294 272 232  86  48 290 209 258 247  91 287
 191 226  78  11 282  62  37  90 168 241  92  17  2

In [16]:
print(ranking)

[203 266 130 254 204 278 121 112 236 200 274  57 247 260 174  38   5 240
 108 242  36 111 243 173   6 196 241  37   4 128 134 123 299 225 192  35
 284 221 110 107 252 220  92 263 122  47 159 179 296  96  33 289  32  34
 256 283  31 286 304 248   7 158  16  60  27  23 120  46  64 213 267 253
  53 305  45 170 109  72 287  71  74  54 166 129  10 219  88 275 165 148
   9  65  89  67 195  90 185  52 191 190 164  66   8  25  15 184  68 136
  59  26 257 205 216 290 250 258  24 106 209  14 114 300 115 303 175 189
  73  97  91  76 292 207  93 288 285 306 135  13 118  86 282  50 227 294
  49 153 293 268  87 229 298 224 302 171 172  75  70 183 187  94 201 160
 234 261 230  79  22 215 199 181  29 271 119  28  99 262 163 297 140 264
 218 139  40  84  19 162 280 180 251 270 141  20 157 142  41 198  83  42
  44 104 222 126  17 103 232   1 167 231  18 255 212 210 146  95 259  11
 145  55 161 272  21 228 269  39  58 246 127 245 244 178  62 239 211 147
 168  82  12 277 301 214 156  56 238 197 149  43   

In [11]:
print(ranking)

[224 252 276 286 211 214 155 204 271 289 290  14  19 148 210 196 193 181
  63  90  17 250 231  34  26 146  68  25  23 191 161 285 261 194 134 189
 199 266 109 198 113 303 283 111 123 192 302 179  27 108  69   4 133  62
 295 260 141 168 132  24 145 291  92 241 177 225 112  29 235  32 223 267
 171 122  22 269 275 129 175 195 265  47  21 173  67 251 274  93 259 239
   3  58 209  56 227 207 233  97  76 305 264  59  75  95 169  96  99 170
 279 176 184 268  28  33   9  66 300  53 243  80  82 248 238   8  65 299
  57  77 116 304 263  78 217  11 216 126  94  51 128   5 242 188 234  79
 136 156 124 249   6 301 296 221 117 183 245 107 187 178 104 294 293 114
 154 205 229  12  13  83  49 186 185  48 158 215 306 254  36 105  91 255
  30 256  20   1  15  35 232 153  84  52 103 218   2 284 131  10 152  44
 167  71  50 149 110 172  46 226 240  31 164 236  88  40  38  18 287 262
 237  16 130 277  39 127 273 219 280 206  55 200  42 272 298 213  41  37
  43 180 278 230 246 253  70 150 182 270 151  45   

In [23]:
print(ranking)

[162 228 200  66 259 253 207 260 148 281 147  48   8 267 269 288 159 301
  49 118  33 105 261  50   6  65  32 154  44 272 156 170 284   7  67 250
 280 220 286  59 146  57  58 145 108 242 117  98  21  23  31  22  35  30
 173 224  88 252  45 265  42  89  37 293 285  38 122  61  56  43 303 299
 273 236 143 251 211  80 229 195 109  47  75 264  81 203 130 297  86 196
  26  28 139  27 121 209 131 268 240 175  87  29 155  40  36 134  84 237
 127 201  85  62  79 295  18 100 245  39  72 194 217 213 193  17  93 212
 179 278 233 306 294 234  77  82 158  78  41 241 188  83 248 290 216 166
 161 129 178 182 255 133 227 140 226 136 305 243 125 190 191 225 174 135
 266  53 151 113  20 215  51 296 262  52 160 263 292 142  63 189 115 230
  19 114 171   1  90  64 132   3  60 112  46  97   2  55  99  14   4  15
 202 231 289 270 144 239  34 107 168  13 214 276 302  12 199  95 164  94
 254  54   5 271 197 300 126 169 163 106 123 149  76 192 101 246   9 102
  96 120 104  69 183 204 181 298 279 275 206  16  1

In [42]:
len(epochs_run.info['chs'])

306

In [68]:
for cur_ind in range(len(epochs_run.info['chs'])):
    if cur_ind == 0:
        pos = epochs_run.info['chs'][cur_ind]['loc'][0:2]
    else: 
        pos = np.vstack((pos, epochs_run.info['chs'][cur_ind]['loc'][0:2]))

In [69]:
pos

array([[-1.06600001e-01,  4.63999994e-02],
       [-1.06600001e-01,  4.63999994e-02],
       [-1.06600001e-01,  4.63999994e-02],
       [-1.01999998e-01,  6.31000027e-02],
       [-1.01999998e-01,  6.31000027e-02],
       [-1.01999998e-01,  6.31000027e-02],
       [-1.08499996e-01,  3.02000009e-02],
       [-1.08499996e-01,  3.02000009e-02],
       [-1.08499996e-01,  3.02000009e-02],
       [-1.09899998e-01,  1.31000001e-02],
       [-1.09899998e-01,  1.31000001e-02],
       [-1.09899998e-01,  1.31000001e-02],
       [-1.07400000e-01,  3.29000019e-02],
       [-1.07400000e-01,  3.29000019e-02],
       [-1.07400000e-01,  3.29000019e-02],
       [-9.88999978e-02,  4.03000005e-02],
       [-9.88999978e-02,  4.03000005e-02],
       [-9.88999978e-02,  4.03000005e-02],
       [-1.01099998e-01,  4.39999998e-03],
       [-1.01099998e-01,  4.39999998e-03],
       [-1.01099998e-01,  4.39999998e-03],
       [-1.08300000e-01, -1.09999999e-03],
       [-1.08300000e-01, -1.09999999e-03],
       [-1.

In [70]:
mne.viz.plot_topomap(ranking, pos=pos)

(<matplotlib.image.AxesImage at 0x1e74af447c0>,
 <matplotlib.contour.QuadContourSet at 0x1e74af44a90>)