# Basic single-trial fNIRS finger tapping classification

This notebook sketches the analysis of a finger tapping dataset with multiple subjects. A simple Linear Discriminant Analysis (LDA) classifier is trained to distinguish left-hand from right-hand finger tapping.

**PLEASE NOTE:** For simplicity's sake we skip many preprocessing steps (e.g. pruning, artifact removal, physiology removal). These are the subject of other example notebooks. For a rigorous analysis you will want to include such steps. The purpose of this notebook is only to demonstrate how easily Cedalion data can be passed to the scikit-learn package.

In [None]:
# This cell sets up the environment when executed in Google Colab.
# Outside of Colab the `google.colab` import fails with ImportError and the
# cell does nothing.
try:
    import google.colab
    # download the setup script from the `colab_setup` branch of the cedalion repository
    !curl -s https://raw.githubusercontent.com/ibs-lab/cedalion/colab_setup/scripts/colab_setup.py -o colab_setup.py
    # Select branch with --branch "branch name" (default is "dev")
    %run colab_setup.py
except ImportError:
    pass

In [None]:
import cedalion
import cedalion.nirs
from cedalion.datasets import get_multisubject_fingertapping_snirf_paths
import cedalion.sigproc.quality as quality
import cedalion.plots as plots
import numpy as np
import xarray as xr
import matplotlib.pyplot as p

from sklearn.preprocessing import LabelEncoder
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split, cross_val_score, cross_val_predict, StratifiedKFold
from sklearn.metrics import accuracy_score,roc_curve, roc_auc_score, auc

from cedalion import units

# Keep the xarray and numpy reprs shown in this notebook compact.
xr.set_options(display_values_threshold=50, display_max_rows=3)
np.set_printoptions(precision=4)

### Loading raw CW-NIRS data from a SNIRF file

This notebook uses a finger-tapping dataset in BIDS layout provided by [Rob Luke](https://github.com/rob-luke/BIDS-NIRS-Tapping). It can be downloaded via `cedalion.datasets`.

Cedalion's `read_snirf` method returns a list of `Recording` objects. These are containers for timeseries and adjunct data objects.

In [None]:
fnames = get_multisubject_fingertapping_snirf_paths()
# zip() below pairs subject labels and files by position and stops at the
# shorter sequence, so only the first three files are used.
# NOTE(review): assumes the returned paths are ordered sub-01, sub-02, ... -- confirm.
subjects = [f"sub-{i:02d}" for i in [1, 2, 3]]

# store data of different subjects in a dictionary
data = {}
for subject, fname in zip(subjects, fnames):
    records = cedalion.io.read_snirf(fname)
    rec = records[0]
    display(rec)

    # Cedalion registers an accessor (attribute .cd ) on pandas DataFrames.
    # Use this to rename trial_types inplace.
    rec.stim.cd.rename_events(
        {"1.0": "control", "2.0": "Tapping/Left", "3.0": "Tapping/Right"}
    )

    # differential pathlength factor of 6 for every wavelength in the recording
    dpf = xr.DataArray(
        np.full(rec["amp"].wavelength.size, 6),
        dims="wavelength",
        coords={"wavelength": rec["amp"].wavelength},
    )

    # Optical density relative to the temporal mean of each channel.
    # (No trailing comma here: `x = expr,` would store a 1-tuple, not a DataArray.)
    rec["od"] = -np.log(rec["amp"] / rec["amp"].mean("time"))
    # NOTE(review): beer_lambert receives the raw amplitudes, not rec["od"];
    # presumably it converts to optical density internally -- confirm against
    # the installed cedalion version.
    rec["conc"] = cedalion.nirs.beer_lambert(rec["amp"], rec.geo3d, dpf)

    data[subject] = rec

Display the recording of one subject.

In [None]:
# Rich display of the complete Recording container of one subject.
example_subject = "sub-01"
display(data[example_subject])

### Frequency filtering and splitting into epochs

In [None]:
# Epoch window around each stimulus onset: 5 s before to 20 s after.
EPOCH_WINDOW = dict(before=5 * units.s, after=20 * units.s)

for subject_id, recording in data.items():
    # Band-pass filter the concentration time series. The .freq_filter method
    # comes from the .cd accessor that cedalion registers on DataArrays.
    recording["conc_freqfilt"] = recording["conc"].cd.freq_filter(
        fmin=0.01, fmax=0.5, butter_order=4
    )

    # Split the filtered signal into epochs for the two tapping conditions,
    # using the stimulus dataframe of this recording.
    recording["cfepochs"] = recording["conc_freqfilt"].cd.to_epochs(
        recording.stim,
        ["Tapping/Left", "Tapping/Right"],
        **EPOCH_WINDOW,
    )

### Plot frequency filtered data
For a single subject and channel, illustrate the effect of the bandpass filter.

In [None]:
rec = data["sub-01"]
channel = "S5D7"

f, ax = p.subplots(2, 1, figsize=(12, 4), sharex=True)
# top panel: unfiltered concentrations, bottom panel: band-pass filtered ones
for panel, ts_name in zip(ax, ["conc", "conc_freqfilt"]):
    ts = rec[ts_name]
    for chromo, fmt in [("HbO", "r-"), ("HbR", "b-")]:
        panel.plot(ts.time, ts.sel(channel=channel, chromo=chromo), fmt, label=chromo)
    # raw string: "\D" and "\m" are invalid escape sequences in a normal
    # string literal (SyntaxWarning on recent Python versions)
    panel.set_ylabel(r"$\Delta c$ / $\mu M$")
    panel.legend(loc="upper left")

ax[0].set_title(f"{channel}: unfiltered")
ax[1].set_title(f"{channel}: band-pass filtered")
ax[0].set_xlim(1000, 1100)
ax[1].set_xlabel("time / s")
f.tight_layout()

### Baseline removal

In [None]:
for subj, recording in data.items():
    epochs = recording["cfepochs"]
    # Baseline: mean over the pre-stimulus samples (reltime < 0) of each epoch.
    is_prestim = epochs.reltime < 0
    baseline = epochs.sel(reltime=is_prestim).mean("reltime")
    # Subtract the baseline from the whole epoch.
    recording["cfbl_epochs"] = epochs - baseline

In [None]:
# Inspect the baseline-corrected epochs of the first subject.
sub01_epochs = data["sub-01"]["cfbl_epochs"]
display(sub01_epochs)

### Block Averages of trials for one participant per condition

In [None]:
# Subject 1 serves as the example for the block averages.
subject = "sub-01"

# Split the epochs by trial_type, then average each group over its epochs.
epochs_by_type = data[subject]["cfbl_epochs"].groupby("trial_type")
blockaverage = epochs_by_type.mean("epoch")

Plotting averaged epochs

In [None]:
# One subplot per channel on a grid with n_cols columns; the number of rows
# follows from the number of channels instead of being hard-coded.
n_cols = 6
n_channels = blockaverage.sizes["channel"]
n_rows = int(np.ceil(n_channels / n_cols))

f, ax = p.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows))
ax = np.atleast_1d(ax).flatten()
for i_ch, ch in enumerate(blockaverage.channel.values):
    # solid line for the first trial type, dashed line for the second
    for linestyle, trial_type in zip(["-", "--"], blockaverage.trial_type.values):
        for chromo, color in [("HbO", "r"), ("HbR", "b")]:
            ax[i_ch].plot(
                blockaverage.reltime,
                blockaverage.sel(chromo=chromo, trial_type=trial_type, channel=ch),
                color,
                lw=2,
                ls=linestyle,
                # labels come from the data, so the legend cannot get out of
                # sync with the plotting order
                label=f"{chromo} {trial_type}",
            )
    ax[i_ch].grid(True)
    ax[i_ch].set_title(ch)
    ax[i_ch].set_ylim(-0.3, 0.6)

# hide grid cells that have no channel assigned
for unused_ax in ax[n_channels:]:
    unused_ax.set_visible(False)

# add legend
ax[0].legend()
p.tight_layout()

## Training a LDA classifier with Scikit-Learn
### Feature Extraction
We use very simple features: the mean, minimum and maximum of the baseline-corrected signal in a window after stimulus onset.

In [None]:
for subject, rec in data.items():
    epochs = rec["cfbl_epochs"]

    # avg signal between 0 and 10 seconds after stimulus onset
    fmean = epochs.sel(reltime=slice(0, 10)).mean("reltime")
    # min signal between 0 and 15 seconds after stimulus onset
    fmin = epochs.sel(reltime=slice(0, 15)).min("reltime")
    # max signal between 0 and 15 seconds after stimulus onset
    fmax = epochs.sel(reltime=slice(0, 15)).max("reltime")

    # Concatenate the three feature types along a new, labelled dimension.
    # The reductions above removed "reltime", so the new axis gets its own
    # name and explicit coordinates rather than reusing "reltime".
    X = xr.concat([fmean, fmin, fmax], dim="feature_type").assign_coords(
        feature_type=["mean", "min", "max"]
    )
    # Flatten chromophore, channel and feature type into a single "features"
    # dimension (assumes the only other dimension left is "epoch" -- TODO confirm).
    X = X.stack(features=["chromo", "channel", "feature_type"])

    # strip units. sklearn would strip them anyway and issue a warning about it.
    X = X.pint.dequantify()

    # need to manually tell xarray to create an index for trial_type
    X = X.set_xindex("trial_type")

    # save in recording container
    rec.aux_obj["X"] = X

In [None]:
# Inspect the feature matrix of the first subject.
sub01_features = data["sub-01"].aux_obj["X"]
display(sub01_features)

In [None]:
# Fit ONE encoder on the trial types of all subjects, so that every subject
# uses the same label -> integer mapping. This also fixes which trial type is
# class 1, i.e. the positive class of the ROC analysis below.
all_trial_types = np.concatenate(
    [rec.aux_obj["X"].trial_type.values for rec in data.values()]
)
label_encoder = LabelEncoder().fit(all_trial_types)

# Encode labels for use in scikit-learn
for subject, rec in data.items():
    rec.aux_obj["y"] = xr.apply_ufunc(label_encoder.transform, rec.aux_obj["X"].trial_type)

# classes_[i] is the trial type encoded as integer i
display(label_encoder.classes_)
display(data["sub-01"].aux_obj["y"])

### Train an LDA classifier for each subject using stratified 5-fold cross-validation

In [None]:
# initialize dictionaries for key metrics for each subject to plot
scores = {}
fpr = {}
tpr = {}
roc_auc = {}

# Stratified k-fold with 5 folds and no shuffling, so the folds are deterministic.
cv = StratifiedKFold(n_splits=5)
# cross_val_score / cross_val_predict fit a fresh clone of the estimator for
# every fold, so one unfitted template can be shared by all subjects.
classifier = LinearDiscriminantAnalysis(n_components=1)

for subject, rec in data.items():
    X = rec.aux_obj["X"]
    y = rec.aux_obj["y"]

    # Perform cross-validation and get per-fold accuracy scores
    scores[subject] = cross_val_score(classifier, X, y, cv=cv, scoring="accuracy")
    # Out-of-fold predicted probability of class 1 for every epoch
    pred_prob = cross_val_predict(classifier, X, y, cv=cv, method="predict_proba")[:, 1]

    # ROC curve (thresholds are not needed) and area under it
    fpr[subject], tpr[subject] = roc_curve(y, pred_prob)[:2]
    roc_auc[subject] = auc(fpr[subject], tpr[subject])

    # Print the mean accuracy across folds
    print(f"Cross-validated accuracy for subject {subject}: {scores[subject].mean():.2f}")

# barplot of mean accuracies
f, ax = p.subplots()
ax.bar(list(scores), [fold_scores.mean() for fold_scores in scores.values()])
ax.axhline(0.5, color="gray", linestyle="--", label="chance level (2 classes)")
ax.set_ylim(0, 1)
ax.set_ylabel("Accuracy")
ax.set_xlabel("Subject")
ax.set_title("Cross-validated LDA accuracy per subject")
ax.legend();


### Plot ROC curves for subjects

In [None]:
# Plot the out-of-fold ROC curves stored in fpr / tpr / roc_auc by the
# cross-validation cell (no training happens here).
fig, ax = p.subplots(figsize=(10, 8))
for subject in roc_auc:
    ax.plot(
        fpr[subject],
        tpr[subject],
        lw=2,
        label=f"Subject {subject} (AUC = {roc_auc[subject]:.2f})",
    )

# diagonal line: performance of random guessing
ax.plot([0, 1], [0, 1], color="gray", linestyle="--", label="chance")

# Adding labels and title
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("ROC Curves for All Subjects")
ax.legend(loc="lower right")
ax.grid(True)
p.show()