In [1]:
import os
import os.path as op
import mne
import numpy as np
import re
import itertools
import pandas as pd

from typing import Tuple, Iterator
from mne_bids import BIDSPath, read_raw_bids, print_dir_tree
from mne.time_frequency import tfr_morlet
from bids import BIDSLayout

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from mne.decoding import SlidingEstimator, cross_val_multiscore

In [2]:
BIDS_ROOT = '../data/bids'
STIM_FREQS = np.array([50, 100, 150, 200, 250])
FS = 2000

## Functions

In [3]:
KeyType = Tuple[str, str, str, str]

def get_fpaths(bids_root) -> Iterator[KeyType]:
    # Get filepaths
    layout = BIDSLayout(bids_root, derivatives = True)
    fpaths = layout.get(scope = 'preprocessing',
                        extension = 'fif.gz',
                        return_type = 'filename')
    fpaths.pop(0)
    
    # Get corresponding subject number
    filter_subs = re.compile('sub-(\d)_')
    subs = list(map(filter_subs.findall, fpaths))
    subs = list(itertools.chain(*subs))
    
    # Get corresponding run number
    filter_runs = re.compile('run-(\d)')
    runs = list(map(filter_runs.findall, fpaths))
    runs = list(itertools.chain(*runs))
    
    for i in range(len(fpaths)):
        key = (fpaths[i], subs[i], 'tasks', runs[i])
        yield key

def load_fif(fname):
    epochs = mne.read_epochs(fname)
    return epochs

def get_power_at_stim_freqs(epochs, stim_freqs):
    n_cycles = stim_freqs / 7.  # different number of cycle per frequency
    power = tfr_morlet(epochs, 
                       freqs = stim_freqs, 
                       n_cycles = n_cycles, 
                       use_fft = True,
                       return_itc = False, 
                       decim = 3, 
                       n_jobs = 1,
                       average = False)

    return power

## Compute power

In [4]:
# powers = []
# events = []
for (fpath, sub, task, run) in get_fpaths(BIDS_ROOT):
    epochs = load_fif(fpath)
    epochs = epochs.crop(tmin = 0)
#     events.append(mne.read_events(fpath)) #CHANGEBACK
    events = mne.read_events(fpath)
    power = get_power_at_stim_freqs(epochs, STIM_FREQS)
    log_power = np.log10(power)
#     powers.append(log_power) #CHANGEBACK
#     power = log_power 
    break

Reading /Users/nusbaumlab/src/pitch_tracking/analysis/../data/bids/derivatives/preprocessing/sub-2/sub-2_task-pitch_run-1_desc-clean_epo.fif.gz ...




    Found the data of interest:
        t =    -200.00 ...     250.00 ms
        0 CTF compensation matrices available
Not setting metadata
Not setting metadata
4553 matching events found
No baseline correction applied
0 projection items activated


  events = mne.read_events(fpath)


Not setting metadata


## Shape data for decoder

In [9]:
# Get some information
power = log_power 
n_epochs = np.shape(power)[0]
n_channels = np.shape(power)[1]
n_freqs = np.shape(power)[2]
n_windows = np.shape(power)[3]

# Reshape for classifier
power = power.reshape((n_epochs, n_freqs * n_channels, n_windows)) # Set order to preserve epoch order

In [11]:
# Create array of condition labels
labels = events[:, 2]
labels = pd.Series(labels)
labels = labels.replace({10001 : 50, 10002 : 100, 10003 : 150, 10004 : 200, 10005 : 250})

(0       250
 1       250
 2        50
 3        50
 4        50
        ... 
 4548     50
 4549    150
 4550    100
 4551     50
 4552     50
 Length: 4553, dtype: int64,)

## Decode

In [None]:
# Create dataframe for classifier (trials x features x time), one for each subject
power = power.reshape((n_epochs, n_freqs * n_channels, n_windows)) # Set order to preserve epoch order

n_stimuli = 5
metric = 'accuracy'

clf = make_pipeline(
    StandardScaler(),
    LogisticRegression(solver = 'liblinear')  # liblinear is faster than lbfgs
)
time_decod = SlidingEstimator(clf, scoring = metric, n_jobs = -1)
scores = cross_val_multiscore(
    time_decod,
    power, # a trials x features x time array
    labels, # an (n_trials,) array of integer condition labels
    cv = 5, scoring = metric, # use stratified 5-fold cross-validation
    n_jobs = -1 # use all available CPU cores
)
scores = np.mean(scores, axis = 0) # average across cv splits

# plot
fig, ax = plt.subplots()
ax.plot(epochs.times, scores, label = 'score')
ax.axhline(1/n_stimuli, color = 'k', linestyle = '--', label = 'chance')
ax.set_xlabel('Times')
ax.set_ylabel(metric)  # Area Under the Curve
ax.legend()
ax.set_title('Sensor space decoding')

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
  if LooseVersion(__version__) < LooseVersion('4.36'):
  if LooseVersion(__version__) < LooseVersion('4.36'):
100%|██████████| Fitting SlidingEstimator : 167/167 [01:20<00:00,    2.07it/s]
100%|██████████| Transforming SlidingEstimator : 167/167 [00:00<00:00,  309.08it/s]
[Parallel(n_jobs=8)]: Done   2 out of   5 | elapsed:  1.7min remaining:  2.5min
