# Storing estimated HRFs in snirf files

This notebook estimates the hemodynamic response functions (HRFs) of a finger-tapping experiment by block averaging and then stores the result in a SNIRF file.

In [None]:
from pathlib import Path
import tempfile

import matplotlib.pyplot as p
import numpy as np
import xarray as xr
from snirf import Snirf

import cedalion
import cedalion.datasets
import cedalion.io
import cedalion.nirs

# Global display settings: compact xarray reprs and 4-digit numpy printing.
xr.set_options(
    display_max_rows=3,
    display_values_threshold=50,
    display_expand_data=False,
)
np.set_printoptions(precision=4)

## Load a finger-tapping dataset 

For this demo we load an example finger-tapping recording through `cedalion.datasets.get_fingertapping`. The file contains a single NIRS element with one block of raw amplitude data. 

In [None]:
# Load the example finger-tapping recording provided through cedalion.datasets.
rec = cedalion.datasets.get_fingertapping()

In [None]:
# Replace the numeric trial-type labels of the stimulus table with descriptive names.
event_names = {
    "1.0": "control",
    "2.0": "Tapping/Left",
    "3.0": "Tapping/Right",
}
rec.stim.cd.rename_events(event_names)

## Calculate concentrations

In [None]:
# Differential pathlength factor: the same value (6) for every wavelength present in the
# data, so this works for recordings with any number of wavelengths (not just two).
wavelengths = rec["amp"].wavelength
dpf = xr.DataArray(
    np.full(wavelengths.size, 6),
    dims="wavelength",
    coords={"wavelength": wavelengths},
)

# Optical density relative to the temporal mean of each channel/wavelength.
rec["od"] = -np.log(rec["amp"] / rec["amp"].mean("time"))

# NOTE(review): beer_lambert is given the raw amplitudes rather than rec["od"];
# presumably it converts to optical density internally in this cedalion version — confirm.
rec["conc"] = cedalion.nirs.beer_lambert(rec["amp"], rec.geo3d, dpf)

## Frequency filtering and splitting into epochs

In [None]:
# Band-pass filter the concentration time series (0.02-0.5 Hz, 4th-order Butterworth).
rec["conc_freqfilt"] = rec["conc"].cd.freq_filter(fmin=0.02, fmax=0.5, butter_order=4)

# Cut epochs around each tapping stimulus: 5 s before to 20 s after onset.
selected_events = ["Tapping/Left", "Tapping/Right"]
cf_epochs = rec["conc_freqfilt"].cd.to_epochs(
    rec.stim,
    selected_events,
    before=5,
    after=20,
)

## Block averaging to estimate the HRFs

In [None]:
# Baseline of every epoch: mean over its pre-stimulus samples (reltime < 0).
pre_stimulus = cf_epochs.reltime < 0
baseline = cf_epochs.sel(reltime=pre_stimulus).mean("reltime")
epochs_blcorrected = cf_epochs - baseline

# Average the baseline-corrected epochs separately for each trial type.
hrf_by_trial_type = epochs_blcorrected.groupby("trial_type")
rec["hrf_blockaverage"] = hrf_by_trial_type.mean("epoch")

display(rec["hrf_blockaverage"])

## Store HRFs 

In [None]:
# Scratch directory for the output file; removed again in the "Tidy up" cell.
tmpdir = tempfile.TemporaryDirectory()
snirf_fname = str(Path(tmpdir.name, "test.snirf"))

In [None]:
# Write the recording `rec` (which now also holds the block-averaged HRFs) to the temporary snirf file.
cedalion.io.snirf.write_snirf(snirf_fname, rec)

## Inspect snirf file

In [None]:
# Open the written file with the snirf library for low-level inspection (closed with s.close() below).
s = Snirf(snirf_fname)

In [None]:
# Show the NIRS elements of the file, then the data elements of the first one.
nirs_elements = s.nirs
display(nirs_elements)
display(nirs_elements[0].data)

#### Stim DataFrame

In [None]:
# Raw stim group(s) of the first NIRS element, as exposed by the snirf library.
display(s.nirs[0].stim)

In [None]:
# The same stimulus information converted to a DataFrame (shown as the cell output).
cedalion.io.snirf.stim_to_dataframe(s.nirs[0].stim)

#### MeasurementList DataFrame

In [None]:
# For every data element of the first NIRS element, convert its measurement list to a
# DataFrame (dropping unset fields) and show the first and last three rows.
for data_element in s.nirs[0].data:
    df_ml = cedalion.io.snirf.measurement_list_to_dataframe(
        data_element.measurementList, drop_none=True
    )
    display(df_ml.head(3))
    display(df_ml.tail(3))

In [None]:
# Close the snirf file opened for inspection above.
s.close()

## Read HRFs from snirf file

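Note: `read_snirf` names the time dimension `time`, whereas the block average computed above calls it `reltime`. The dimension is therefore renamed before the two arrays are compared; a common naming convention still needs to be agreed on.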

In [None]:
# Read the file back with cedalion. The result is indexable; its first recording is used below.
hrf_recs = cedalion.io.read_snirf(snirf_fname)
display(hrf_recs)

In [None]:
# Select the stored block averages from the first recording. Their name ("hrf_conc") derives
# from the snirf data type (HRF) and the type of data (concentration).
read_blockaverage = hrf_recs[0]["hrf_conc"]
display(read_blockaverage)

#### Assert that the written and read HRFs are identical.

In [None]:
# Element-wise comparison of the read-back and the original block averages. xarray aligns
# the operands with an inner join on their coordinates, so mismatched labels would be
# silently dropped and .all() would still pass. Also require that the comparison covers
# every element of the written HRFs.
hrf_comparison = read_blockaverage.rename({"time": "reltime"}) == rec["hrf_blockaverage"]
assert hrf_comparison.size == rec["hrf_blockaverage"].size
assert hrf_comparison.all()

## Plot the HRFs

In [None]:
ba = read_blockaverage

# One subplot per channel on a six-column grid. The number of rows follows from the
# channel count instead of a fixed 5 x 6 grid, which would overflow beyond 30 channels.
n_channels = ba.sizes["channel"]
ncols = 6
nrows = max(1, (n_channels + ncols - 1) // ncols)
# Trial types are told apart by line style (cycled, so no trial type is silently dropped);
# chromophores are told apart by colour (HbO red, HbR blue).
linestyles = ["-", "--", ":", "-."]

fig, ax = p.subplots(nrows, ncols, figsize=(12, 2 * nrows), squeeze=False)
ax = ax.flatten()
for i_ch, ch in enumerate(ba.channel.values):
    for i_tt, trial_type in enumerate(ba.trial_type):
        ls = linestyles[i_tt % len(linestyles)]
        ax[i_ch].plot(ba.time, ba.sel(chromo="HbO", trial_type=trial_type, channel=ch), "r", lw=2, ls=ls)
        ax[i_ch].plot(ba.time, ba.sel(chromo="HbR", trial_type=trial_type, channel=ch), "b", lw=2, ls=ls)
    ax[i_ch].grid(1)
    ax[i_ch].set_title(ch)
    ax[i_ch].set_ylim(-.3, .6)

# Hide the grid cells that have no channel to show.
for unused_ax in ax[n_channels:]:
    unused_ax.set_visible(False)

p.tight_layout()

## Tidy up

In [None]:
# Remove the temporary directory explicitly. Relying on `del` leaves cleanup to the
# garbage-collection finalizer, which emits an "Implicitly cleaning up" ResourceWarning.
tmpdir.cleanup()
del tmpdir