In [1]:
import os
import sys
import numpy as np
import pandas as pd
import xarray as xr
import pyvista as pv
import matplotlib.pyplot as p
import pyvista as pv
from matplotlib.colors import LinearSegmentedColormap,ListedColormap, BoundaryNorm

import matplotlib.cm as cm

import cedalion.data
import cedalion.dot as dot

import cedalion.dataclasses as cdc
import cedalion.sigproc.motion_correct as motion_correct

pv.set_jupyter_backend("html")

from cedalion import units

import cedalion
import cedalion.sigproc.quality as quality
import cedalion.sigproc.physio as physio
import cedalion.nirs as nirs
# import cedalion.plots as plots
from cedalion.io.probe_geometry import load_tsv
import cedalion.geometry.landmarks

# from cedalion.imagereco.solver import pseudo_inverse_stacked
from cedalion.io.forward_model import load_Adot,save_Adot

from typing import Tuple
from pint import Quantity

%load_ext autoreload
%autoreload 2
%matplotlib widget

os.getcwd()

'/home/orabe/fNIRS_sparseToDense'

In [2]:
def get_bad_ch_mask(int_data, ch_preproc, pruned_sig):
    """Flag dark/saturated and low-SNR channels and prune them from a signal.

    Parameters
    ----------
    int_data :
        Intensity (amplitude) time series the quality metrics are computed on;
        passed to ``quality.mean_amp`` and ``quality.snr`` and must expose a
        ``channel`` dimension.
    ch_preproc : dict
        Preprocessing settings. Reads ``'dark_sat_thresh'`` as
        ``[dark_threshold, saturation_threshold]`` mean-amplitude bounds and,
        optionally, ``'snr_thresh'`` (defaults to 10, the value previously
        hard-coded here).
    pruned_sig :
        Signal to prune (e.g. filtered OD), handed to ``quality.prune_ch``
        together with the two masks using the ``"all"`` operator.

    Returns
    -------
    tuple
        ``(list_bad_ch, pruned_sig)``: the channels reported by ``prune_ch``
        and the pruned signal it returns.
    """
    dark_thresh = ch_preproc['dark_sat_thresh'][0]
    sat_thresh = ch_preproc['dark_sat_thresh'][1]
    snr_thresh = ch_preproc.get('snr_thresh', 10)

    # A single mean-amplitude check over [dark, saturation]. The previous
    # version intersected [0, sat] with [dark, 1], which gives the same result
    # for sat <= 1 but silently capped the saturation bound at 1 otherwise.
    _, amp_mask = quality.mean_amp(int_data, [dark_thresh, sat_thresh])
    _, snr_mask = quality.snr(int_data, snr_thresh)

    # NOTE(review): presumably "all" keeps a channel only if it passes every
    # mask (see cedalion quality.prune_ch) — confirm.
    pruned_sig, list_bad_ch = quality.prune_ch(pruned_sig, [amp_mask, snr_mask], "all")

    n_channels = len(int_data.channel)
    n_bad = len(list_bad_ch)
    # Guard against an empty channel axis instead of raising ZeroDivisionError.
    pct_bad = int(n_bad / n_channels * 100) if n_channels else 0
    print("Flagged Channels : ", n_bad, '/', n_channels)
    print("Percentage: ", pct_bad, '%')

    return list_bad_ch, pruned_sig

In [3]:
import glob

# Dataset to process.
# DATASET_NAME = "BallSqueezingHD_modified"   # alternative dataset
DATASET_NAME = "FreshMotor"
dataset_path = f"datasets/raw/{DATASET_NAME}"

# glob returns matches in filesystem order, which can differ between machines.
# Sort so the processing order (and the "run" index built from it later) is
# reproducible.
files = sorted(glob.glob(dataset_path + "/**/*.snirf", recursive=True))
len(files)

40

In [4]:

# Preprocessing configuration shared by every run. It does not depend on the
# file, so it is defined once here instead of being rebuilt on each iteration.
ch_preproc = {
    'FLAG_OW'   : False,
    'FLAG_Plot' : True,
    'channel'   : 'S1D1',
    'pwindow'   : slice(4320,4500),
    'hp_filt'   : [0.008,0],            # fmin = 0.008, fmax = 0
    'sci_thresh' : 0.6,
    'psp_thresh' : 0.1,
    'dark_sat_thresh' : [1e-3, 0.84],
    'perc_time_clean' : 0.5             # 50 %
}

# One entry per successfully processed file: the channel-mean and the
# channel-variance of that run's measurement variance.
c_meas_mean_runs = []
c_meas_var_runs = []

for f in files:
    # ---------------------------- load ----------------------------
    try:
        rec = cedalion.io.read_snirf(f)[0]  # first recording in the file
    except Exception as e:
        # Skip unreadable files. Without `continue`, `rec` would still hold the
        # previous file's recording (or be undefined on the first iteration),
        # so that run would be processed and appended a second time.
        print(f"Skipping {f}: {e}")
        continue

    # ------------------------ pre-processing ------------------------
    rec['rep_amp'] = quality.repair_amp(rec['amp'], median_len=3, method='linear')
    rec['od_amp'], baseline = nirs.cw.int2od(rec['rep_amp'], return_baseline=True)

    # motion correction: TDDR followed by wavelet
    rec["od_tddr"] = motion_correct.tddr(rec["od_amp"])
    rec["od_tddr_wavel"] = motion_correct.wavelet(rec["od_tddr"])

    # high-pass filter; cutoffs come from ch_preproc['hp_filt'] so the config
    # and the call cannot drift apart
    rec['od_hpfilt'] = rec['od_tddr_wavel'].cd.freq_filter(
        fmin=ch_preproc['hp_filt'][0],
        fmax=ch_preproc['hp_filt'][1],
        butter_order=4,
    )

    # back to intensity (using the OD baseline) for the amplitude/SNR checks
    rec['amp_clean'] = nirs.cw.od2int(rec['od_hpfilt'], baseline)

    # flag bad channels on the cleaned intensities and prune them from the
    # filtered OD signal (rec['od_hpfilt'] is replaced by the pruned version)
    list_bad_ch, rec['od_hpfilt'] = get_bad_ch_mask(
        rec["amp_clean"],
        ch_preproc,
        rec['od_hpfilt'])

    print('the list of bad channels: ', len(list_bad_ch))

    # per-channel measurement variance, then summarised across channels.
    # NOTE(review): list_bad_ch was already pruned from od_hpfilt above; confirm
    # how measurement_variance treats those channels, and that any inflated
    # bad-channel variances (bad_rel_var) do not leak into the mean/var below.
    c_meas = quality.measurement_variance(
        rec["od_hpfilt"],
        list_bad_channels=list_bad_ch,
        bad_rel_var=1e6,
        calc_covariance=False,
    )

    c_meas_mean_runs.append(c_meas.mean(dim="channel"))
    c_meas_var_runs.append(c_meas.var(dim="channel"))


Flagged Channels :  14 / 68
Percentage:  20 %
the list of bad channels:  14
Flagged Channels :  15 / 68
Percentage:  22 %
the list of bad channels:  15
Flagged Channels :  15 / 68
Percentage:  22 %
the list of bad channels:  15
Flagged Channels :  14 / 68
Percentage:  20 %
the list of bad channels:  14
Flagged Channels :  10 / 68
Percentage:  14 %
the list of bad channels:  10
Flagged Channels :  11 / 68
Percentage:  16 %
the list of bad channels:  11
Flagged Channels :  13 / 68
Percentage:  19 %
the list of bad channels:  13
Flagged Channels :  10 / 68
Percentage:  14 %
the list of bad channels:  10
Flagged Channels :  10 / 68
Percentage:  14 %
the list of bad channels:  10
Flagged Channels :  55 / 68
Percentage:  80 %
the list of bad channels:  55
Flagged Channels :  56 / 68
Percentage:  82 %
the list of bad channels:  56
Flagged Channels :  12 / 68
Percentage:  17 %
the list of bad channels:  12
Flagged Channels :  13 / 68
Percentage:  19 %
the list of bad channels:  13
Flagged Chan

In [5]:
# `c_meas` is whatever the final iteration of the preprocessing loop left
# behind, so this is the across-channel spread for the last processed run only.
c_meas.std(dim="channel")

0,1
Magnitude,[0.0013390826658658355 0.0003612924840349499]
Units,dimensionless


In [6]:
# Stack the per-run summaries along a new "run" dimension (loop order).
c_meas_mean_allruns, c_meas_var_allruns = (
    xr.concat(per_run, dim="run")
    for per_run in (c_meas_mean_runs, c_meas_var_runs)
)

In [7]:
import pickle

# Cache the stacked per-run statistics to disk (file names carry the dataset
# name) so the summary cells can be re-run without repeating preprocessing.
_stats_to_save = (
    (f'c_meas_mean_allRuns_{DATASET_NAME}.pkl', c_meas_mean_allruns),
    (f'c_meas_var_allRuns_{DATASET_NAME}.pkl', c_meas_var_allruns),
)
for _out_path, _stat in _stats_to_save:
    with open(_out_path, 'wb') as handle:
        pickle.dump(_stat, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [8]:
# Reload the cached statistics written by the save cell. pickle.load can
# execute arbitrary code, so only point this at files this notebook produced.
def _read_cached(path):
    with open(path, 'rb') as fh:
        return pickle.load(fh)

c_meas_mean_allruns = _read_cached(f'c_meas_mean_allRuns_{DATASET_NAME}.pkl')
c_meas_var_allruns = _read_cached(f'c_meas_var_allRuns_{DATASET_NAME}.pkl')

# Collapse the "run" dimension by averaging over runs.
c_meas_mean_avg, c_meas_var_avg = (
    stacked.mean(dim="run") for stacked in (c_meas_mean_allruns, c_meas_var_allruns)
)


In [9]:
# Channel-mean measurement variance, averaged over runs. The output has two
# entries — presumably one per wavelength; confirm against the data.
c_meas_mean_avg

0,1
Magnitude,[0.00023028080397673807 0.00018084825168167192]
Units,dimensionless


In [10]:
# Across-channel variance of the measurement variance, averaged over runs.
# The output has two entries — presumably one per wavelength; confirm.
c_meas_var_avg

0,1
Magnitude,[7.355987996118687e-07 4.302497628473456e-08]
Units,dimensionless
