# Precompute forward-model results

In [None]:
import pyvista as pv
pv.set_jupyter_backend('static')
#import memory_profiler

In [None]:
from functools import lru_cache
import os

import numpy as np
import xarray as xr

import cedalion.datasets
import cedalion.geometry.segmentation
import cedalion.imagereco.forward_model as fw
import cedalion.io.forward_model
from cedalion.io.forward_model import FluenceFile, load_Adot
import cedalion.plots
xr.set_options(display_expand_data=False);

In [None]:
def compute_fluence_mcx(rec, head, output_file):
    """Run the MCX fluence simulation for a recording and write it to disk.

    The recording's optode positions are aligned and snapped to the scalp of
    ``head`` before a forward model is built from the ``"amp"`` measurement
    list.

    Args:
        rec: recording providing ``geo3d`` and the ``"amp"`` measurement list.
        head: head model, e.g. a ``TwoSurfaceHeadModel`` from ``get_headmodel``.
        output_file: path passed to ``ForwardModel.compute_fluence_mcx``.

    Nothing is returned; read the results back with ``FluenceFile``.
    """
    geo3d_snapped_ijk = head.align_and_snap_to_scalp(rec.geo3d)

    # NOTE(review): relies on the private ``_measurement_lists`` attribute of
    # the recording; switch to a public accessor if one exists.
    fwm = fw.ForwardModel(
        head, geo3d_snapped_ijk, rec._measurement_lists["amp"]
    )

    fwm.compute_fluence_mcx(output_file)


def plot_fluence(rec, head, fluence_fname, src, det, wl, clim=(-10, 0)):
    """Volume-render the log10 of the source x detector fluence product.

    The brain surface and the scalp-snapped optode positions are drawn on
    top of the volume.

    Args:
        rec: recording providing ``geo3d`` (optode positions).
        head: head model; its ``brain`` surface is drawn and used to aim
            the camera.
        fluence_fname: fluence file, e.g. as written by
            ``compute_fluence_mcx``.
        src: source identifier passed to ``FluenceFile.get_fluence``.
        det: detector identifier passed to ``FluenceFile.get_fluence``.
        wl: wavelength passed to ``FluenceFile.get_fluence``.
        clim: color limits, in log10 units, of the volume rendering.

    Raises:
        ValueError: if the fluence product has no positive value, so there
            is nothing to clip non-positive voxels to before taking the log.
    """
    geo3d_snapped_ijk = head.align_and_snap_to_scalp(rec.geo3d)

    # NOTE(review): opened with FluenceFile's default mode; confirm that it
    # is read-only.
    with FluenceFile(fluence_fname) as fluence_file:
        f = fluence_file.get_fluence(src, wl) * fluence_file.get_fluence(det, wl)

    # log10 is undefined for non-positive values. Clip them to the smallest
    # positive fluence so they land at the bottom of the color scale.
    positive = f > 0
    if not positive.any():
        raise ValueError(
            f"fluence product for src={src!r}, det={det!r}, wl={wl!r} "
            "contains no positive values"
        )
    f[f <= 0] = f[positive].min()
    f = np.log10(f)

    vf = pv.wrap(f)

    plotter = pv.Plotter()
    plotter.add_volume(
        vf,
        log_scale=False,
        cmap="plasma_r",
        clim=clim,
    )
    cedalion.plots.plot_surface(plotter, head.brain, color="w")
    cedalion.plots.plot_labeled_points(
        plotter, geo3d_snapped_ijk, show_labels=False
    )

    # Aim the camera at the mean brain-vertex position from a fixed offset,
    # with +z as the up direction.
    cog = head.brain.vertices.mean("label").values
    plotter.camera.position = cog + [-300, 30, 150]
    plotter.camera.focal_point = cog
    plotter.camera.up = [0, 0, 1]

    plotter.show()

def compute_sensitivity(rec, head, fluence_fname, sensitivity_fname):
    """Compute the forward-model sensitivity from a precomputed fluence file.

    The optodes of ``rec`` are snapped to the scalp of ``head``, a forward
    model is built from the ``"amp"`` measurement list, and
    ``ForwardModel.compute_sensitivity`` reads ``fluence_fname`` and writes
    ``sensitivity_fname``.
    """
    snapped_optodes = head.align_and_snap_to_scalp(rec.geo3d)
    # NOTE(review): private attribute access on the recording.
    amp_measurements = rec._measurement_lists["amp"]

    model = fw.ForwardModel(head, snapped_optodes, amp_measurements)
    model.compute_sensitivity(fluence_fname, sensitivity_fname)

@lru_cache
def get_headmodel(model : str):
    """Build a two-surface head model by name.

    Supported names are ``"colin27"``, built from its segmentation, and
    ``"icbm152"``, built from precomputed surfaces with
    ``brain_face_count`` and ``scalp_face_count`` set to None. Results are
    memoized per name by ``lru_cache``, so repeated calls return the same
    object.

    Raises:
        ValueError: for any other model name.
    """
    if model not in ("colin27", "icbm152"):
        raise ValueError("unknown head model")

    if model == "colin27":
        seg_dir, masks, landmarks = cedalion.datasets.get_colin27_segmentation()
        return fw.TwoSurfaceHeadModel.from_segmentation(
            segmentation_dir=seg_dir,
            mask_files=masks,
            landmarks_ras_file=landmarks,
        )

    seg_dir, masks, landmarks = cedalion.datasets.get_icbm152_segmentation()
    return fw.TwoSurfaceHeadModel.from_surfaces(
        segmentation_dir=seg_dir,
        mask_files=masks,
        brain_surface_file=os.path.join(seg_dir, "mask_brain.obj"),
        landmarks_ras_file=landmarks,
        brain_face_count=None,
        scalp_face_count=None,
    )

def get_fnirs_dataset(dataset):
    """Load one of the cedalion example fNIRS recordings by name.

    Args:
        dataset: ``"fingertappingDOT"`` or ``"fingertapping"``.

    Returns:
        Whatever the matching ``cedalion.datasets`` loader returns.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset == "fingertappingDOT":
        return cedalion.datasets.get_fingertappingDOT()
    if dataset == "fingertapping":
        return cedalion.datasets.get_fingertapping()
    # Fail fast. Returning None would only surface later as a confusing
    # attribute error far away from the misspelled name.
    raise ValueError(f"unknown dataset {dataset!r}")


def compute_fluence_and_sensitivity(dataset : str, headmodel : str):
    """Precompute the fluence and sensitivity files for one dataset/head pair.

    Writes ``fluence_<dataset>_<headmodel>.h5`` and then
    ``sensitivity_<dataset>_<headmodel>.nc``. The sensitivity step is given
    the fluence file produced by the first step.
    """
    rec = get_fnirs_dataset(dataset)
    head = get_headmodel(headmodel)

    tag = f"{dataset}_{headmodel}"
    fluence_path = f"fluence_{tag}.h5"
    sensitivity_path = f"sensitivity_{tag}.nc"

    compute_fluence_mcx(rec, head, fluence_path)
    compute_sensitivity(rec, head, fluence_path, sensitivity_path)

In [None]:
# Precompute fluence and sensitivity for every dataset / head-model pair.
# Each call writes fluence_<dataset>_<headmodel>.h5 and
# sensitivity_<dataset>_<headmodel>.nc, both as relative paths,
# i.e. into the current working directory.
compute_fluence_and_sensitivity("fingertappingDOT", "colin27")

In [None]:
compute_fluence_and_sensitivity("fingertappingDOT", "icbm152")

In [None]:
compute_fluence_and_sensitivity("fingertapping", "colin27")

In [None]:
compute_fluence_and_sensitivity("fingertapping", "icbm152")

In [None]:
# IPython `ls` magic (not plain Python): list the generated fluence files
# and their sizes.
ls -lh fluence*.h5