# Example notebook 2

This notebook sketches the analysis of a finger-tapping dataset with multiple subjects. A simple Linear Discriminant Analysis (LDA) classifier is trained to distinguish left from right finger tapping.

In [None]:
import cedalion
import cedalion.nirs
import cedalion.xrutils as xrutils
from cedalion.datasets import get_multisubject_fingertapping_snirf_paths
import numpy as np
import xarray as xr
import pint
import matplotlib.pyplot as p
import scipy.signal
import os.path

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split, cross_validate
from sklearn.metrics import accuracy_score

# Keep the rich reprs in this notebook compact: limit the number of rows that
# xarray displays and the number of array values it prints in full.
xr.set_options(
    display_values_threshold=50,
    display_max_rows=3,
)
# Print numpy floating point values with a precision of 4 digits.
np.set_printoptions(precision=4)

### Loading raw CW-NIRS data from a SNIRF file

This notebook uses a finger-tapping dataset in BIDS layout provided by [Rob Luke](https://github.com/rob-luke/BIDS-NIRS-Tapping). It can be downloaded via `cedalion.datasets`.

xarray provides another container: `xr.Dataset`. A `Dataset` is a collection of `xr.DataArray` objects that share coordinate axes, which makes it convenient for grouping related arrays together. Below, one `Dataset` per subject holds that subject's arrays.

In [None]:
fnames = get_multisubject_fingertapping_snirf_paths()
subjects = [f"sub-{idx:02d}" for idx in range(1, 6)]

# One xr.Dataset per subject, keyed by the subject label.
data = {}
for subject, fname in zip(subjects, fnames):
    # read_snirf returns a sequence; only its first element is used here
    recording = cedalion.io.read_snirf(fname)[0]

    amp = recording.data[0]
    stim = recording.stim  # pandas DataFrame
    geo3d = recording.geo3d

    # The .cd accessor registered by cedalion on pandas DataFrames is used
    # to replace the numeric event codes with descriptive names.
    stim.cd.rename_events(
        {
            "1.0": "control",
            "2.0": "Tapping/Left",
            "3.0": "Tapping/Right",
        }
    )

    # One value per wavelength (presumably the differential pathlength
    # factors -- TODO confirm); assumes amp has exactly two wavelengths.
    dpf = xr.DataArray(
        [6, 6],
        dims="wavelength",
        coords={"wavelength": amp.wavelength},
    )

    data[subject] = xr.Dataset(
        data_vars={
            "amp": amp,
            # negative log of the amplitude relative to its mean over time
            "od": -np.log(amp / amp.mean("time")),
            "geo": geo3d,
            "conc": cedalion.nirs.beer_lambert(amp, geo3d, dpf),
        },
        attrs={"stim": stim},  # the stimulus table travels along in attrs
        coords={"subject": subject},  # subject label as a coordinate
    )

Illustrate the dataset of one subject

In [None]:
# Show the rich representation of the dataset stored under the key "sub-01".
display(data["sub-01"])

### Frequency filtering and splitting into epochs

In [None]:
# Settings of the frequency filter applied to the concentration time series.
bandpass = dict(fmin=0.02, fmax=0.5, butter_order=4)

# The .cd accessor that cedalion registers on DataArrays provides common
# functionality such as frequency filtering and splitting into epochs.
for ds in data.values():
    ds["conc_freqfilt"] = ds["conc"].cd.freq_filter(**bandpass)

    ds["cfepochs"] = ds["conc_freqfilt"].cd.to_epochs(
        ds.attrs["stim"],  # stimulus dataframe
        ["Tapping/Left", "Tapping/Right"],  # events to extract epochs for
        before=5,  # seconds before stimulus
        after=20,  # seconds after stimulus
    )

### Plot frequency filtered data
Illustrate for a single subject and channel the effect of the bandpass filter.

In [None]:
ds = data["sub-01"]
channel = "S5D7"

f, ax = p.subplots(2, 1, figsize=(12, 4), sharex=True)

# Top panel: concentrations before filtering; bottom panel: after filtering.
for axis, conc in zip(ax, (ds.conc, ds.conc_freqfilt)):
    for chromo, fmt in (("HbO", "r-"), ("HbR", "b-")):
        axis.plot(ds.time, conc.sel(channel=channel, chromo=chromo), fmt, label=chromo)
    # Raw string: "\D" and "\m" are not valid Python escape sequences and
    # trigger a SyntaxWarning in non-raw literals on recent Python versions.
    axis.set_ylabel(r"$\Delta c$ / $\mu M$")
    axis.legend(loc="upper left")

# Restrict the view to 1000..1200 on the time axis; the panels share their x-axis.
ax[0].set_xlim(1000, 1200)
ax[1].set_xlabel("time / s")

In [None]:
# Show the "cfepochs" variable of the dataset stored under the key "sub-01".
display(data["sub-01"]["cfepochs"])

In [None]:
# Collect the epochs of every subject and join them along the "epoch" dimension.
epochs_per_subject = [subject_ds["cfepochs"] for subject_ds in data.values()]
all_epochs = xr.concat(epochs_per_subject, dim="epoch")
all_epochs

### Block Averages

In [None]:
# Baseline: mean over the samples with negative reltime.
is_prestimulus = all_epochs.reltime < 0
baseline = all_epochs.sel(reltime=is_prestimulus).mean("reltime")

# Remove the baseline from every epoch.
all_epochs_blcorrected = all_epochs - baseline

# Group the epochs by trial_type and average over the epoch dimension
# within each group.
blockaverage = (
    all_epochs_blcorrected
    .groupby("trial_type")
    .mean("epoch")
)

Plotting averaged epochs

In [None]:
f, ax = p.subplots(5, 6, figsize=(12, 10))
ax = ax.flatten()

linestyles = ["-", "--"]  # one line style per trial type
chromo_colors = {"HbO": "r", "HbR": "b"}  # drawn in this order
reltime = blockaverage.reltime

# One panel per channel; the grid offers 5 x 6 = 30 panels.
for i_ch, ch in enumerate(blockaverage.channel):
    panel = ax[i_ch]
    for ls, trial_type in zip(linestyles, blockaverage.trial_type):
        for chromo, color in chromo_colors.items():
            trace = blockaverage.sel(chromo=chromo, trial_type=trial_type, channel=ch)
            panel.plot(reltime, trace, color, lw=2, ls=ls)
    panel.grid(True)
    panel.set_title(ch.values)
    panel.set_ylim(-.3, .6)

p.tight_layout()

### Training an LDA classifier with Scikit-Learn

In [None]:
# Starting point: the baseline-corrected epochs.
post_onset = all_epochs_blcorrected.reltime >= 0

epochs = (
    all_epochs_blcorrected
    # keep only the samples at or after the stimulus onset
    .sel(reltime=post_onset)
    # strip units; sklearn would strip them anyway and issue a warning about it
    .pint.dequantify()
    # xarray needs to be told explicitly to create an index for trial_type
    .set_xindex("trial_type")
)

In [None]:
# Show the prepared epochs before they are reshaped for the classifier.
display(epochs)

In [None]:
# Combine the chromo, channel and reltime dimensions into a single
# "features" dimension.
feature_dims = ["chromo", "channel", "reltime"]
X = epochs.stack(features=feature_dims)
display(X)

In [None]:
# Encode the trial_type labels with a LabelEncoder, applied through
# xr.apply_ufunc.
label_encoder = LabelEncoder()
y = xr.apply_ufunc(label_encoder.fit_transform, X.trial_type)
display(y)

In [None]:
# Hold out 30% of the epochs for testing, stratified by label. The fixed
# random_state makes the split, and with it the reported accuracy,
# reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=.3, stratify=y, random_state=42
)

# Fit the LDA classifier on the training epochs and score it on the held-out ones.
classifier = LinearDiscriminantAnalysis(n_components=1).fit(X_train, y_train)
y_pred = classifier.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")

In [None]:
f, ax = p.subplots(1, 2, figsize=(12, 3))

# Histogram of the classifier's decision function per trial type,
# left panel for the training split and right panel for the test split.
for trial_type, c in zip(["Tapping/Left", "Tapping/Right"], ["r", "g"]):
    hist_style = {"alpha": .5, "fc": c, "label": trial_type}
    for axis, subset in zip(ax, (X_train, X_test)):
        scores = classifier.decision_function(subset.sel(trial_type=trial_type))
        axis.hist(scores, **hist_style)

for axis, title in zip(ax, ("train", "test")):
    axis.set_xlabel("LDA score")
    axis.set_title(title)
    axis.legend(ncol=1, loc="upper left")