# Import Dependencies

In [None]:
import time
import sys
import os

print(sys.version)

import matplotlib.pyplot as p
from matplotlib.lines import Line2D
import numpy as np
import xarray as xr
import gzip
import pickle

import pyvista as pv
pv.set_jupyter_backend('server')
#pv.set_jupyter_backend('static')

from IPython.display import Image

import cedalion
import cedalion.sigproc.quality as quality
import cedalion.dataclasses as cdc
import cedalion.datasets
import cedalion.geometry.registration # import icp_with_full_transform, find_spread_points
import cedalion.geometry.segmentation
from cedalion.geometry.photogrammetry.processors import ColoredStickerProcessor, geo3d_from_scan
from cedalion.geometry.landmarks import order_ref_points_6
import cedalion.imagereco.forward_model as fw
import cedalion.imagereco.tissue_properties
import cedalion.io
import cedalion.plots # import OptodeSelector
import cedalion.xrutils as xrutils
from cedalion.imagereco.solver import pseudo_inverse_stacked
from cedalion import units

# for dev purposes
%load_ext autoreload
%autoreload 2

%matplotlib widget


# Root of the data tree. Defaults to the local path below; set the
# CEDALION_DATA_PREFIX environment variable (e.g. to '/home/avonluh/data/' on
# the cluster) to use another location without editing the notebook.
prefix = os.environ.get(
    'CEDALION_DATA_PREFIX',
    'C:/Users/avonl/OneDrive/Work/Research/projects/2024 - AvLtapCedalion/data/',  # local
)
# All paths below are built by plain string concatenation, so make sure the
# prefix ends with a path separator.
if not prefix.endswith(('/', '\\')):
    prefix += '/'

## Data loading directories and files
# head segmentations
SEG_DATADIR = prefix+'mri/segmented/'
landmarks = 'landmarks.mrk.json'
# Finger Tapping fNIRS data
FT_DATADIR = prefix+'fnirs/data/'
dataset ='2024-02-05_013.snirf'
# Photogrammetric Scan (note: this is a file path, not a directory)
PG_DATADIR = prefix+'photogrammetry/scan.obj'
#PG_DATADIR = 'C://Users//avonl//My Drive (ibs.bifold@gmail.com)//photogrammetry//shift4.obj'
# Fluence Profile (precalculated)
FP_DATADIR = prefix+'fluence/'
TEMP_DATADIR = prefix+'tempdata/'
datafull = 'ftfull.snirf'

# Load Finger Tapping Data

In [None]:
# Load the finger-tapping DOT example recording that ships with cedalion.
# Reading the local SNIRF file instead is kept below for reference.
# FIXME temporarily define ADU unit in this dataset to avoid an error
#cedalion.units.define("ADU = 1")
#record = cedalion.io.read_snirf(FT_DATADIR+dataset)
#rec = record[0]
rec = cedalion.datasets.get_fingertappingDOT()

# Replace the numeric stimulus codes with descriptive trial-type names.
event_names = {
    "1": "Rest",
    "2": "FTapping/Left",
    "3": "FTapping/Right",
    "4": "BallSqueezing/Left",
    "5": "BallSqueezing/Right",
}
rec.stim.cd.rename_events(event_names)

# Load Headmodel
Getting segmented MRI scans from AvL

In [None]:
# Tissue mask files (inside SEG_DATADIR) keyed by tissue name.
masks = {
    'csf': 'csf.nii',
    'gm': 'gm.nii',
    'scalp': 'scalp.nii',
    'skull': 'skull.nii',
    'wm': 'wm.nii',
}

# Build the two-surface head model from the segmentation and landmark file.
head = fw.TwoSurfaceHeadModel.from_segmentation(
    segmentation_dir=SEG_DATADIR,
    mask_files=masks,
    landmarks_ras_file=landmarks,
)

Plot loaded head

In [None]:
# Notebook-embedded plotter showing the brain surface inside a translucent scalp.
plt = pv.Plotter(notebook=True, off_screen=False)
plt.clear()  # start from an empty scene

brain_color = "#d3a6a1"
cedalion.plots.plot_surface(plt, head.brain, color=brain_color)
cedalion.plots.plot_surface(plt, head.scalp, opacity=0.1)

plt.show()


# Register Probe to Headmodel using Photogrammetric scan

#### First plot the default probe 

In [None]:
# plot montage
# plot montage: 3D view of the default probe geometry stored in the recording
# (rec.geo3d), drawn together with the amplitude time series rec["amp"]
cedalion.plots.plot_montage3D(rec["amp"], rec.geo3d)

## Load and Display Photogrammetric Scan 

In [None]:
# Load the photogrammetric head scan (.obj, read via read_einstar_obj).
scan = cedalion.io.read_einstar_obj(PG_DATADIR)

# Find the colored label stickers on the scan. Each entry maps a sticker group
# to (hue_min, hue_max, value_min, value_max).
sticker_colors = {
    "O": (0.11, 0.21, 0.7, 1),
    # "L": (0.25, 0.37, 0.35, 0.6),
}
PGprocessor = ColoredStickerProcessor(colors=sticker_colors)
sticker_centers, normals, details = PGprocessor.process(scan, details=True)


### Optodes can be manually removed or added, if not all were found

In [None]:
# Interactive selector over the scan, initialised with the detected sticker
# centers and normals, so that missed optodes can be added or wrong ones removed
# by picking points in the view.
visualizer = cedalion.plots.OptodeSelector(scan, sticker_centers, normals)
visualizer.plot()
visualizer.enable_picking()
# draw the (opaque) scan surface into the selector's own plotter
cedalion.plots.plot_surface(visualizer.plotter, scan, opacity=1.0)
visualizer.plotter.show()

After selecting all optodes, update sticker_centers and normals:

In [None]:
# Take over the (possibly edited) points from the interactive selector.
sticker_centers = visualizer.points
# Keep the previously detected normals unless the selector provides its own.
if visualizer.normals is not None:
    normals = visualizer.normals

## Optode projection and landmark selection

In [None]:
# Offset between a sticker center and the scalp, measured along the normal.
optode_length = 22.6 * cedalion.units.mm

# Move optode stickers (group "O") inward along their normals by optode_length
# to estimate scalp positions; any other points are left unchanged.
is_optode = sticker_centers.group == 'O'
scalp_coords = sticker_centers.copy()
scalp_coords[is_optode] = sticker_centers[is_optode] - optode_length * normals[is_optode]

# Red: sticker centers, green: projected scalp positions, plus the normals.
plt = pv.Plotter()
cedalion.plots.plot_surface(plt, scan, opacity=0.3)
for point_set, point_color in ((sticker_centers, "r"), (scalp_coords, "g")):
    cedalion.plots.plot_labeled_points(plt, point_set, color=point_color)
cedalion.plots.plot_vector_field(plt, sticker_centers, normals)
plt.show()

### Landmark selection

In [None]:
plt = pv.Plotter()
get_landmarks = cedalion.plots.plot_surface(plt, scan, opacity=1.0, pick_landmarks = True)
plt.show(interactive = True)

### Get landmarks from the plot
Call *get_landmarks* from the previous cell and write into xarray
* 1st value - coordinates of picked landmarks
* 2nd - labels of corresponding landmarks

In [None]:
import tkinter as tk
from tkinter import messagebox

# Retrieve the landmarks picked interactively in the previous cell:
# their coordinates (one entry per picked point) and their labels.
landmark_coordinates, landmark_labels = get_landmarks()
# Size the per-point coordinates from the number of picked labels instead of a
# hard-coded 5, so picking a different number of landmarks does not fail.
n_landmarks = len(landmark_labels)

# write into Xarray: one row per landmark, tagged as type LANDMARK, group "L"
landmarks = xr.DataArray(
    np.vstack(landmark_coordinates),
    dims=["label", "digitized"],
    coords={
        "label": ("label", landmark_labels),
        "type": ("label", [cdc.PointType.LANDMARK] * n_landmarks),
        "group": ("label", ["L"] * n_landmarks),
    },
).pint.quantify("mm")


display(landmarks)

### Load montage info from the SNIRF file and find the transformation between the montage and the landmarks

In [None]:
# Inspect the montage geometry loaded from the SNIRF file. The next cell derives
# `montage_elements` from it; that name does not exist yet at this point.
rec.geo3d

In [None]:
# Montage geometry from the recording, with its coordinate dimension renamed
# from "digitized" to "aligned".
montage_elements = rec.geo3d.rename({"digitized": "aligned"})

# Transformation between the picked landmarks and the montage elements
# (applied to the montage below).
trafo = cedalion.geometry.registration.register_trans_rot(landmarks, montage_elements)

# Keep only sources and detectors, then apply the transformation to them.
point_type = montage_elements.type
is_src_or_det = (point_type == cdc.PointType.SOURCE) | (point_type == cdc.PointType.DETECTOR)
filtered_montage_elements = montage_elements.where(is_src_or_det, drop=True)
filtered_montage_elements_t = filtered_montage_elements.points.apply_transform(trafo)


### Coregistration: find and assign optode labels of scalp coordinates

In [None]:
# NOTE(review): this restarts from the raw sticker centers and so discards the
# optode-length projection computed in the "Optode projection" cell; confirm
# that this is intended.
scalp_coords = sticker_centers.copy()

# Iterative closest point registration; idx holds, for every point in
# scalp_coords, the index of its matched montage element.
idx = cedalion.geometry.registration.icp_with_full_transform(
    scalp_coords, filtered_montage_elements_t, max_iterations=100
)

# Translate the matched indices into montage optode labels.
montage_labels = filtered_montage_elements.coords['label'].values
labels = [montage_labels[i] for i in idx]

# write labels to scalp_coords
scalp_coords = scalp_coords.assign_coords(label=labels)

# Visualize: transformed montage elements on the scan mesh, plus the detected
# optode centers in green annotated with their ICP-assigned labels.
plt = pv.Plotter()
cedalion.plots.plot3d(None, scan.mesh, filtered_montage_elements_t, None, plotter=plt)
cedalion.plots.plot_labeled_points(plt, scalp_coords, color="green", show_labels=True)
plt.show(interactive=True)


### Update geo3D coordinates for the headmodel 

Transform coordinates to headmodel coordinates and snap to surface. Save them.

In [None]:
import copy

# Build geo3D from the labeled scalp optode positions and the picked landmarks.
geo3Dscan = geo3d_from_scan(scalp_coords, landmarks)

# Save geo3Dscan to disk via SNIRF, using a real *copy* of rec. A plain
# assignment (rectmp = rec) only aliases rec, so setting rectmp.geo3d would
# silently replace rec.geo3d, which later cells read (channel distances, Beer-Lambert).
rectmp = copy.deepcopy(rec)
rectmp.geo3d = geo3Dscan

# save data
SAVEDATA = False
if SAVEDATA:
    cedalion.io.write_snirf(TEMP_DATADIR + datafull, rectmp)

# Preprocess Finger Tapping Data

In [None]:
## Prune with SNR threshold
snr_thresh = 10 # dB
snr, rec.masks["snr_mask"] = quality.snr(rec["amp"], snr_thresh)
# prune channels using the list of masks and the operator "all", which keeps only
# channels that pass every mask in the list (here only the SNR mask)
rec["amp_pruned"], drop_list = quality.prune_ch(rec["amp"], [rec.masks["snr_mask"]], "all")

# Convert to OD. Note: computed from the unpruned amplitudes rec["amp"], so all
# channels are kept (presumably so the OD channels match the forward model's
# measurement list used for image reconstruction; confirm).
rec["od"] = cedalion.nirs.int2od(rec["amp"])

## find movement artifacts (detection only; no correction is applied in this cell)
# define parameters for motion artifact detection. We follow the method from Homer2/3: "hmrR_MotionArtifactByChannel" and "hmrR_MotionArtifact".
t_motion = 0.5*units.s  # time window for motion artifact detection
t_mask = 1.0*units.s    # time window for masking motion artifacts (+- t_mask s before/after detected motion artifact)
stdev_thresh = 4.0      # threshold for standard deviation of the signal used to detect motion artifacts. Default is 50. We set it very low to find something in our good data for demonstration purposes.
amp_thresh = 5.0        # threshold for amplitude of the signal used to detect motion artifacts. Default is 5.
# to identify motion artifacts with these parameters we call the following function
rec.masks["ma_mask"] = quality.id_motion(rec["od"], t_motion, t_mask, stdev_thresh, amp_thresh)
rec.masks["ma_mask"], ma_info = quality.id_motion_refine(rec.masks["ma_mask"], 'by_channel')
# TODO: motion correction (e.g. spline / Savitzky-Golay) would be applied here.
# ma_mask is not used by any of the later cells in this notebook.


# Convert to HbO/HbR with the Beer-Lambert conversion, using the SNR-pruned
# amplitudes and a differential pathlength factor (DPF) of 6 for both wavelengths
dpf = xr.DataArray(
        [6, 6],
        dims="wavelength",
        coords={"wavelength": rec["amp"].wavelength},
    )
rec["conc"] = cedalion.nirs.beer_lambert(rec["amp_pruned"], rec.geo3d, dpf)


# bandpass filter (0.01-0.5 Hz, butter_order=4) the OD data for image recon later
rec["od_freqfilt"] = rec["od"].cd.freq_filter(
        fmin=0.01, fmax=0.5, butter_order=4
    )
# same bandpass filter on the concentration data for the block averages below
rec["conc_freqfilt"] = rec["conc"].cd.freq_filter(
        fmin=0.01, fmax=0.5, butter_order=4
    )


# Plot a channel for quality control

In [None]:
# Quality-control plot: HbO/HbR time course of one channel with stimulus onsets.
qc_channel = "S1D2"
p.figure()
p.plot(rec["conc"].time, rec["conc"].sel(channel=qc_channel, chromo="HbO"), "r-", label="HbO")
p.plot(rec["conc"].time, rec["conc"].sel(channel=qc_channel, chromo="HbR"), "b-", label="HbR")

# Color of the onset line for each trial type.
clr = {'Rest': 'g', 'FTapping/Left': 'y', 'FTapping/Right': 'm', 'BallSqueezing/Left': 'c', 'BallSqueezing/Right': 'k'}

# Dashed vertical line at every stimulus onset. Iterating onset and trial type
# together avoids looking rows up by position with .at (which only works for a
# default 0..n-1 index), shadowing the builtin `type`, and clobbering `idx`.
for onset, trial_type in zip(rec.stim['onset'], rec.stim['trial_type']):
    p.axvline(x=onset, color=clr[trial_type], linestyle='--', label=trial_type)

p.xlabel("time / s")
p.ylabel("delta Conc / µM")

# Custom legend with one entry per trial type, built from the same color map
# so that the legend and the lines cannot drift apart.
custom_handles = [Line2D([0], [0], color=color, lw=2, label=name) for name, color in clr.items()]
p.legend(handles=custom_handles, loc='upper right')



# Plot Block Averages

In [None]:
## keep only subset of long channels
sd_threshs = [2, 4.5]*units.cm # defines the lower and upper bounds for the source-detector separation that we would like to keep
ch_dist, rec.masks["sd_mask"] = quality.sd_dist(rec["conc_freqfilt"], rec.geo3d, sd_threshs)
rec["conc_freqfilt_LD"], masked_elements = xrutils.apply_mask(rec["conc_freqfilt"], rec.masks["sd_mask"], "drop", "channel")


# segment data into epochs
rec["cfepochs"] = rec["conc_freqfilt_LD"].cd.to_epochs(
        rec.stim,  # stimulus dataframe
        ["FTapping/Left", "FTapping/Right"],#, "BallSqueezing/Left", "BallSqueezing/Right"], # select events. do not use "Rest"
        before=5*units.s,  # seconds before stimulus
        after=20*units.s,  # seconds after stimulus
)

# calculate baseline: mean over the pre-stimulus part (reltime < 0) of each epoch
baseline_conc = rec["cfepochs"].sel(reltime=(rec["cfepochs"].reltime < 0)).mean("reltime")
# subtract baseline
rec["conc_epochs_blcorrected_LD"] = rec["cfepochs"] - baseline_conc

# group trials by trial_type. For each group individually average the epoch dimension
rec["blockaverage_conc"] = rec["conc_epochs_blcorrected_LD"].groupby("trial_type").mean("epoch")
blockavg = rec["blockaverage_conc"]

# Pair each trial type actually present in the block average with a line style.
# Plot and legend both use this mapping, so the legend lists only the plotted
# trial types with the styles they were drawn in (zip truncates to the
# available styles, as before).
line_styles = ["-", "--", "-.", ":"]
trial_types = [str(tt) for tt in blockavg.trial_type.values]
style_by_trial_type = dict(zip(trial_types, line_styles))

legend_handles = [
    Line2D([0], [0], color='k', lw=2, ls=ls, label=tt)
    for tt, ls in style_by_trial_type.items()
]

# One subplot per channel on a square grid; squeeze=False keeps `ax` a 2D array
# even for a 1x1 grid, so .flatten() always works.
n_channels = blockavg.sizes["channel"]
noPlts2 = int(np.ceil(np.sqrt(n_channels)))
f, ax = p.subplots(noPlts2, noPlts2, figsize=(12, 10), squeeze=False)
ax = ax.flatten()
for i_ch, ch in enumerate(blockavg.channel):
    for trial_type, ls in style_by_trial_type.items():
        # HbO in red, HbR in blue, both drawn in the trial type's line style
        for chromo, color in (("HbO", "r"), ("HbR", "b")):
            ax[i_ch].plot(
                blockavg.reltime,
                blockavg.sel(chromo=chromo, trial_type=trial_type, channel=ch),
                color,
                lw=2,
                ls=ls,
            )
    ax[i_ch].grid(1)
    ax[i_ch].set_title(ch.values)
    ax[i_ch].set_ylim(-0.3, 0.6)

# Hide the unused cells of the square grid.
for unused_ax in ax[n_channels:]:
    unused_ax.set_visible(False)

# Add the legend to the first subplot
ax[0].legend(handles=legend_handles, title="Trial Types", loc='lower right')
p.tight_layout()

p.show()

# DOT Image Reconstruction

## Optode Registration
Align the photogrammetrically registered optode positions with the scalp surface

In [None]:
LOAD_SCANCOORDS = False
# if we did the photogrammetric coregistration in a previous session and saved the results, load them here from our temp snirf file
if LOAD_SCANCOORDS:
    recordtmp = cedalion.io.read_snirf(TEMP_DATADIR + datafull)
    # first recording of the file just read (`record` is not defined in this notebook)
    rectmp = recordtmp[0]
    geo3Dscan = rectmp.geo3d

# Align the scan-based positions to the head model and snap them onto the scalp.
geo3dscan_snapped_ijk = head.align_and_snap_to_scalp(geo3Dscan)

# Brain surface, translucent scalp and the snapped, labeled positions.
plt = pv.Plotter(notebook=True, off_screen=False)
cedalion.plots.plot_surface(plt, head.brain, color="#d3a6a1")
cedalion.plots.plot_surface(plt, head.scalp, opacity=.1)
cedalion.plots.plot_labeled_points(plt, geo3dscan_snapped_ijk, show_labels = True)
plt.show(interactive = True)

## Simulate light propagation with MCX or NIRFASTER

In [None]:
fwm = cedalion.imagereco.forward_model.ForwardModel(head, geo3dscan_snapped_ijk, rec._measurement_lists["amp"])

USE_CACHED = True
RUN_PACKAGE = 'NIRFASTer' # or 'MCX'

# Single cache location, used both for loading and for saving.
fname = FP_DATADIR+'AvL_fluence.pickle.gz'

if USE_CACHED:
    # NOTE: unpickling can execute arbitrary code; only load trusted cache files.
    with gzip.GzipFile(fname) as fin:
        fluence_all, fluence_at_optodes = pickle.load(fin)
else:
    if RUN_PACKAGE == 'MCX':
        fluence_all, fluence_at_optodes = fwm.compute_fluence_mcx()
    elif RUN_PACKAGE == 'NIRFASTer':
        fluence_all, fluence_at_optodes = fwm.compute_fluence_nirfaster()
    else:
        # fail loudly instead of leaving fluence_all undefined
        raise ValueError(f"unknown RUN_PACKAGE {RUN_PACKAGE!r}; expected 'MCX' or 'NIRFASTer'")
    # save computed fluence data (xarrays) to disk to avoid having to re-run each time;
    # the context manager closes the file even if pickling fails
    with gzip.GzipFile(fname, 'wb') as fout:
        pickle.dump([fluence_all, fluence_at_optodes], fout)

## Plot fluence

To illustrate the tissue probed by light travelling from a source to a detector, the fluence profiles of the source and the detector need to be multiplied.

In [None]:
# NOTE(review): the purpose of this short pause is not evident from the code;
# confirm it is needed.
time.sleep(1)

plt = pv.Plotter()

# Multiply the S1 and D8 fluence volumes at 760 nm to show the tissue probed by
# this source-detector pair, then take log10 for display. Non-positive voxels
# get the smallest positive value so that the logarithm stays finite.
f = fluence_all.loc["S1", 760].values * fluence_all.loc["D8", 760].values
f[f <= 0] = f[f > 0].min()
f = np.log10(f)

plt.add_volume(
    pv.wrap(f),
    log_scale=False,
    cmap='plasma_r',
    clim=(-10, 0),
)
cedalion.plots.plot_surface(plt, head.brain, color="w")
cedalion.plots.plot_labeled_points(plt, geo3dscan_snapped_ijk, show_labels=True)

# Point the camera at the brain's mean vertex position from a fixed offset.
cog = head.brain.vertices.mean("label").values
plt.camera.position = cog + [-300, 30, 150]
plt.camera.focal_point = cog
plt.camera.up = [0, 0, 1]

plt.show()

## Inverse Problem

### Calculate the sensitivity matrices

The sensitivity matrix describes the effect of an absorption change at a given surface vertex in the OD recording in a given channel and at given wavelength. The coordinate `is_brain` holds a mask to distinguish brain and scalp voxels.

The sensitivity `Adot` has shape (nchannel, nvertex, nwavelengths). To solve the inverse problem we need a matrix that relates OD in channel space to absorption in image space. Hence, the sensitivity must include the extinction coefficients to translate between OD and concentrations. Furthermore, the channels at different wavelengths must be stacked into a new dimension (flat_channel), and likewise the vertices and chromophores into another new dimension (flat_vertex):

$$ \left( \begin{matrix} OD_{c_1, \lambda_1} \\ \vdots \\ OD_{c_{N_c},\lambda_1} \\ OD_{c_1,\lambda_2} \\ \vdots \\ OD_{c_{N_c},\lambda_2} \end{matrix}\right) = A \cdot
\left( \begin{matrix} \Delta c_{v_1, HbO} \\ \vdots \\ \Delta c_{v_{N_v}, HbO} \\ \Delta c_{v_1, HbR} \\ \vdots \\ \Delta c_{v_{N_v}, HbR} \end{matrix}\right) $$

In [None]:
# Sensitivity from the fluence, then stacked into the 2D (flat_channel x
# flat_vertex) form used for the inverse problem.
Adot = fwm.compute_sensitivity(fluence_all, fluence_at_optodes)
Adot_stacked = fwm.compute_stacked_sensitivity(Adot)

# Regularized pseudo-inverse of the stacked sensitivity matrix.
reg_alpha, reg_alpha_spatial = 0.01, 0.001
B = pseudo_inverse_stacked(Adot_stacked, alpha=reg_alpha, alpha_spatial=reg_alpha_spatial)


### Calculate average concentration changes on the cortex

In [None]:
# Epoch the bandpass-filtered OD data around all four task conditions,
# from 5 s before to 20 s after each stimulus.
od_trial_types = ["FTapping/Left", "FTapping/Right", "BallSqueezing/Left", "BallSqueezing/Right"]
od_epochs = rec["od_freqfilt"].cd.to_epochs(
    rec.stim,
    od_trial_types,
    before=5 * units.s,
    after=20 * units.s,
)

# Baseline = mean of the pre-stimulus samples (reltime < 0) of every epoch.
od_baseline = od_epochs.sel(reltime=od_epochs.reltime < 0).mean("reltime")
od_epochs_blcorrected = od_epochs - od_baseline

# Average the epochs separately for each trial type.
od_blockaverage = od_epochs_blcorrected.groupby("trial_type").mean("epoch")

# Concentration changes on brain and scalp, obtained by multiplying with the
# inverted sensitivity matrix.
dC_brain, dC_scalp = fw.apply_inv_sensitivity(od_blockaverage, B)

## Plot results of image recon

Using functionality from pyvista and VTK, plot the concentration changes on the brain surface.

In [None]:
from cedalion.plots import image_recon_multi_view

filename_multiview = 'image_recon_multiview'

# Brain and scalp reconstructions of the right-hand finger-tapping condition,
# combined along "vertex" and ordered (vertex, chromo, time); is_brain flags
# the brain vertices.
shown_trial = "FTapping/Right"
X_ts = (
    xr.concat([dC_brain.sel(trial_type=shown_trial), dC_scalp.sel(trial_type=shown_trial)], dim="vertex")
    .rename({"reltime": "time"})
    .transpose("vertex", "chromo", "time")
)
X_ts = X_ts.assign_coords(is_brain=('vertex', Adot.is_brain.values))

# Symmetric color limits at the 99th percentile of |HbO|.
scl = np.percentile(np.abs(X_ts.sel(chromo='HbO').values.reshape(-1)), 99)
clim = (-scl, scl)

# optionally select the nearest time sample at t=5 s for a static image:
#X_ts = X_ts.sel(time=5*units.s, method="nearest")

image_recon_multi_view(
    X_ts,  # time series data; can be 2D (static) or 3D (dynamic)
    head,
    cmap='seismic',
    clim=clim,
    view_type='hbo_brain',
    title_str='HbO',
    filename=filename_multiview,
    SAVE=True,
    time_range=(-5, 30, 0.5) * units.s,
    fps=6,
    geo3d_plot=None,
    wdw_size=(1024, 768),
)

In [None]:
# Show the saved snapshot. Image(filename=...) reads the file itself, so no file
# handle is left open, and the format is taken from the extension.
display(Image(filename="image_recon_multiview.png"))

In [None]:
from cedalion.plots import image_recon_multi_view

# Same multi-view rendering as for the brain, this time of HbO on the scalp.
filename_multiview = 'image_recon_multiview_scalp'

scalp_view_kwargs = dict(
    cmap='seismic',
    clim=clim,
    view_type='hbo_scalp',
    title_str='HbO',
    filename=filename_multiview,
    SAVE=True,
    time_range=(-5, 30, 0.5) * units.s,
    fps=6,
    geo3d_plot=None,
    wdw_size=(1024, 768),
)
image_recon_multi_view(X_ts, head, **scalp_view_kwargs)

In [None]:
# The scalp rendering is read from a .gif file. Let IPython infer the format
# from the extension instead of labelling GIF data as PNG; Image(filename=...)
# also avoids leaving the file handle open.
display(Image(filename="image_recon_multiview_scalp.gif"))

## WIP! Training a LDA classifier with Scikit-Learn
For this example we use a very simple, unsophisticated approach. This section is a work in progress and probably flawed.

In [None]:
from sklearn.preprocessing import LabelEncoder
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.metrics import accuracy_score

In [None]:
# Start from the frequency-filtered, epoched, baseline-corrected long-distance
# concentration data and keep only post-onset samples (reltime >= 0).
blc_epochs = rec["conc_epochs_blcorrected_LD"]
epochs = blc_epochs.sel(reltime=blc_epochs.reltime >= 0)
# Strip units now; sklearn would strip them anyway and warn about it.
epochs = epochs.pint.dequantify()
# Index trial_type so that trials can be selected via .sel(trial_type=...).
epochs = epochs.set_xindex("trial_type")
# NOTE: pooling finger tapping and ball squeezing into "left" vs "right" is not
# implemented; these epochs only contain the FTapping trial types.

# Feature: mean amplitude 8-15 s after stimulus onset, flattened over
# chromophores and channels into one feature vector per epoch.
epochs_meanfeature = epochs.sel(reltime=slice(8, 15)).mean("reltime")
X = epochs_meanfeature.stack(features=["chromo", "channel"])

In [None]:

# Keep only the trial types used for classification (finger tapping; the
# ball-squeezing types could be added to this list).
ttype_des = ["FTapping/Left", "FTapping/Right"]  # "BallSqueezing/Left", "BallSqueezing/Right"
is_desired = X.trial_type.isin(ttype_des)
X_des = X.where(is_desired, drop=True)
display(X_des.sel(channel="S4D6", chromo="HbO"))

### Scatter plot of an example feature: HbO at channel S4D6 for both trial types

In [None]:
import matplotlib.pyplot as plt

# Example feature: HbO at channel S4D6 for each left and each right trial.
HbO_data_l = X_des.sel(channel="S4D6", chromo="HbO", trial_type="FTapping/Left")
HbO_data_r = X_des.sel(channel="S4D6", chromo="HbO", trial_type="FTapping/Right")

# One value per trial (atleast_1d guards against a lone trial becoming a scalar).
HbO_data_lvs = np.atleast_1d(HbO_data_l.values)
HbO_data_rvs = np.atleast_1d(HbO_data_r.values)

# Left and right trials are independent samples, possibly of different counts.
# Plotting one against the other (x=left, y=right) pairs unrelated trials and
# fails when the counts differ. Instead, show each class in its own column with
# a small, reproducible horizontal jitter so that overlapping points stay visible.
rng = np.random.default_rng(0)
plt.figure(figsize=(8, 6))
classes = [
    (HbO_data_lvs, 'blue', "FTapping/Left"),
    (HbO_data_rvs, 'red', "FTapping/Right"),
]
for pos, (vals, color, name) in enumerate(classes):
    jitter = rng.uniform(-0.1, 0.1, size=vals.size)
    plt.scatter(pos + jitter, vals, c=color, alpha=0.5, label=name)
plt.axhline(0, color='k', ls='--', lw=0.8)
plt.xticks([0, 1], ["L", "R"])

# Add labels and title
plt.xlabel("Trial type")
plt.ylabel("HbO Values")
plt.title("HbO feature per trial, L vs R (Channel: S4D6)")
plt.legend()

# Show the plot
plt.show()

In [None]:
# Integer-encode the trial types as class labels.
encoder = LabelEncoder()
y = xr.apply_ufunc(encoder.fit_transform, X.trial_type)

# Linear discriminant analysis with a single discriminant component.
classifier = LinearDiscriminantAnalysis(n_components=1)

# Stratified 10-fold cross-validation (no shuffling).
kf = StratifiedKFold(n_splits=10)
cross_val_scores = cross_val_score(classifier, X, y, cv=kf)

# Per-fold accuracies, then their mean.
print("Cross-validation accuracy scores for each fold:")
print("\n".join(f"Fold {i}: {score:.4f}" for i, score in enumerate(cross_val_scores, start=1)))
print(f"\nMean accuracy across all folds: {cross_val_scores.mean():.4f}")

In [None]:
from sklearn.model_selection import train_test_split

# cross_val_score above fits only clones of `classifier`, so `classifier` itself
# is still unfitted, and X_train / X_test were never defined. Hold out a
# stratified, reproducible test split and fit the classifier on the rest.
y_arr = np.asarray(y)
train_idx, test_idx = train_test_split(
    np.arange(X.sizes["epoch"]), test_size=0.25, stratify=y_arr, random_state=0
)
X_train, X_test = X.isel(epoch=train_idx), X.isel(epoch=test_idx)
classifier.fit(X_train, y_arr[train_idx])

# Histograms of the LDA scores per trial type, for training and test trials.
f, ax = p.subplots(1, 2, figsize=(12, 3))

for trial_type, c in zip(["FTapping/Left", "FTapping/Right"], ["r", "g"]):
    kw = dict(alpha=0.5, fc=c, label=trial_type)
    # Boolean selection keeps the 2D (epoch, features) shape that sklearn expects.
    in_train = (X_train.trial_type == trial_type).values
    in_test = (X_test.trial_type == trial_type).values
    ax[0].hist(classifier.decision_function(X_train.isel(epoch=in_train)), **kw)
    ax[1].hist(classifier.decision_function(X_test.isel(epoch=in_test)), **kw)

ax[0].set_xlabel("LDA score")
ax[1].set_xlabel("LDA score")
ax[0].set_title("train")
ax[1].set_title("test")
ax[0].legend(ncol=1, loc="upper left")
ax[1].legend(ncol=1, loc="upper left");