In [1]:
# Parcellates each subject's encounter maps with either the smorgasbord atlas or the Schaefer atlas
# (includes both cortical and subcortical) and saves the resulting dfs in smor_parcel_dfs or schaefer400_dfs

In [1]:
import os
import sys
import numpy as np
import pandas as pd
import nibabel as nib
import nilearn
import json
import datetime
import pickle
import seaborn as sns
import gc
import psutil
import math
import scipy.stats as stats
from matplotlib.patches import Patch
from nilearn import plotting
from nilearn.glm.first_level import FirstLevelModel
from nilearn.glm.second_level import SecondLevelModel
from nilearn.glm import threshold_stats_img
from nilearn.image import concat_imgs, mean_img, index_img
from nilearn.reporting import make_glm_report
from nilearn import masking, image
from nilearn import datasets
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from collections import defaultdict
from nilearn.maskers import NiftiLabelsMasker
from sklearn.utils import Bunch

# Import shared utilities and configuration
# need to do it this way because in a sub-directory (later turn config and utils into part of a package)
from utils import (
    TASKS, CONTRASTS, SUBJECTS, SESSIONS, ENCOUNTERS,
    build_first_level_contrast_map_path, is_valid_contrast_map, clean_z_map_data,
    convert_to_regular_dict, create_smor_atlas,load_smor_atlas, load_schaefer_atlas, cleanup_memory
)
from config import BASE_DIR, INPUT_LEVEL, OUTPUT_DIRS

In [2]:
# Flatten the per-task contrast lists into a single list of unique contrast names.
# dict.fromkeys keeps the first occurrence of each name, so first-seen order is preserved.
compiled_req_contrasts = list(dict.fromkeys(
    contrast for task in TASKS for contrast in CONTRASTS[task]
))
print(compiled_req_contrasts)

['twoBack-oneBack', 'match-mismatch', 'task-baseline', 'response_time', 'incongruent-congruent', 'neg-con', 'nogo_success-go', 'nogo_success', 'DDD', 'DDS', 'DNN', 'DSD', 'main_vars', 'SDD', 'SNN', 'SSS', 'go', 'stop_failure-go', 'stop_failure', 'stop_failure-stop_success', 'stop_success-go', 'stop_success', 'stop_success-stop_failure', 'cue_switch_cost', 'task_switch_cost', 'task_switch_cue_switch-task_stay_cue_stay']


# Load subject files per session

In [3]:
# NOTE(review): no map location is defined in this cell; the first level contrast map paths
# appear to be built from BASE_DIR / INPUT_LEVEL by build_first_level_contrast_map_path.
# maximum number of encounters each subject has with a task
max_num_encounters = 5

In [4]:
# Index every first-level contrast map that exists on disk in two ways:
#   all_contrast_maps[task][contrast][subject]  -> list of map paths, in SESSIONS order
#   encounter_maps[task][contrast][subject][n]  -> path of that subject's n-th (0-based) existing map
all_contrast_maps = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
encounter_maps = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))

for task in TASKS:
    for contrast_name in CONTRASTS[task]:
        for subject in SUBJECTS:
            # The encounter index only advances for sessions whose map file exists,
            # so sessions without this task do not leave gaps in the numbering.
            overall_encounter_count = 0

            for session in SESSIONS:
                contrast_map_path = build_first_level_contrast_map_path(
                    BASE_DIR, INPUT_LEVEL, subject, session, task, contrast_name
                )
                if not os.path.exists(contrast_map_path):
                    continue

                all_contrast_maps[task][contrast_name][subject].append(contrast_map_path)
                encounter_maps[task][contrast_name][subject][overall_encounter_count] = contrast_map_path
                overall_encounter_count += 1

# Aliases under the names used by the loading helpers further down
first_level_session_maps = all_contrast_maps
first_level_encounter_maps = encounter_maps

# general loading and plotting functions that can apply across all tasks

In [5]:
# relevant loading functions taken from 3_create_RSMs_first_level
# function to gather maps of a certain task/contrast from first_level_encounter_maps
def gather_tc_maps(req_tasks,req_contrasts,all_maps=first_level_encounter_maps,req_encounters=[0,1,2,3,4], req_subjects = SUBJECTS):
    '''
    Get a list of loaded niftis for specific task/contrast/encounter combinations of first level maps

    Parameters
        req_tasks: list of tasks as strings (tasks that are not in TASKS are reported and skipped)
        req_contrasts: list of contrasts as strings (a contrast is skipped for any task whose CONTRASTS entry does not list it)
        all_maps: [task][contrast_name][subject][overall_encounter_count] -> one map each (here it is in a filepath format)
        req_encounters: list of 0-based encounter numbers that are requested (default is all 5); values outside 0..4 are ignored.
            The default list is only iterated, never mutated.
        req_subjects: list of subject id strings that are requested (default is all in SUBJECTS)
    Return
        specified_maps: list of nifti images (from nib.load) that fit the requested task, contrast, and encounter, for all requested subjects
        specified_descriptors: one "<subject>:encounter-0<n>" label per loaded map (n is 1-based), parallel to specified_maps
        data_title: informative title for the RSM that will later be created, "Task:<t1>,<t2>|Contrast:<c1>,<c2>"

        If req_tasks or req_contrasts is empty, ([], [], '') is returned.

    Missing entries and load failures are printed and skipped; they never raise.
    '''
    specified_maps = []
    specified_descriptors = []
    max_num_encounters = 5

    if (len(req_tasks) == 0) or (len(req_contrasts) == 0):
        return [], [], ''

    for task in req_tasks:
        if task not in TASKS:
            print(f"task {task} not in task masterlist")
            continue

        for contrast in req_contrasts:
            if contrast not in CONTRASTS[task]: # make sure this contrast exists in the given task
                print(f"skipped for contrast {contrast} and task {task}")
                continue

            for subject in req_subjects:
                if subject not in SUBJECTS:
                    print(f"subject: {subject} is not in this dataset, so skipped")
                    continue

                for encounter in req_encounters:
                    if encounter < 0 or encounter >= max_num_encounters:
                        continue

                    descriptor_name = f"{subject}:encounter-0{encounter + 1}"

                    # Walk the nested dict one level at a time and report the first level that is
                    # missing. Membership tests (rather than indexing) avoid creating empty
                    # entries when all_maps is a defaultdict.
                    if task not in all_maps:
                        print(f"{task}")
                        continue
                    if contrast not in all_maps[task]:
                        print(f"{task}:{contrast}")
                        continue
                    if subject not in all_maps[task][contrast]:
                        print(f"{task}|{contrast} subject {subject}")
                        continue
                    if encounter not in all_maps[task][contrast][subject]:
                        print(f"{task}|{contrast}|{subject}: {encounter}")
                        continue

                    map_data = all_maps[task][contrast][subject][encounter]

                    # Only file paths are supported; anything else is reported and skipped
                    if not isinstance(map_data, str):
                        print(f"Unexpected data type for {descriptor_name}: {type(map_data)}")
                        continue

                    # Best-effort load: one unreadable file must not abort the whole gather
                    try:
                        if os.path.exists(map_data):
                            specified_maps.append(nib.load(map_data))
                            specified_descriptors.append(descriptor_name)
                        else:
                            print(f"File not found: {map_data}")
                    except Exception as e:
                        print(f"Error loading {map_data}: {str(e)}")

    # create RSM title (both lists are non-empty here because of the early return above)
    task_part = ','.join(str(task) for task in req_tasks)
    contrast_part = ','.join(str(contrast) for contrast in req_contrasts)
    data_title = f'Task:{task_part}|Contrast:{contrast_part}'

    return specified_maps, specified_descriptors, data_title

In [6]:
# sort this by encounter (later move this into the main gather_tc-maps func)
def reorganize_dict(original_dict):
    """
    Re-key gathered maps by subject and encounter.

    Input:  original_dict[task][contrast] is a dict holding the parallel lists
            'maps_list' and 'descriptors_list', where each descriptor looks like
            'sub-s10:encounter-01'.
    Output: regrouped[task][contrast][subject][encounter_num] = map, where
            encounter_num is the string after the dash (e.g. '01').

    Every task/contrast of the input appears in the output, even when its lists are empty.
    """
    regrouped = {}

    for task_name, per_contrast in original_dict.items():
        task_bucket = regrouped[task_name] = {}

        for contrast_name, gathered in per_contrast.items():
            by_subject = task_bucket[contrast_name] = {}

            # Maps and descriptors are parallel lists; pair them up positionally
            paired = zip(gathered['maps_list'], gathered['descriptors_list'])

            for nifti_map, label in paired:
                pieces = label.split(':')                # ['sub-s10', 'encounter-01']
                subject_id = pieces[0]
                encounter_num = pieces[1].split('-')[1]  # '01'

                # Create the subject's dict on first sight, then file the map under its encounter
                by_subject.setdefault(subject_id, {})[encounter_num] = nifti_map

    return regrouped
    
def _gather_contrasts_for_task(task, contrasts):
    '''Load each contrast of one task with gather_tc_maps; returns {contrast: {"maps_list", "descriptors_list", "data_title"}}.'''
    per_contrast = {}
    for contrast in contrasts:
        maps_list, descriptors_list, data_title = gather_tc_maps(
            [task],
            [contrast],
            all_maps=first_level_encounter_maps,
            req_encounters=[0,1,2,3,4],
            req_subjects=SUBJECTS,
        )
        per_contrast[contrast] = {
            "maps_list": maps_list,
            "descriptors_list": descriptors_list,
            "data_title": data_title,
        }
    return per_contrast


def gather_tc_maps_full_task(req_tasks_contrasts = CONTRASTS, curr_task = None, req_contrasts = None, create_subset=False):
    '''
    Get a dict of loaded niftis for every contrast of a requested task (uses gather_tc_maps and puts it into a dict organized
    by task/contrast). If create_subset is True and curr_task is given, only that task is loaded; otherwise every
    task/contrast in req_tasks_contrasts is loaded.

    Parameters
        req_tasks_contrasts = dict mapping each task to its contrasts
        curr_task: specific task to load (only used if create_subset = True)
        req_contrasts: specific contrasts to load for curr_task (only used in subset mode).
            Defaults to req_tasks_contrasts[curr_task] when None.
        create_subset: default False (if true, then don't organize the dict for all task/contrasts, but only for the requested current task and its contrasts).
            If create_subset is True but curr_task is None, all tasks are loaded.
    Return
        task_contrast_all_maps[task][contrast] = {"maps_list": ..., "descriptors_list": ..., "data_title": ...}

    gather_tc_maps prints a line for every missing map, so those lines appear after each "MISSING" header.
    '''
    task_contrast_all_maps = {}

    # if it was just a subset
    if create_subset and (curr_task is not None):
        if req_contrasts is None:
            req_contrasts = req_tasks_contrasts[curr_task]

        # req_contrasts is a plain iterable of names, so it is listed directly (no .items())
        print(f"creating a subset dict for {curr_task} and contrasts {list(req_contrasts)}")

        print("MISSING:")
        task_contrast_all_maps[curr_task] = _gather_contrasts_for_task(curr_task, req_contrasts)
        return task_contrast_all_maps

    # if it was for loading all of the task/contrast combos
    for curr_task in req_tasks_contrasts:
        print(f"MISSING for {curr_task}:")
        task_contrast_all_maps[curr_task] = _gather_contrasts_for_task(curr_task, req_tasks_contrasts[curr_task])

    return task_contrast_all_maps

# Parcellate across all task/contrasts/subjects

In [7]:
# Load every task/contrast map into one organized dict; the run prints which maps are missing per task
task_contrast_all_maps = gather_tc_maps_full_task(CONTRASTS, create_subset=False)

# Re-key the gathered maps as [task][contrast][subject][encounter] for per-subject lookup
task_contrast_enc_all_maps = reorganize_dict(task_contrast_all_maps)

MISSING for nBack:
MISSING for flanker:
MISSING for directedForgetting:
MISSING for goNogo:
goNogo|nogo_success-go|sub-s19: 4
goNogo|nogo_success-go|sub-s29: 4
goNogo|nogo_success-go|sub-s43: 4
goNogo|nogo_success|sub-s19: 4
goNogo|nogo_success|sub-s29: 4
goNogo|nogo_success|sub-s43: 4
goNogo|task-baseline|sub-s19: 4
goNogo|task-baseline|sub-s29: 4
goNogo|task-baseline|sub-s43: 4
goNogo|response_time|sub-s19: 4
goNogo|response_time|sub-s29: 4
goNogo|response_time|sub-s43: 4
MISSING for shapeMatching:
MISSING for stopSignal:
MISSING for cuedTS:
cuedTS|cue_switch_cost|sub-s29: 4
cuedTS|task_switch_cost|sub-s29: 4
cuedTS|task_switch_cue_switch-task_stay_cue_stay|sub-s29: 4
cuedTS|task-baseline|sub-s29: 4
cuedTS|response_time|sub-s29: 4
MISSING for spatialTS:


In [8]:
# GET ATLASES: only build the smorgasbord atlas when its pickle is not on disk yet
# NOTE(review): create_smor_atlas() is called without a path - presumably it writes to
# smor_atlas_path; confirm in utils.
smor_atlas_path = 'processed_data_dfs/smor_parcel_dfs/smorgasbord_atlas_files/smorgasbord_atlas.pkl'
if os.path.exists(smor_atlas_path):
    print("atlas already exists")
else:
    print("Smorgasbord atlas not found. Creating new atlas...")
    create_smor_atlas()

atlas already exists


In [9]:
encounters = ['01', '02','03','04','05']

def save_subset_of_parcels(subs_requested, run_num, date, atlas='schaefer'):
    """
    Extract parcel-wise mean activation values from fMRI contrast maps and pickle them.

    For every requested subject x task x contrast x encounter, the map is looked up in
    task_contrast_enc_all_maps, averaged within each atlas region, and stored as a
    DataFrame with columns 'region', 'activation', 'network' (plus 'roi_value' when the
    atlas provides roi_values). Missing maps and maps that fail to process are stored as None.

    Parameters:
    -----------
    subs_requested : list
        Subject IDs to process
    run_num : str/int
        Run identifier for output filename
    date : str
        Processing date tag, embedded in the output filename
    atlas : str
        'schaefer' selects the Schaefer atlas; any other value falls back to the
        smorgasbord atlas ('smor')

    Output:
    -------
    Returns nothing. Writes parcel_dict[subject][task][contrast][encounter] to
    processed_data_dfs/<atlas_name>_dfs/discovery_parcel_indiv_mean_updated_<date>_<run_num>.pkl
    """
    parcel_dict = {}

    # Select atlas
    if atlas == 'schaefer':
        current_atlas = load_schaefer_atlas()
        atlas_name = 'schaefer400'
    else:
        current_atlas = load_smor_atlas()
        atlas_name = 'smor_parcel'
    print(f'doing {atlas_name}')

    # The masker configuration and the decoded labels depend only on the atlas, so they are
    # built once instead of once per map. fit_transform still refits for every image.
    masker = NiftiLabelsMasker(
        labels_img=current_atlas.maps,
        standardize=False,
        memory='nilearn_cache',
        strategy='mean'  # Average activation within each region
    )

    # Handle labels (decode if bytes)
    region_labels = [
        label.decode('utf-8') if isinstance(label, bytes) else label
        for label in current_atlas.labels
    ]

    # ROI values are only available on some atlases (the smorgasbord atlas)
    has_roi_values = hasattr(current_atlas, 'roi_values')

    for curr_subj in subs_requested:
        parcel_dict[curr_subj] = {}

        for curr_task in CONTRASTS:
            parcel_dict[curr_subj][curr_task] = {}

            for curr_contrast in CONTRASTS[curr_task]:
                parcel_dict[curr_subj][curr_task][curr_contrast] = {}

                for enc in encounters:
                    print(f"Processing: {curr_subj} - {curr_task} - {curr_contrast} - Enc {enc}")

                    # The lookup has its own try so that only a genuinely missing map is
                    # reported as "Data not found"; a KeyError raised later during processing
                    # is reported as a processing error instead.
                    try:
                        fmri_img = task_contrast_enc_all_maps[curr_task][curr_contrast][curr_subj][enc]
                    except KeyError as e:
                        print(f"Warning: Data not found - Missing key: {e}")
                        parcel_dict[curr_subj][curr_task][curr_contrast][enc] = None
                        continue

                    try:
                        print(f"fMRI data loaded | Shape: {fmri_img.shape}")

                        # Extract regional values
                        regional_values = masker.fit_transform(fmri_img)

                        # Create activation dataframe
                        # NOTE(review): if the masker returns a different number of regions
                        # than the atlas has labels, this raises and the encounter is stored
                        # as None by the handler below.
                        activation_df = pd.DataFrame({
                            'region': region_labels,
                            'activation': regional_values.flatten()
                        })

                        # Add network information (handle both Schaefer and other atlases)
                        # NOTE(review): for labels shaped like '7Networks_LH_Vis_1', index 1 is
                        # the hemisphere ('LH') and the network name is index 2 - confirm the
                        # label format returned by the atlas loaders before relying on this column.
                        activation_df['network'] = activation_df['region'].apply(
                            lambda x: x.split('_')[1] if 'Networks' in x else 'Subcortical'
                        )

                        # Add ROI values if available (for smorgasbord atlas)
                        if has_roi_values:
                            activation_df['roi_value'] = current_atlas.roi_values

                        # Save the activation df
                        parcel_dict[curr_subj][curr_task][curr_contrast][enc] = activation_df
                        print(f"Extracted {len(activation_df)} regions")

                    except Exception as e:
                        # Best effort: one bad map must not abort the whole subject batch
                        print(f"Error processing data: {str(e)}")
                        parcel_dict[curr_subj][curr_task][curr_contrast][enc] = None

    # Save to pickle
    output_dir = f'processed_data_dfs/{atlas_name}_dfs'
    os.makedirs(output_dir, exist_ok=True)

    output_file = f'{output_dir}/discovery_parcel_indiv_mean_updated_{date}_{run_num}.pkl'
    with open(output_file, 'wb') as f:
        pickle.dump(parcel_dict, f)

    print(f"\nSaved to: {output_file}")

    # Drop the local reference and let the shared helper clean up before the next batch
    del parcel_dict
    cleanup_memory()

In [None]:
# Run 1: parcellate the first two subjects with the smorgasbord atlas and save
save_subset_of_parcels(subs_requested=SUBJECTS[:2], run_num=1, date="1117", atlas="smor")

In [None]:
# get next 2 subjects and save
# (uses the same "1117" date tag as the other runs so all output filenames share one prefix)
save_subset_of_parcels(SUBJECTS[2:4], 2, date="1117", atlas="smor")

In [10]:
# Run 3: parcellate the last subject with the smorgasbord atlas and save
save_subset_of_parcels(subs_requested=[SUBJECTS[-1]], run_num=3, date="1117", atlas="smor")

Loading Smorgasbord atlas...
Atlas loaded with 429 regions
Atlas shape: (193, 229, 193)
doing smor_parcel
Processing: sub-s43 - nBack - twoBack-oneBack - Enc 01
fMRI data loaded | Shape: (97, 115, 97)
Extracted 429 regions
Processing: sub-s43 - nBack - twoBack-oneBack - Enc 02
fMRI data loaded | Shape: (97, 115, 97)
Extracted 429 regions
Processing: sub-s43 - nBack - twoBack-oneBack - Enc 03
fMRI data loaded | Shape: (97, 115, 97)
Extracted 429 regions
Processing: sub-s43 - nBack - twoBack-oneBack - Enc 04
fMRI data loaded | Shape: (97, 115, 97)
Extracted 429 regions
Processing: sub-s43 - nBack - twoBack-oneBack - Enc 05
fMRI data loaded | Shape: (97, 115, 97)
Extracted 429 regions
Processing: sub-s43 - nBack - match-mismatch - Enc 01
fMRI data loaded | Shape: (97, 115, 97)
Extracted 429 regions
Processing: sub-s43 - nBack - match-mismatch - Enc 02
fMRI data loaded | Shape: (97, 115, 97)
Extracted 429 regions
Processing: sub-s43 - nBack - match-mismatch - Enc 03
fMRI data loaded | Shap