# Storing estimated HRFs in snirf files

This notebook estimates the HRF in a finger-tapping experiment by block averaging and then stores the result in a SNIRF file.

In [None]:
import cedalion
import cedalion.nirs
import cedalion.datasets
import cedalion.io

from snirf import Snirf
import numpy as np
import xarray as xr

import matplotlib.pyplot as p
import os.path
import tempfile

# Compact xarray reprs for the notebook output.
xr.set_options(
    display_max_rows=3,
    display_values_threshold=50,
    display_expand_data=False,
)

# Print numpy floats with four digits of precision.
np.set_printoptions(precision=4)

## Load a finger-tapping dataset 

For this demo we load an example finger-tapping recording through `cedalion.datasets.get_fingertapping`. The file contains a single NIRS element with one block of raw amplitude data. 

In [None]:
# Fetch the example finger-tapping recording; the result is indexed like a
# sequence of NIRS elements in the cells below.
elements = cedalion.datasets.get_fingertapping()

In [None]:
# Unpack the first element of the recording.
recording = elements[0]
amp = recording.data[0]
stim = recording.stim  # pandas Dataframe
geo3d = recording.geo3d

# Map the event labels stored in the file to descriptive trial-type names.
event_names = {
    "1.0": "control",
    "2.0": "Tapping/Left",
    "3.0": "Tapping/Right",
}
stim.cd.rename_events(event_names)

## Calculate concentrations

In [None]:
# Differential pathlength factor: one entry per wavelength, 6 for both.
dpf = xr.DataArray([6, 6], dims="wavelength", coords={"wavelength": amp.wavelength})

# beer_lambert is called with the raw amplitudes, so no separate
# optical-density array is computed in this cell.
# NOTE(review): presumably beer_lambert performs the optical-density
# conversion internally - confirm against the cedalion documentation.
conc = cedalion.nirs.beer_lambert(amp, geo3d, dpf)

## Frequency filtering and splitting into epochs

In [None]:
# Frequency-filter the concentration time series (fmin=0.02, fmax=0.5).
conc_freqfilt = conc.cd.freq_filter(fmin=0.02, fmax=0.5, butter_order=4)

# Only the two tapping conditions are cut into epochs; "control" events are
# not selected.
tapping_events = ["Tapping/Left", "Tapping/Right"]

# Each epoch spans 5 seconds before to 20 seconds after its stimulus.
cf_epochs = conc_freqfilt.cd.to_epochs(stim, tapping_events, before=5, after=20)

## Block averaging to estimate the HRFs

In [None]:
# calculate baseline
baseline = cf_epochs.sel(reltime=(cf_epochs.reltime < 0)).mean("reltime")
# subtract baseline
epochs_blcorrected = cf_epochs - baseline

# group trials by trial_type. For each group individually average the epoch dimension
blockaverage = epochs_blcorrected.groupby("trial_type").mean("epoch")

display(blockaverage)

## Store HRFs 

In [None]:
# Scratch directory that holds the output file for the rest of the notebook.
tmpdir = tempfile.TemporaryDirectory()

# Full path of the snirf file inside the scratch directory.
output_dir = tmpdir.name
snirf_fname = os.path.join(output_dir, "test.snirf")

In [None]:
# Write the block-averaged data together with the probe geometry and the
# stimulus table to the snirf file.
cedalion.io.snirf.write_snirf(snirf_fname, "hrf", blockaverage, geo3d, stim)

## Inspect snirf file

In [None]:
# Open the written file with the snirf package to inspect its contents.
s = Snirf(snirf_fname)

#### Stim DataFrame

In [None]:
# Show the stim entries of the first NIRS element as a DataFrame.
cedalion.io.snirf.stim_to_dataframe(s.nirs[0].stim)

#### MeasurementList DataFrame

In [None]:
# Measurement list of the first data block of the first NIRS element.
measurement_list = s.nirs[0].data[0].measurementList

# Convert it to a DataFrame and show it as the cell output.
df_ml = cedalion.io.snirf.measurement_list_to_dataframe(
    measurement_list, drop_none=True
)
df_ml

In [None]:
# Close the file before it is read again in the next section.
s.close()

## Read HRFs from snirf file

Note: `read_snirf` names the time dimension `time`, whereas in `blockaverage` it is called `reltime`. The two need to agree on a convention; until then, the dimension has to be renamed before the arrays can be compared.

In [None]:
# Read the file back with cedalion's own reader.
hrf_elements = cedalion.io.read_snirf(snirf_fname)

In [None]:
# Take the first data block of the first element that was read back.
read_blockaverage = hrf_elements[0].data[0]
display(read_blockaverage)

#### Assert that the written and read HRFs are identical.

In [None]:
# Give the read-back array the same time-dimension name as the written one.
renamed = read_blockaverage.rename({"time": "reltime"})

# Element-wise comparison of the read-back and the written HRFs.
# NOTE(review): `==` on DataArrays presumably compares only entries whose
# coordinates match in both arrays - confirm that no entries are dropped.
matches = renamed == blockaverage
assert matches.all()

## Plot the HRFs

In [None]:
ba = read_blockaverage

# Size the subplot grid from the number of channels so that recordings with
# more channels than a fixed grid could hold do not raise an IndexError.
channels = ba.channel.values
n_cols = 6
n_rows = max(1, -(-len(channels) // n_cols))  # ceiling division

# Two inches of height per row; squeeze=False always yields a 2D axes array.
f, ax = p.subplots(n_rows, n_cols, figsize=(12, 2 * n_rows), squeeze=False)
ax = ax.flatten()

for i_ch, ch in enumerate(channels):
    panel = ax[i_ch]
    # One line style per trial type: solid for the first, dashed for the second.
    for ls, trial_type in zip(["-", "--"], ba.trial_type):
        # HbO is drawn in red, HbR in blue.
        panel.plot(ba.time, ba.sel(chromo="HbO", trial_type=trial_type, channel=ch), "r", lw=2, ls=ls)
        panel.plot(ba.time, ba.sel(chromo="HbR", trial_type=trial_type, channel=ch), "b", lw=2, ls=ls)
    panel.grid(True)
    panel.set_title(ch)
    # Same y-range in every panel so that channels can be compared visually.
    panel.set_ylim(-.3, .6)

# Hide the panels of the grid that have no channel assigned.
for unused_panel in ax[len(channels):]:
    unused_panel.set_visible(False)

p.tight_layout()

## Tidy up

In [None]:
# Remove the temporary directory and the snirf file inside it explicitly
# instead of relying on the TemporaryDirectory object being garbage collected.
tmpdir.cleanup()