In [3]:
# Show the GPU(s) visible to this kernel; presumably a GPU is needed for the
# MCX fluence simulation (fwm.compute_fluence_mcx) further below.
!nvidia-smi

Sat Dec 20 17:35:54 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.57.08              Driver Version: 575.57.08      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        On  |   00000000:B1:00.0 Off |                  N/A |
|  0%   26C    P8             14W /  350W |       1MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
# NOTE(review): the last recorded run of this cell failed with
# "ModuleNotFoundError: No module named 'cedalion.sigproc.motion_correct'".
# The installed cedalion version may expose motion correction under a different
# module; check against the installed version before running the cells below.
import pickle
import cedalion
import numpy as np
import cedalion.sigproc.motion_correct as motion_correct
import cedalion.sigproc.quality as quality
import cedalion.sigproc.physio as physio
from cedalion.io.forward_model import load_Adot,save_Adot
import cedalion.dot as dot
from cedalion import units

import cedalion.nirs as nirs
import os
import xarray as xr
import glob
import warnings
import sys
from pathlib import Path, PureWindowsPath
# Suppresses ALL Python warnings for the rest of the session.
warnings.filterwarnings("ignore")


ModuleNotFoundError: No module named 'cedalion.sigproc.motion_correct'

In [None]:
def get_bad_ch_mask(int_data, dark_sat_thresh=(1e-3, 0.84), snr_thresh=10):
    """Return the channels flagged as bad by the amplitude and SNR quality checks.

    Despite its name, the function returns the bad-channel list produced by
    ``quality.prune_ch`` (its second return value), not a boolean mask.

    Parameters
    ----------
    int_data :
        Intensity/amplitude data passed to ``quality.mean_amp`` and ``quality.snr``.
    dark_sat_thresh : pair of float, optional
        ``(dark, saturation)`` mean-amplitude limits. The saturation check uses the
        range ``[0, saturation]`` and the dark check uses ``[dark, 1]``; the two
        resulting masks are AND-ed. NOTE(review): the scale of these limits follows
        ``int_data`` (presumably normalized amplitudes; confirm).
    snr_thresh : float, optional
        SNR threshold passed to ``quality.snr``.

    Returns
    -------
    list_bad_ch
        Bad channels as returned by ``quality.prune_ch`` with the ``"all"`` operator
        applied to the amplitude and SNR masks.
    """
    # Saturated and dark channels
    amp_threshs_sat = [0., dark_sat_thresh[1]]
    amp_threshs_low = [dark_sat_thresh[0], 1]
    _, amp_mask_sat = quality.mean_amp(int_data, amp_threshs_sat)
    _, amp_mask_low = quality.mean_amp(int_data, amp_threshs_low)
    amp_mask = amp_mask_sat & amp_mask_low

    # Low-SNR channels
    _, snr_mask = quality.snr(int_data, snr_thresh)

    _, list_bad_ch = quality.prune_ch(int_data, [amp_mask, snr_mask], "all")
    return list_bad_ch

In [None]:
# Absolute project root. NOTE(review): not used by the paths below, which are
# relative to the current working directory.
base_path = "/home/orabe/fNIRS_sparseToDense/"
DATASET_NAME = "BallSqueezingHD_modified"

# Per-dataset locations follow the layout datasets/<stage>/<DATASET_NAME>.
pre_processed_path = Path(f'datasets/pre_processed/{DATASET_NAME}')
raw_path = Path(f'datasets/raw/{DATASET_NAME}')

# Make sure the output directory exists before anything writes into it.
pre_processed_path.mkdir(exist_ok=True, parents=True)

In [None]:
base_path = '/home/orabe/fNIRS_sparseToDense'

# Available datasets:
# DATASET_NAME = "BallSqueezingHD_modified"
DATASET_NAME = "FreshMotor"
# DATASET_NAME = "BS_Laura"
# DATASET_NAME = "ElectricalThermal"

# Rebuild every dataset-dependent path from DATASET_NAME in this cell, so the cell is
# idempotent and never reuses paths that an earlier cell built for another dataset.
# Layout: datasets/<stage>/<DATASET_NAME>/...
# NOTE(review): assumes raw recordings live directly in datasets/raw/<DATASET_NAME>/sub-*/...
RAW_ROOT = Path('datasets/raw')
pre_processed_path = Path(f'datasets/pre_processed/{DATASET_NAME}')
pre_processed_path.mkdir(parents=True, exist_ok=True)

if DATASET_NAME in ("BallSqueezingHD_modified", "BS_Laura"):
    raw_path = f"{RAW_ROOT}/{DATASET_NAME}/sub-*/nirs/sub-*.snirf"

elif DATASET_NAME in ("ElectricalThermal", "Electrical_Thermal"):
    # Both spellings are accepted: the list above advertises "ElectricalThermal".
    raw_path = f"{RAW_ROOT}/{DATASET_NAME}/sub-*/ses-*/nirs/sub-*_ses-*_task-Electrical*_nirs.snirf"
    # TODO: exclude subjects without txt files for landmarks coords

elif DATASET_NAME == "FreshMotor":
    duration = "*"  # session-name wildcard: "*" includes both the 2s and 3s sessions
    raw_path = f"{RAW_ROOT}/{DATASET_NAME}/sub-*/ses-*{duration}/nirs/sub-*_ses-*{duration}_task-FRESHMOTOR_nirs.snirf"
else:
    raise ValueError("Unknown dataset name")

files = glob.glob(raw_path)

# TODO: to be confirmed
# remove non-BS files for Laura's dataset to avoid errors
if DATASET_NAME == "BS_Laura":
    files = [p for p in files if "BS" in os.path.basename(p)]

files = sorted(files)
print(f"{len(files)} files found.")

In [None]:
# Build the forward model and the image-reconstruction operator once, from the
# probe geometry of the first recording.
# NOTE(review): all recordings of the dataset are assumed to share this probe layout,
# because `fwm`, `Adot` and `recon` are reused for every file in the processing loop.
RECOMPUTE_FORWARD_MODEL = False  # set True to force re-running the slow MCX simulation

if not files:
    raise FileNotFoundError(f"No SNIRF files found for pattern {raw_path!r}")

filename = files[0]  # reference recording for the probe geometry
rec = cedalion.io.read_snirf(filename)[0]  # read snirf files
meas_list = rec._measurement_lists["amp"]

head_icbm152 = dot.get_standard_headmodel('icbm152')
geo3d_snapped_ijk = head_icbm152.align_and_snap_to_scalp(rec.geo3d)

fwm = cedalion.dot.forward_model.ForwardModel(
    head_icbm152,
    geo3d_snapped_ijk,
    meas_list
)

fluence_fname = os.path.join(pre_processed_path, "fluence_" + DATASET_NAME + ".h5")
sensitivity_fname = os.path.join(pre_processed_path, "sensitivity_" + DATASET_NAME + ".h5")

# Fluence (MCX simulation) and sensitivity are cached per dataset. The sensitivity
# is always rebuilt whenever the fluence it is derived from was just recomputed.
fluence_recomputed = RECOMPUTE_FORWARD_MODEL or not os.path.exists(fluence_fname)
if fluence_recomputed:
    fwm.compute_fluence_mcx(fluence_fname)
if fluence_recomputed or not os.path.exists(sensitivity_fname):
    fwm.compute_sensitivity(fluence_fname, sensitivity_fname)

Adot = load_Adot(sensitivity_fname)
recon = dot.ImageRecon(
    Adot,
    recon_mode="mua2conc",
    brain_only=True,
    alpha_meas=10,
    alpha_spatial=10e-3,  # note: 10e-3 == 1e-2
    apply_c_meas=True,
    spatial_basis_functions=None,
)

In [None]:
# Preprocess every recording into parcel-level HbO/HbR time series and pickle them
# to pre_processed_path/ts_all_parcels/<subject>/<name>_ts_all_parcels.pkl.
subject_to_rec = {}     # subject directory -> list (only the keys are filled in below)
skipped_subjects = []   # input files skipped because every channel was flagged bad


def _find_subject_dir(path):
    """Return the closest parent directory of ``path`` whose name starts with ``sub-``.

    Covers both ``sub-*/nirs/<file>`` and ``sub-*/ses-*/nirs/<file>`` layouts; falls
    back to ``path.parts[-3]`` when no ``sub-*`` directory is present.
    """
    for part in reversed(path.parts[:-1]):
        if part.startswith('sub-'):
            return part
    return path.parts[-3]


for f in files:
    records = cedalion.io.read_snirf(f)
    rec = records[0]

    # Sort stimuli chronologically (added for the Yuanyuan dataset).
    rec.stim = rec.stim.sort_values(by="onset")

    rec['rep_amp'] = quality.repair_amp(rec['amp'], median_len=3, method='linear')  # Repair Amp
    rec['od_amp'], baseline = nirs.cw.int2od(rec['rep_amp'], return_baseline=True)

    # motion correct [TDDR + WAVELET]
    rec["od_tddr"] = motion_correct.tddr(rec["od_amp"])
    rec["od_tddr_wavel"] = motion_correct.wavelet(rec["od_tddr"])

    # high-pass filter (fmin=0.008, fmax=0)
    rec['od_hpfilt'] = rec['od_tddr_wavel'].cd.freq_filter(fmin=0.008, fmax=0, butter_order=4)

    # clean amplitude data
    rec['amp_clean'] = cedalion.nirs.cw.od2int(rec['od_hpfilt'], baseline)

    # Bad channels from the amplitude/SNR checks (thresholds are set in get_bad_ch_mask).
    list_bad_ch = get_bad_ch_mask(rec["amp_clean"])
    print('the list of bad channels: ', len(list_bad_ch))

    # With no usable channel left nothing can be reconstructed: record the file and skip it.
    if len(list_bad_ch) >= rec["amp_clean"].sizes["channel"]:
        skipped_subjects.append(f)
        continue

    # Per-channel OD variance; bad channels are handled via bad_rel_var.
    od_var_vec = quality.measurement_variance(rec["od_hpfilt"], list_bad_channels=list_bad_ch, bad_rel_var=1e6, calc_covariance=False)

    dpf = xr.DataArray(
        [6, 6],
        dims="wavelength",
        coords={"wavelength": rec["amp"].wavelength},
    )
    rec['conc'] = cedalion.nirs.cw.od2conc(rec['od_hpfilt'], rec.geo3d, dpf, spectrum="prahl")

    # Remove the global component in concentration space, weighting channels by 1/variance.
    chromo_var = quality.measurement_variance(rec['conc'], list_bad_channels=list_bad_ch, bad_rel_var=1e6, calc_covariance=False)
    rec['conc_pcr'], gb_comp_rem = physio.global_component_subtract(rec['conc'], ts_weights=1/chromo_var, k=0, spatial_dim='channel', spectral_dim='chromo')

    # Back to OD for image reconstruction. The measurement variance is the OD variance
    # computed above (the original recomputed it from od_hpfilt with identical arguments).
    rec['od_pcr1'] = cedalion.nirs.cw.conc2od(rec['conc_pcr'], rec.geo3d, dpf, spectrum="prahl")
    c_meas = od_var_vec

    dC_recon = recon.reconstruct(rec['od_pcr1'], c_meas)
    dC_recon.time.attrs["units"] = units.s

    dC_brain = dC_recon.cd.freq_filter(fmin=0.01, fmax=0.5, butter_order=4)
    # Keep 3 s before the first and 13 s after the last stimulus onset.
    dC_brain = dC_brain.sel(time=slice(rec.stim.onset.values[0] - 3, rec.stim.onset.values[-1] + 13))
    dC_brain = dC_brain.where(dC_brain.is_brain == True)
    # alternatively use 1/conc_var to weight vertex sensitivity and then normalize by sum of weights
    dC_brain = dC_brain.pint.quantify().pint.to("uM").pint.dequantify()

    # Average within each parcel, per chromophore.
    hbr = dC_brain.sel(chromo='HbR').groupby('parcel').mean()
    hbo = dC_brain.sel(chromo='HbO').groupby('parcel').mean()
    signal_raw = xr.concat([hbo, hbr], dim='chromo')

    # Drop the medial-wall background parcels of both hemispheres.
    signal_raw = signal_raw.sel(parcel=signal_raw.parcel != 'Background+FreeSurfer_Defined_Medial_Wall_LH')
    signal_raw = signal_raw.sel(parcel=signal_raw.parcel != 'Background+FreeSurfer_Defined_Medial_Wall_RH')

    delta_conc, global_comp = physio.global_component_subtract(
        signal_raw,
        ts_weights=None, k=0,
        spatial_dim='parcel',
        spectral_dim='chromo')

    # Scale to a maximum absolute value of 1, then zero-fill NaNs (e.g. masked values).
    delta_conc = delta_conc / np.abs(delta_conc).max()
    delta_conc = delta_conc.fillna(0)
    delta_conc = delta_conc.transpose("time", "parcel", "chromo")

    parcel_dOD, parcel_mask = fwm.parcel_sensitivity(
        Adot,
        list_bad_ch,
        dOD_thresh=0.001,
        minCh=1,
        dHbO=10,
        dHbR=-3
    )
    sensitive_parcels = parcel_mask.where(parcel_mask, drop=True)["parcel"].values.tolist()
    dropped_parcels = parcel_mask.where(~parcel_mask, drop=True)["parcel"].values.tolist()
    print(f"Number of sensitive parcels: {len(sensitive_parcels)}")
    print(f"Number of dropped parcels: {len(dropped_parcels)}")

    data = {
        'delta_conc': delta_conc,
        'dropped_parcels': dropped_parcels,
        'sensitive_parcels': sensitive_parcels,
        # This recording's (sorted) stimulus table, so epoching can use its own onsets.
        'stim': rec.stim,
    }

    # Output location: ts_all_parcels/<subject>/<name>_ts_all_parcels.pkl
    path = PureWindowsPath(f)  # parses both forward and back slashes as separators
    subject_dir = _find_subject_dir(path)
    filename = path.stem

    if DATASET_NAME == "FreshMotor":
        # Sessions are renamed to runs: <sub>_<task>_run-<session>_nirs
        session_label = path.parts[-3]
        task_fragment = next(
            (part for part in filename.split('_') if part.startswith('task-')),
            f"task-{DATASET_NAME.replace('_', '').upper()}",
        )
        run_fragment = session_label.replace('ses-', 'run-')
        filename = f'{subject_dir}_{task_fragment}_{run_fragment}_nirs'

    subject_to_rec.setdefault(subject_dir, [])

    all_parcels_dir = pre_processed_path / 'ts_all_parcels' / subject_dir
    all_parcels_dir.mkdir(parents=True, exist_ok=True)

    file_name_to_save = all_parcels_dir / f'{filename}_ts_all_parcels.pkl'

    with open(file_name_to_save, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

if skipped_subjects:
    print(f"Skipped {len(skipped_subjects)} file(s) because all channels were bad.")

print("Processing complete.")

In [None]:
# BallSqueezingHD_modified only: build the common parcel template.
# The template is the sensitive-parcel list of the recording with the most sensitive
# parcels (the first such recording in sorted file order wins ties).
# NOTE(review): parcels that are sensitive only in other recordings are not included;
# confirm that this, rather than a union across recordings, is intended.
if DATASET_NAME == "BallSqueezingHD_modified":
    MAX_PARCELS = 0
    PARCEL_TEMPLATE = None

    # Read the per-recording pickles written by the preprocessing loop. The raw
    # `files` list holds .snirf recordings, not pickles.
    template_sources = sorted(glob.glob(str(pre_processed_path / 'ts_all_parcels' / 'sub-*' / '*.pkl')))
    if not template_sources:
        raise FileNotFoundError(
            f"No preprocessed parcel pickles found under {pre_processed_path / 'ts_all_parcels'}"
        )

    for file in template_sources:
        # NOTE: pickle.load can execute arbitrary code; only load files this pipeline wrote.
        with open(file, 'rb') as handle:
            data_pickle = pickle.load(handle)
        timeseries = data_pickle['delta_conc']
        sensitive_parcels = data_pickle['sensitive_parcels']

        sensitive_timeseries = timeseries.sel(parcel=sensitive_parcels)

        if sensitive_timeseries.sizes["parcel"] > MAX_PARCELS:
            MAX_PARCELS = sensitive_timeseries.sizes["parcel"]
            PARCEL_TEMPLATE = sensitive_timeseries.parcel.values.tolist()

    print(f"Max parcels across subjects: {MAX_PARCELS}")
    print(f"Parcel template: {PARCEL_TEMPLATE}")

    # Save the template parcels to pkl file
    with open(f'{pre_processed_path}/parcel_template_{DATASET_NAME}.pkl', 'wb') as handle:
        pickle.dump(PARCEL_TEMPLATE, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# For any dataset: load Pre-defined template parcel from BallSqueezingHD_modified file
# Every dataset reuses the parcel template built from BallSqueezingHD_modified, so all
# datasets are reindexed onto the same parcel axis during segmentation.
# NOTE(review): pickle.load executes code from the file; only load trusted files.
template_file = f'{pre_processed_path.parent}/BallSqueezingHD_modified/parcel_template_BallSqueezingHD_modified.pkl'
with open(template_file, 'rb') as handle:
    PARCEL_TEMPLATE = pickle.load(handle)
len(PARCEL_TEMPLATE)

# Segmentation: cut the parcel time series into stimulus-locked epochs

In [None]:
# load (all) parcel files
preproc_files_path = str(pre_processed_path / 'ts_all_parcels' / 'sub-*' / '*.pkl')
pkl_files = glob.glob(preproc_files_path)

len(pkl_files), pkl_files[:2]

In [None]:
# Output directory for the segmented epochs of the current dataset.
processed_path = Path(f'datasets/processed/{DATASET_NAME}')
os.makedirs(processed_path, exist_ok=True)

In [None]:
# Cut every preprocessed recording into stimulus-locked epochs and save each one as
# NetCDF under processed_path/<subject>/<recording>_<label>_<index>[_test].nc.
# For each stimulus, one window is produced per start offset in `start_shift`
# (presumably shift-based data augmentation); the unshifted window (offset 0) is
# tagged "_test". Each window is baseline-corrected with the mean over the
# `baseline_duration` seconds before the true onset and cut to `n_timepoints` samples.
baseline_duration = 2.5  # in seconds
n_shifts = 9
duration = 10  # in seconds
post_padding = 5  # in seconds
n_timepoints = 87  # fixed length after shifting

if DATASET_NAME == "BallSqueezingHD_modified":
    delta_range = (-2.5, 2.5)
elif DATASET_NAME == "FreshMotor":
    delta_range = (-2.0, 0.0)
else:
    raise ValueError(f"No epoch shift range (delta_range) defined for dataset {DATASET_NAME!r}")
start_shift = np.linspace(*delta_range, n_shifts)

PKL_SUFFIX = "_ts_all_parcels"


def _epoch_output_path(pkl_file, label, index, is_test):
    """Return (creating its folder) the NetCDF path for one epoch.

    Mirrors the subject folder of ``pkl_file`` below ``processed_path`` and names the
    file ``<recording>_<label>_<index>[_test].nc``, where ``<recording>`` is the pickle
    stem without its ``_ts_all_parcels`` suffix.
    """
    source = Path(pkl_file)
    stem = source.stem
    if stem.endswith(PKL_SUFFIX):
        stem = stem[:-len(PKL_SUFFIX)]
    out_dir = processed_path / source.parent.name
    out_dir.mkdir(parents=True, exist_ok=True)
    test_tag = "_test" if is_test else ""
    return out_dir / f"{stem}_{label}_{index}{test_tag}.nc"


for file in pkl_files:
    with open(file, 'rb') as handle:
        data_pickle = pickle.load(handle)

    delta_brain = data_pickle['delta_conc']
    sensitive_parcels = data_pickle['sensitive_parcels']

    # Stimuli of THIS recording. Pickles written without a 'stim' entry fall back to the
    # global `rec`, whose stimuli belong to the last recording that was read.
    stim = data_pickle.get('stim')
    if stim is None:
        print(f"WARNING: no 'stim' stored in {os.path.basename(file)}; falling back to rec.stim, "
              "which may belong to a different recording.")
        stim = rec.stim

    # Align subject-specific parcels to a common parcel template (zero-pad missing parcels)
    delta_brain = delta_brain.sel(parcel=sensitive_parcels).reindex(parcel=PARCEL_TEMPLATE, fill_value=0)

    i = 0
    n_short = 0
    for _, row in stim.iterrows():
        label = str(row["trial_type"]).lower()
        onset = row["onset"]
        # The baseline is anchored to the true onset, so it is the same for every shift.
        baseline_mean = delta_brain.sel(
            time=slice(onset - baseline_duration, onset)
        ).mean("time")

        for s in start_shift:
            start_time = onset + s
            end_time = start_time + duration + post_padding  # in seconds
            x = delta_brain.sel(time=slice(start_time, end_time)) - baseline_mean
            x = x.isel(time=slice(0, n_timepoints))
            # Windows that run past the end of the cropped recording are too short;
            # skip them so every saved epoch has exactly n_timepoints samples.
            if x.sizes["time"] < n_timepoints:
                n_short += 1
                continue
            x = x.transpose("parcel", "chromo", "time")
            # Drop the pint unit set on the time axis during preprocessing; a pint Unit
            # object is not a valid NetCDF attribute value.
            x.time.attrs.pop('units', None)

            x.to_netcdf(_epoch_output_path(file, label, i, is_test=bool(np.isclose(s, 0.0))))
            i += 1

    if n_short:
        print(f"skipped {n_short} epoch(s) shorter than {n_timepoints} samples in {os.path.basename(file)}")
    print("finished processing file: ", os.path.basename(file))