In [None]:
# JAN 2026
# Cleaned-up version of step 9.
# Switched back to beta values: putting z-values into the linear regression model doesn't make sense.

In [1]:
import os
import sys
import numpy as np
import pandas as pd
import nibabel as nib
import nilearn
import json
import datetime
import pickle
import seaborn as sns
import gc
import psutil
import math
import scipy.stats as stats
from matplotlib.patches import Patch
from nilearn import plotting
from nilearn.glm.first_level import FirstLevelModel
from nilearn.glm.second_level import SecondLevelModel
from nilearn.glm import threshold_stats_img
from nilearn.image import concat_imgs, mean_img, index_img
from nilearn.reporting import make_glm_report
from nilearn import masking, image
from nilearn import datasets
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from collections import defaultdict
from nilearn.maskers import NiftiLabelsMasker
from sklearn.utils import Bunch
from nilearn.image import resample_to_img

# Import shared utilities and configuration
# need to do it this way because in a sub-directory (later turn config and utils into part of a package)
from utils import (
    TASKS, CONTRASTS, SUBJECTS, SESSIONS, ENCOUNTERS,
    build_first_level_contrast_map_path, is_valid_contrast_map, clean_z_map_data,
    convert_to_regular_dict, create_smor_atlas,load_smor_atlas, load_schaefer_atlas, cleanup_memory
)
from config import BASE_DIR, OUTPUT_DIRS

In [2]:
# Flatten the per-task contrast lists into one list, keeping first-seen order
# and dropping contrasts that are shared by several tasks.
compiled_req_contrasts = list(dict.fromkeys(
    contrast_name
    for task_name in TASKS
    for contrast_name in CONTRASTS[task_name]
))
print(compiled_req_contrasts)

['twoBack-oneBack', 'match-mismatch', 'task-baseline', 'response_time', 'incongruent-congruent', 'neg-con', 'nogo_success-go', 'nogo_success', 'DDD', 'DDS', 'DNN', 'DSD', 'main_vars', 'SDD', 'SNN', 'SSS', 'go', 'stop_failure-go', 'stop_failure', 'stop_failure-stop_success', 'stop_success-go', 'stop_success', 'stop_success-stop_failure', 'cue_switch_cost', 'task_switch_cost', 'task_switch_cue_switch-task_stay_cue_stay']


# load subject files per session:

In [3]:
# Max number of encounters each subject has within each task.
# Used by gather_maps_for_conditions as the default encounter range
# (0 .. max_num_encounters-1) and as the upper bound for valid encounter indices.
max_num_encounters = 5

In [4]:
# Index every subject's first-level contrast maps in two ways:
#   all_contrast_maps[task][contrast][subject]            -> list of paths (session order)
#   encounter_maps[task][contrast][subject][encounter_ix] -> path
# The encounter index is 0-based and only advances for sessions whose map exists
# on disk, so it counts "n-th time this subject did the task", not the session number.
all_contrast_maps = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
encounter_maps = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))

for task in TASKS:
    for contrast_name in CONTRASTS[task]:
        for subject in SUBJECTS:
            # Beta maps are loaded here; passing "z" as an extra final argument to
            # build_first_level_contrast_map_path would load the z-stat maps instead.
            session_paths = [
                build_first_level_contrast_map_path(BASE_DIR, subject, session, task, contrast_name)
                for session in SESSIONS
            ]
            paths_on_disk = filter(os.path.exists, session_paths)

            for encounter_ix, map_path in enumerate(paths_on_disk):
                all_contrast_maps[task][contrast_name][subject].append(map_path)
                encounter_maps[task][contrast_name][subject][encounter_ix] = map_path

first_level_session_maps = all_contrast_maps
first_level_encounter_maps = encounter_maps

## useful functions

In [5]:
def load_nifti_map(file_path, descriptor_name):
    """
    Load a single NIfTI file and return it with its descriptor.
    
    Parameters:
        file_path: path to the NIfTI file (str or any os.PathLike, e.g. pathlib.Path)
        descriptor_name: descriptive name for this map
        
    Returns:
        tuple: (loaded_map, descriptor_name), or (None, None) if the path has an
        unexpected type, does not exist, or cannot be loaded. Failures are
        reported with print() instead of being raised, so callers can skip the map.
    """
    # Accept pathlib.Path and friends as well as plain strings.
    if not isinstance(file_path, (str, os.PathLike)):
        print(f"Unexpected data type for {descriptor_name}: {type(file_path)}")
        return None, None
    
    # Normalise to a plain path so the messages below print it consistently.
    file_path = os.fspath(file_path)
    
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return None, None
    
    try:
        loaded_map = nib.load(file_path)
        return loaded_map, descriptor_name
    except Exception as e:
        # Deliberate best-effort: one unreadable file should not abort a bulk load.
        print(f"Error loading {file_path}: {str(e)}")
        return None, None


def gather_maps_for_conditions(tasks, contrasts, encounters=None, subjects=SUBJECTS, 
                                 all_maps=first_level_encounter_maps):
    """
    Gather NIfTI maps for specified tasks, contrasts, encounters, and subjects.
    
    Combinations that are unknown or have no data are skipped with a printed
    message; encounter indices outside 0..max_num_encounters-1 are skipped silently.
    
    Parameters:
        tasks: list of task names
        contrasts: list of contrast names
        encounters: list of 0-based encounter indices
                    (default: every index in range(max_num_encounters))
        subjects: list of subject IDs (default: all SUBJECTS)
        all_maps: nested dict structure [task][contrast][subject][encounter] -> filepath
        
    Note:
        The defaults for `subjects` and `all_maps` are bound when this function is
        defined. If first_level_encounter_maps is rebuilt, re-run this definition
        (or pass `all_maps` explicitly) to pick up the new object.
        
    Returns:
        tuple: (list of loaded maps, list of descriptors, data title string).
        Descriptors have the form "<subject>:encounter-NN" with NN 1-based and
        zero-padded.
    """
    if encounters is None:
        encounters = list(range(max_num_encounters))
    
    maps = []
    descriptors = []
    
    for task in tasks:
        if task not in TASKS:
            print(f"Task '{task}' not in task masterlist")
            continue
        
        # Membership tests come before indexing so a defaultdict is never
        # grown as a side effect of a lookup.
        if task not in all_maps:
            print(f"No data for task '{task}'")
            continue
        task_maps = all_maps[task]
        
        for contrast in contrasts:
            if contrast not in CONTRASTS[task]:
                print(f"Contrast '{contrast}' not valid for task '{task}'")
                continue
            
            if contrast not in task_maps:
                print(f"No data for contrast '{contrast}' in task '{task}'")
                continue
            contrast_maps = task_maps[contrast]
            
            for subject in subjects:
                if subject not in SUBJECTS:
                    print(f"Subject '{subject}' not in dataset")
                    continue
                
                if subject not in contrast_maps:
                    print(f"No data for subject '{subject}' in {task}/{contrast}")
                    continue
                subject_maps = contrast_maps[subject]
                
                for encounter in encounters:
                    if encounter < 0 or encounter >= max_num_encounters:
                        continue
                    
                    if encounter not in subject_maps:
                        # This message reports the 0-based index, whereas the
                        # descriptors below are 1-based.
                        print(f"Missing: {task}|{contrast}|{subject}|encounter-{encounter}")
                        continue
                    
                    file_path = subject_maps[encounter]
                    descriptor = f"{subject}:encounter-{encounter+1:02d}"
                    
                    loaded_map, loaded_descriptor = load_nifti_map(file_path, descriptor)
                    if loaded_map is not None:
                        maps.append(loaded_map)
                        descriptors.append(loaded_descriptor)
    
    # Create descriptive title. join() covers the single-item case too and,
    # unlike indexing [0], does not raise IndexError when a list is empty.
    task_str = ','.join(str(t) for t in tasks)
    contrast_str = ','.join(str(c) for c in contrasts)
    title = f"Task:{task_str}|Contrast:{contrast_str}"
    
    return maps, descriptors, title


def organize_maps_by_task_contrast(tasks_contrasts_dict=CONTRASTS, single_task=None, 
                                     single_contrasts=None):
    """
    Load the maps of every requested task/contrast pair into one nested dict.
    
    Parameters:
        tasks_contrasts_dict: dict mapping tasks to their contrasts (default: CONTRASTS)
        single_task: if given, restrict processing to this one task
        single_contrasts: contrasts to process for single_task; when omitted, the
                          task's contrasts from tasks_contrasts_dict are used.
                          Ignored unless single_task is given.
        
    Returns:
        dict: {task: {contrast: {"maps_list": [...], "descriptors_list": [...], "data_title": "..."}}}
    """
    # Work out the task -> contrasts selection first
    if single_task is None:
        selection = tasks_contrasts_dict
    else:
        if single_contrasts is None:
            single_contrasts = tasks_contrasts_dict[single_task]
        selection = {single_task: single_contrasts}
        print(f"Creating subset for task '{single_task}' with contrasts: {single_contrasts}")
    
    organized = {}
    
    for task_name, contrast_names in selection.items():
        per_contrast = {}
        organized[task_name] = per_contrast
        print(f"\nProcessing task: {task_name}")
        
        # One gather call per (task, contrast) pair, using the gatherer's defaults
        # for encounters, subjects and the map index
        for contrast_name in contrast_names:
            loaded, labels, heading = gather_maps_for_conditions(
                tasks=[task_name],
                contrasts=[contrast_name]
            )
            per_contrast[contrast_name] = dict(
                maps_list=loaded,
                descriptors_list=labels,
                data_title=heading,
            )
    
    return organized


def reorganize_by_encounter(task_contrast_dict):
    """
    Turn the flat per-contrast map lists into [task][contrast][subject][encounter].
    
    Each descriptor is expected to look like 'sub-s10:encounter-01': the part
    before ':' becomes the subject key and the text after the first '-' of the
    second part ('01') becomes the encounter key (kept as a string).
    
    Parameters:
        task_contrast_dict: dict with structure {task: {contrast: {"maps_list": [...], "descriptors_list": [...]}}}
        
    Returns:
        dict: {task: {contrast: {subject: {encounter: map_object}}}}
    """
    nested = {}
    
    for task_name, per_contrast in task_contrast_dict.items():
        task_bucket = nested[task_name] = {}
        
        for contrast_name, payload in per_contrast.items():
            by_subject = task_bucket[contrast_name] = {}
            
            # Maps and descriptors are parallel lists; walk them together
            paired = zip(payload['maps_list'], payload['descriptors_list'])
            
            for nifti_obj, label in paired:
                subject_id, encounter_part = label.split(':')
                encounter_key = encounter_part.split('-')[1]
                
                # setdefault creates the subject level on first sight
                by_subject.setdefault(subject_id, {})[encounter_key] = nifti_obj
    
    return nested

## Load all tasks/contrasts/subjects into an organized structure

In [6]:
# Load all maps organized by task/contrast.
# Called with no arguments, so every task/contrast in CONTRASTS is processed;
# each entry holds the loaded images plus "subject:encounter-NN" descriptors.
print("Loading all task/contrast maps...")
task_contrast_all_maps = organize_maps_by_task_contrast()

# Reorganize by encounter: re-key the flat lists as
# [task][contrast][subject][encounter], where the encounter key is the
# zero-padded string parsed from the descriptor (e.g. '01').
print("\nReorganizing by encounter...")
task_contrast_enc_all_maps = reorganize_by_encounter(task_contrast_all_maps)

print("\nDone!")

Loading all task/contrast maps...

Processing task: nBack

Processing task: flanker

Processing task: directedForgetting

Processing task: goNogo
Missing: goNogo|nogo_success-go|sub-s19|encounter-4
Missing: goNogo|nogo_success-go|sub-s29|encounter-4
Missing: goNogo|nogo_success-go|sub-s43|encounter-4
Missing: goNogo|nogo_success|sub-s19|encounter-4
Missing: goNogo|nogo_success|sub-s29|encounter-4
Missing: goNogo|nogo_success|sub-s43|encounter-4
Missing: goNogo|task-baseline|sub-s19|encounter-4
Missing: goNogo|task-baseline|sub-s29|encounter-4
Missing: goNogo|task-baseline|sub-s43|encounter-4
No data for contrast 'response_time' in task 'goNogo'

Processing task: shapeMatching
Missing: shapeMatching|DDD|sub-s10|encounter-4
Missing: shapeMatching|DDS|sub-s10|encounter-4
Missing: shapeMatching|DNN|sub-s10|encounter-4
Missing: shapeMatching|DSD|sub-s10|encounter-4
Missing: shapeMatching|main_vars|sub-s10|encounter-4
Missing: shapeMatching|SDD|sub-s10|encounter-4
Missing: shapeMatching|SNN|

# load atlas data

In [7]:
# GET ATLASES:
# Build the Smorgasbord atlas if its pickle is missing, then load it either way
# so that `current_atlas` is always defined once this cell has run (previously
# it stayed undefined on the run that created the atlas).
smor_atlas_path = 'smorgasbord_atlas_files/smorgasbord_atlas.pkl'
if not os.path.exists(smor_atlas_path):
    print("Smorgasbord atlas not found. Creating new atlas...")
    # NOTE(review): assumes create_smor_atlas() writes the file that
    # load_smor_atlas() reads — confirm in utils.
    create_smor_atlas()
else:
    print("atlas already exists")
current_atlas = load_smor_atlas()

atlas already exists
Loading Smorgasbord atlas...
Atlas loaded with 429 regions
Atlas shape: (193, 229, 193)


# parcellate each individual's encounter data

In [8]:
encounters = ['01', '02','03','04','05']

def save_subset_of_parcels(subs_requested, run_num, date, atlas='smor', file_type="default"):
    """
    Extract parcel-wise mean activation values from fMRI contrast maps and pickle them.
    
    For every requested subject, every task/contrast in CONTRASTS and every
    encounter key in the module-level `encounters` list, the map is looked up in
    `task_contrast_enc_all_maps` and averaged within each atlas parcel. A missing
    map or a processing error is reported with print() and stored as None.
    
    Parameters:
    -----------
    subs_requested : list, tuple or single subject ID
        Subject IDs to process; a single ID is wrapped in a list
    run_num : str/int
        Run identifier for output filename
    date : str
        Added as processed date (part of the output filename)
    atlas : str
        'smor' selects the Smorgasbord atlas; any other value selects the
        Schaefer atlas. The same value is used as the key into OUTPUT_DIRS.
    file_type : str
        'z' gives the output file a '_z_scored' suffix, anything else '_betas'.
        Only the filename is affected, not which maps are read.
        
    Returns:
    --------
    dict
        {subject: {task: {contrast: {encounter: DataFrame or None}}}}; each
        DataFrame has 'region', 'activation', 'network' and, when the atlas
        provides `roi_values`, 'roi_value' columns.
        
    Raises:
    -------
    KeyError
        If `atlas` is not a key of OUTPUT_DIRS (raised before any processing).
    """
    # Resolve the output directory first so a bad `atlas` key fails immediately
    # rather than after all subjects have been parcellated.
    output_dir = OUTPUT_DIRS[atlas]
    
    parcel_dict = {}
    
    # Select atlas
    if atlas == 'smor':
        current_atlas = load_smor_atlas()
        atlas_name = 'smor'
    else:
        current_atlas = load_schaefer_atlas()
        atlas_name = 'schaefer400'
    print(f'doing {atlas_name}')
    
    # Label image handed to the masker. Resolved on the first map that loads
    # successfully and then reused for all later maps.
    resampled_atlas = None

    if not isinstance(subs_requested, (list, tuple)):
        subs_requested = [subs_requested]
        
    for curr_subj in subs_requested:   
        parcel_dict[curr_subj] = {}
        
        for curr_task in CONTRASTS:
            parcel_dict[curr_subj][curr_task] = {}
    
            for curr_contrast in CONTRASTS[curr_task]:
                parcel_dict[curr_subj][curr_task][curr_contrast] = {}
    
                for enc in encounters:
                    print(f"Processing: {curr_subj} - {curr_task} - {curr_contrast} - Enc {enc}")
    
                    try:
                        fmri_img = task_contrast_enc_all_maps[curr_task][curr_contrast][curr_subj][enc]
                        print(f"fMRI data loaded | Shape: {fmri_img.shape}")
                        
                        # Check if atlas needs resampling (only check once)
                        if resampled_atlas is None:
                            if current_atlas.maps.shape != fmri_img.shape:
                                print(f"Atlas shape {current_atlas.maps.shape} != fMRI shape {fmri_img.shape}")
                                print("Resampling atlas once to match fMRI resolution...")

                                # Rebuild the atlas as a fresh in-memory image
                                # (data + affine only, no file path attached)
                                atlas_data = current_atlas.maps.get_fdata()
                                atlas_affine = current_atlas.maps.affine
                                clean_atlas = nib.Nifti1Image(atlas_data, atlas_affine)

                                # Nearest-neighbour keeps the integer parcel labels intact
                                resampled_atlas = resample_to_img(
                                    clean_atlas,
                                    fmri_img,
                                    interpolation='nearest',
                                    force_resample=True,
                                    copy_header=True
                                )
                                print(f"Atlas resampled to: {resampled_atlas.shape}")
                            else:
                                resampled_atlas = current_atlas.maps
                                print("Atlas and fMRI shapes match - no resampling needed")
                        
                        # Use resampled atlas for masker
                        masker = NiftiLabelsMasker(
                            labels_img=resampled_atlas,
                            standardize=False, 
                            memory='nilearn_cache',
                            strategy='mean'
                        )
                        
                        # Extract regional values
                        regional_values = masker.fit_transform(fmri_img)
                        
                        # Handle labels (decode if bytes)
                        region_labels = [
                            label.decode('utf-8') if isinstance(label, bytes) else label 
                            for label in current_atlas.labels
                        ]
                        
                        # Create activation dataframe. A length mismatch between the
                        # labels and the extracted values raises here and is handled
                        # by the generic except below (entry stored as None).
                        activation_df = pd.DataFrame({
                            'region': region_labels,
                            'activation': regional_values.flatten()
                        })
                        
                        # Add network information: second '_'-separated token for
                        # labels containing 'Networks', otherwise 'Subcortical'.
                        # NOTE(review): for labels shaped like '7Networks_LH_Vis_1' the
                        # second token is the hemisphere, not the network — confirm
                        # against the atlas label format.
                        activation_df['network'] = activation_df['region'].apply(
                            lambda x: x.split('_')[1] if 'Networks' in x else 'Subcortical'
                        )
                        
                        # Add ROI values if available (for smorgasbord atlas)
                        if hasattr(current_atlas, 'roi_values'):
                            activation_df['roi_value'] = current_atlas.roi_values
                        
                        # Save the activation df
                        parcel_dict[curr_subj][curr_task][curr_contrast][enc] = activation_df
                        print(f"Extracted {len(activation_df)} regions")
                        
                    except KeyError as e:
                        # No map for this subject/task/contrast/encounter
                        print(f"Warning: Data not found - Missing key: {e}")
                        parcel_dict[curr_subj][curr_task][curr_contrast][enc] = None
                    except Exception as e:
                        # Deliberate best-effort: record the failure and keep going
                        print(f"Error processing data: {str(e)}")
                        parcel_dict[curr_subj][curr_task][curr_contrast][enc] = None
    
    # Save to pickle
    output_ending = "_betas"
    if (file_type == "z"):
        output_ending = "_z_scored"
    
    os.makedirs(output_dir, exist_ok=True)
    
    output_file = f'{output_dir}/discovery_parcel_indiv_mean_updated_{date}_{run_num}{output_ending}.pkl'
    with open(output_file, 'wb') as f:
        pickle.dump(parcel_dict, f)
    
    print(f"\nSaved to: {output_file}")
    
    return parcel_dict

## parcellate in groups of 2 subjects at a time (due to memory constraints)

In [None]:
# Parcellate the first 2 subjects (SUBJECTS[0:2]) with the 'smor' atlas and
# pickle the result as run 1; the returned dict is kept in parcel_dict_temp.
parcel_dict_temp = save_subset_of_parcels(SUBJECTS[0:2], 1, date="0111", atlas="smor")

In [None]:
# Parcellate the next 2 subjects (SUBJECTS[2:4]) and pickle the result as run 2.
# parcel_dict_temp is overwritten; the previous run survives only in its pickle.
parcel_dict_temp = save_subset_of_parcels(SUBJECTS[2:4], 2, date="0111", atlas="smor")

In [9]:
# Parcellate the last subject (SUBJECTS[-1]) and pickle the result as run 3.
# A single ID (not a list) is passed here; save_subset_of_parcels wraps it in a list.
# NOTE(review): runs 1-3 together cover every subject only if SUBJECTS has
# exactly 5 entries — confirm.
parcel_dict_temp = save_subset_of_parcels(SUBJECTS[-1], 3, date="0111", atlas="smor")

Loading Smorgasbord atlas...
Atlas loaded with 429 regions
Atlas shape: (193, 229, 193)
doing smor
Processing: sub-s43 - nBack - twoBack-oneBack - Enc 01
fMRI data loaded | Shape: (97, 115, 97)
Atlas shape (193, 229, 193) != fMRI shape (97, 115, 97)
Resampling atlas once to match fMRI resolution...
Atlas resampled to: (97, 115, 97)
Extracted 429 regions
Processing: sub-s43 - nBack - twoBack-oneBack - Enc 02
fMRI data loaded | Shape: (97, 115, 97)
Extracted 429 regions
Processing: sub-s43 - nBack - twoBack-oneBack - Enc 03
fMRI data loaded | Shape: (97, 115, 97)
Extracted 429 regions
Processing: sub-s43 - nBack - twoBack-oneBack - Enc 04
fMRI data loaded | Shape: (97, 115, 97)
Extracted 429 regions
Processing: sub-s43 - nBack - twoBack-oneBack - Enc 05
fMRI data loaded | Shape: (97, 115, 97)
Extracted 429 regions
Processing: sub-s43 - nBack - match-mismatch - Enc 01
fMRI data loaded | Shape: (97, 115, 97)
Extracted 429 regions
Processing: sub-s43 - nBack - match-mismatch - Enc 02
fMRI d