In [None]:
%matplotlib inline
import mne
import matplotlib.pyplot as plt

fname = "./oddball-epo.fif"
epochs = mne.read_epochs(fname)

mne.set_log_level(True)

event_ids = {"standard/stimulus": 200, "target/stimulus": 100}

## MVPA/decoding

Can we predict trial type from EEG activity?

In [None]:
from sklearn.svm import LinearSVC
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_val_score

In [None]:
epochs.pick_types(eeg=True)
X = epochs.get_data()  # features
y = epochs.events[:, -1] == event_ids["target/stimulus"]  # targets
X.shape, y.shape

In [None]:
y[:9]

`X` has the wrong shape for scikit-learn: it is (`samples`, `channels`, `times`), but estimators expect (`samples`, `features`).

We can use `mne.decoding.Vectorizer` to correctly shape the data. It fits right into a scikit-learn pipeline.
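
For intuition, the reshaping amounts to keeping the first axis and flattening everything else into one feature axis. Below is a minimal NumPy-only sketch of that idea. The array `X_demo` is a made-up stand-in, not the oddball data, and the sketch only mimics the reshaping step, not the full `Vectorizer` class.

```python
import numpy as np

# Synthetic stand-in for epochs.get_data(): 6 epochs, 3 channels, 5 time points.
X_demo = np.arange(6 * 3 * 5, dtype=float).reshape(6, 3, 5)

# Keep the first axis (one row per epoch) and flatten the rest into features.
X_flat = X_demo.reshape(len(X_demo), -1)
print(X_flat.shape)  # (6, 15)

# Nothing is lost: reshaping back recovers the original array.
X_back = X_flat.reshape(X_demo.shape)
```

Each row of `X_flat` holds all channels and time points of one epoch, which is the layout scikit-learn estimators expect.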

In [None]:
from mne.decoding import Vectorizer
clf = make_pipeline(
    Vectorizer(),
    StandardScaler(),
    LinearSVC(class_weight="balanced"),
)

The resulting object behaves like any other scikit-learn classifier:

In [None]:
clf.fit(X, y)

In [None]:
clf.predict(X[:9])

Predicting the data the classifier was trained on tells us little, though. What we usually care about are cross-validated scores.

In [None]:
cross_val_score(clf, X, y)  # accuracy
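
Accuracy is only meaningful relative to a baseline, and in an oddball design the two classes are unlikely to be equally frequent. The sketch below compares a classifier against a majority-class baseline. It runs on synthetic data, not the oddball recording: the arrays `X_demo` and `y_demo`, the 80/20 class split, and the injected effect are all assumptions made for illustration.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

rng = np.random.default_rng(0)

# Synthetic, imbalanced stand-in: 80 "standard" and 20 "target" epochs.
y_demo = np.array([False] * 80 + [True] * 20)
X_demo = rng.standard_normal((100, 4, 10))  # epochs x channels x times
X_demo[y_demo, 0, :] += 3.0  # inject a strong effect on channel 0 for targets
X_flat = X_demo.reshape(len(X_demo), -1)  # flatten channels x times into features

# Baseline that always predicts the most frequent class.
baseline = cross_val_score(
    DummyClassifier(strategy="most_frequent"), X_flat, y_demo, cv=5
)

# A pipeline similar to the one above, scored with the same 5 stratified folds.
clf_demo = make_pipeline(StandardScaler(), LinearSVC(class_weight="balanced"))
scores = cross_val_score(clf_demo, X_flat, y_demo, cv=5)

print(baseline.mean(), scores.mean())
```

With an 80/20 split the baseline already reaches 80% accuracy, so a score has to clearly exceed that before we read it as evidence of decodable information.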

So there is some decodable information in the brain data. Can we investigate this in a bit more detail?
For example: at which time points in the trial is there information about trial category?

We need two more tools for this: one to train and score at each time point, and one to handle the cross-validated scoring for the former.

In [None]:
from mne.decoding import SlidingEstimator, cross_val_multiscore
sl = SlidingEstimator(clf)

In [None]:
scores_time_decoding = cross_val_multiscore(sl, X, y)

In [None]:
scores_time_decoding.shape
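
Conceptually, this amounts to running one independent cross-validation per time point, with the channels as features. Here is a rough sketch of that loop in plain scikit-learn. It is not how `SlidingEstimator` is implemented, and it uses synthetic data (`X_demo`, `y_demo`, and the injected effect are assumptions for illustration).

```python
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

rng = np.random.default_rng(0)

# Synthetic stand-in: 100 epochs, 4 channels, 10 time points.
y_demo = np.array([False] * 50 + [True] * 50)
X_demo = rng.standard_normal((100, 4, 10))
X_demo[y_demo, 0, 5:] += 3.0  # the classes differ only from time index 5 onward

clf_demo = make_pipeline(StandardScaler(), LinearSVC(class_weight="balanced"))

# One independent cross-validation per time point, using channels as features.
scores_demo = np.stack(
    [cross_val_score(clf_demo, X_demo[:, :, t], y_demo, cv=5)
     for t in range(X_demo.shape[-1])],
    axis=1,
)
print(scores_demo.shape)  # (5, 10): folds x time points
```

In this toy example the scores stay near chance for the first five time points and rise once the injected effect starts.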

In [None]:
fig, ax = plt.subplots()
ax.plot(epochs.times, scores_time_decoding.T)
plt.show()

In [None]:
fig, ax = plt.subplots()
ax.plot(epochs.times, scores_time_decoding.mean(0))
plt.show()

But is the same thing happening at each time point? We can investigate that with generalization-across-time decoding: a classifier trained at one time point is tested at every time point, not just its own.

In [None]:
from mne.decoding import GeneralizingEstimator
gen = GeneralizingEstimator(clf)
scores_gat = cross_val_multiscore(gen, X, y)

In [None]:
scores_gat.shape
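
The score array has one axis for folds, one for training times and one for testing times. The following sketch spells out the idea with explicit loops in plain scikit-learn. It is an illustration on synthetic data, not the `GeneralizingEstimator` implementation; `X_demo`, `y_demo`, and the sustained effect on one channel are assumptions.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

rng = np.random.default_rng(0)

# Synthetic stand-in: 100 epochs, 4 channels, 10 time points.
y_demo = np.array([False] * 50 + [True] * 50)
X_demo = rng.standard_normal((100, 4, 10))
X_demo[y_demo, 0, 5:] += 3.0  # sustained effect from time index 5 onward

n_times = X_demo.shape[-1]
cv = StratifiedKFold(n_splits=5)
scores_demo = np.zeros((cv.get_n_splits(), n_times, n_times))

for fold, (train, test) in enumerate(cv.split(X_demo[:, :, 0], y_demo)):
    for t_train in range(n_times):
        # Fit on the training epochs at a single time point ...
        clf_demo = make_pipeline(
            StandardScaler(), LinearSVC(class_weight="balanced")
        )
        clf_demo.fit(X_demo[train][:, :, t_train], y_demo[train])
        # ... and score on the held-out epochs at every time point.
        for t_test in range(n_times):
            scores_demo[fold, t_train, t_test] = clf_demo.score(
                X_demo[test][:, :, t_test], y_demo[test]
            )

print(scores_demo.shape)  # (5, 10, 10): folds x training times x testing times
```

Because the toy effect is identical at all late time points, a classifier trained at one late time point also works at the others. A pattern that changed over time would instead give high scores only near the diagonal.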

In [None]:
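import numpy as np

data = scores_gat.mean(0)  # average over cross-validation folds
vmax = np.abs(data).max()
tmin, tmax = epochs.times[[0, -1]]

fig, ax = plt.subplots()
im = ax.imshow(
    data,
    origin="lower", cmap="RdBu_r",
    extent=(tmin, tmax, tmin, tmax),
    vmax=vmax, vmin=1 - vmax,  # color scale centered on 0.5
)
ax.set_xlabel("Testing time (s)")
ax.set_ylabel("Training time (s)")
plt.colorbar(im)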

More complex decoding pipelines are easy to build, for example by reducing the channel dimension with PCA before classification.

In [None]:
from mne.decoding import UnsupervisedSpatialFilter
from sklearn.decomposition import PCA

In [None]:
pca = UnsupervisedSpatialFilter(PCA(.85))

In [None]:
pca_clf = make_pipeline(pca, Vectorizer(), StandardScaler(),
                        LinearSVC(class_weight="balanced"))

In [None]:
cross_val_score(pca_clf, X, y)
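
To see what the spatial filtering step roughly does, we can reproduce it by hand: treat every time sample of every epoch as one observation with the channels as features, fit the PCA on those, and restore the three-dimensional layout afterwards. The sketch below reflects our understanding of `UnsupervisedSpatialFilter` rather than its actual source, and it uses synthetic data (`X_demo` and its dimensions are assumptions).

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
n_epochs, n_channels, n_times = 20, 8, 50
X_demo = rng.standard_normal((n_epochs, n_channels, n_times))

# Stack all time samples of all epochs as rows, with channels as columns.
samples = X_demo.transpose(0, 2, 1).reshape(-1, n_channels)

# Keep as many components as needed to explain 85% of the variance.
pca_demo = PCA(0.85).fit(samples)

# Project every sample, then restore the epochs x components x times layout.
X_pca = (
    pca_demo.transform(samples)
    .reshape(n_epochs, n_times, -1)
    .transpose(0, 2, 1)
)
print(X_pca.shape)
```

The output keeps the epoch and time axes and replaces the channel axis with PCA components, so the rest of the pipeline sees fewer spatial features per time point.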

What do the learned patterns actually look like?

In [None]:
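svc = LinearSVC(class_weight="balanced")
# Fit one classifier per time point on (epochs, channels) data and scale
# each channel's weight by that channel's standard deviation across epochs.
topos = np.array([svc.fit(time_point.T, y).coef_ * time_point.std(1)
                  for time_point in X.T])[:, 0, :]
# topos is (times, channels); EvokedArray expects (channels, times).
topo_ev = mne.EvokedArray(topos.T, info=epochs.info, tmin=epochs.tmin,
                          nave=len(y))
topo_ev.plot_joint(times=[.22, .3, .375, .45]);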