In [1]:
# Parcellates each first-level contrast map with the Schaefer-400 atlas and saves the resulting DataFrames (pickled) in the schafer400_dfs directory

# imports and general helper functions

In [15]:
import os
import numpy as np
import pandas as pd
import nibabel as nib
import nilearn
import json
import datetime
import pickle
import seaborn as sns
import gc
import psutil
import math
import scipy.stats as stats
from matplotlib.patches import Patch
from nilearn import plotting
from nilearn.glm.first_level import FirstLevelModel
from nilearn.glm.second_level import SecondLevelModel
from nilearn.glm import threshold_stats_img
from nilearn.image import concat_imgs, mean_img, index_img
from nilearn.reporting import make_glm_report
from nilearn import masking, image
from nilearn import datasets
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from collections import defaultdict
from nilearn.maskers import NiftiLabelsMasker


In [16]:
# general helper functions:
def build_contrast_map_path(base_dir, level, subject, session, task, contrast_name):
    """Return the expected location of one first-level contrast map.

    The path is <base_dir>/<level>/<subject>/<task>/indiv_contrasts/<filename>,
    where the filename encodes subject, session, task and contrast.
    """
    # Some files carry an extra "run-1" tag in their name:
    #   - as of 7/6/25: sub-s10 in flanker, e.g.
    #     sub-s10_ses-01_run-1_task-flanker_contrast-incongruent-congruent_rtmodel-rt_centered_stat-effect-size.nii.gz
    #   - as of 10/1/25: sub-s03 in all tasks
    has_run_tag = subject == 'sub-s03' or (subject == 'sub-s10' and task == 'flanker')
    run_tag = 'run-1_' if has_run_tag else ''

    filename = (
        f'{subject}_{session}_{run_tag}task-{task}_contrast-{contrast_name}'
        '_rtmodel-rt_centered_stat-effect-size.nii.gz'
    )
    return os.path.join(base_dir, level, subject, task, 'indiv_contrasts', filename)

def is_valid_contrast_map(img_path):
    """Tell whether the image at img_path is usable.

    A map is considered valid when its voxel standard deviation exceeds 1e-10
    and it contains no NaN. Any error while loading or reading the image is
    printed and reported as "not valid" (False).
    """
    try:
        voxels = nib.load(img_path).get_fdata()
        has_variance = np.std(voxels) > 1e-10
        return has_variance and not np.isnan(voxels).any()
    except Exception as e:
        print(f"Error validating {img_path}: {e}")
        return False
        
def clean_z_map_data(z_map, task, contrast_name, encounter):
    """Return z_map with NaN/Inf voxels replaced via np.nan_to_num.

    If every voxel is already finite the input image is returned untouched.
    Otherwise a new Nifti1Image is built with the same affine and header and
    a warning naming the task, contrast and (1-based) encounter is printed.
    """
    voxels = z_map.get_fdata()

    # Nothing to repair: hand back the very same object.
    if np.isfinite(voxels).all():
        return z_map

    repaired = nib.Nifti1Image(np.nan_to_num(voxels), z_map.affine, z_map.header)
    print(f"Warning: Fixed NaN/Inf values in {task}:{contrast_name}:encounter-{encounter+1}")
    return repaired

def cleanup_memory():
    """Run the garbage collector, then print current system memory usage.

    Intended to be called between batches.
    """
    gc.collect()

    # Report the state of system memory after the collection pass.
    mem = psutil.virtual_memory()
    available_gb = mem.available / (1024**3)
    print(f"Memory after cleanup: {mem.percent:.1f}% used ({available_gb:.1f}GB available)")
    
def convert_to_regular_dict(d):
    """Recursively turn defaultdicts into plain dicts.

    Lists are walked element by element; defaultdicts become regular dicts
    with converted values. Anything else (including a plain dict) is returned
    as-is without descending into it.
    """
    if isinstance(d, list):
        return [convert_to_regular_dict(item) for item in d]
    if isinstance(d, defaultdict):
        return {key: convert_to_regular_dict(value) for key, value in d.items()}
    return d

# constants, filenames, and roi labels

In [17]:
# all tasks and contrasts
TASKS = [
    "nBack", "flanker", "directedForgetting", "goNogo",
    "shapeMatching", "stopSignal", "cuedTS", "spatialTS",
]

# contrast names available for each task
CONTRASTS = {
    "nBack": ["twoBack-oneBack", "match-mismatch", "task-baseline", "response_time"],  # the nback contrasts
    "flanker": ["incongruent-congruent", "task-baseline"],
    "directedForgetting": ["neg-con", "task-baseline", "response_time"],
    "goNogo": ["nogo_success-go", "nogo_success", "task-baseline", "response_time"],  # go_rtModel check
    "shapeMatching": [
        "DDD", "DDS", "DNN", "DSD", "main_vars", "SDD", "SNN", "SSS",
        "task-baseline", "response_time",
    ],
    "stopSignal": [
        "go", "stop_failure-go", "stop_failure", "stop_failure-stop_success",
        "stop_success-go", "stop_success", "stop_success-stop_failure",
        "task-baseline", "response_time",
    ],
    "cuedTS": [
        "cue_switch_cost", "task_switch_cost",
        "task_switch_cue_switch-task_stay_cue_stay",
        "task-baseline", "response_time",
    ],
    "spatialTS": [
        "cue_switch_cost", "task_switch_cost",
        "task_switch_cue_switch-task_stay_cue_stay",
        "task-baseline", "response_time",
    ],
}

# interested in looking at them all now: every task requests its complete
# contrast list (the very same list objects stored in CONTRASTS)
requested_task_contrasts = defaultdict(lambda: defaultdict(list))
for task in TASKS:
    requested_task_contrasts[task] = CONTRASTS[task]

# compiled_req_contrasts = ["twoBack-oneBack", 'task-baseline', "incongruent-congruent", "neg-con", "nogo_success-go", "main_vars", "stop_failure-go","task_switch_cost"]

# encounter labels as zero-padded strings
encounters = ['01', '02', '03', '04', '05']

# compile all requested contrasts into one list (first-seen order, no repeats)
compiled_req_contrasts = []
for task in TASKS:
    for contrast in requested_task_contrasts[task]:
        if contrast in compiled_req_contrasts:
            continue
        compiled_req_contrasts.append(contrast)
print(compiled_req_contrasts)

['twoBack-oneBack', 'match-mismatch', 'task-baseline', 'response_time', 'incongruent-congruent', 'neg-con', 'nogo_success-go', 'nogo_success', 'DDD', 'DDS', 'DNN', 'DSD', 'main_vars', 'SDD', 'SNN', 'SSS', 'go', 'stop_failure-go', 'stop_failure', 'stop_failure-stop_success', 'stop_success-go', 'stop_success', 'stop_success-stop_failure', 'cue_switch_cost', 'task_switch_cost', 'task_switch_cue_switch-task_stay_cue_stay']


# load subject files per session

In [18]:
# load files per subject per session

# root of the derivatives tree and the sub-folder holding the first level contrast maps
BASE_DIR = '/oak/stanford/groups/russpold/data/network_grant/discovery_BIDS_20250402/derivatives/'
LEVEL = 'output_lev1_mni'

# subjects in the discovery sample
SUBJECTS = [f'sub-s{subject_num}' for subject_num in ('03', '10', '19', '29', '43')]
# session labels ses-01 ... ses-10
SESSIONS = [f'ses-{session_num:02d}' for session_num in range(1, 11)]

# number of encounters each subject has with a task
max_num_encounters = 5

In [19]:
# Index every contrast map that exists on disk, per task / contrast / subject:
#   all_contrast_maps[task][contrast][subject] -> list of paths in session order
#   encounter_maps[task][contrast][subject][k] -> path of the k-th (0-based) map found
all_contrast_maps = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
encounter_maps = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))

for task in TASKS:
    for contrast_name in CONTRASTS[task]:
        for subject in SUBJECTS:
            # counts only the sessions for which a file was actually found
            overall_encounter_count = 0

            for session in SESSIONS:
                contrast_map_path = build_contrast_map_path(
                    BASE_DIR, LEVEL, subject, session, task, contrast_name
                )

                # sessions without a map for this task/contrast are skipped silently
                if not os.path.exists(contrast_map_path):
                    continue

                all_contrast_maps[task][contrast_name][subject].append(contrast_map_path)
                encounter_maps[task][contrast_name][subject][overall_encounter_count] = contrast_map_path
                overall_encounter_count += 1

first_level_session_maps = all_contrast_maps
first_level_encounter_maps = encounter_maps

# general loading and plotting functions that can apply across all tasks

In [20]:
# relevant loading functions taken from 3_create_RSMs_first_level
# function to gather maps of a certain task/contrast from first_level_encounter_maps
def gather_tc_maps(req_tasks,req_contrasts,all_maps=first_level_encounter_maps,req_encounters=[0,1,2,3,4], req_subjects = SUBJECTS):
    '''
    Get a list of loaded niftis for specific task/contrast/encounter combinations of first level maps

    Parameters
        req_tasks: list of tasks as strings (tasks not in TASKS are reported and skipped)
        req_contrasts: list of contrasts as strings (contrasts not in CONTRASTS[task] are reported and skipped)
        all_maps: [task][contrast_name][subject][overall_encounter_count] -> one map each (here it is in a filepath format)
        req_encounters: list of 0-based encounter numbers that are requested (default is all 5);
            values outside 0..4 are silently ignored
        req_subjects: list of subject id strings that are requested (default is all in SUBJECTS)
    Return
        specified_maps: list of loaded nifti files that fit the requested task, contrast, and encounter (this returns this for all subjects)
        specified_descriptors: list of descriptions of each file, "<subject>:encounter-0<n>" with n 1-based,
            aligned index-for-index with specified_maps
        data_title: informative title for the RSM that will later be created
            ("Task:<t1>,<t2>|Contrast:<c1>,<c2>")

    If req_tasks or req_contrasts is empty, ([], [], '') is returned.
    Combinations missing from all_maps are printed (so the output doubles as a
    "what is missing" report) and skipped; nothing is raised for them.
    '''
    specified_maps = []
    specified_descriptors = []
    max_num_encounters = 5

    if (len(req_tasks) == 0) or (len(req_contrasts) == 0):
        return [], [], ''

    for task in req_tasks:
        if task not in TASKS:
            print(f"task {task} not in task masterlist")
            continue

        for contrast in req_contrasts:
            if contrast not in CONTRASTS[task]: # make sure this contrast exists in the given task
                print(f"skipped for contrast {contrast} and task {task}")
                continue

            for subject in req_subjects:
                if subject not in SUBJECTS:
                    print(f"subject: {subject} is not in this dataset, so skipped")
                    continue

                for encounter in req_encounters:
                    if encounter < 0 or encounter >= max_num_encounters:
                        continue

                    descriptor_name = f"{subject}:encounter-0{encounter + 1}"

                    # Walk down all_maps one level at a time using membership
                    # tests, so a defaultdict is never auto-populated with
                    # empty entries just by looking something up.
                    if task not in all_maps:
                        print(f"{task}")
                        continue
                    if contrast not in all_maps[task]:
                        print(f"{task}:{contrast}")
                        continue
                    if subject not in all_maps[task][contrast]:
                        print(f"{task}|{contrast} subject {subject}")
                        continue
                    if encounter not in all_maps[task][contrast][subject]:
                        print(f"{task}|{contrast}|{subject}: {encounter}")
                        continue

                    map_data = all_maps[task][contrast][subject][encounter]

                    # Only file paths are supported as stored values.
                    if not isinstance(map_data, str):
                        print(f"Unexpected data type for {descriptor_name}: {type(map_data)}")
                        continue

                    if not os.path.exists(map_data):
                        # Previously this branch also appended to an undefined
                        # `failed_loads` list; the resulting NameError was
                        # swallowed and printed as a bogus "Error loading" line.
                        print(f"File not found: {map_data}")
                        continue

                    try:
                        loaded_map = nib.load(map_data)
                    except Exception as e:
                        print(f"Error loading {map_data}: {str(e)}")
                        continue

                    specified_maps.append(loaded_map)
                    specified_descriptors.append(descriptor_name)

    # create RSM title
    data_title = f"Task:{','.join(req_tasks)}|Contrast:{','.join(req_contrasts)}"

    return specified_maps, specified_descriptors, data_title

In [21]:
# sort this by encounter (later move this into the main gather_tc-maps func)
def reorganize_dict(original_dict):
    """Re-key gathered maps by subject and encounter.

    Input:  original_dict[task][contrast] = {'maps_list': [...], 'descriptors_list': [...], ...}
            where each descriptor looks like 'sub-s10:encounter-01'.
    Output: result[task][contrast][subject][encounter_num] = map object,
            with encounter_num kept as the string after the dash (e.g. '01').

    Maps and descriptors are paired positionally.
    """
    reorganized = {}

    for task, contrasts in original_dict.items():
        by_contrast = reorganized[task] = {}

        for contrast, data in contrasts.items():
            by_subject = by_contrast[contrast] = {}

            for map_obj, descriptor in zip(data['maps_list'], data['descriptors_list']):
                # 'sub-s10:encounter-01' -> subject 'sub-s10', encounter number '01'
                pieces = descriptor.split(':')
                subject_id = pieces[0]
                encounter_num = pieces[1].split('-')[1]

                # create the per-subject dict on first sight, then store the map
                by_subject.setdefault(subject_id, {})[encounter_num] = map_obj

    return reorganized
    
def gather_tc_maps_full_task(req_tasks_contrasts = requested_task_contrasts, curr_task = None, req_contrasts = None, create_subset=False):
    '''
    Get a dict of loaded niftis for every contrast of a requested task (uses gather_tc_maps and puts it into a dict organized
    by task/contrast). If create_subset is True and curr_task is given, only that one task is loaded (not all task/contrasts in
    req_tasks_contrasts).

    Parameters
        req_tasks_contrasts: mapping task -> iterable of contrast names to load
        curr_task: specific task to load (only used if create_subset = True)
        req_contrasts: specific contrasts to load for curr_task (only used in subset mode);
            when None, req_tasks_contrasts[curr_task] is used
        create_subset: default False (if true and curr_task is not None, only the requested
            task and its contrasts are loaded)
    Return
        task_contrast_all_maps[task][contrast] = {"maps_list": ..., "descriptors_list": ..., "data_title": ...}

    Side effect: prints a "MISSING" header per task; gather_tc_maps prints the
    missing task/contrast/subject/encounter combinations underneath it.
    '''
    # if it was just a subset
    if create_subset and curr_task is not None:
        if req_contrasts is None:
            req_contrasts = req_tasks_contrasts[curr_task]

        # req_contrasts is normally a list, which has no .items(); the old
        # message raised AttributeError before any map was loaded.
        print(f"creating a subset dict for {curr_task} and contrasts {list(req_contrasts)}")

        print("MISSING:")
        return {curr_task: _gather_contrasts_for_task(curr_task, req_contrasts)}

    # if it was for loading all of the task/contrast combos
    task_contrast_all_maps = {}
    for task_name in req_tasks_contrasts:
        print(f"MISSING for {task_name}:")
        task_contrast_all_maps[task_name] = _gather_contrasts_for_task(
            task_name, req_tasks_contrasts[task_name]
        )

    return task_contrast_all_maps


def _gather_contrasts_for_task(task_name, contrasts):
    """Load each contrast of one task with gather_tc_maps (all subjects, encounters 0-4)."""
    contrast_maps = {}
    for contrast in contrasts:
        maps_list, descriptors_list, data_title = gather_tc_maps(
            [task_name],
            [contrast],
            all_maps=first_level_encounter_maps,
            req_encounters=[0,1,2,3,4],
            req_subjects=SUBJECTS,
        )
        contrast_maps[contrast] = {
            "maps_list": maps_list,
            "descriptors_list": descriptors_list,
            "data_title": data_title,
        }
    return contrast_maps

In [22]:
def standardize_mask(mask_img, dtype=bool):
    """Return a binarized copy of mask_img.

    Voxels strictly greater than 0 become True/1, everything else False/0,
    cast to `dtype`. The result is built with image.new_img_like using
    mask_img as the reference image.
    """
    binarized = np.greater(mask_img.get_fdata(), 0).astype(dtype)
    return image.new_img_like(mask_img, binarized)

# parcellate across encounters

In [23]:
# Load every requested task/contrast map into
# task_contrast_all_maps[task][contrast] = {"maps_list", "descriptors_list", "data_title"}.
# The call prints a "MISSING for <task>:" header per task, followed by the
# task|contrast|subject: encounter combinations that have no map.
task_contrast_all_maps = gather_tc_maps_full_task(req_tasks_contrasts = requested_task_contrasts,create_subset=False)

# Re-key the loaded maps as [task][contrast][subject][encounter], where the
# encounter key is the zero-padded string taken from the descriptor (e.g. '01').
task_contrast_enc_all_maps = reorganize_dict(task_contrast_all_maps)

MISSING for nBack:
MISSING for flanker:
MISSING for directedForgetting:
MISSING for goNogo:
goNogo|nogo_success-go|sub-s19: 4
goNogo|nogo_success-go|sub-s29: 4
goNogo|nogo_success-go|sub-s43: 4
goNogo|nogo_success|sub-s19: 4
goNogo|nogo_success|sub-s29: 4
goNogo|nogo_success|sub-s43: 4
goNogo|task-baseline|sub-s19: 4
goNogo|task-baseline|sub-s29: 4
goNogo|task-baseline|sub-s43: 4
goNogo|response_time|sub-s19: 4
goNogo|response_time|sub-s29: 4
goNogo|response_time|sub-s43: 4
MISSING for shapeMatching:
MISSING for stopSignal:
MISSING for cuedTS:
cuedTS|cue_switch_cost|sub-s29: 4
cuedTS|task_switch_cost|sub-s29: 4
cuedTS|task_switch_cue_switch-task_stay_cue_stay|sub-s29: 4
cuedTS|task-baseline|sub-s29: 4
cuedTS|response_time|sub-s29: 4
MISSING for spatialTS:


# For all subjects, parcellate each encounter's contrast map and save the resulting DataFrames

In [24]:
# 1. Load the Schaefer 400 atlas
# schaefer_atlas.maps is the label image and schaefer_atlas.labels the region
# names; both are used later to build the parcel DataFrames.
print("Loading Schaefer 400 atlas...")
schaefer_atlas = datasets.fetch_atlas_schaefer_2018(
    n_rois=400, 
    yeo_networks=7,  # 7 or 17 networks available
    resolution_mm=2  # 1mm or 2mm resolution
)
# Sanity check: number of region labels and the voxel grid of the atlas image.
print(f"Atlas loaded with {len(schaefer_atlas.labels)} regions")
print(f"Atlas shape: {nib.load(schaefer_atlas.maps).shape}")

Loading Schaefer 400 atlas...


Atlas loaded with 400 regions
Atlas shape: (91, 109, 91)


In [25]:
# Encounter keys as zero-padded strings; save_subset_of_parcels uses them to
# index task_contrast_enc_all_maps[task][contrast][subject][enc].
encounters = ['01', '02','03','04','05']
# Layout of the dict built below:
# parcel_dict = {curr_subj: {curr_task: {curr_contrast: {}}}}

def save_subset_of_parcels(subs_requested, run_num):
    """Parcellate every available map of the requested subjects and pickle the result.

    For each subject / task / contrast / encounter found in
    task_contrast_enc_all_maps, the map is averaged within each atlas region
    (NiftiLabelsMasker, strategy='mean') and stored as a DataFrame with columns
    'region', 'activation' and 'network'.

    Parameters
        subs_requested: iterable of subject ids to process
        run_num: suffix used in the output file name, so that separate subject
            batches are written to separate pickles
    Output
        schafer400_dfs/discovery_parcel_indiv_mean_updated_1001_<run_num>.pkl holding
        parcel_dict[subject][task][contrast][encounter] = DataFrame.
        Encounters without a map are reported and left out of the dict.
    """
    # The label list and the masker do not depend on the map being processed,
    # so build them once instead of once per encounter.
    region_labels = [label.decode('utf-8') if isinstance(label, bytes) else label
                     for label in schaefer_atlas.labels]
    masker = NiftiLabelsMasker(
        labels_img=schaefer_atlas.maps,
        standardize=False,
        memory='nilearn_cache',
        strategy='mean'  # Average activation within each region
    )

    parcel_dict = {}

    for curr_subj in subs_requested:
        parcel_dict[curr_subj] = {}

        for curr_task in requested_task_contrasts:
            parcel_dict[curr_subj][curr_task] = {}

            for curr_contrast in requested_task_contrasts[curr_task]:
                parcel_dict[curr_subj][curr_task][curr_contrast] = {}

                for enc in encounters:
                    print(f"{curr_task}{curr_contrast}{curr_subj}")

                    # Only the lookup is guarded: a KeyError raised further
                    # down (masker / DataFrame code) is a real error and must
                    # not be reported as "data not found".
                    try:
                        fmri_img = task_contrast_enc_all_maps[curr_task][curr_contrast][curr_subj][enc]
                    except KeyError as e:
                        print(f"Warning: Data not found for {curr_subj} {curr_task} {curr_contrast} encounter {enc}")
                        print(f"Missing key: {e}")
                        continue

                    print(f"fMRI data loaded successfully for {curr_subj} {curr_task} {curr_contrast} encounter {enc}")
                    print(f"fMRI shape: {fmri_img.shape}")

                    # Extract regional values (one mean per atlas region)
                    regional_values = masker.fit_transform(fmri_img).flatten()

                    # NOTE(review): some nilearn releases appear to prepend a
                    # 'Background' entry to the atlas labels; drop it when that
                    # is what makes the lengths disagree — confirm against the
                    # installed nilearn version.
                    labels = region_labels
                    if len(labels) == len(regional_values) + 1 and labels[0] == 'Background':
                        labels = labels[1:]

                    # Labels look like '7Networks_LH_Vis_1': field 1 is the
                    # hemisphere, field 2 the network. The previous code took
                    # field 1, so the column held 'LH'/'RH'.
                    networks = []
                    for label in labels:
                        fields = label.split('_')
                        if '7Networks' in label and len(fields) > 2:
                            networks.append(fields[2])
                        else:
                            networks.append('Unknown')

                    activation_df = pd.DataFrame({
                        'region': labels,
                        'activation': regional_values,
                        'network': networks
                    })

                    # save the activation df
                    parcel_dict[curr_subj][curr_task][curr_contrast][enc] = activation_df

    # save to pickle (create the output folder first so the write cannot fail
    # on a missing directory after all the parcellation work is done)
    os.makedirs('schafer400_dfs', exist_ok=True)
    with open(f'schafer400_dfs/discovery_parcel_indiv_mean_updated_1001_{run_num}.pkl', 'wb') as f:
        pickle.dump(parcel_dict, f)

    # drop the local reference and collect garbage before the next batch
    del parcel_dict
    cleanup_memory()

In [None]:
# Parcellate the first two subjects (SUBJECTS[0:2]) and save them to the
# pickle with run suffix 1
save_subset_of_parcels(SUBJECTS[0:2], 1)

In [None]:
# Parcellate the next two subjects (SUBJECTS[2:4]) and save them to the
# pickle with run suffix 2
save_subset_of_parcels(SUBJECTS[2:4], 2)

In [31]:
# Parcellate the last subject (SUBJECTS[-1]) and save it to the pickle with
# run suffix 3
save_subset_of_parcels([SUBJECTS[-1]], 3)

nBacktwoBack-oneBacksub-s43
fMRI data loaded successfully for sub-s43 nBack twoBack-oneBack encounter 01
fMRI shape: (97, 115, 97)
nBacktwoBack-oneBacksub-s43
fMRI data loaded successfully for sub-s43 nBack twoBack-oneBack encounter 02
fMRI shape: (97, 115, 97)
nBacktwoBack-oneBacksub-s43
fMRI data loaded successfully for sub-s43 nBack twoBack-oneBack encounter 03
fMRI shape: (97, 115, 97)
nBacktwoBack-oneBacksub-s43
fMRI data loaded successfully for sub-s43 nBack twoBack-oneBack encounter 04
fMRI shape: (97, 115, 97)
nBacktwoBack-oneBacksub-s43
fMRI data loaded successfully for sub-s43 nBack twoBack-oneBack encounter 05
fMRI shape: (97, 115, 97)
nBackmatch-mismatchsub-s43
fMRI data loaded successfully for sub-s43 nBack match-mismatch encounter 01
fMRI shape: (97, 115, 97)
nBackmatch-mismatchsub-s43
fMRI data loaded successfully for sub-s43 nBack match-mismatch encounter 02
fMRI shape: (97, 115, 97)
nBackmatch-mismatchsub-s43
fMRI data loaded successfully for sub-s43 nBack match-mismat