## Motion Artefact Detection and Correction (WIP / in validation)
The same xarray-based masks can be used to indicate motion artefacts. The example below shows how to check channels for motion artefacts using standard thresholds from Homer2/3. The output is a mask that can be handed to motion correction algorithms.

In [None]:
import cedalion
import cedalion.nirs
import cedalion.sigproc.quality as quality
import cedalion.sigproc.motion_correct as motion_correct
import cedalion.xrutils as xrutils
import cedalion.datasets as datasets
import xarray as xr
import matplotlib.pyplot as p
import cedalion.plots as plots
from cedalion import units
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

In [None]:
# get example finger tapping dataset

rec = datasets.get_fingertapping()
rec["od"] = cedalion.nirs.int2od(rec["amp"])

# Plot some data for visual validation
f,ax = p.subplots(1,1, figsize=(12,4))
ax.plot( rec["amp"].time, rec["amp"].sel(channel="S3D3", wavelength="850"), "r-", label="850nm")
ax.plot( rec["amp"].time, rec["amp"].sel(channel="S3D3", wavelength="760"), "b-", label="760nm")  # blue, so the two wavelengths can be told apart
p.legend()
ax.set_xlabel("time / s")
ax.set_ylabel("Signal intensity / a.u.")

display(rec["amp"])

### Detecting Motion Artifacts and generating the MA mask

In [None]:
# we use Optical Density data for motion artifact detection
fNIRSdata = rec["od"]

# define parameters for motion artifact detection. We follow the method from Homer2/3: "hmrR_MotionArtifactByChannel" and "hmrR_MotionArtifact".
t_motion = 0.5*units.s  # time window for motion artifact detection
t_mask = 1.0*units.s    # time window for masking motion artifacts (+- t_mask s before/after detected motion artifact)
stdev_thresh = 4.0      # threshold for standard deviation of the signal used to detect motion artifacts. Default is 50. We set it very low to find something in our good data for demonstration purposes.
amp_thresh = 5.0        # threshold for amplitude of the signal used to detect motion artifacts. Default is 5.

# to identify motion artifacts with these parameters we call the following function
ma_mask = quality.id_motion(fNIRSdata, t_motion, t_mask, stdev_thresh, amp_thresh)
# it hands us a boolean mask (xarray) of the input dimension, where False indicates a motion artifact at a given time point.

# show the mask data
ma_mask
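
To make the roles of the four parameters more concrete, here is a deliberately simplified NumPy sketch of a Homer-style thresholding scheme for a single time series. It is an illustration under stated assumptions, not a port of `hmrR_MotionArtifact` and not the code behind `quality.id_motion`: the helper name `detect_motion_sketch`, the use of plain floats (seconds, Hz) instead of `units` quantities, and the choice of the standard deviation of sample-to-sample differences as the reference are all assumptions made for this example.

```python
import numpy as np

def detect_motion_sketch(signal, fs, t_motion, t_mask, stdev_thresh, amp_thresh):
    """Hypothetical helper: return a boolean mask where False marks a suspected artifact."""
    signal = np.asarray(signal, dtype=float)
    n_motion = max(1, int(round(t_motion * fs)))  # window length in samples
    n_mask = int(round(t_mask * fs))              # padding around each detection

    # Reference variability (assumption): std of the sample-to-sample changes.
    step_std = np.std(np.diff(signal))

    clean = np.ones(signal.shape, dtype=bool)
    for lag in range(1, n_motion + 1):
        # Signal change over `lag` samples, i.e. within the t_motion window.
        change = np.abs(signal[lag:] - signal[:-lag])
        hits = np.flatnonzero(
            (change > stdev_thresh * step_std) | (change > amp_thresh)
        )
        for i in hits:
            start = max(0, i - n_mask)
            stop = min(len(signal), i + lag + n_mask + 1)
            clean[start:stop] = False  # flag +- t_mask around the change
    return clean

# Synthetic example: low-amplitude noise with an abrupt baseline shift.
rng = np.random.default_rng(0)
demo_signal = 0.01 * rng.standard_normal(1000)
demo_signal[500:] += 1.0
demo_mask = detect_motion_sketch(
    demo_signal, fs=10.0, t_motion=0.5, t_mask=1.0, stdev_thresh=10.0, amp_thresh=5.0
)
```

In this sketch only the samples around the baseline shift are flagged, while the noise alone stays below both thresholds.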

The output mask is quite detailed and still contains all original dimensions (e.g. single wavelengths) and allows us to combine it with a mask from another motion artifact detection method. This is the same approach as for the channel quality metrics above.
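
Because `False` marks an artifact, combining two masks amounts to a logical AND: a time point stays clean only if both methods consider it clean. A minimal sketch with plain NumPy arrays standing in for the xarray masks (the names `mask_a` and `mask_b` are made up for this example):

```python
import numpy as np

# Two hypothetical masks for the same time points, True = clean, False = artifact.
mask_a = np.array([True, True, False, False, True])
mask_b = np.array([True, False, True, False, True])

# A sample is clean only if both detection methods agree that it is clean.
combined = mask_a & mask_b
```

The same `&` operator should also work element-wise on boolean xarray objects with matching dimensions and coordinates.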

Let us now plot the result for an example channel. Note that a different number of artifacts was identified for each wavelength, which can sometimes happen:

In [None]:
p.figure()
p.plot(ma_mask.sel(time=slice(0,250)).time, ma_mask.sel(channel="S3D3", wavelength="760", time=slice(0,250)), "b-")
p.plot(ma_mask.sel(time=slice(0,250)).time, ma_mask.sel(channel="S3D3", wavelength="850", time=slice(0,250)), "r-")
p.xlabel("time / s")
p.ylabel("Motion artifact mask")
p.show() 

Our example dataset is very clean, so we deliberately used a very low threshold to provoke detections. Below we plot the mask and the data together (the mask is rescaled so that both fit on one axis):

In [None]:
p.figure()
p.plot(fNIRSdata.sel(time=slice(0,250)).time, fNIRSdata.sel(channel="S3D3", wavelength="760", time=slice(0,250)), "r-")
p.plot(ma_mask.sel(time=slice(0,250)).time, ma_mask.sel(channel="S3D3", wavelength="760", time=slice(0,250))/10, "k-")
p.xlabel("time / s")
p.ylabel("fNIRS Signal / Motion artifact mask")
p.show() 

### Refining the MA Mask
By the time we want to correct motion artifacts, we usually no longer need the level of granularity that the mask provides. For instance, we usually want to treat a motion artifact detected in either wavelength (or chromophore) of a channel as a single artifact that is flagged for both. We might also want to flag motion artifacts globally, i.e. mask time points for all channels even if only some of them show an artifact. This can easily be done with the "id_motion_refine" function, which also returns useful information about the motion artifacts in each channel in "ma_info".
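
Conceptually, collapsing a mask means reducing it along one or more dimensions so that a single artifact in any collapsed entry taints the result. The following NumPy sketch illustrates that idea on a toy array; it assumes the axis order (channel, wavelength, time) and the convention True = clean, and it is not the implementation of "id_motion_refine".

```python
import numpy as np

# Toy mask with assumed axis order (channel, wavelength, time); True = clean.
toy_mask = np.ones((2, 2, 6), dtype=bool)
toy_mask[0, 0, 2] = False  # artifact in channel 0, first wavelength
toy_mask[1, 1, 4] = False  # artifact in channel 1, second wavelength

# Per-channel mask: a time point is clean only if it is clean for all wavelengths.
per_channel = toy_mask.all(axis=1)       # shape (channel, time)

# Global mask: a time point is clean only if it is clean everywhere.
global_mask = toy_mask.all(axis=(0, 1))  # shape (time,)
```

The first reduction corresponds in spirit to the 'by_channel' operator, the second to the 'all' operator.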

In [None]:
# refine the motion artifact mask. This function collapses the mask along dimensions that are chosen by the "operator" argument.
# Here we use "by_channel", which will yield a mask for each channel by collapsing the masks along either the wavelength or concentration dimension.
ma_mask_refined, ma_info = quality.id_motion_refine(ma_mask, 'by_channel')

# show the refined mask
ma_mask_refined

Now the mask does not have the "wavelength" or "concentration" dimension anymore, and the masks of these dimensions are combined:

In [None]:
# plot the figure

p.figure()
p.plot(fNIRSdata.sel(time=slice(0,250)).time, fNIRSdata.sel(channel="S3D3", wavelength="760", time=slice(0,250)), "r-")
p.plot(ma_mask_refined.sel(time=slice(0,250)).time, ma_mask_refined.sel(channel="S3D3", time=slice(0,250))/10, "k-")
p.xlabel("time / s")
p.ylabel("fNIRS Signal / Refined Motion artifact mask")
p.show() 

# show the information about the motion artifacts: we get a pandas dataframe telling us
# 1) for which channels artifacts were detected,
# 2) what fraction of time points was marked as artifacts and
# 3) how many artifacts were detected
ma_info
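
For intuition, the fraction of flagged time points and the number of artifacts can be derived from a boolean mask as sketched below. This is an illustrative NumPy version with a hypothetical helper name (`summarize_mask`); it counts each contiguous run of flagged samples as one artifact, which is an assumption and may differ from how "ma_info" is actually computed.

```python
import numpy as np

def summarize_mask(clean):
    """Hypothetical helper: return (fraction of flagged samples, number of flagged runs)."""
    tainted = ~np.asarray(clean, dtype=bool)  # True where an artifact was flagged
    fraction = float(tainted.mean())

    # A new artifact starts wherever `tainted` switches from False to True.
    padded = np.concatenate(([False], tainted))
    n_artifacts = int(np.count_nonzero(~padded[:-1] & padded[1:]))
    return fraction, n_artifacts

# Toy mask: two separate artifacts covering three of eight samples.
toy_clean = [True, False, False, True, True, False, True, True]
toy_fraction, toy_count = summarize_mask(toy_clean)
```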

Now we look at the "all" operator, which collapses the mask across all dimensions except time, leading to a single motion artifact mask.

In [None]:
# "all" yields a mask that flags an artifact at any given time if flagged for any channel, wavelength, chromophore, etc.
ma_mask_refined, ma_info = quality.id_motion_refine(ma_mask, 'all')

# show the refined mask
ma_mask_refined

In [None]:
# plot the figure

p.figure()
p.plot(fNIRSdata.sel(time=slice(0,250)).time, fNIRSdata.sel(channel="S3D3", wavelength="760", time=slice(0,250)), "r-")
p.plot(ma_mask_refined.sel(time=slice(0,250)).time, ma_mask_refined.sel(time=slice(0,250))/10, "k-")
p.xlabel("time / s")
p.ylabel("fNIRS Signal / Refined Motion artifact mask")
p.show() 

# show the information about the motion artifacts: we get a pandas dataframe telling us
# 1) that the mask is for all channels
# 2) the fraction of time points that were marked as artifacts for this mask across all channels
# 3) how many artifacts were detected in total
ma_info

### Motion Correction
#### SplineSG method: 
1. identifies baseline shifts in the data and uses spline interpolation to correct these shifts
2. uses a Savitzky-Golay filter to remove spikes
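
The second step builds on standard Savitzky-Golay smoothing: within a sliding window, a low-order polynomial is fitted by least squares and evaluated at the window centre. The sketch below implements only that generic filter in NumPy to show why isolated spikes are attenuated. The helper name `savgol_smooth`, the edge padding, and the window and polynomial order used in the example are assumptions for illustration; this is not the code inside `motion_correct_splineSG`.

```python
import numpy as np

def savgol_smooth(x, window_length, polyorder):
    """Hypothetical helper: generic Savitzky-Golay smoothing (window_length must be odd)."""
    x = np.asarray(x, dtype=float)
    half = window_length // 2
    offsets = np.arange(-half, half + 1)

    # Least-squares polynomial fit over the window; the first row of the
    # pseudo-inverse gives the weights for the fitted value at the window centre.
    design = np.vander(offsets, polyorder + 1, increasing=True)
    weights = np.linalg.pinv(design)[0]

    # Repeat the edge values so that the output has the same length as the input.
    padded = np.pad(x, half, mode="edge")
    return np.convolve(padded, weights[::-1], mode="valid")

# Slow oscillation with a single-sample spike.
t = np.linspace(0.0, 10.0, 501)
slow = np.sin(2 * np.pi * 0.1 * t)
spiky = slow.copy()
spiky[250] += 1.0
smoothed = savgol_smooth(spiky, window_length=51, polyorder=3)
```

The spike is spread over the window and strongly attenuated, while the slow oscillation is largely preserved.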


In [None]:
frame_size = 10 * units.s
rec['od_splineSG'] = motion_correct.motion_correct_splineSG(rec['od'], frame_size=frame_size, p=1)

f, ax = p.subplots(1,1, figsize=(12,4))
ax.plot(rec['od'].time, rec['od'].sel(channel="S3D3", wavelength="850"), "r-", label="850nm raw")
ax.plot(rec['od'].time, rec['od_splineSG'].sel(channel="S3D3", wavelength="850"), "g-", label="850nm cleaned")
ax.set_xlim(0,500)
ax.legend()

#### PCA recurse method:

- WIP