In [1]:
import sys
import numpy as np
from matplotlib import pyplot as plt
from joblib import Parallel, delayed
import seaborn as sns
import pandas as pd
import warnings
from nilearn import image
from nilearn.interfaces.fmriprep import load_confounds
sys.path.append('..')
from utils.data import Subject, load_participant_list

In [2]:
# Data locations: project root and its fMRIPrep 24.0.1 derivatives folder.
base_dir = '/home/ubuntu/data/learning-habits'
bids_dir = "/home/ubuntu/data/learning-habits/bids_dataset/derivatives/fmriprep-24.0.1"

# Participants listed in modeling_participants.tsv. No further sub-selection is
# applied here, so sub_ids is the very same list object as all_sub_ids.
all_sub_ids = sub_ids = load_participant_list(base_dir, file_name='modeling_participants.tsv')

In [3]:
print(f'Number of subjects: {len(sub_ids)}')

Number of subjects: 67


In [4]:
# One Subject object per participant id, built with both the include_modeling
# and include_imaging flags enabled and pointed at the fmriprep derivatives.
subjects = [
    Subject(
        base_dir,
        participant_id,
        include_modeling=True,
        include_imaging=True,
        bids_dir=bids_dir,
    )
    for participant_id in sub_ids
]



# Loading fmriprep confounds directly

In [5]:
# Run identifiers of the first subject. They are reused below for every subject,
# which assumes all subjects share the same run list (TODO confirm).
runs = subjects[0].runs

In [6]:
# Path to the first subject's confounds file for runs[0]; a quick sanity check,
# not used by any of the visible cells below.
cfds = subjects[0].get_confounds_path(runs[0])

In [7]:
# Stack every subject's confound table for runs[0] (presumably learning1) into a
# single long frame; a leading 'sub_id' column records the source subject.
learning1_frames = []
for subject in subjects:
    frame = pd.read_csv(subject.get_confounds_path(runs[0]), sep='\t').reset_index(drop=True)
    frame.insert(0, 'sub_id', subject.sub_id)
    learning1_frames.append(frame)
learning1 = pd.concat(learning1_frames, ignore_index=True)

In [8]:
# Same stacking as for learning1, now for runs[1] (presumably learning2).
learning2_frames = []
for subject in subjects:
    frame = pd.read_csv(subject.get_confounds_path(runs[1]), sep='\t').reset_index(drop=True)
    frame.insert(0, 'sub_id', subject.sub_id)
    learning2_frames.append(frame)
learning2 = pd.concat(learning2_frames, ignore_index=True)

In [9]:
# Same stacking again, for runs[2] (presumably the test run).
test_frames = []
for subject in subjects:
    frame = pd.read_csv(subject.get_confounds_path(runs[2]), sep='\t').reset_index(drop=True)
    frame.insert(0, 'sub_id', subject.sub_id)
    test_frames.append(frame)
test = pd.concat(test_frames, ignore_index=True)

In [10]:
# Volumes whose framewise_displacement exceeds this value are flagged below.
fd_thresh = 0.50

In [11]:
# Add a boolean 'flagged' column: True where framewise displacement exceeds
# fd_thresh. NaN displacements compare as False, so they are never flagged.
learning1, learning2, test = (
    frame.assign(flagged=frame['framewise_displacement'] > fd_thresh)
    for frame in (learning1, learning2, test)
)

In [12]:
# Reference run lengths (number of confound rows) taken from the first subject,
# rather than a hard-coded 'sub-01' whose id/format may not match sub_id values.
# Assumes every subject's learning runs share one length (TODO confirm).
reference_id = subjects[0].sub_id
N_learning = (learning1.sub_id == reference_id).sum()
N_test = (test.sub_id == reference_id).sum()
# A zero count means the reference id was not found; it would silently turn
# every later "> 0.2 * N" exclusion test into "> 0".
assert N_learning > 0 and N_test > 0, f'no confound rows found for {reference_id!r}'

In [13]:
# Positions (into sub_ids / subjects) of subjects with more than 20% of their
# volumes flagged (FD > fd_thresh), separately for each run.
# - sort=False keeps the groups in order of first appearance, i.e. the order in
#   which subjects were concatenated above, so positions line up with sub_ids.
#   The default sort=True orders groups by sub_id string instead.
# - .mean() of the boolean column is each subject's own flagged fraction, so no
#   assumption is made that every subject has exactly N_learning / N_test rows.
excl_learning1 = np.where((learning1.groupby('sub_id', sort=False).flagged.mean() > 0.2).values)[0]
excl_learning2 = np.where((learning2.groupby('sub_id', sort=False).flagged.mean() > 0.2).values)[0]
excl_test = np.where((test.groupby('sub_id', sort=False).flagged.mean() > 0.2).values)[0]

In [14]:
# Positional indices (into sub_ids) of subjects flagged for exclusion in learning1.
excl_learning1

array([61])

In [15]:
# Print, for each run, the subject labels at the excluded positions.
print('Subjects excluded due to motion:')
for run_title, excluded_idx in (('Learning 1', excl_learning1),
                                ('Learning 2', excl_learning2),
                                ('Test', excl_test)):
    print(run_title)
    print(['sub-' + sub_ids[i] for i in excluded_idx])

Subjects excluded due to motion:
Learning 1
['sub-68']
Learning 2
['sub-44', 'sub-48', 'sub-68']
Test
['sub-17', 'sub-31', 'sub-48', 'sub-68']


# Testing the loading of confounds

In [16]:
# Spot-check the project's Subject.load_confounds on one subject (subjects[4]),
# learning1 run, with FD threshold 0.5 and no std-DVARS threshold.
# NOTE(review): the meaning of the positional arguments 'basic' and False is
# defined in utils.data.Subject -- TODO confirm and pass them by keyword.
sub = subjects[4]
conf, mask = sub.load_confounds('learning1', 'basic', False, scrub=0, fd_thresh=.5, std_dvars_thresh=None)

In [17]:
# Motion regressor files exported for SPM (sub-01, test run), with and without
# dummy-scan rows. NOTE(review): these live under /mnt/data while base_dir points
# to /home/ubuntu/data -- confirm which location is current.
with_dummies_path = '/mnt/data/learning-habits/spm_format_20250603/sub-01/func/sub-01_ses-1_task-test_run-3_space-MNI152NLin2009cAsym_desc-preproc_bold_motion_with_dummies.txt'
motion_path = '/mnt/data/learning-habits/spm_format_20250603/sub-01/func/sub-01_ses-1_task-test_run-3_space-MNI152NLin2009cAsym_desc-preproc_bold_motion.txt'

with_dummies = pd.read_csv(with_dummies_path, sep='\t', header=None)
motion = pd.read_csv(motion_path, sep='\t', header=None)

# Old approach (slow): valid-volume proportions recomputed with nilearn `load_confounds`

In [None]:
# Reference run lengths: size of the last (time) axis of the first subject's
# learning1 and test images.
first_subject_imgs = subjects[0].img
N_learning = image.load_img(first_subject_imgs.get('learning1')).shape[-1]
N_test = image.load_img(first_subject_imgs.get('test')).shape[-1]

In [None]:
# Expected volume count per subject (rows) and run (columns, presumably
# learning1, learning2, test): N_learning for the first two, N_test for the last.
all_volumes = np.tile(np.array([N_learning, N_learning, N_test], dtype=float), (len(subjects), 1))

In [None]:
# Grid of scrubbing thresholds to explore.
fd_thresholds = [
    0.3, 0.5, 0.75, 1, 2,       # framewise displacement
]
std_dvars_thresholds = [
    1, 1.5, 2, 2.5, 3, 5,       # standardised DVARS
]

In [None]:
def compute_valid_volumes_for_thresholds(fd_t, sd_t, scrub=0):
    """
    Return the proportion of volumes retained after scrubbing, per subject and run.

    For every subject in the module-level ``subjects`` list and each of its runs,
    nilearn's ``load_confounds`` is called with the
    ('motion', 'high_pass', 'wm_csf', 'scrub') strategy. Its ``sample_mask`` lists
    the volumes that survive the thresholds (``None`` means none were removed).

    Parameters
    ----------
    fd_t : float
        Passed to ``load_confounds`` as ``fd_threshold``.
    sd_t : float
        Passed to ``load_confounds`` as ``std_dvars_threshold``.
    scrub : int, default 0
        Passed to ``load_confounds`` as ``scrub``.

    Returns
    -------
    numpy.ndarray of shape (len(subjects), 3)
        Entry [i, j] is the retained fraction for subject i's j-th run.
    """
    out = np.zeros((len(subjects), 3))
    for i, sub in enumerate(subjects):
        for j, run in enumerate(sub.runs):
            img_path = sub.img.get(run)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=DeprecationWarning)
                confounds, sample_mask = load_confounds(
                    img_path,
                    strategy=('motion','high_pass','wm_csf','scrub'),
                    scrub=scrub,
                    fd_threshold=fd_t,
                    std_dvars_threshold=sd_t
                )
            # The confound table has one row per acquired scan, so it gives this
            # run's own length. This avoids assuming every subject matches the
            # reference lengths N_learning / N_test taken from subjects[0].
            n_volumes = len(confounds)
            n_kept = n_volumes if sample_mask is None else len(sample_mask)
            out[i, j] = n_kept / n_volumes
    return out

In [None]:
# Retained-volume proportions at FD 0.5 / std-DVARS 2.5, without extra scrubbing.
n_valid = compute_valid_volumes_for_thresholds(fd_t=0.5, sd_t=2.5, scrub=0)

In [None]:
thresh = 0.2

# Report, per run, the subjects retaining fewer than (1 - thresh) of their volumes.
for run in range(3):
    subject_ids = [sid for sid, frac in zip(sub_ids, n_valid[:, run]) if frac < 1 - thresh]
    print(f'Run {run+1}: {len(subject_ids)} subjects with > {thresh} scrubbed volumes')
    print(f'Subject IDs: {subject_ids}')

In [None]:
# NOTE(review): this cell repeats the previous one verbatim; consider deleting it.
thresh = 0.2

for run in range(3):
    low_coverage = np.flatnonzero(n_valid[:, run] < 1 - thresh)
    subject_ids = [sub_ids[k] for k in low_coverage]
    print(f'Run {run+1}: {len(subject_ids)} subjects with > {thresh} scrubbed volumes')
    print(f'Subject IDs: {subject_ids}')

In [None]:
# Evaluate every (FD, std-DVARS) pair in parallel. The grid is FD-major (outer
# loop) and DVARS-minor (inner loop); the reshaping step below relies on that.
threshold_grid = [(fd_t, sd_t) for fd_t in fd_thresholds for sd_t in std_dvars_thresholds]
results = Parallel(n_jobs=30)(
    delayed(compute_valid_volumes_for_thresholds)(fd_t, sd_t) for fd_t, sd_t in threshold_grid
)

In [None]:
n_fd = len(fd_thresholds)
n_sd = len(std_dvars_thresholds)

# results[k] holds the grid point k = i_fd * n_sd + j_sd (FD outer, DVARS inner),
# so a C-order reshape places it at prop_valid[i_fd, j_sd].
# Final shape: (n_fd, n_sd, n_subjects, n_runs).
prop_valid = np.stack(results).reshape(n_fd, n_sd, len(subjects), 3)

In [None]:
# Spot-check: retained volume count (proportion * N_learning) for the first
# subject's first run at fd_thresholds[1] (0.5) and std_dvars_thresholds[3] (2.5).
prop_valid[1,3,0,0]*N_learning

In [None]:
# Mean and SD of the retained-volume proportion over subjects and runs for every
# (FD, std-DVARS) threshold pair -> arrays of shape (n_fd, n_sd).
mean_prop_valid = prop_valid.mean(axis=(2, 3))
std_prop_valid = prop_valid.std(axis=(2, 3))

# "mean±std" cell annotations. The separator is the real "±" (U+00B1); a
# mis-decoded "Â±" would render a stray "Â" in every cell.
annot = np.array([
    ["{:.2f}±{:.2f}".format(mu, sd) for mu, sd in zip(row_mean, row_std)]
    for row_mean, row_std in zip(mean_prop_valid, std_prop_valid)
])

fig, ax = plt.subplots(figsize=(9, 7))
sns.heatmap(
    mean_prop_valid,  # averaged over subjects and runs
    xticklabels=std_dvars_thresholds,
    yticklabels=fd_thresholds,
    cmap="viridis",
    annot=annot,
    fmt="",  # annotations are already formatted strings
    cbar_kws={'label': "Proportion of valid volumes"},
    ax=ax,
)
ax.set_xlabel("Standard DVARS threshold")
ax.set_ylabel("FD threshold")
ax.set_title("Proportion of Retained Volumes")

plt.show()


# Scrub parameter

In [None]:
# Effect of the `scrub` parameter at one fixed FD / std-DVARS threshold pair.
fd_threshold, std_dvars_threshold = 0.5, 2
scrub = list(range(1, 6))

# One job per scrub value. NOTE(review): this overwrites `results` from the
# threshold-grid cell, so re-running the prop_valid reshaping afterwards fails.
results = Parallel(n_jobs=len(scrub))(
    delayed(compute_valid_volumes_for_thresholds)(fd_threshold, std_dvars_threshold, s)
    for s in scrub
)

In [None]:
# Grand mean / SD of the retained-volume proportion (over subjects and runs)
# for each scrub value.
mean_prop_valid_scrub = np.array([res.mean(axis=(0, 1)) for res in results])
std_prop_valid_scrub = np.array([res.std(axis=(0, 1)) for res in results])

fig, ax = plt.subplots(figsize=(9, 7))
ax.bar(scrub, mean_prop_valid_scrub, yerr=std_prop_valid_scrub, capsize=5,
       color='skyblue', edgecolor='black')
ax.set_ylim((0.7, 1))
ax.set_xlabel("Scrub parameter")
ax.set_ylabel("Proportion of valid volumes")
ax.set_title("Effect of Scrub on Proportion of Retained Volumes")
plt.show()

In [None]:
# Per-subject retained fractions at FD = 0.5 and std-DVARS = 2,
# shape (n_subjects, n_runs).
chosen_fd_threshold_index = fd_thresholds.index(0.5)
chosen_std_threshold_index = std_dvars_thresholds.index(2)
valid_volumes = prop_valid[chosen_fd_threshold_index, chosen_std_threshold_index]

# One histogram per run, with the mean marked by a dashed red line.
panels = [('learning1', 'Learning1'), ('learning2', 'Learning2'), ('test', 'Test')]
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for col, (ax, (run_label, run_title)) in enumerate(zip(axes, panels)):
    run_fractions = valid_volumes[:, col]
    ax.hist(run_fractions, bins=10, label=run_label)
    ax.axvline(run_fractions.mean(), color='red', linestyle='dashed', linewidth=1)
    ax.set_title(run_title)
    ax.set_xlabel('Fraction of valid Volumes')
    ax.set_ylabel('Count')
    ax.set_xlim(0.7, 1)

plt.tight_layout()
plt.show()

In [None]:
# Exclusion criterion: a run is reported when more than this fraction of its
# volumes is scrubbed (i.e. fewer than 1 - max_scrub remain valid).
max_scrub = 0.20

In [None]:
# Report, per run, the subjects whose retained fraction falls below 1 - max_scrub.
min_valid = 1 - max_scrub
for run in range(3):
    subject_ids = [sid for sid, frac in zip(sub_ids, valid_volumes[:, run]) if frac < min_valid]
    print(f'Run {run+1}: {len(subject_ids)} subjects with > {max_scrub} scrubbed volumes')
    print(f'Subject IDs: {subject_ids}')