In [4]:
import os
import numpy as np
import pandas as pd
import mne
from mne.decoding import SlidingEstimator, cross_val_multiscore
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegressionCV
from sklearn.metrics import confusion_matrix, roc_auc_score, ConfusionMatrixDisplay, accuracy_score, balanced_accuracy_score
from config import *
from mne.beamformer import make_lcmv, apply_lcmv_epochs
from collections import defaultdict
from scipy.stats import ttest_1samp, spearmanr
import matplotlib.pyplot as plt
import gc
from jax import jit, grad, vmap, device_put, random
import jax.numpy as jnp
from jax.lib import xla_bridge
from functools import partial
import time
from base import ensure_dir

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


In [5]:
# params
# Everything a reader might tune lives in this cell; later cells read these names.
trial_types = ["all", "pattern", "random"]
trial_type = 'pattern'
# Fail early on a typo in trial_type.
assert trial_type in trial_types, f"unknown trial_type {trial_type!r}; expected one of {trial_types}"
data_path = DATA_DIR
lock = "stim"
subjects = SUBJS
sessions = ['practice', 'b1', 'b2', 'b3', 'b4']
subjects_dir = FREESURFER_DIR
res_path = RESULTS_DIR
folds = 5
chance = 0.5
threshold = 0.05
# scoring = "accuracy"
scoring = "roc_auc"
parc = 'aparc'
hemi = 'both'
params = "pred_decoding"
verbose = False
jobs = -1
# When True, every third time sample is kept (here and wherever X is decimated).
decim = False

plt.style.use('dark_background')

# figures dir (built from the res_path alias so the cell uses one name per location)
figures = res_path / 'figures' / lock / params / 'source' / trial_type
ensure_dir(figures)

# get times
# Only the time axis is needed, so skip loading the epoch data into memory.
# NOTE(review): assumes every recording shares this file's time axis — confirm.
epoch_fname = data_path / lock / 'sub01_0_s-epo.fif'
epochs = mne.read_epochs(epoch_fname, preload=False, verbose=verbose)
times = epochs.times
if decim:
    times = times[::3]
del epochs
gc.collect()

  dirpos = int(tag.data)
  version=int(np.frombuffer(fid.read(4), dtype=">i4")),
  secs=int(np.frombuffer(fid.read(4), dtype=">i4")),
  usecs=int(np.frombuffer(fid.read(4), dtype=">i4")))
  logger.debug('    ' * indent + 'start { %d' % block)
  version=int(np.frombuffer(fid.read(4), dtype=">i4")),
  secs=int(np.frombuffer(fid.read(4), dtype=">i4")),
  usecs=int(np.frombuffer(fid.read(4), dtype=">i4")))
  logger.debug('    ' * indent + 'start { %d' % block)
  logger.debug('    ' * indent + 'start { %d' % block)
  logger.debug('    ' * (indent + 1) + 'block = %d nent = %d nchild = %d'
  logger.debug('    ' * indent + 'end } %d' % block)
  logger.debug('    ' * (indent + 1) + 'block = %d nent = %d nchild = %d'
  logger.debug('    ' * indent + 'end } %d' % block)
  logger.debug('    ' * indent + 'start { %d' % block)
  logger.debug('    ' * indent + 'start { %d' % block)
  logger.debug('    ' * indent + 'start { %d' % block)
  logger.debug('    ' * (indent + 1) + 'block = %d nent = %d nchi

547

In [6]:
# Source reconstruction for one subject / one session: read the epochs, behaviour
# and forward model, build an LCMV beamformer and apply it to every epoch.
# Leaves `labels`, `behav`, `session`, `info` and `stcs` in the namespace.
subject = subjects[0]
# get labels
labels = mne.read_labels_from_annot(subject=subject, parc=parc, hemi=hemi, subjects_dir=subjects_dir, verbose=verbose)
# label = labels[ilabel]
        
# NOTE(review): the `session` value assigned here is overwritten a few lines below
# before it is read; only `session_id` is used in between.
session_id, session = 0, sessions[0]
# read stim epoch
epoch_fname = data_path / lock / f"{subject}_{session_id}_s-epo.fif"
# NOTE(review): verbose=True here, unlike the `verbose=verbose` used by the other calls.
epoch = mne.read_epochs(epoch_fname, preload=True, verbose=True)
# read behav
behav_fname = data_path / "behav" / f"{subject}_{session_id}.pkl"
# SECURITY: unpickling executes arbitrary code; only load files produced by this project.
behav = pd.read_pickle(behav_fname).reset_index()    
# build the session name used for display: 'prac' for session 0, 'sess-XX' otherwise
if session_id == 0:
    session = 'prac'
else:
    session = 'sess-%s' % (str(session_id).zfill(2))

# button-locked analyses take their noise estimate from separate baseline epochs
if lock == 'button': 
    epoch_bsl_fname = data_path / "bsl" / f"{subject}_{session_id}_bl-epo.fif"
    epoch_bsl = mne.read_epochs(epoch_bsl_fname, verbose=verbose)
# read forward solution    
fwd_fname = res_path / "fwd" / lock / f"{subject}-fwd-{session_id}.fif"
fwd = mne.read_forward_solution(fwd_fname, verbose=verbose)
# compute data covariance matrix on the 0 to 0.6 s window of the epochs
data_cov = mne.compute_covariance(epoch, tmin=0, tmax=.6, method="empirical", rank="info", verbose=verbose)
# compute noise covariance: baseline epochs for button lock, else the -0.2 to 0 s window
if lock == 'button':
    noise_cov = mne.compute_covariance(epoch_bsl, method="empirical", rank="info", verbose=verbose)
else:
    noise_cov = mne.compute_covariance(epoch, tmin=-.2, tmax=0, method="empirical", rank="info", verbose=verbose)
info = epoch.info
# compute rank
rank = mne.compute_rank(noise_cov, info=info, rank=None, tol_kind='relative', verbose=verbose)
# build the LCMV spatial filters and apply them to each epoch (one source estimate per epoch)
filters = make_lcmv(info, fwd, data_cov=data_cov, noise_cov=noise_cov,
                pick_ori=None, rank=rank, reduce_rank=True, verbose=verbose)
stcs = apply_lcmv_epochs(epoch, filters=filters, verbose=verbose)

# free the large intermediates; `epoch_bsl` (button lock only) is not released here
del epoch, fwd, data_cov, noise_cov, rank, filters
gc.collect()


Reading /Users/coum/Library/CloudStorage/OneDrive-etu.univ-lyon1.fr/asrt/preprocessed/stim/sub01_0_s-epo.fif ...
    Found the data of interest:
        t =    -196.61 ...     599.65 ms
        0 CTF compensation matrices available
0 bad epochs dropped
Not setting metadata
115 matching events found
No baseline correction applied
0 projection items activated


  dirpos = int(tag.data)
  version=int(np.frombuffer(fid.read(4), dtype=">i4")),
  secs=int(np.frombuffer(fid.read(4), dtype=">i4")),
  usecs=int(np.frombuffer(fid.read(4), dtype=">i4")))
  logger.debug('    ' * indent + 'start { %d' % block)
  version=int(np.frombuffer(fid.read(4), dtype=">i4")),
  secs=int(np.frombuffer(fid.read(4), dtype=">i4")),
  usecs=int(np.frombuffer(fid.read(4), dtype=">i4")))
  logger.debug('    ' * indent + 'start { %d' % block)
  logger.debug('    ' * indent + 'start { %d' % block)
  logger.debug('    ' * (indent + 1) + 'block = %d nent = %d nchild = %d'
  logger.debug('    ' * indent + 'end } %d' % block)
  logger.debug('    ' * (indent + 1) + 'block = %d nent = %d nchild = %d'
  logger.debug('    ' * indent + 'end } %d' % block)
  logger.debug('    ' * indent + 'start { %d' % block)
  logger.debug('    ' * indent + 'start { %d' % block)
  logger.debug('    ' * indent + 'start { %d' % block)
  logger.debug('    ' * (indent + 1) + 'block = %d nent = %d nchi



0

In [7]:
# Extract the source time courses of one cortical label and select the trials
# (and their target positions) that match `trial_type`.
ilabel, label = 0, labels[0]
print(f"{ilabel+1}/{len(labels)}", subject, session, label.name)

# get stcs in label: one (vertices, time points) array per epoch, stacked on axis 0
stcs_data = np.array([stc.in_label(label).data for stc in stcs])
assert len(stcs_data) == len(behav)

if trial_type == 'pattern':
    pattern = behav.trialtypes == 1
    X = stcs_data[pattern]
    y = behav.positions[pattern]
elif trial_type == 'random':
    # Do not call this mask `random`: that name is the notebook-global
    # `jax.random` module that JaxReg.mb_gd relies on.
    random_trials = behav.trialtypes == 2
    X = stcs_data[random_trials]
    y = behav.positions[random_trials]
elif trial_type == 'all':
    X = stcs_data
    y = behav.positions
else:
    # Previously any unknown value silently selected every trial.
    raise ValueError(f"unknown trial_type {trial_type!r}; expected 'all', 'pattern' or 'random'")
y = y.reset_index(drop=True).to_numpy()
assert X.shape[0] == y.shape[0]

if decim:
    # keep every third time sample
    X = X[:, :, ::3]

print("X shape:", X.shape, "(trials, vertices, time points)")

1/68 sub01 prac bankssts-lh
X shape: (51, 47, 163) (trials, vertices, time points)


In [None]:
# Hold out 10 % of the trials, then turn every (trial, time point) pair into its
# own sample whose features are the label's vertices.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

# (trials, vertices, time) -> (trials, time, vertices) -> (trials * time, vertices)
X_train = np.swapaxes(X_train, 1, 2)
X_train = X_train.reshape(-1, X_train.shape[-1])

X_test = np.swapaxes(X_test, 1, 2)
X_test_original = X_test.copy()  # 3-D copy taken before flattening
X_test = X_test.reshape(-1, X_test.shape[-1])

# Repeat each trial label once per time point so the labels line up with the
# flattened rows; subtracting 1 shifts the labels down by one.
y_train = np.repeat(y_train, X.shape[-1]) - 1
y_test = np.repeat(y_test, X.shape[-1]) - 1
# NOTE(review): copied after the repeat, so unlike X_test_original this holds one
# label per (trial, time point), not one per trial — confirm that is intended.
y_test_original = y_test.copy()

print("X_train shape: ", X_train.shape)
print("y_train shape: ", y_train.shape)

print("X_test shape: ", X_test.shape)
print("y_test shape: ", y_test.shape)

In [None]:
from jax import jit,grad,vmap,device_put,random
import jax.numpy as jnp
from functools import partial

class JaxReg:
    """
    Multinomial (softmax) logistic regression trained with mini-batch gradient descent on JAX.

    The point of this class is fitting speed: it is meant to fit a model for very large
    datasets (k49 in particular) as quickly as possible.

    - jit compilation is used in the ``sigma`` and ``loss`` methods (the biggest gain is in
      ``sigma`` because of the matrix multiplication). ``jit`` is applied through ``partial``
      with ``static_argnums=0`` because the functions are methods and ``self`` is not an array.

    - jax.numpy (jnp) operations are JAX implementations of numpy functions.

    - jax.grad is the gradient function. It differentiates with respect to the first
      parameter (the weights ``W``).

    - jax.vmap vectorizes the gradient over the samples of a batch so that all per-sample
      gradients are computed at once.

    Parameters
    ----------
    learning_rate : float
        Step size of each gradient-descent update.
    num_epochs : int
        Number of passes over the training data.
    size_batch : int
        Number of samples per mini-batch.

    Attributes
    ----------
    K : int
        Number of classes, set by ``fit`` as ``max(y) + 1`` (labels are expected to be
        integers starting at 0).
    coeff : jax array of shape (p + 1, K)
        Fitted weights, set by ``fit``; row 0 is the intercept.
    """

    def __init__(self, learning_rate=.001, num_epochs=50, size_batch=20):
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.size_batch = size_batch

    @staticmethod
    def _add_intercept(data):
        """Prepend a column of ones to ``data`` (n, p) so row 0 of W acts as the intercept."""
        ones = jnp.ones((data.shape[0], 1))
        return jnp.concatenate((ones, data), axis=1)

    def fit(self, data, y):
        """
        Fit the weights on ``data`` (n, p) with integer class labels ``y`` (n,).

        Sets ``self.K`` and ``self.coeff``; returns None.
        """
        data = jnp.asarray(data)
        y = jnp.asarray(y)
        # One device reduction instead of iterating the array element by element
        # with Python's max() (which the original did twice).
        self.K = int(jnp.max(y)) + 1
        X = self._add_intercept(data)
        W = jnp.zeros((X.shape[1], self.K))

        self.coeff = self.mb_gd(W, X, y)

    # Mini-batch gradient descent (jitted functions are recompiled whenever the array shapes change)
    def mb_gd(self, W, X, y):
        """
        Run ``num_epochs`` passes of mini-batch gradient descent and return the updated W.

        W is (p + 1, K), X is the intercept-augmented design matrix (n, p + 1), y is (n,).
        """
        num_epochs = self.num_epochs
        size_batch = self.size_batch
        eta = self.learning_rate
        N = X.shape[0]

        # Define the gradient function using jit, vmap, and jax's own gradient function, grad.
        # vmap is especially useful for mini-batch GD since all gradients of the batch are computed at once.
        # in_axes/out_axes give the axis of each input (W, X, y) and of the output (per-sample gradients)
        # over which to vectorize: W is shared, X and y are split along axis 0.
        # grads_b = loss_grad(W, X_batch, y_batch) has shape (batch_size, p+1, k) for p variables and k classes.
        loss_grad = jit(vmap(grad(self.loss), in_axes=(None, 0, 0), out_axes=0))

        for e in range(num_epochs):
            # The epoch number seeds the shuffle, so the batch order is reproducible.
            shuffle_index = random.permutation(random.PRNGKey(e), N)
            for m in range(0, N, size_batch):
                # The last batch is shorter when N is not a multiple of size_batch.
                i = shuffle_index[m:m + size_batch]

                grads_b = loss_grad(W, X[i, :], y[i])  # (batch_size, p+1, k): one gradient per batch element

                W -= eta * jnp.mean(grads_b, axis=0)  # Update W with the average gradient of the batch
        return W

    def predict(self, data):
        """Return the predicted class index (argmax of the softmax probabilities) for each row of ``data``."""
        X = self._add_intercept(data)  # Augment to account for intercept
        return jnp.argmax(self.sigma(X, self.coeff), axis=1)

    def score(self, data, y_true):
        """Return the accuracy: the mean of ``predict(data) == y_true``."""
        return jnp.mean(self.predict(data) == y_true)

    # jitting 'sigma' is the biggest speed-up compared to the original implementation
    @partial(jit, static_argnums=0)
    def sigma(self, X, W):
        """
        Row-wise softmax of ``X @ W``.

        X is (n, p + 1), or a single sample of shape (p + 1,), which is treated as one row.
        Returns an (n, K) array of class probabilities whose rows sum to 1.
        """
        if X.ndim == 1:
            X = jnp.reshape(X, (1, -1))  # a single sample (as seen under vmap/grad): X -> (1, p+1)
        logits = jnp.matmul(X, W)
        # Subtract the row maximum before exponentiating: the softmax is unchanged,
        # but exp() can no longer overflow to inf and produce NaN probabilities.
        logits = logits - jnp.max(logits, axis=1, keepdims=True)
        s = jnp.exp(logits)
        return s / jnp.sum(s, axis=1, keepdims=True)

    @partial(jit, static_argnums=0)
    def loss(self, W, X, y):
        """
        Mean cross-entropy of labels ``y`` under the softmax model with weights ``W``.

        Accepts a batch (X of shape (n, p + 1), y of shape (n,)) or a single sample
        (X of shape (p + 1,), scalar y). The number of classes is the number of columns of W.
        """
        log_probs = jnp.log(self.sigma(X, W) + 1e-10)  # epsilon guards against log(0)
        # Boolean one-hot mask of shape (n, K); also handles a scalar y (n = 1).
        is_target = jnp.reshape(y, (-1, 1)) == jnp.arange(W.shape[1])
        return -jnp.mean(jnp.sum(log_probs * is_target, axis=1))

In [None]:
from jax.lib import xla_bridge

# Fit a JaxReg model for 20 epochs.

# Print the platform JAX is running on, to confirm a GPU is in use.
# NOTE(review): newer JAX releases deprecate jax.lib.xla_bridge; jax.default_backend()
# looks like the public equivalent — confirm against the installed version.
print(xla_bridge.get_backend().platform)

# Commit the training data to the device; both results are JAX arrays from here on.
X_train_dp, y_train_dp = device_put(X_train), device_put(y_train)

# size_batch is the number of training rows, so every epoch is one full-batch update.
lg_sgd_jax = JaxReg(
    learning_rate=1e-6,
    num_epochs=20,
    size_batch=X_train_dp.shape[0],
)
lg_sgd_jax.fit(X_train_dp, y_train_dp)

In [None]:
# Accuracy on the held-out rows. X_test / y_test were flattened earlier, so this is
# the fraction of (trial, time point) samples classified correctly, not of trials.
print(lg_sgd_jax.score(X_test, y_test))