In [1]:
import os
import sys
import subprocess
import argparse
from bids import BIDSLayout
from util.io.iter_BIDSPaths import *
from util.io.bids import DataSink

import gc
import sys
import mne
import numpy as np
import pandas as pd

# Needed by the power, decoding and plotting cells further down.
import matplotlib.pyplot as plt
from mne.time_frequency import tfr_morlet
from mne.decoding import SlidingEstimator, cross_val_multiscore

from sklearn.pipeline import make_pipeline
from sklearn import preprocessing
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

# Not used by any cell at present; kept commented out for reference.
# from typing import Tuple, Iterator
# from mne_bids import BIDSPath, read_raw_bids, print_dir_tree
# from bids import BIDSLayout

In [6]:
# Root folders, given relative to the notebook's working directory.
BIDS_ROOT = '../data/bids'
FIGS_ROOT = '../figs'
# Frequencies (Hz) passed as `freqs` to the Morlet power computation.
STIM_FREQS = np.array([130, 200, 280])

# Subject and run identifiers. The plotting cell builds its output file name
# from `sub` and `run`, which were never defined anywhere; the values below
# match the recording loaded in the "Load data" cell (sub-12, run-1).
sub = '12'
run = '1'

# Event name of interest: its first character selects the condition group in
# CONDS, and the full name selects the one-vs-rest mapping in EVENT_DICTS.
cond = '11'

In [3]:
print("---------- Load data ----------")
# Absolute path of the cleaned epochs file for sub-12, run-1, split over
# adjacent literals for readability (the resulting string is unchanged).
fpath = (
    '/project2/hcn1/pitch_tracking_attention/data/bids'
    '/derivatives/preprocessing/sub-12'
    '/sub-12_task-pitch_run-1_desc-clean_epo.fif.gz'
)
epochs = mne.read_epochs(fpath)
# Show the event-name -> event-code mapping stored with the epochs.
print(epochs.event_id)

---------- Load data ----------
Reading /project2/hcn1/pitch_tracking_attention/data/bids/derivatives/preprocessing/sub-12/sub-12_task-pitch_run-1_desc-clean_epo.fif.gz ...
    Found the data of interest:
        t =    -200.00 ...     350.00 ms
        0 CTF compensation matrices available
Reading /project2/hcn1/pitch_tracking_attention/data/bids/derivatives/preprocessing/sub-12/sub-12_task-pitch_run-1_desc-clean_epo.fif-1.gz ...
    Found the data of interest:
        t =    -200.00 ...     350.00 ms
        0 CTF compensation matrices available
0 bad epochs dropped
0 bad epochs dropped
Not setting metadata
3350 matching events found
No baseline correction applied
0 projection items activated
{'11': 10001, '12': 10002, '13': 10003, '21': 10004, '22': 10005, '23': 10006, '31': 10007, '32': 10008, '33': 10009}


In [7]:
print("---------- Subset epochs ----------")
# Each condition key ('1', '2', '3') groups the three event names that start
# with that digit, e.g. '1' -> ['11', '12', '13'].
CONDS = {first: [first + second for second in '123'] for first in '123'}

# Keep only the trials whose event name belongs to the group selected by the
# first character of `cond`.
condition_epochs = epochs[CONDS[cond[0]]]
events = condition_epochs.events
print(condition_epochs.event_id)
print(condition_epochs)

---------- Subset epochs ----------
{'11': 10001, '12': 10002, '13': 10003}
<EpochsFIF |  824 events (all good), -0.2 - 0.35 sec, baseline -0.2 – 0 sec, ~1.05 GB, data loaded,
 '11': 271
 '12': 288
 '13': 265>


In [9]:
# Event code of every retained trial (third column of the events array).
labels = pd.Series(events[:, 2])

# One-vs-rest lookup tables: for each event name, map its own event code to 1
# and every other code to 0.
_EVENT_NAMES = [a + b for a in '123' for b in '123']  # '11', '12', ..., '33'
_EVENT_CODES = list(range(10001, 10010))              # 10001, ..., 10009
EVENT_DICTS = {
    name: {code: int(code == target) for code in _EVENT_CODES}
    for name, target in zip(_EVENT_NAMES, _EVENT_CODES)
}

# Binary target: 1 where the trial's event matches `cond`, 0 otherwise.
y = labels.replace(EVENT_DICTS[cond])
print(labels)
print(y)

0      10003
1      10002
2      10003
3      10002
4      10002
       ...  
819    10003
820    10001
821    10003
822    10003
823    10001
Length: 824, dtype: int32
0      0
1      0
2      0
3      0
4      0
      ..
819    0
820    1
821    0
822    0
823    1
Length: 824, dtype: int32


In [None]:
print("---------- Compute power ----------")
# The number of wavelet cycles grows with frequency (one value per entry of
# STIM_FREQS).
n_cycles = STIM_FREQS / 7

# Compute power on the condition subset, not on the full `epochs` object:
# `events` (and therefore the labels built from it) comes from
# `condition_epochs`, so the trial axis of the power array must match it.
power = tfr_morlet(condition_epochs,
                   freqs = STIM_FREQS,
                   n_cycles = n_cycles,
                   use_fft = True,
                   return_itc = False,
                   decim = 3,
                   n_jobs = 1,
                   average = False)
# With average = False the result is a per-epoch TFR container; take its
# underlying array before applying the log so that the shape queries and the
# later reshape operate on a plain ndarray.
power = np.log10(power.data)

# The full set of epochs is no longer needed; release its memory.
del epochs
gc.collect()

# Get some information
n_epochs, n_channels, n_freqs, n_windows = power.shape
assert n_epochs == len(events), "power and events must describe the same trials"
print("n_windows: " + str(n_windows))

---------- Compute power ----------
Not setting metadata


In [None]:
print("---------- Prepare for decoder ----------")
# Collapse the channel and frequency axes into a single feature axis; the
# trial axis (first) and the time-window axis (last) are left as they are.
n_features = n_freqs * n_channels
X = power.reshape((n_epochs, n_features, n_windows))

# Build integer class labels: translate the event codes to the values in the
# mapping below, then encode them as consecutive integers.
code_to_freq = {10001 : 130, 10002 : 200, 10003 : 280}
labels = pd.Series(events[:, 2])
y = labels.replace(code_to_freq)
le = preprocessing.LabelEncoder()
y = le.fit_transform(y)

In [None]:
print("---------- Decode ----------")
# Classifier applied at each time window: feature standardisation followed by
# a logistic regression.
scaler = StandardScaler()
classifier = LogisticRegression(solver = 'liblinear')
clf = make_pipeline(scaler, classifier)

print("Creating sliding estimators")
time_decod = SlidingEstimator(clf)

print("Fit estimators")
# X is a trials x features x time array, y an (n_trials,) array of integer
# condition labels; 5-fold cross-validation, run on all available CPU cores.
fold_scores = cross_val_multiscore(time_decod, X, y, cv = 5, n_jobs = -1)

# Average across cv splits.
scores = np.mean(fold_scores, axis = 0)

In [None]:
print("---------- Save decoder scores ----------")
# `scores_fpath` was never defined in the notebook, so this cell failed with a
# NameError. Build it here from names that already exist: the scores go under
# BIDS_ROOT/derivatives/decoding and reuse the entity prefix of the input file
# (everything before '_desc-').
# NOTE(review): the output folder and suffix are a choice made in this cell,
# not something the rest of the notebook prescribes - adjust if the project
# has a different convention.
scores_dir = os.path.join(BIDS_ROOT, 'derivatives', 'decoding')
os.makedirs(scores_dir, exist_ok = True)
scores_prefix = os.path.basename(fpath).split('_desc-')[0]
scores_fpath = os.path.join(scores_dir, scores_prefix + '_desc-log_reg_no_crop_scores.npy')

print('Saving scores to: ' + scores_fpath)
np.save(scores_fpath, scores)

---------- Load data ----------
/project2/hcn1/pitch_tracking_attention/data/bids/derivatives/preprocessing/sub-12/sub-12_task-pitch_run-1_desc-clean_epo.fif.gz
Reading /project2/hcn1/pitch_tracking_attention/data/bids/derivatives/preprocessing/sub-12/sub-12_task-pitch_run-1_desc-clean_epo.fif.gz ...
    Found the data of interest:
        t =    -200.00 ...     350.00 ms
        0 CTF compensation matrices available
Reading /project2/hcn1/pitch_tracking_attention/data/bids/derivatives/preprocessing/sub-12/sub-12_task-pitch_run-1_desc-clean_epo.fif-1.gz ...
    Found the data of interest:
        t =    -200.00 ...     350.00 ms
        0 CTF compensation matrices available
0 bad epochs dropped
0 bad epochs dropped
Not setting metadata
3350 matching events found
No baseline correction applied
0 projection items activated
---------- Compute power ----------
Not setting metadata


In [None]:
print("---------- Plot ----------")
n_stimuli = 3
chance_level = 1/n_stimuli  # expected accuracy when guessing among n_stimuli classes

# One decoding score per time window, with the chance level as a dashed line.
fig, ax = plt.subplots()
window_index = range(len(scores))
ax.plot(window_index, scores, label = 'score')
ax.axhline(chance_level, color = 'k', linestyle = '--', label = 'chance')
ax.set(xlabel = 'Times', ylabel = 'Accuracy', title = 'Sensor space decoding')
ax.legend()

# Save plot
fig_fpath = f"{FIGS_ROOT}/subj-{sub}_task-pitch_run-{run}_log_reg_no_crop.png"
print('Saving figure to: ' + fig_fpath)
fig.savefig(fig_fpath)