# AMPD - Automatic Multiscale Peak Detection

This notebook provides an end-to-end pipeline for processing and analyzing fNIRS data collected during a finger-tapping task. The primary goal is to identify peaks in the time series data using an **Optimized AMPD** algorithm.

The **AMPD** algorithm is a multiscale peak detection technique that is especially effective for periodic and quasi-periodic signals, such as heartbeats, even in the presence of noise. By analyzing the signal at multiple scales, the algorithm can reliably detect local maxima while minimizing false positives. This method is based on the work of **[Scholkmann et al. 2012](https://doi.org/10.3390/a5040588)**.
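For reference, the core of the original method in Scholkmann et al. (2012) can be summarized as follows. This is a paraphrase with simplified indexing; the optimized implementation used in this notebook may encode the matrix differently. For a detrended signal $x_1, \dots, x_N$ and scales $k = 1, \dots, L$ with $L = \lceil N/2 \rceil - 1$, the local scalogram matrix $m_{k,i}$, its row sums $\gamma_k$, and the selected scale $\lambda$ are

$$
m_{k,i} =
\begin{cases}
0, & x_i > x_{i-k} \ \text{and}\ x_i > x_{i+k},\\
r + \alpha, & \text{otherwise},
\end{cases}
\qquad
\gamma_k = \sum_{i=1}^{N} m_{k,i},
\qquad
\lambda = \arg\min_{k} \gamma_k ,
$$

where $r$ is drawn uniformly from $[0, 1]$, $\alpha$ is a constant (the paper uses $\alpha = 1$), and samples whose neighbors $x_{i \pm k}$ fall outside the signal take the "otherwise" branch. A sample $x_i$ is reported as a peak when $m_{k,i} = 0$ for all $k = 1, \dots, \lambda$, i.e. when the standard deviation of column $i$ over the first $\lambda$ rows is zero.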



In [None]:
# This cell sets up the environment when executed in Google Colab.
try:
    import google.colab
    !curl -s https://raw.githubusercontent.com/ibs-lab/cedalion/colab_setup/scripts/colab_setup.py -o colab_setup.py
    # Select branch with --branch "branch name" (default is "dev")
    %run colab_setup.py
except ImportError:
    pass

In [None]:
import cedalion.io
import cedalion.nirs
from cedalion import units
from cedalion.sigproc import quality
from cedalion.sigproc.frequency import freq_filter
import cedalion.xrutils as xrutils
from cedalion.datasets import get_fingertapping_snirf_path
import time
import numpy as np
import xarray as xr
from cedalion.sigproc.physio import ampd
import matplotlib.pyplot as plt

xr.set_options(display_max_rows=3, display_values_threshold=50)
np.set_printoptions(precision=4)

### Loading raw CW-NIRS data from a SNIRF file
This notebook uses a finger-tapping dataset in BIDS layout provided by [Rob Luke](https://github.com/rob-luke/BIDS-NIRS-Tapping). It can be downloaded via `cedalion.datasets`.

### Load amplitude data from the SNIRF file and extract the first 60 seconds for further processing

In [None]:
path_to_snirf_file = get_fingertapping_snirf_path()

recordings = cedalion.io.read_snirf(path_to_snirf_file)
rec = recordings[0]  # there is only one NirsElement in this snirf file...
amp = rec["amp"]  # ... which holds amplitude data

# restrict to the first 60 seconds
amp = amp.sel(time=amp.time < 60)
# sample timestamps converted from seconds to milliseconds (used for plotting)
times = amp.time.values * 1000


### Utility functions for normalizing, filtering, and plotting the signal

In [None]:

# collection of utility functions

def normalize(sig):
    """Min-max normalize a signal to the range [0, 1]."""
    min_val = np.min(sig)
    max_val = np.max(sig)
    return (sig - min_val) / (max_val - min_val)

def filter_signal(amplitudes):
    """Band-pass filter the amplitudes between 0.5 and 3 Hz (filter order 2)."""
    return freq_filter(amplitudes, 0.5 * units.Hz, 3 * units.Hz, 2)

def plot_peaks(signal, s_times, s_peaks, label, title='peaks'):
    """Plot a signal and mark each peak time with a dashed vertical line."""
    fig, ax = plt.subplots(1, 1, figsize=(24, 8))
    ax.plot(s_times, signal, label=label)

    for peak in s_peaks:
        ax.axvline(x=peak, color='black', linestyle='--', linewidth=1)

    ax.set_title(title)
    ax.legend()



### This is the amplitude data structure

In [None]:

# optionally band-pass filter the signal (0.5 to 3 Hz) to suppress noise
# amp = filter_signal(amp)
amp

### Run the *optimized AMPD* on the amplitude data

In [None]:

# use the optimized AMPD to find the peaks
peaks = ampd(amp)

### Optimized AMPD

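The **Optimized AMPD** uses vectorized array operations and splits the data into overlapping segments. This improves speed and keeps the memory footprint of the intermediate matrices bounded, even for long recordings.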


#### Methodology:
1. **Detrending**: Each channel and wavelength signal is first detrended to remove baseline drift, which would otherwise bias the comparison of neighboring samples.
2. **Local Scalogram Matrix (LSM)**: The detrended signal is processed in overlapping chunks. For each chunk, a matrix (the LSM) records, for every sample and every scale $k$, whether the sample is a local maximum, i.e. larger than both of its neighbors at distance $k$.
3. **Multi-Scale Analysis**: The LSM is accumulated scale by scale into a vector $G$, which identifies the scale $\lambda$ at which local maxima are most frequent.
4. **Peak Identification**: Peaks are the samples that remain local maxima on all scales up to $\lambda$. In the LSM, this shows up as a low standard deviation in the corresponding column.
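The four steps can be illustrated with a plain NumPy sketch of the original, non-chunked AMPD. This is not cedalion's `ampd`: the function name `ampd_reference` is hypothetical, it builds the full LSM in memory (quadratic in the signal length), and it uses a linear least-squares detrend. Instead of the paper's random non-maximum entries, it stores a boolean "is a local maximum" matrix, so the zero-standard-deviation test becomes "is a maximum on every scale up to $\lambda$".

```python
import numpy as np


def ampd_reference(x):
    """Return a 0/1 peak mask for a 1-D signal (illustrative, non-chunked AMPD)."""
    x = np.asarray(x, dtype=float)
    n = x.size
    if n < 3:
        raise ValueError("need at least 3 samples")
    t = np.arange(n)

    # 1. Detrending: subtract the least-squares straight-line fit.
    slope, intercept = np.polyfit(t, x, 1)
    x = x - (slope * t + intercept)

    # 2. LSM: row k-1 is True where x[i] exceeds both neighbours at distance k.
    #    Samples too close to the edges for scale k stay False.
    n_scales = int(np.ceil(n / 2)) - 1
    is_max = np.zeros((n_scales, n), dtype=bool)
    for k in range(1, n_scales + 1):
        centre = x[k:n - k]
        is_max[k - 1, k:n - k] = (centre > x[:n - 2 * k]) & (centre > x[2 * k:])

    # 3. Multi-scale analysis: count maxima per scale. The paper sums the
    #    (random, nonzero) non-maximum entries and takes the argmin; counting
    #    maxima directly and taking the argmax plays the same role here.
    maxima_per_scale = is_max.sum(axis=1)
    lam = int(np.argmax(maxima_per_scale)) + 1  # number of scales to keep

    # 4. Peak identification: keep samples that are maxima on all scales 1..lam.
    peaks = np.all(is_max[:lam], axis=0)
    return peaks.astype(int)
```

Applied to a sine wave with a linear trend, for example `ampd_reference(np.sin(2 * np.pi * np.arange(400) / 40) + 0.01 * np.arange(400))`, the detrending step removes the trend, and the remaining detections fall on the sine's maxima. Maxima closer to the signal edges than $\lambda$ samples cannot be confirmed on every scale and are not reported. This is one reason a chunked implementation needs overlapping segments.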

#### Parameters:
- **`amplitudes`**: An `xarray.DataArray` containing fNIRS amplitude data.
- optional **`chunk_size`**: The length of each overlapping segment. Larger segments give the algorithm more context per segment but may increase computation time and memory use.
- optional **`step_size`**: The offset between the starts of consecutive segments. The smaller it is relative to `chunk_size`, the more the segments overlap.
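To illustrate how `chunk_size` and `step_size` interact, the sketch below splits a signal into overlapping windows, runs a per-window detector, and merges the windows' 0/1 masks with a logical OR. The helpers `overlapping_chunks`, `detect_in_chunks`, and `scale1_maxima` are hypothetical and only demonstrate the general segmentation idea. Cedalion's `ampd` may place and merge its chunks differently.

```python
import numpy as np


def overlapping_chunks(n, chunk_size, step_size):
    """Hypothetical helper: (start, stop) bounds of overlapping windows covering n samples."""
    if chunk_size <= 0 or step_size <= 0:
        raise ValueError("chunk_size and step_size must be positive")
    if step_size > chunk_size:
        raise ValueError("step_size > chunk_size would leave gaps between chunks")
    bounds = []
    start = 0
    while True:
        stop = min(start + chunk_size, n)  # the last window is clipped to the signal end
        bounds.append((start, stop))
        if stop == n:
            break
        start += step_size
    return bounds


def detect_in_chunks(x, detector, chunk_size, step_size):
    """Hypothetical helper: run a 0/1 detector per window and OR the results together."""
    x = np.asarray(x, dtype=float)
    mask = np.zeros(x.size, dtype=int)
    for start, stop in overlapping_chunks(x.size, chunk_size, step_size):
        mask[start:stop] |= np.asarray(detector(x[start:stop]), dtype=int)
    return mask


def scale1_maxima(seg):
    """Toy per-window detector: samples larger than both direct neighbours."""
    m = np.zeros(seg.size, dtype=int)
    m[1:-1] = (seg[1:-1] > seg[:-2]) & (seg[1:-1] > seg[2:])
    return m


# usage: a sine with period 40 samples, windows of 100 samples shifted by 50
x = np.sin(2 * np.pi * np.arange(400) / 40)
mask = detect_in_chunks(x, scale1_maxima, chunk_size=100, step_size=50)
```

A plain OR merge assumes that the per-window detector does not report spurious maxima at window edges. A generous overlap (`chunk_size - step_size`) makes it likely that a peak near one window's edge lies in the interior of a neighboring window, where it can be detected reliably.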

#### Output:
`ampd` returns an `xarray.DataArray` (stored in `peaks` above) that mirrors the shape and structure of the input data (`amplitudes`). Each detected peak is marked with a `1`, and all other samples are marked with `0`. This format makes it easy to select peak locations for any channel and wavelength with the usual xarray indexing.
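Because the output is a 0/1 mask aligned with the time axis, peak timestamps follow from boolean indexing, and consecutive differences give inter-peak intervals. The NumPy sketch below uses a made-up toy mask and a hypothetical helper `peak_times_from_mask`. Converting the mean interval into a rate per minute is only meaningful if the detected peaks correspond to a periodic event such as heartbeats.

```python
import numpy as np


def peak_times_from_mask(mask, times):
    """Hypothetical helper: return the timestamps where a 0/1 peak mask equals 1."""
    mask = np.asarray(mask)
    times = np.asarray(times)
    return times[mask == 1]


# toy example: one sample every 100 ms, peaks at samples 2, 10 and 18
times_ms = np.arange(25) * 100.0
mask = np.zeros(25, dtype=int)
mask[[2, 10, 18]] = 1

pt = peak_times_from_mask(mask, times_ms)  # array([ 200., 1000., 1800.])
intervals_ms = np.diff(pt)  # array([800., 800.])
rate_per_min = 60_000.0 / intervals_ms.mean()  # 75.0
```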


In [None]:
peaks

### Plot the signal of channel S1D1 with the peaks detected by AMPD

In [None]:
# select a channel and a wavelength for displaying the results
channel = "S1D1"
wl_index = 1  # position of the wavelength along the "wavelength" dimension
channel_data = amp.sel(channel=channel).isel(wavelength=wl_index)

# retrieve the peak mask for the same channel and wavelength,
# where peaks are represented by 1 and non-peaks by 0
peak_mask = peaks.sel(channel=channel).isel(wavelength=wl_index).values

# extract the timestamps (in ms) of the identified peaks;
# boolean indexing also keeps a peak at t = 0, which multiplying by the mask would drop
peak_times = times[peak_mask == 1]

# for plotting, take the signal values of the same channel and wavelength
signal = channel_data.values

# plot the signal and the peaks calculated by the optimized AMPD
plot_peaks(signal, times, peak_times, channel, f"peaks: {len(peak_times)}")
