In [None]:
# Visualizes the parcellated results from step 9 (Schaefer 400-parcel atlas)

# imports and general helper functions

In [1]:
import os
import numpy as np
import pandas as pd
import nibabel as nib
import nilearn
import json
import datetime
import pickle
import seaborn as sns
import gc
import psutil
import math
import scipy.stats as stats
from matplotlib.patches import Patch
from nilearn import plotting
from nilearn.glm.first_level import FirstLevelModel
from nilearn.glm.second_level import SecondLevelModel
from nilearn.glm import threshold_stats_img
from nilearn.image import concat_imgs, mean_img, index_img
from nilearn.reporting import make_glm_report
from nilearn import masking, image
from nilearn import datasets
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from collections import defaultdict
from nilearn.maskers import NiftiLabelsMasker

In [2]:
# general helper functions:
def build_contrast_map_path(base_dir, level, subject, session, task, contrast_name):
    """Return the path of one subject/session effect-size contrast map.

    Layout: <base_dir>/<level>/<subject>/<task>/indiv_contrasts/<filename>.
    """
    # As of 7/6/25, sub-s10's flanker maps carry an extra run-1 entity, e.g.
    # sub-s10_ses-01_run-1_task-flanker_contrast-incongruent-congruent_rtmodel-rt_centered_stat-effect-size.nii.gz
    is_s10_flanker = (subject, task) == ('sub-s10', 'flanker')
    filename = (
        f'{subject}_{session}_run-1_task-{task}_contrast-{contrast_name}_rtmodel-rt_centered_stat-effect-size.nii.gz'
        if is_s10_flanker
        else f'{subject}_{session}_task-{task}_contrast-{contrast_name}_rtmodel-rt_centered_stat-effect-size.nii.gz'
    )
    return os.path.join(base_dir, level, subject, task, 'indiv_contrasts', filename)

def is_valid_contrast_map(img_path):
    """Report whether the image at img_path is usable as a contrast map.

    Usable means: the voxel standard deviation exceeds 1e-10 (not flat) and no
    voxel is NaN. Any exception while loading/reading is printed and the map is
    treated as invalid (returns False).
    """
    try:
        voxels = nib.load(img_path).get_fdata()
        has_variance = np.std(voxels) > 1e-10
        return has_variance and not np.isnan(voxels).any()
    except Exception as e:
        print(f"Error validating {img_path}: {e}")
        return False
        
def clean_z_map_data(z_map, task, contrast_name, encounter):
    """Return z_map with NaN/Inf voxels replaced via np.nan_to_num.

    NaN becomes 0.0 and +/-Inf become the largest/smallest finite float
    (np.nan_to_num defaults). If nothing needs fixing, the same image object is
    returned; otherwise a new Nifti1Image with the original affine/header is
    built and a warning is printed. `encounter` is printed as encounter+1, so it
    is presumably a 0-based integer index -- confirm with callers.
    """
    voxels = z_map.get_fdata()
    if not (np.isnan(voxels).any() or np.isinf(voxels).any()):
        return z_map

    cleaned = nib.Nifti1Image(np.nan_to_num(voxels), z_map.affine, z_map.header)
    print(f"Warning: Fixed NaN/Inf values in {task}:{contrast_name}:encounter-{encounter+1}")
    return cleaned

def cleanup_memory():
    """
    Force a full garbage collection, then print system-wide memory usage
    (percent used and GB still available, as reported by psutil).
    """
    gc.collect()

    vm = psutil.virtual_memory()
    available_gb = vm.available / (1024 ** 3)
    print(f"Memory after cleanup: {vm.percent:.1f}% used ({available_gb:.1f}GB available)")

def convert_to_regular_dict(d):
    """
    Recursively rebuild nested (default)dicts as plain dicts, descending into lists.

    Useful before pickling: a defaultdict whose default_factory is a lambda
    (e.g. defaultdict(lambda: defaultdict(list))) cannot be pickled. Every dict
    subclass -- plain dicts included -- is rebuilt, so defaultdicts nested
    inside an ordinary dict are converted too. Any other value (numbers,
    strings, DataFrames, ...) is returned unchanged.
    """
    # defaultdict is a dict subclass, so this branch covers both kinds.
    if isinstance(d, dict):
        return {k: convert_to_regular_dict(v) for k, v in d.items()}
    if isinstance(d, list):
        return [convert_to_regular_dict(i) for i in d]
    return d

In [15]:
# all tasks and their available contrasts
TASKS = ["nBack", "flanker", "directedForgetting", "goNogo", "shapeMatching", "stopSignal", "cuedTS", "spatialTS"]
CONTRASTS = {
    "nBack": ["twoBack-oneBack", "match-mismatch", "task-baseline", "response_time"],
    "flanker": ["incongruent-congruent", "task-baseline"],
    "directedForgetting": ["neg-con", "task-baseline", "response_time"],
    "goNogo": ["nogo_success-go", "nogo_success", "task-baseline", "response_time"],  # go_rtModel check
    "shapeMatching": ["DDD", "DDS", "DNN", "DSD", "main_vars", "SDD", "SNN", "SSS", "task-baseline", "response_time"],
    "stopSignal": ["go", "stop_failure-go", "stop_failure", "stop_failure-stop_success", "stop_success-go",
                   "stop_success", "stop_success-stop_failure", "task-baseline", "response_time"],
    "cuedTS": ["cue_switch_cost", "task_switch_cost", "task_switch_cue_switch-task_stay_cue_stay", "task-baseline", "response_time"],
    "spatialTS": ["cue_switch_cost", "task_switch_cost", "task_switch_cue_switch-task_stay_cue_stay", "task-baseline", "response_time"],
}

# main conditions and contrasts that we're interested in looking at.
# Values are plain lists, so the default factory is `list`: looking up a task
# that was not requested yields []. (A lambda factory would also make this
# object unpicklable.)
requested_task_contrasts = defaultdict(list, {
    'nBack': ["twoBack-oneBack", 'task-baseline'],
    'flanker': ["incongruent-congruent", 'task-baseline'],
    'directedForgetting': ["neg-con", 'task-baseline'],
    'goNogo': ["nogo_success-go", 'task-baseline'],
    'shapeMatching': ["main_vars", 'task-baseline'],
    'stopSignal': ["stop_failure-go", 'task-baseline'],
    'cuedTS': ["task_switch_cost", 'task-baseline'],
    'spatialTS': ["task_switch_cost", 'task-baseline'],
})

# sanity checks: the config tables agree with each other
assert set(CONTRASTS) == set(TASKS), "CONTRASTS keys and TASKS disagree"
assert all(
    contrast in CONTRASTS[task]
    for task, contrasts in requested_task_contrasts.items()
    for contrast in contrasts
), "a requested contrast is not listed in CONTRASTS for its task"

# unique requested contrasts across all tasks, in first-seen order
compiled_req_contrasts = list(dict.fromkeys(
    contrast for contrasts in requested_task_contrasts.values() for contrast in contrasts
))

ENCOUNTERS = ['01', '02', '03', '04', '05']
SUBJECTS = ['sub-s03', 'sub-s10', 'sub-s19', 'sub-s29', 'sub-s43']


In [4]:
SCHAFER_PARCELLATED_DIR = 'schafer400_dfs'  # directory holding the pickled parcel dataframes
schafer_files = dict(mean='discovery_parcel_mean_924.pkl')  # statistic -> pickle filename

# load the df

In [30]:
# NOTE: pickle.load can execute arbitrary code -- only load files this pipeline produced.
mean_filename = os.path.join(SCHAFER_PARCELLATED_DIR, schafer_files['mean'])
with open(mean_filename, mode='rb') as f:
    loaded_mean_parcel_dict = pickle.load(f)

In [31]:
# spot-check one subject/task/contrast/encounter frame (one row per parcel)
print(
    loaded_mean_parcel_dict['sub-s19']['nBack']['twoBack-oneBack']['01']
)

                             region  activation network
0                7Networks_LH_Vis_1    0.086090      LH
1                7Networks_LH_Vis_2    0.244737      LH
2                7Networks_LH_Vis_3    0.229025      LH
3                7Networks_LH_Vis_4   -0.175264      LH
4                7Networks_LH_Vis_5   -0.102105      LH
..                              ...         ...     ...
395  7Networks_RH_Default_pCunPCC_5    0.390901      RH
396  7Networks_RH_Default_pCunPCC_6    0.262363      RH
397  7Networks_RH_Default_pCunPCC_7    0.311818      RH
398  7Networks_RH_Default_pCunPCC_8    0.228172      RH
399  7Networks_RH_Default_pCunPCC_9    0.111930      RH

[400 rows x 3 columns]


# relevant parcel analysis functions

In [36]:
def analyze_parcel_practice_effects(parcel_dict, subject, task, contrast, encounters_str=None):
    """
    Per-parcel analysis of practice effects across encounters.

    For every parcel listed in the first encounter's frame, the activation is
    collected across `encounters_str` (in order) and regressed on the encounter
    number 1..len(encounters_str) with scipy.stats.linregress.

    inputs:
    parcel_dict: nested dict subject -> task -> contrast -> encounter -> DataFrame
        with (at least) 'region' and 'activation' columns, one row per parcel
    subject: subject id whose parcel trajectories are analysed
    task: task to look at
    contrast: contrast to look at
    encounters_str: encounter keys to include, in order (e.g. ['01', ..., '05']).
        None (the default) means ENCOUNTERS. At least two are required.

    returns:
    dict parcel -> metrics dict with keys 'trajectory', 'slope', 'intercept',
    'r_squared', 'p_value', 'std_error', 'initial_activation',
    'final_activation', 'percent_change', 'cohens_d', 'max_activation',
    'min_activation', 'activation_range', 'significant_change', 'large_change'.
    The frames' 'network' column is not copied (in the sub-s19 example printed
    in this notebook it appears to hold hemisphere labels, LH/RH).

    raises:
    ValueError if fewer than two encounters are requested.
    KeyError if the subject/task/contrast/encounter is absent from parcel_dict,
    or if a parcel of the first encounter is absent from a later encounter.
    """
    # Resolved at call time rather than binding a shared list default at def time.
    if encounters_str is None:
        encounters_str = ENCOUNTERS
    encounters_str = list(encounters_str)

    print(f"{subject}/{task}/{contrast}")
    if len(encounters_str) < 2:
        raise ValueError(
            f"need at least two encounters to fit a practice slope, got {encounters_str}"
        )

    contrast_frames = parcel_dict[subject][task][contrast]
    # Raises KeyError (e.g. '05') up front when an encounter is missing.
    frames = [contrast_frames[enc] for enc in encounters_str]
    all_parcels = frames[0]['region'].tolist()

    # One region -> activation lookup per encounter, built once instead of
    # re-filtering each full frame for every parcel. drop_duplicates keeps the
    # first row of a repeated region, i.e. the row `.iloc[0]` would select.
    lookups = [
        frame.drop_duplicates(subset='region').set_index('region')['activation']
        for frame in frames
    ]
    # Regression x-axis 1..N, always the same length as the trajectory.
    encounters = np.arange(1, len(encounters_str) + 1)
    parcel_results = {}

    for parcel in all_parcels:
        trajectory = np.empty(len(encounters_str), dtype=float)
        for i, (enc, lookup) in enumerate(zip(encounters_str, lookups)):
            activation = lookup.loc[parcel]
            try:
                activation = float(activation)
            except (ValueError, TypeError):
                print(f"Warning: Could not convert activation '{activation}' to float for {subject}/{task}/{contrast}/{enc}/{parcel}")
                activation = 0.0
            trajectory[i] = activation

        # Linear trend of activation over encounters
        slope, intercept, r_value, p_value, std_err = stats.linregress(encounters, trajectory)

        initial_activation = trajectory[0]
        final_activation = trajectory[-1]
        max_activation = np.max(trajectory)
        min_activation = np.min(trajectory)

        # Percent change relative to the first encounter; a near-zero baseline
        # (|initial| <= 0.001) would blow up the ratio, so it reports 0 instead.
        if abs(initial_activation) > 0.001:
            percent_change = ((final_activation - initial_activation) / abs(initial_activation)) * 100
        else:
            percent_change = 0

        # NOTE: not a classic two-group Cohen's d -- it is |last - first| scaled
        # by the population SD (ddof=0) of this parcel's own trajectory.
        trajectory_std = np.std(trajectory)
        if trajectory_std > 0:
            cohens_d = abs(final_activation - initial_activation) / trajectory_std
        else:
            cohens_d = 0

        # Classification (uncorrected p < 0.05; no multiple-comparison correction)
        significant_change = (p_value < 0.05)
        large_change = significant_change and (abs(percent_change) > 10)

        parcel_results[parcel] = {
            'trajectory': trajectory,
            'slope': slope,
            'intercept': intercept,
            'r_squared': r_value**2,
            'p_value': p_value,
            'std_error': std_err,
            'initial_activation': initial_activation,
            'final_activation': final_activation,
            'percent_change': percent_change,
            'cohens_d': cohens_d,
            'max_activation': max_activation,
            'min_activation': min_activation,
            'activation_range': max_activation - min_activation,
            'significant_change': significant_change,
            'large_change': large_change,
        }

    return parcel_results

In [37]:
# Per-parcel practice-effect metrics for every analysed subject / task / contrast.
# sub-s03 is skipped (reason not recorded here -- TODO document why).
# A failing combination (e.g. a KeyError for a missing encounter) is reported
# and skipped so that the remaining combinations still run.
parcel_traj_results = {}
for subj in (s for s in SUBJECTS if s != "sub-s03"):
    subject_results = parcel_traj_results[subj] = {}

    for task, contrasts in requested_task_contrasts.items():
        task_results = subject_results[task] = {}

        for contrast in contrasts:
            try:
                task_results[contrast] = analyze_parcel_practice_effects(
                    loaded_mean_parcel_dict, subj, task, contrast
                )
            except Exception as e:
                print(f"Error processing {subj}/{task}/{contrast}: {e}")

sub-s10/nBack/twoBack-oneBack
sub-s10/nBack/task-baseline
sub-s10/flanker/incongruent-congruent
sub-s10/flanker/task-baseline
sub-s10/directedForgetting/neg-con
sub-s10/directedForgetting/task-baseline
sub-s10/goNogo/nogo_success-go
sub-s10/goNogo/task-baseline
sub-s10/shapeMatching/main_vars
sub-s10/shapeMatching/task-baseline
sub-s10/stopSignal/stop_failure-go
sub-s10/stopSignal/task-baseline
sub-s10/cuedTS/task_switch_cost
sub-s10/cuedTS/task-baseline
sub-s10/spatialTS/task_switch_cost
sub-s10/spatialTS/task-baseline
sub-s19/nBack/twoBack-oneBack
sub-s19/nBack/task-baseline
sub-s19/flanker/incongruent-congruent
sub-s19/flanker/task-baseline
sub-s19/directedForgetting/neg-con
sub-s19/directedForgetting/task-baseline
sub-s19/goNogo/nogo_success-go
Error processing sub-s19/goNogo/nogo_success-go: '05'
sub-s19/goNogo/task-baseline
Error processing sub-s19/goNogo/task-baseline: '05'
sub-s19/shapeMatching/main_vars
sub-s19/shapeMatching/task-baseline
sub-s19/stopSignal/stop_failure-go
sub

# visualization functions

In [None]:
def _parcel_network(parcel_name):
    """Network label parsed from a Schaefer-style parcel name, else 'unknown'."""
    # e.g. '7Networks_RH_Default_pCunPCC_5' -> ['7Networks', 'RH', 'Default', ...] -> 'Default'
    parts = str(parcel_name).split('_')
    return parts[2] if len(parts) >= 4 else 'unknown'


def create_parcel_practice_heatmap(parcel_traj, title='Practice Effects Across All Parcels'):
    """
    Plot a heatmap of per-parcel activation trajectories across encounters.

    Rows are parcels grouped by network (white lines mark network boundaries)
    and, within each network, ordered by practice slope (most negative first).

    input:
    parcel_traj: dict parcel_name -> metrics dict (the output format of
        analyze_parcel_practice_effects); each metrics dict needs at least
        'trajectory' (one activation per encounter) and 'slope'. An optional
        'network' entry is used for grouping; where it is absent, the network
        is parsed from Schaefer-style names such as '7Networks_LH_Vis_1'
        (third '_'-separated field -> 'Vis').
    title: the title for this heatmap

    raises:
    ValueError if parcel_traj is empty.
    """
    if not parcel_traj:
        raise ValueError("parcel_traj is empty; nothing to plot")

    # One row per parcel, one column per metric
    df = pd.DataFrame(parcel_traj).T
    df['slope'] = df['slope'].astype(float)

    # 'network' is optional in the metrics dicts; fill any gaps from the name
    derived_networks = pd.Series([_parcel_network(name) for name in df.index], index=df.index)
    if 'network' in df.columns:
        df['network'] = df['network'].fillna(derived_networks).astype(str)
    else:
        df['network'] = derived_networks

    # Group by network so the boundary lines below separate contiguous blocks
    df_sorted = df.sort_values(['network', 'slope'])

    trajectory_matrix = np.vstack([np.asarray(t, dtype=float) for t in df_sorted['trajectory']])
    n_encounters = trajectory_matrix.shape[1]

    fig, ax = plt.subplots(figsize=(30, 12))
    sns.heatmap(trajectory_matrix,
                xticklabels=[f'Enc {i}' for i in range(1, n_encounters + 1)],
                yticklabels=[str(name).split('_')[-1] for name in df_sorted.index],
                cmap='RdBu_r', center=0,
                cbar_kws={'label': 'Activation'},
                ax=ax)

    ax.set_title(title)
    ax.set_xlabel('Encounter')
    ax.set_ylabel('Brain Parcels (grouped by network, sorted by slope)')

    # Heatmap row i spans y in [i, i+1], so the line before row i sits at y=i
    networks = df_sorted['network'].tolist()
    for i in range(1, len(networks)):
        if networks[i] != networks[i - 1]:
            ax.axhline(y=i, color='white', linewidth=2)

    fig.tight_layout()
    plt.show()

# individual analysis

# group analysis