In [1]:
import numpy as np
import cedalion
import cedalion.nirs
import cedalion.imagereco.forward_model as fw
import cedalion.datasets
import os
import cedalion.xrutils as xrutils
import cedalion.plots
import xarray as xr
import cedalion.geometry.landmarks as cd_landmarks
import matplotlib.pyplot as plt
import cedalion.sim.synthetic_hrf as synHRF_ced
from cedalion import units
import cedalion.dataclasses as cdc
import pyvista as pv
#pv.set_jupyter_backend('server') # this enables interactive plots

# Keep xarray reprs compact: collapse the data section whenever an array is displayed.
xr.set_options(
    display_expand_data=False,
);

## Loading and preprocessing the dataset

This notebook uses a finger-tapping dataset in BIDS layout provided by [Rob Luke](https://github.com/rob-luke/BIDS-NIRS-Tapping). It can be downloaded via `cedalion.datasets`.

In [2]:
# Load the BIDS finger-tapping recording and unpack the pieces used by the rest of the notebook.
rec = cedalion.datasets.get_fingertapping()
amp = rec["amp"]
# NOTE(review): reads a private attribute of the recording; prefer a public accessor if available.
meas_list = rec._measurement_lists["amp"]
geo3d = rec.geo3d
stim = rec.stim

# Number of stimulus onsets per trial type in the original recording.
display(
    stim.groupby("trial_type")[["onset"]].count()
)

Unnamed: 0_level_0,onset
trial_type,Unnamed: 1_level_1
1.0,30
15.0,2
2.0,30
3.0,30


We only use the geo3d information from this dataset for image reconstruction. The dataset already contains stims, but since we want to construct synthetic HRFs, we overwrite the OD data with noise for demonstration purposes. Later we will create our own stim dataframe and add our synthetic HRFs to the OD data.

In [3]:
# Optical density relative to the time-averaged intensity of each channel/wavelength.
od = -np.log(amp / amp.mean(dim="time"))

In [4]:
# Replace the measured OD values with zero-mean Gaussian noise whose standard deviation is
# one tenth of the overall OD standard deviation. The real signal is discarded, so the
# synthetic HRFs added later are the only structured component.
# A seeded generator makes the noise (and everything derived from it) reproducible under
# "Restart & Run All".
rng = np.random.default_rng(seed=42)
noise_std = float(od.std().pint.dequantify()) / 10
# Sizing the noise from od.shape guarantees it matches od's own dimension order.
noise_data = rng.normal(0.0, noise_std, size=od.shape)
# Assigning the ndarray to .data replaces the values while keeping od's dims and coords.
# As before, any pint units previously attached to od's data are dropped here.
od.data = noise_data

## Construct headmodel

We load the Colin27 head model, since we need its geometry for image reconstruction.

In [5]:
# Location of the Colin27 segmentation data: data directory, mask files and landmark file.
(SEG_DATADIR, mask_files, landmarks_file) = cedalion.datasets.get_colin27_segmentation()

In [6]:
# Two-surface (brain + scalp) head model built from the Colin27 segmentation.
# Face counts are left at None (presumably: no mesh decimation; confirm in cedalion docs).
head = fw.TwoSurfaceHeadModel.from_surfaces(
    segmentation_dir=SEG_DATADIR,
    mask_files=mask_files,
    landmarks_ras_file=landmarks_file,
    brain_surface_file=os.path.join(SEG_DATADIR, "mask_brain.obj"),
    scalp_surface_file=os.path.join(SEG_DATADIR, "mask_scalp.obj"),
    brain_face_count=None,
    scalp_face_count=None,
    # must stay True, otherwise the landmark calculation fails
    fill_holes=True,
)

In [7]:
# Declare both surface meshes as millimetre-based (brain first, then scalp).
head.brain.units = head.scalp.units = cedalion.units.mm
# Re-tag the landmark coordinates as millimetres: strip the current unit, then attach mm
# to the unchanged magnitudes (no unit conversion takes place).
head.landmarks = head.landmarks.pint.dequantify()
head.landmarks.pint.units = cedalion.units.mm

In [8]:
# Display the fiducial landmarks (Nz, Iz, LPA, RPA) of the head model, now tagged in millimetres.
head.landmarks

0,1
Magnitude,[[89.95007602497842 205.78946685790896 35.85887908935477] [91.75606727600317 25.05000305176084 17.14201354980642] [18.05000305175645 109.8805942535382 17.343021392821214] [165.94999694824514 112.9619064331096 17.922157287600022]]
Units,millimeter


`head.landmarks` contains the 4 landmarks ['Nz' 'Iz' 'LPA' 'RPA'].
Since we want to create synthetic HRFs on the brain surface at landmark positions, we need to build the remaining 10-10 landmarks.

In [9]:
# Extend the four fiducials to the full 10-10 system on the scalp surface, so that
# labels such as C3/C4 become available on head.landmarks.
lmbuilder = cd_landmarks.LandmarksBuilder1010(head.scalp, head.landmarks)
head.landmarks = all_landmarks = lmbuilder.build()



In [10]:
# Align the probe geometry to the head model and snap the optodes onto the scalp surface.
geo3d_snapped = head.align_and_snap_to_scalp(geo3d)
# Centroid of the brain mesh vertices; used below to aim the 3D camera.
center_brain = head.brain.mesh.vertices.mean(axis=0)

We want to build the synthetic HRFs at C3 and C4 (green dots in the image below)

In [11]:
# 3D overview: brain, translucent scalp, the C3/C4 landmarks and all non-landmark probe points.
plt_pv = pv.Plotter()
cedalion.plots.plot_surface(plt_pv, head.brain, color="#d3a6a1")
cedalion.plots.plot_surface(plt_pv, head.scalp, opacity=.1)
cedalion.plots.plot_labeled_points(
    plt_pv,
    head.landmarks.sel(label=["C3", "C4"]),
    show_labels=True,
)
cedalion.plots.plot_labeled_points(
    plt_pv,
    geo3d_snapped[geo3d_snapped.type != cdc.PointType.LANDMARK],
    show_labels=True,
)
# Put the camera on the ray from the brain centre through C3, seven times as far out.
plt_pv.camera.position = (head.landmarks.sel(label="C3").values - center_brain) * 7 + center_brain
plt_pv.show()

  data = np.asarray(data)


Widget(value='<iframe src="http://localhost:46207/index.html?ui=P_0x7d9a286225d0_0&reconnect=auto" class="pyvi…

## Build blob on brain surface for landmarks C3 and C4

Using the nearest brain vertex to a given landmark, we build a blob on the brain surface. The blob is a Gaussian of the geodesic distance. The size of the blob is determined by the standard deviation of this Gaussian, given by the parameter `scale`.

In [12]:
# Gaussian blobs (std. dev. 2 cm of geodesic distance) around the brain vertex nearest to
# each landmark, built for C3 first and then C4.
blob_img_c3, blob_img_c4 = (
    synHRF_ced.build_blob(head, landmark, scale=2 * cedalion.units.cm)
    for landmark in ("C3", "C4")
)

  return splu(A).solve


The resulting xarray.DataArray contains an activation value for each vertex on the brain surface. We will use this spatial information to create synthetic HRFs.

In [13]:
# Per-vertex activation values of the C3 blob on the brain surface.
blob_img_c3

## Plot blobs @ C3 & C4

There exists a helper function to plot the blobs on the brain surface

In [14]:
# Toggle for the optional figures in the `if plot:` cells below; set to False to skip them.
plot = True

In [15]:
# Render the C3 activation blob on the brain surface (skipped when plotting is disabled).
if plot:
    synHRF_ced.plot_blob(
        blob_img_c3,
        head.brain,
        title="C3 Blob",
    )

Widget(value='<iframe src="http://localhost:46207/index.html?ui=P_0x7d99a9608550_1&reconnect=auto" class="pyvi…

In [16]:
# Render the C4 activation blob on the brain surface (skipped when plotting is disabled).
if plot:
    synHRF_ced.plot_blob(
        blob_img_c4,
        head.brain,
        title="C4 Blob",
    )

Widget(value='<iframe src="http://localhost:46207/index.html?ui=P_0x7d99a9eb8b90_2&reconnect=auto" class="pyvi…

## Image Reconstruction

We compute cedalion's forward model to obtain the sensitivity matrix `Adot`, which lets us map from image space to channel space.

In [17]:
# Sensitivity matrix Adot (projects image space onto channel space), computed from the
# precomputed fluence for the finger-tapping probe on the Colin27 head.
fluence_all, fluence_at_optodes = cedalion.datasets.get_precomputed_fluence("fingertapping", "colin27")
fwm = fw.ForwardModel(head, geo3d_snapped, meas_list)
Adot = fwm.compute_sensitivity(fluence_all, fluence_at_optodes)

## HRFs in channel Space

We build an HRF basis in concentration space (one time course per chromophore). This will be our temporal HRF model.

In [2]:
# Temporal HRF basis for a 10 s stimulus, sampled on the first 18 s of the recording's time
# axis, with amplitude scales of 10 µM and -4 µM (presumably HbO and HbR, in chromo order;
# confirm against generate_hrf).
stim_dur = 10 * units.seconds
tbasis = synHRF_ced.generate_hrf(
    amp.time.sel(time=(amp.time < 18)),
    stim_dur,
    scale=[10 * units.micromolar, -4 * units.micromolar],
)

# Extinction coefficients (Prahl) for the recording's wavelengths. tbasis_od contracts them
# with the concentration basis times a 1 mm length factor over 'chromo', giving the basis
# in wavelength (optical density) space.
E = cedalion.nirs.get_extinction_coefficients('prahl', amp.wavelength)
Einv = xrutils.pinv(E)  # NOTE(review): not used by any cell of this notebook
tbasis_od = xr.dot(E, tbasis * 1 * cedalion.units.mm, dims='chromo')

plt.rcParams['figure.figsize'] = [15, 5]

if plot:
    fig, ax = plt.subplots()
    # Transpose by name so each column is one chromophore's time course, regardless of
    # the dimension order generate_hrf returns.
    ax.plot(tbasis.time, tbasis.transpose("time", "chromo"))
    # Label with the basis' actual units rather than a hard-coded unit string.
    hrf_units = tbasis.pint.units
    ax.set_ylabel(f"Amplitude ({hrf_units:~P})" if hrf_units is not None else "Amplitude")
    ax.set_xlabel("Time (s)")
    ax.set_title('HRF chromo')
    ax.legend(tbasis.chromo.values)
    plt.show()


Together with the spatial blobs and the sensitivity matrix `Adot` (which projects from image to channel space), we use this temporal HRF model to create spatially and temporally realistic HRFs in channel space.

In [19]:
# Project each spatial blob through Adot and combine it with the temporal basis to obtain
# per-channel synthetic HRFs, first for C3 and then for C4.
syn_HRF_chan_c3, syn_HRF_chan_c4 = (
    synHRF_ced.hrfs_from_image_reco(blob, tbasis, Adot)
    for blob in (blob_img_c3, blob_img_c4)
)

  data = np.asarray(data)
  data = np.asarray(data)


In [20]:
# Stack both activations along a new "trial_type" dimension labelled by stimulus name.
syn_HRFs_chan = xr.concat(
    [syn_HRF_chan_c3, syn_HRF_chan_c4],
    dim="trial_type",
).assign_coords(trial_type=["Stim C3", "Stim C4"])

### We now have a synthetic HRF for each channel, trial_type and wavelength

In [21]:
# Synthetic HRF time courses per channel, trial_type and wavelength.
syn_HRFs_chan

We plot the synthetic HRFs for each channel, trial_type and wavelength

In [3]:
# One panel per channel: red = 760 nm, blue = 850 nm; solid = C3, dashed = C4.
# The grid height follows the channel count (7 columns), so any probe size fits.
n_channels = syn_HRFs_chan.sizes["channel"]
n_cols = 7
n_rows = max(1, int(np.ceil(n_channels / n_cols)))
f, ax = plt.subplots(n_rows, n_cols, figsize=(12, 2 * n_rows))
ax = ax.flatten()
for i_ch, ch in enumerate(syn_HRFs_chan.channel):
    for ls, trial_type in zip(["-", "--"], syn_HRFs_chan.trial_type):
        for wl, color in [(760.0, "r"), (850.0, "b")]:
            ax[i_ch].plot(
                syn_HRFs_chan.time,
                syn_HRFs_chan.sel(wavelength=wl, trial_type=trial_type, channel=ch),
                color,
                lw=2,
                ls=ls,
            )
    ax[i_ch].grid(1)
    ax[i_ch].set_title(ch.values)
    ax[i_ch].set_ylim(-0.01, 0.02)

# Hide grid cells that have no channel assigned.
for unused_ax in ax[n_channels:]:
    unused_ax.set_visible(False)

# Lines in ax[0] are drawn in this order: C3 760, C3 850, C4 760, C4 850.
ax[0].legend(["760.0 C3", "850.0 C3", "760.0 C4", "850.0 C4"])
# Figure-level title; plt.title would only label the last subplot.
f.suptitle("Synthetic HRFs in Channel Space")
plt.tight_layout()
plt.show()

We can also make a scalp plot of the maximum HRF amplitude for each channel (maximum over time and wavelength).

In [23]:
# Per-channel peak of each synthetic HRF: maximum over time and wavelength.
c3_chan_max, c4_chan_max = (
    hrf.max(dim=["time", "wavelength"])
    for hrf in (syn_HRF_chan_c3, syn_HRF_chan_c4)
)

In [4]:
# Scalp maps of the per-channel maximum synthetic-HRF amplitude, one figure per activation.
# Both maps share colour limits and labelling so they are directly comparable.
for site, chan_max in [("C3", c3_chan_max), ("C4", c4_chan_max)]:
    fig, ax = plt.subplots(1, 1)
    cedalion.plots.scalp_plot(
        rec["amp"],
        rec.geo3d,
        chan_max.values,
        ax,
        cmap="jet",
        title=f"{site} synHRFs",
        vmin=-0.01,
        vmax=0.02,
        cb_label="Max peak amplitude",
    )
    plt.show()

Alternatively, without the full image reconstruction, we could simply add the temporal HRF model to all long channels (the corresponding code is commented out below). This has to be done in optical density (!), because the mapping from concentration to optical density is otherwise performed as one step of the image reconstruction.

In [25]:
# Disabled alternative without image reconstruction: add the OD-space temporal basis
# `tbasis_od` directly to all long channels.
# NOTE(review): this cell contains only commented-out code. Un-commenting the first line
# would rebind the global `geo3d` (renaming 'digitized' to 'pos'), presumably because
# hrf_to_long_channels expects a 'pos' dimension; confirm before enabling.
# geo3d = geo3d.rename({'digitized':'pos'})
# syn_HRFs_chan_no_imagereco = synHRF_ced.hrf_to_long_channels(tbasis_od, amp, geo3d)

## Add HRFs to OD

We generate a synthetic stim dataframe by specifying the number of stims, the stim duration, the trial types, the range of intervals between stims, and the order of stims (alternating between trial types or random).

In [26]:
# Synthetic stimulus schedule: num_stims=10 (check build_stim_df for whether this is total
# or per trial type), duration 18, both synthetic trial types in alternating order, and
# inter-stimulus intervals drawn between 30 and 60 (presumably seconds).
stim_df = synHRF_ced.build_stim_df(
    num_stims=10,
    stim_dur=18,
    trial_types=syn_HRFs_chan.trial_type.values,
    min_interval=30,
    max_interval=60,
    order="alternating",
)

In [27]:
# First rows of the synthetic stimulus schedule (onset, duration, value, trial_type).
stim_df.head()

Unnamed: 0,onset,duration,value,trial_type
0,39.64,18,1,Stim C3
1,113.12,18,1,Stim C4
2,173.29,18,1,Stim C3
3,228.6,18,1,Stim C4
4,293.43,18,1,Stim C3


### We can now use our synthetic HRFs and our stim dataframe to add the HRFs to the OD data

In [28]:
# Insert the channel-space synthetic HRFs into the noise-only OD according to stim_df.
od_w_hrf = synHRF_ced.add_hrf_to_od(
    od,
    syn_HRFs_chan,
    stim_df,
)

In [29]:
# Noise-only OD with the synthetic HRFs added.
od_w_hrf

Below we plot
1) the OD data with added HRFs (channel S3D2)
2) the difference between the OD data with added HRFs and the original OD data, which shows only the HRFs (channel S3D2)

In [5]:
if plot:
    # Channel S3D2 over the first 350 s: first the OD with HRFs, then its difference to the
    # noise-only OD, which isolates the added HRFs.
    window = {"channel": "S3D2", "time": od.time < 350}
    for series, title in [(od_w_hrf, "od_w_hrf"), (od_w_hrf - od, "od_w_hrf - od")]:
        plt.figure()
        # Select by the numeric wavelength values (other cells use 760.0 / 760); string
        # labels such as "760.0" do not match a numeric coordinate. Iterating over the
        # coordinate also keeps the plotting order identical to the legend order.
        for wl in od_w_hrf.wavelength.values:
            series.sel(wavelength=wl, **window).plot()
        plt.xlabel("Time (s)")
        plt.ylabel("Amplitude")
        plt.title(title)
        plt.legend(od_w_hrf.wavelength.values)
        plt.show()

## Recover the HRFs from OD data

We filter the data and calculate block averages over epochs to recover the HRFs

In [31]:
# Band-pass the OD before epoching: 0.02 to 0.5 (presumably Hz), filter order 4.
od_w_hrf_filtered = od_w_hrf.cd.freq_filter(
    fmin=0.02,
    fmax=0.5,
    butter_order=4,
)

In [6]:
# Cut the filtered OD into epochs around each synthetic stimulus.
epochs = od_w_hrf_filtered.cd.to_epochs(
    stim_df,  # stimulus dataframe
    ["Stim C3", "Stim C4"],  # select events
    before=5,  # seconds before stimulus
    after=20,  # seconds after stimulus
)

# Baseline-correct each epoch with its mean over the pre-stimulus interval (reltime < 0).
baseline = epochs.sel(reltime=(epochs.reltime < 0)).mean("reltime")
epochs_blcorrected = epochs - baseline

# Group trials by trial_type and average over the epoch dimension within each group.
blockaverage = epochs_blcorrected.groupby("trial_type").mean("epoch")

# One panel per channel; the grid height follows the channel count (7 columns).
n_channels = blockaverage.sizes["channel"]
n_cols = 7
n_rows = max(1, int(np.ceil(n_channels / n_cols)))
f, ax = plt.subplots(n_rows, n_cols, figsize=(12, 2 * n_rows))
ax = ax.flatten()
for i_ch, ch in enumerate(blockaverage.channel):
    for ls, trial_type in zip(["-", "--"], blockaverage.trial_type):
        for wl, color in [(760, "r"), (850, "b")]:
            ax[i_ch].plot(
                blockaverage.reltime,
                blockaverage.sel(wavelength=wl, trial_type=trial_type, channel=ch),
                color,
                lw=2,
                ls=ls,
            )
    ax[i_ch].grid(1)
    ax[i_ch].set_title(ch.values)
    ax[i_ch].set_ylim(-.01, .02)

# Hide grid cells that have no channel assigned.
for unused_ax in ax[n_channels:]:
    unused_ax.set_visible(False)

plt.suptitle("760nm: r | 850nm: b | C3: - | C4: --")
plt.tight_layout()
plt.show()