# Data Loading and Preprocessing

In [None]:
# Set up the environment when this notebook is executed in Google Colab.
# Outside of Colab the import below raises ImportError and the cell does nothing.
try:
    import google.colab
    !curl -s https://raw.githubusercontent.com/ibs-lab/cedalion/dev/scripts/colab_setup.py -o colab_setup.py
    # Select branch with --branch "branch name" (default is "dev")
    %run colab_setup.py
except ImportError:
    pass

In [None]:
import cedalion
import cedalion.sigproc.quality as quality
import cedalion.sigproc.motion_correct as motion_correct
from cedalion.plots import segmented_cmap
from cedalion import units
import cedalion.xrutils as xrutils
import cedalion.datasets

from pathlib import Path
import numpy as np
import xarray as xr

import matplotlib.pyplot as p

## Load Data

Example datasets are accessible through functions in `cedalion.datasets`. These take care of downloading, caching and updating the data files. Often they also already load the data.

Here we load a single-subject DOT dataset with a motor task.

In [None]:
# Load the example fingertapping DOT recording via cedalion.datasets.
rec = cedalion.datasets.get_fingertappingDOT()

This recording object holds a single NIRS time series, `'amp'`:

In [None]:
# Names of the NIRS time series stored in the recording.
rec.timeseries.keys()

It contains several auxiliary time series from additional sensors:

In [None]:
# Names of the auxiliary time series stored in the recording.
rec.aux_ts.keys()

## Inspecting the Datasets

### Raw Amplitude Time Series

In [None]:
# Show the raw amplitude time series.
rec["amp"]

### Stimulus event information

In [None]:
# Show the stimulus event information stored with the recording.
rec.stim

### Montage

In [None]:
# Show the montage geometry; it is passed to the plotting and distance
# functions in the following cells.
rec.geo3d

In [None]:
# Render the montage in 3D from the amplitude time series and the geometry.
cedalion.plots.plot_montage3D(rec["amp"], rec.geo3d)

### Channel Distances

In [None]:
# Channel distances, computed from the amplitude time series and the geometry.
distances = cedalion.nirs.channel_distances(rec["amp"], rec.geo3d)

# Histogram of the distances using 40 bins.
f, ax = p.subplots(figsize=(8, 4))
ax.hist(distances, bins=40)
ax.set_xlabel("channel distance / mm")
ax.set_ylabel("channel count");

### Plot raw amplitude for one channel

In [None]:
# Example time trace: both wavelengths of one channel with stimulus markers.
ch = "S12D25"
amp = rec["amp"]

f, ax = p.subplots(1, 1, figsize=(12, 4))
ax.set_prop_cycle("color", cedalion.plots.COLORBREWER_Q8)
for wl in (760, 850):
    trace = amp.sel(channel=ch, wavelength=wl)
    ax.plot(amp.time, trace, label=f"amp. {wl} nm")
cedalion.plots.plot_stim_markers(ax, rec.stim, y=1)

# Restrict the view to the first 150 s of the recording.
ax.set(xlabel="time / s", ylabel="amplitude / V", xlim=(0, 150), title=ch)
ax.legend();

## Quality Metrics : SCI & PSP

- using functions from cedalion.sigproc.quality we calculate two metrics:
  - scalp coupling index (SCI) 
  - peak spectral power (PSP)

- note the different time axis: both metrics are calculated in sliding windows
- both functions return the metric and a boolean array (mask) indicating whether the metric is above the threshold

In [None]:
# Thresholds and the sliding-window length shared by both metrics.
sci_threshold = 0.75
psp_threshold = 0.03
window_length = 10 * units.s

# Each call yields the windowed metric together with its boolean mask.
sci, sci_mask = quality.sci(rec["amp"], window_length, sci_threshold)
psp, psp_mask = quality.psp(rec["amp"], window_length, psp_threshold)

display(sci.rename("sci"), sci_mask.rename("sci_mask"))


In [None]:
# Define three colormaps: reddish below a threshold, bluish above it.
def _threshold_cmap(name, threshold):
    """Build a (norm, cmap) pair on [0, 1] whose colors switch at `threshold`.

    Bad, over-range and under-range values are drawn in magenta.
    """
    color_stops = [
        (0.0, "#000000"),
        (threshold, "#DC3220"),
        (threshold, "#5D3A9B"),
        (1.0, "#0C7BDC"),
    ]
    return segmented_cmap(
        name,
        0,
        1.0,
        color_stops,
        bad="magenta",
        over="magenta",
        under="magenta",
    )


sci_norm, sci_cmap = _threshold_cmap("sci_cmap", sci_threshold)
psp_norm, psp_cmap = _threshold_cmap("psp_cmap", psp_threshold)

# Two-color map for boolean masks: red below 0.5, blue above.
mask_norm, mask_cmap = segmented_cmap(
    "mask_cmap",
    0,
    1.0,
    [(0.0, "#DC3220"), (0.5, "#DC3220"), (0.5, "#0C7BDC"), (1.0, "#0C7BDC")],
)

def _plot_metric_heatmap(metric, cb_label, cmap, norm):
    """Draw a heatmap of a windowed quality metric, one row per channel.

    Args:
        metric: array with `time` and `channel` coordinates. It is handed to
            `pcolormesh` as-is, so it is assumed to be laid out as
            (channel, time) -- TODO confirm.
        cb_label: label of the colorbar.
        cmap: colormap passed to `pcolormesh`.
        norm: normalization passed to `pcolormesh`.
    """
    f, ax = p.subplots(1, 1, figsize=(17, 10))

    rows = np.arange(len(metric.channel))
    m = ax.pcolormesh(
        metric.time, rows, metric, shading="nearest", cmap=cmap, norm=norm
    )
    cb = p.colorbar(m, ax=ax)
    cb.set_label(cb_label)
    ax.set_xlabel("time / s")
    ax.yaxis.set_ticks(rows)
    ax.yaxis.set_ticklabels(metric.channel.values, fontsize=7)
    # Run the layout last so that it takes the channel tick labels into account.
    p.tight_layout()


def plot_sci(sci):
    """Plot the SCI heatmap using the notebook-level `sci_cmap` / `sci_norm`."""
    _plot_metric_heatmap(sci, "SCI", sci_cmap, sci_norm)


def plot_psp(psp):
    """Plot the PSP heatmap using the notebook-level `psp_cmap` / `psp_norm`."""
    _plot_metric_heatmap(psp, "PSP", psp_cmap, psp_norm)

def plot_quality_mask(mask, cb_label : str, bool_labels = ("TAINTED", "CLEAN")):
    """Plot a boolean quality mask as a two-color heatmap, one row per channel.

    Args:
        mask: boolean array with `time` and `channel` coordinates. It is handed
            to `pcolormesh` as-is, so it is assumed to be laid out as
            (channel, time) -- TODO confirm.
        cb_label: label of the colorbar.
        bool_labels: colorbar tick labels for the False and the True value,
            in that order.
    """
    f, ax = p.subplots(1, 1, figsize=(17, 10))

    rows = np.arange(len(mask.channel))
    m = ax.pcolormesh(
        mask.time, rows, mask, shading="nearest", cmap=mask_cmap, norm=mask_norm
    )
    cb = p.colorbar(m, ax=ax)
    cb.set_label(cb_label)
    # Ticks sit in the middle of the two color bands of `mask_cmap`.
    cb.set_ticks([.25, .75])
    cb.set_ticklabels(bool_labels)

    ax.set_xlabel("time / s")
    ax.yaxis.set_ticks(rows)
    ax.yaxis.set_ticklabels(mask.channel.values, fontsize=7)
    # Run the layout last so that it takes the channel tick labels into account.
    p.tight_layout()

In [None]:
# For each metric: heatmap of the metric, then the mask obtained by comparing
# it against its threshold.
plot_sci(sci)
plot_quality_mask(sci > sci_threshold, f"SCI > {sci_threshold}")
plot_psp(psp)
plot_quality_mask(psp > psp_threshold, f"PSP > {psp_threshold}")

## Combining Signal Quality Masks

We want both SCI and PSP to be above their respective thresholds for a window to be considered clean. We can use the boolean AND operation to combine both masks and then look at the percentage of time during which both metrics are above their thresholds.

In [None]:
# A window is marked clean only where both the SCI mask and the PSP mask are True.
combined_mask = sci_mask & psp_mask

display(combined_mask)
plot_quality_mask(combined_mask, "combined_mask")

- calculate percentage of clean time per channel

In [None]:
# Fraction of windows per channel in which both metrics pass their threshold.
# Normalise by the length of the mask's own time axis (rather than that of
# `sci`), so the ratio stays correct even if the two axes ever differ.
perc_time_clean = combined_mask.sum(dim="time") / combined_mask.sizes["time"]

display(perc_time_clean)

f, ax = p.subplots(1,1,figsize=(6.5,6.5))

# Color scale is limited to 0.8 .. 1.
cedalion.plots.scalp_plot(
    rec["amp"],
    rec.geo3d,
    perc_time_clean,
    ax,
    cmap="RdYlGn",
    vmin=0.80,
    vmax=1,
    title=None,
    cb_label="Percentage of clean time",
    channel_lw=2,
    optode_labels=True
)
f.tight_layout()

## Correct Motion Artefacts
- use `cedalion.nirs.int2od` to get optical densities
- apply Temporal Derivative Distribution Repair (TDDR) first to correct jumps 
- then apply Wavelet motion artifact correction

In [None]:
# Convert the raw amplitudes to optical densities.
rec["od"] = cedalion.nirs.int2od(rec["amp"])
# Apply TDDR first ...
rec["od_tddr"] = motion_correct.tddr(rec["od"])
# ... then the wavelet correction on the TDDR output.
rec["od_wavelet"] = motion_correct.wavelet(rec["od_tddr"])
# Convert the corrected optical densities back to amplitudes. The second
# argument is the time-averaged original amplitude (presumably used as the
# baseline intensity -- confirm against cedalion.nirs.od2int).
rec["amp_corrected"] = cedalion.nirs.od2int(rec["od_wavelet"], rec["amp"].mean("time"))


In [None]:
# recalculate sci & psp on cleaned data
sci_corr, sci_corr_mask = quality.sci(rec["amp_corrected"], window_length, sci_threshold)
psp_corr, psp_corr_mask = quality.psp(rec["amp_corrected"], window_length, psp_threshold)
combined_corr_mask = sci_corr_mask & psp_corr_mask

In [None]:
# Combined mask before and after motion correction.
plot_quality_mask(combined_mask, "combined mask")
plot_quality_mask(combined_corr_mask, "combined corrected mask")

## Compare masks before and after motion artifact correction

In [None]:
# Windows that were tainted before the correction and are clean afterwards.
improved_windows = (combined_mask == quality.TAINTED) & (combined_corr_mask == quality.CLEAN)
plot_quality_mask(
    improved_windows,
    "mask of time windows cleaned by motion correction",
    bool_labels=["unchanged", "improved"],
)

# Windows that were clean before the correction and are tainted afterwards.
worsened_windows = (combined_mask == quality.CLEAN) & (combined_corr_mask == quality.TAINTED)
plot_quality_mask(
    worsened_windows,
    "mask of time windows corrupted by motion correction",
    bool_labels=["unchanged", "worsened"],
)

Recalculate the percentage of clean time after motion correction.

In [None]:
# Fraction of clean windows per channel after motion correction. Normalise by
# the corrected mask's own time axis instead of the time axis of `sci`, which
# belongs to the uncorrected data.
perc_time_clean_corr = (
    combined_corr_mask.sum(dim="time") / combined_corr_mask.sizes["time"]
)

f, ax = p.subplots(1,1,figsize=(6.5,6.5))

# Same color scale (0.8 .. 1) as in the plot for the uncorrected data.
cedalion.plots.scalp_plot(
    rec["amp"],
    rec.geo3d,
    perc_time_clean_corr,
    ax,
    cmap="RdYlGn",
    vmin=0.80,
    vmax=1,
    title=None,
    cb_label="Percentage of clean time",
    channel_lw=2,
    optode_labels=True
)
f.tight_layout()

## Global Variance of the Temporal Derivative (GVTD) for identifying global bad time segments

In [None]:
# GVTD metric and mask for the original and for the motion-corrected amplitudes.
gvtd, gvtd_mask = quality.gvtd(rec["amp"])
gvtd_corr, gvtd_corr_mask = quality.gvtd(rec["amp_corrected"])

In [None]:
def _peak_gvtd(segment):
    """Maximum GVTD of the uncorrected data within a (start, end) segment."""
    return gvtd.sel(time=slice(segment[0], segment[1])).max()


# Keep the 10 segments with the highest GVTD peak. The segments are derived
# from the combined mask reduced over all channels.
top10_bad_segments = sorted(
    quality.mask_to_segments(combined_mask.all("channel")),
    key=_peak_gvtd,
    reverse=True,
)[:10]


Plot GVTD and the combined masks for the original and the corrected time series, with the selected segments highlighted.

In [None]:
f, ax = p.subplots(4, 1, figsize=(16, 6), sharex=True)

# One row per panel: (x values, y values, y-axis label).
panels = [
    (gvtd.time, gvtd, "GVTD"),
    (combined_mask.time, combined_mask.all("channel"), "combined_mask"),
    (gvtd_corr.time, gvtd_corr, "GVTD"),
    (combined_corr_mask.time, combined_corr_mask.all("channel"), "combined_corr_mask"),
]
for axis, (times, values, ylabel) in zip(ax, panels):
    axis.plot(times, values)
    axis.set_ylabel(ylabel)

# Both GVTD panels share the same fixed y-range.
for axis in (ax[0], ax[2]):
    axis.set_ylim(0, 0.02)
ax[3].set_xlabel("time / s")

# Mark the selected segments in every panel (after the y-limits are fixed).
for axis in ax:
    cedalion.plots.plot_segments(axis, top10_bad_segments)


## Highlight motion correction in selected segments

In [None]:
example_channels = ["S4D10", "S13D26"]
padding = 15  # margin added on both sides of each segment

f, ax = p.subplots(5, 4, figsize=(16, 16), sharex=False)
# Transpose before flattening so that the panels are filled column by column.
ax = ax.T.flatten()

# One panel per (channel, segment) combination, channels varying slowest.
panels = [(ch, seg) for ch in example_channels for seg in top10_bad_segments]

for axis, (ch, (start, end)) in zip(ax, panels):
    window = slice(start - padding, end + padding)
    axis.set_prop_cycle(color=["#e41a1c", "#ff7f00", "#377eb8", "#984ea3"])

    # For every wavelength draw the original trace followed by the corrected one.
    for wl in rec["od"].wavelength.values:
        for key, tag in (("od", "orig"), ("od_wavelet", "corr")):
            trace = rec[key].sel(time=window, channel=ch, wavelength=wl)
            axis.plot(trace.time, trace, label=f"{wl:.0f} nm {tag}")

    axis.set_title(ch)
    axis.legend(ncol=2, loc="upper center")

    # Extend the upper y-limit by 25% of the range to make space for the legend.
    lower, upper = axis.get_ylim()
    axis.set_ylim(lower, upper + 0.25 * (upper - lower))

p.tight_layout()


## Final channel selection

In [None]:
# Channels whose clean-time fraction stays below 0.95 after motion correction.
perc_time_clean_corr[perc_time_clean_corr < 0.95]

In [None]:
# Select channels with a clean-time fraction of at least 0.95. Note that this
# uses `perc_time_clean`, i.e. the values computed before motion correction,
# and that the pruning is applied to the uncorrected `amp` time series.
signal_quality_selection_masks = [perc_time_clean >= .95]

# NOTE(review): "all" presumably means a channel has to satisfy every mask in
# the list -- confirm against quality.prune_ch.
rec["amp_pruned"], pruned_channels = quality.prune_ch(
    rec["amp"], signal_quality_selection_masks, "all"
)
display(rec["amp_pruned"])
display(pruned_channels)