In [3]:
import os
import shutil
from collections import Counter
from glob import glob

import mne
from mne_nirs.io.snirf import write_raw_snirf
from mne_nirs.io import fold
import mne_nirs
from mne_bids import write_raw_bids, BIDSPath, read_raw_bids
import matplotlib.pyplot as plt
from mne_nirs.experimental_design import make_first_level_design_matrix
from mne.preprocessing.nirs import (optical_density,
                                    temporal_derivative_distribution_repair)
from nilearn.plotting import plot_design_matrix
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA, FastICA
from mne.decoding import UnsupervisedSpatialFilter

In [4]:
root_bids = r'C:\Datasets\Test-retest study\bids_dataset'  # Replace with the path to your data
root_nirx = r'C:\Datasets\Test-retest study\sourcedata\sub-05\ses-01\nirs'
task = "auditory"        # Set to whatever the name of your experiment is
stimulus_duration = {'Control': 5, 'Noise': 5, 'Speech': 5.25}
trigger_info = {'1.0': 'Control',
                '2.0': 'Noise',
                '3.0': 'Speech',
                '4.0': 'XStop_break', # start of break?
                '5.0': 'XStart_break'}   # end of break?

subject_dirs = glob(os.path.join(root_bids, "sub-*"))
subjects = [subject_dir.split("-")[-1] for subject_dir in subject_dirs] # ["01","02",...]

ses_list = np.array([])
for folder in subject_dirs:
    ses_list = np.append(ses_list,np.array([ses for ses in os.listdir(folder)]))
ses_dict = dict(Counter(ses_list))
# Check if all participants came the same number of times
sessions = list(ses_dict.keys())
sessions =[ses.split("-")[-1] for ses in sessions] # ["01","02",...]
complete_sub_ses = all([value == len(subjects) for value in ses_dict.values()])

print(f"Found subjects: {subjects}")
print(f"Found sessions: {sessions}")
if complete_sub_ses:
    print(f"All {len(subjects)} subject came for {len(sessions)} sessions.")

Found subjects: ['01', '010', '011', '012', '02', '03', '04', '05', '06', '07', '08', '09']
Found sessions: ['01', '02']
All 12 subject came for 2 sessions.


In [None]:
## PRE-PROCESSING PIPELINE
# Steps: load the BIDS recording -> set stimulus durations -> locate break markers
# -> crop the breaks out and concatenate the task blocks into one Raw.

subject = '05'
session = '01'

data_path = root_bids

print(f"Processing subject {subject}, session {session}...")

# Create path to file based on experiment info
bids_path = BIDSPath(subject=subject,
                     session=session,
                     task=task,
                     root=data_path,
                     datatype="nirs",
                     suffix="nirs",
                     extension=".snirf")

# Load data
print("Loading raw NIRS data from BIDS dataset format")
raw_intensity = read_raw_bids(bids_path=bids_path, verbose=False)
raw_intensity.load_data()

# Set durations (seconds) from the config cell rather than repeating the values here
raw_intensity.annotations.set_durations(stimulus_duration)

# Get event timings
print("Extracting event timings...")
AllEvents, event_id = mne.events_from_annotations(raw_intensity)
Breaks, _ = mne.events_from_annotations(raw_intensity, {'XStop_break': 4, 'XStart_break': 5})
# Event sample indices include raw.first_samp, while crop() takes times relative
# to the first sample, so subtract first_samp before converting to seconds.
sfreq = raw_intensity.info['sfreq']
Breaks = (Breaks[:, 0] - raw_intensity.first_samp) / sfreq
LastEvent = (AllEvents[-1, 0] - raw_intensity.first_samp) / sfreq

if len(Breaks) % 2 == 0:
    raise ValueError("Breaks array should have an odd number of elements.")

# Compute total experiment duration with breaks
original_duration = raw_intensity.times[-1] - raw_intensity.times[0]
print(f"Original duration: {original_duration:.2f} seconds")

# Cropping dataset.
# Break markers pair up as (0, 1), (2, 3), ...; each pair delimits one task block.
# The last marker opens the final block, which is kept until `post_event_tail`
# seconds after the last event. TODO confirm the reason for 15.25 s; it is presumably
# the 5.25 s speech stimulus plus a 10 s response window.
# The end is clamped so crop() never asks for data past the end of the recording.
print("Cropping the breaks from the dataset...")
post_event_tail = 15.25
last_tmax = min(LastEvent + post_event_tail, raw_intensity.times[-1])
block_intervals = [(Breaks[j], Breaks[j + 1]) for j in range(0, len(Breaks) - 1, 2)]
block_intervals.append((Breaks[-1], last_tmax))

cropped_intensity = raw_intensity.copy().crop(*block_intervals[0])
for tmin, tmax in block_intervals[1:]:
    cropped_intensity.append(raw_intensity.copy().crop(tmin, tmax))

cropped_duration = cropped_intensity.times[-1] - cropped_intensity.times[0]
print(f"Cropped duration: {cropped_duration:.2f} seconds")

if cropped_duration >= original_duration:
    print(f"WARNING: Cropping did not reduce duration for {subject} - {session}!")

raw_intensity_cropped = cropped_intensity.copy()

Processing subject 05, session 01...
Loading raw NIRS data from BIDS dataset format
Reading 0 ... 8666  =      0.000 ...  1663.872 secs...
Extracting event timings...
Used Annotations descriptions: [np.str_('Control'), np.str_('Noise'), np.str_('Speech'), np.str_('XStart_break'), np.str_('XStop_break')]
Used Annotations descriptions: [np.str_('XStart_break'), np.str_('XStop_break')]
Original duration: 1663.87 seconds
Cropping the breaks from the dataset...
Cropped duration: 1262.59 seconds


'epochs = mne_nirs.Epochs(\n    raw_intensity_cropped, AllEvents, event_id=events_id,\n)'

In [7]:
# Remove the break markers, plus the 'BAD boundary' / 'EDGE boundary' annotations that
# MNE inserts where Raw objects are appended, so only stimulus annotations remain.
annotations_to_drop = ['XStart_break', 'XStop_break', 'BAD boundary', 'EDGE boundary']
for raw_label, target_raw in (("original", raw_intensity), ("cropped", raw_intensity_cropped)):
    print(f"Removing break annotations for the {raw_label} raw...")
    # target_raw.annotations is the Raw's own Annotations object, so delete() edits it in place
    drop_idx = np.flatnonzero(np.isin(target_raw.annotations.description, annotations_to_drop))
    target_raw.annotations.delete(drop_idx)
    

Removing break annotations for the orginal raw...
Removing break annotations for the cropped raw...


In [8]:
# Optical density 
# Convert raw light intensity to optical density, for the cropped and the full recording.
print("Converting to optical density...")
cropped_od, original_od = [optical_density(intensity)
                           for intensity in (raw_intensity_cropped, raw_intensity)]

Converting to optical density...


In [None]:
# PCA spatial filter with 30 components, wrapped for MNE data (average=False).
# NOTE(review): `pca` is constructed but not applied anywhere in the visible notebook.
pca = UnsupervisedSpatialFilter(PCA(n_components=30), average=False)

In [None]:
# Motion artifact correction: run TDDR on both the cropped and the full optical-density data.
print("Applying 'Temporal Derivative Distribution Repair' motion artefact correction...")
cropped_corrected_od, original_corrected_od = map(
    temporal_derivative_distribution_repair, (cropped_od, original_od))

In [None]:
def signal_quality_single_subject(data_path, save_path, task, subject, session):
    print(f"Processing subject {subject}, session {session}...")
    #os.makedirs(save_path, exist_ok=True)

    # Create path to file based on experiment info
    bids_path = BIDSPath(subject=subject,
                            session=session,
                            task=task,
                            root=data_path,
                            datatype="nirs",
                            suffix="nirs",
                            extension=".snirf")

    # Load data
    print("Loading raw NIRS data from BIDS dataset format")
    raw_intensity = read_raw_bids(bids_path=bids_path, verbose=False)
    raw_intensity.load_data()

    # Set durations
    raw_intensity.annotations.set_durations({'Control': 5, 'Noise': 5, 'Speech': 5.25})

    # Get event timings
    print("Extracting event timings...")
    AllEvents, _ = mne.events_from_annotations(raw_intensity)
    Breaks, _ = mne.events_from_annotations(raw_intensity, {'XStop_break': 4, 'XStart_break': 5})
    # Get Breaks from index to time stamps
    Breaks = Breaks[:, 0] / raw_intensity.info['sfreq']
    LastEvent = AllEvents[-1, 0] / raw_intensity.info['sfreq']

    if len(Breaks) % 2 == 0:
        raise ValueError("Breaks array should have an odd number of elements.")

    # Compute total experiment duration with breaks
    original_duration = raw_intensity.times[-1] - raw_intensity.times[0]
    print(f"Original duration: {original_duration:.2f} seconds")

    # Cropping dataset
    print("Cropping the breaks from the dataset...")
    cropped_intensity = raw_intensity.copy().crop(Breaks[0], Breaks[1]) # block 1 in between break
    # Crop and append blocks 2, 3 and 4
    for j in range(2, len(Breaks) - 1, 2):
        block = raw_intensity.copy().crop(Breaks[j], Breaks[j + 1]) 
        cropped_intensity.append(block)
    cropped_intensity.append(raw_intensity.copy().crop(Breaks[-1], LastEvent + 15.25)) # why 15.25?

    cropped_duration = cropped_intensity.times[-1] - cropped_intensity.times[0]
    print(f"Cropped duration: {cropped_duration:.2f} seconds")

    if cropped_duration >= original_duration:
        print(f"WARNING: Cropping did not reduce duration for {subject} - {session}!")

    raw_intensity_cropped = cropped_intensity.copy()

    # Remove break annotations
    print("Removing break annotations for the orginal raw...")
    raw_intensity.annotations.delete(np.where(
        (raw_intensity.annotations.description == 'XStart_break') | 
        (raw_intensity.annotations.description == 'XStop_break') | 
        (raw_intensity.annotations.description == 'BAD boundary') | 
        (raw_intensity.annotations.description == 'EDGE boundary')
        )[0])

    print("Removing break annotations for the cropped raw...")
    raw_intensity_cropped.annotations.delete(np.where(
        (raw_intensity_cropped.annotations.description == 'XStart_break') | 
        (raw_intensity_cropped.annotations.description == 'XStop_break') | 
        (raw_intensity_cropped.annotations.description == 'BAD boundary') | 
        (raw_intensity_cropped.annotations.description == 'EDGE boundary')
        )[0])
    
    # Optical density and correction
    print("Converting to optical density...")
    cropped_od = optical_density(raw_intensity_cropped)
    original_od= optical_density(raw_intensity)

    # Replace oversaturated channels with high variance noise

    # Flag bad channels if standard deviation exceeds 15% - averaged signal over two wavelengths
    
    # Linearly nterpolate all flagged channels from adjacent good channels

    # Motion artifact correction
    print("Applying 'Temporal Derivative Distribution Repair' motion artefact correction...")
    cropped_corrected_od = temporal_derivative_distribution_repair(cropped_od)
    original_corrected_od = temporal_derivative_distribution_repair(original_od)



In [None]:
# Run the full single-subject pipeline.
# NOTE(review): the variable name says sub01_ses01, but this call processes subject 05, session 01.
raw_intensity_sub01_ses01 = signal_quality_single_subject(
    data_path=root_bids, save_path='', task=task, subject='05', session='01')

Processing subject 05, session 01...
Loading raw NIRS data from BIDS dataset format
Reading 0 ... 8666  =      0.000 ...  1663.872 secs...
Extracting event timings...
Used Annotations descriptions: [np.str_('Xend'), np.str_('Xstart')]
Used Annotations descriptions: [np.str_('Control'), np.str_('Noise'), np.str_('Speech'), np.str_('Xend'), np.str_('Xstart')]
