In [2]:
# Uses the parcellated results from step 9 (Schaefer 400 or Smorgasbord atlas, chosen via `req_atlas`) and creates a version averaged across all participants

# imports and general helper functions

In [1]:
import os
import numpy as np
import pandas as pd
import nibabel as nib
import nilearn
import json
import datetime
import pickle
import seaborn as sns
import gc
import psutil
import math
import scipy.stats as stats
from matplotlib.patches import Patch
from nilearn import plotting
from nilearn.glm.first_level import FirstLevelModel
from nilearn.glm.second_level import SecondLevelModel
from nilearn.glm import threshold_stats_img
from nilearn.image import concat_imgs, mean_img, index_img
from nilearn.reporting import make_glm_report
from nilearn import masking, image
from nilearn import datasets
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from collections import defaultdict
from nilearn.maskers import NiftiLabelsMasker
from nilearn.plotting.find_cuts import find_cut_slices

In [2]:
# general helper functions:
def cleanup_memory():
    """
    Run a full garbage-collection pass, then report current system memory usage.

    Meant to be called between processing batches. Prints the percentage of RAM in use
    and the available memory in GB; returns None.
    """
    gc.collect()

    vm_stats = psutil.virtual_memory()
    available_gb = vm_stats.available / 1024 ** 3
    print(f"Memory after cleanup: {vm_stats.percent:.1f}% used ({available_gb:.1f}GB available)")

In [3]:
# all tasks and contrasts
TASKS = ["nBack","flanker","directedForgetting","goNogo", "shapeMatching", "stopSignal", "cuedTS", "spatialTS"]
CONTRASTS = {
    "nBack": ["twoBack-oneBack", "match-mismatch", "task-baseline", "response_time"],  # the nback contrasts
    "flanker": ["incongruent-congruent", "task-baseline"],
    "directedForgetting": ["neg-con", "task-baseline", "response_time"],
    "goNogo": ["nogo_success-go", "nogo_success", "task-baseline", "response_time"],  # go_rtModel check
    "shapeMatching": ["DDD", "DDS", "DNN", "DSD", "main_vars", "SDD", "SNN", "SSS", "task-baseline", "response_time"],
    "stopSignal": ["go", "stop_failure-go", "stop_failure", "stop_failure-stop_success", "stop_success-go",
                   "stop_success", "stop_success-stop_failure", "task-baseline", "response_time"],
    "cuedTS": ["cue_switch_cost", "task_switch_cost", "task_switch_cue_switch-task_stay_cue_stay", "task-baseline", "response_time"],
    "spatialTS": ["cue_switch_cost", "task_switch_cost", "task_switch_cue_switch-task_stay_cue_stay", "task-baseline", "response_time"],
}

# Task -> contrasts to analyse (currently every contrast of every task; edit to subset).
# Each list is a copy, so editing a selection never mutates CONTRASTS.
# Looking up a task that was not requested yields an empty list, so iterating it is a no-op.
requested_task_contrasts = defaultdict(list, {task: list(CONTRASTS[task]) for task in TASKS})

# Every requested contrast name, de-duplicated, in first-seen order (task order, then contrast order).
compiled_req_contrasts = list(dict.fromkeys(
    contrast for task in TASKS for contrast in requested_task_contrasts[task]
))


ENCOUNTERS = ['01', '02','03','04','05']
SUBJECTS = ['sub-s03', 'sub-s10', 'sub-s19', 'sub-s29', 'sub-s43']


In [7]:
# Parcel-file shard suffixes shared by both atlases: the loading step reads <dir>/<stem>_<date>_<n>.pkl for each n.
indices = [1, 2, 3]

# schafer stuff
SCHAFER_PARCELLATED_DIR = 'schafer400_dfs'
schafer_files = {'mean': 'discovery_parcel_indiv_mean_updated'}
schafer_date_updated = '1001'
# Get schaefer atlas (400 ROIs)
SCHAEFER = datasets.fetch_atlas_schaefer_2018(n_rois=400)
SCHAEFER_IMG = nib.load(SCHAEFER.maps)
SCHAEFER_DATA = SCHAEFER_IMG.get_fdata()


# smorgasbord stuff
SMORG_PARCELLATED_DIR = 'smor_parcel_dfs'
smor_files = {'mean': 'discovery_parcel_indiv_mean_updated'}
smor_date_updated = '1027'
# get smorgasbord atlas
# NOTE: pickle.load can execute arbitrary code; only load this file if it is locally produced / trusted.
with open(f'{SMORG_PARCELLATED_DIR}/smorgasbord_atlas_files/smorgasbord_atlas.pkl', 'rb') as f:
    smorgasbord_atlas = pickle.load(f)
SMORG_IMG = smorgasbord_atlas.maps
SMORG_DATA = SMORG_IMG.get_fdata()

# Load the per-subject parcel data

In [11]:
req_atlas = "smor"  # "schafer" or "smor"

# Select atlas configuration.
# atlas_obj only needs a `.labels` attribute here (used in the summary print below).
if req_atlas == "schafer":
    main_dir = SCHAFER_PARCELLATED_DIR
    main_files = schafer_files
    date_updated = schafer_date_updated
    atlas_obj = SCHAEFER  # bunch returned by fetch_atlas_schaefer_2018 in the atlas cell
elif req_atlas == "smor":
    main_dir = SMORG_PARCELLATED_DIR
    main_files = smor_files
    date_updated = smor_date_updated
    atlas_obj = smorgasbord_atlas
else:
    raise ValueError(f"Unknown atlas: {req_atlas}. Use 'schafer' or 'smor'")

# Load mean parcel data from multiple files.
# Each shard is a pickled dict keyed by subject id; the shards are merged into one dict.
# mean_filename is also the stem used when saving the averaged results later.
# NOTE: pickle.load can execute arbitrary code - only load trusted, locally produced files.
loaded_mean_parcel_dict = {}
mean_filename = f"{main_dir}/{main_files['mean']}_{date_updated}"

for num in indices:
    fin_filename = f"{mean_filename}_{num}.pkl"
    print(f"Loading: {fin_filename}")

    try:
        with open(fin_filename, 'rb') as f:
            dict_data = pickle.load(f)

        # dict.update silently replaces subjects loaded from an earlier shard; surface that.
        duplicate_subjects = sorted(loaded_mean_parcel_dict.keys() & dict_data.keys())
        if duplicate_subjects:
            print(f"Warning: {fin_filename} overrides already-loaded subjects: {duplicate_subjects}")

        loaded_mean_parcel_dict.update(dict_data)
        print(f"Loaded {len(dict_data)} subjects")
    except FileNotFoundError:
        print(f"Warning: File not found - {fin_filename}")
        continue
    except Exception as e:
        # Best-effort: report an unreadable/corrupt shard and keep loading the others.
        print(f"Error loading {fin_filename}: {e}")
        continue

if not loaded_mean_parcel_dict:
    print("Warning: no parcel data was loaded - check main_dir / date_updated / indices")

print(f"\nTotal subjects loaded: {len(loaded_mean_parcel_dict)}")
print(f"Atlas: {req_atlas} ({len(atlas_obj.labels)} regions)")

Loading: smor_parcel_dfs/discovery_parcel_indiv_mean_updated_1027_1.pkl
Loaded 2 subjects
Loading: smor_parcel_dfs/discovery_parcel_indiv_mean_updated_1027_2.pkl
Loaded 2 subjects
Loading: smor_parcel_dfs/discovery_parcel_indiv_mean_updated_1027_3.pkl
Loaded 1 subjects

Total subjects loaded: 5
Atlas: smor (429 regions)


# relevant parcel analysis functions

In [12]:
def _parcel_trajectory_stats(trajectory):
    """
    Summary statistics for one parcel's activation trajectory.

    trajectory: 1-D float array of activations, one per consecutive encounter starting at
        encounter 1 (at least one value).

    The regression is activation ~ encounter index (1, 2, ...). With fewer than two points
    the regression is undefined, so slope/intercept/r_squared/p_value/std_error are NaN
    (and significant_change is False) instead of raising.
    """
    n_points = len(trajectory)
    if n_points >= 2:
        enc_this = range(1, n_points + 1)
        slope, intercept, r_value, p_value, std_err = stats.linregress(enc_this, trajectory)
    else:
        slope = intercept = r_value = p_value = std_err = np.nan

    initial_activation = trajectory[0]
    final_activation = trajectory[-1]
    max_activation = np.max(trajectory)
    min_activation = np.min(trajectory)

    # Percent change from the first encounter; reported as 0 when the baseline is within
    # 0.001 of zero, where the ratio would blow up.
    if abs(initial_activation) > 0.001:
        percent_change = ((final_activation - initial_activation) / abs(initial_activation)) * 100
    else:
        percent_change = 0

    # Effect size labelled "Cohen's d": |final - initial| divided by the population std (ddof=0)
    # of the whole trajectory. NOTE(review): this is not the textbook two-group Cohen's d.
    trajectory_std = np.std(trajectory)
    if trajectory_std > 0:
        cohens_d = abs(final_activation - initial_activation) / trajectory_std
    else:
        cohens_d = 0

    return {
        'trajectory': trajectory,
        'slope': slope,
        'intercept': intercept,
        'r_squared': r_value**2,
        'p_value': p_value,
        'std_error': std_err,
        'initial_activation': initial_activation,
        'final_activation': final_activation,
        'percent_change': percent_change,
        'cohens_d': cohens_d,
        'max_activation': max_activation,
        'min_activation': min_activation,
        'activation_range': max_activation - min_activation,
        'significant_change': (p_value < 0.05),
    }


def analyze_parcel_practice_effects(parcel_dict, subject, task, contrast, encounters_str = ENCOUNTERS):
    """
    Detailed analysis of practice effects for individual parcels.

    For every parcel listed in the first encounter's dataframe, builds that parcel's activation
    trajectory across encounters. It then regresses activation on encounter index (1, 2, ...)
    and computes simple change metrics.

    inputs:
    parcel_dict: dict with format subject -> task -> contrast -> encounter -> dataframe; each
        dataframe has a 'region' column (parcel label) and an 'activation' column.
    subject: subject id whose parcel trajectories are analysed
    task: task to look at
    contrast: contrast to look at
    encounters_str: ordered encounter labels to include (default ENCOUNTERS, i.e. '01'..'05')

    returns:
    dict mapping parcel label -> dict of statistics ('trajectory', 'slope', 'intercept',
    'r_squared', 'p_value', 'std_error', 'initial_activation', 'final_activation',
    'percent_change', 'cohens_d', 'max_activation', 'min_activation', 'activation_range',
    'significant_change').

    raises:
    KeyError if subject/task/contrast or the first encounter is missing from parcel_dict.

    Note: encounters are read in order, and a trajectory stops at the first encounter that is
    missing (for the whole contrast, or for that parcel). Trajectories are therefore always a
    contiguous prefix of encounters_str. This matches the assumption that encounters 1-4 are
    always present and only encounter 5 may be missing. Group averaging right-pads shorter
    trajectories, so this prefix property keeps encounters aligned across subjects.
    """
    print(f"{subject}/{task}/{contrast}")

    contrast_data = parcel_dict[subject][task][contrast]

    # Get all individual parcels; the first encounter defines the parcel list
    first_encounter = contrast_data[encounters_str[0]]
    all_parcels = first_encounter['region'].tolist()

    # Build one region -> activation lookup per encounter up front; filtering each dataframe
    # once per parcel would be O(parcels^2). A missing encounter is reported once, and
    # everything after it is skipped so trajectories stay a contiguous prefix.
    encounter_lookups = []
    for enc in encounters_str:
        if enc not in contrast_data:
            print(f"for {subject}, encounter {enc} is missing for {task} {contrast}")
            break
        df = contrast_data[enc]
        # keep='first' uses the first row when a region label appears more than once
        lookup = (
            df.drop_duplicates(subset='region', keep='first')
              .set_index('region')['activation']
              .to_dict()
        )
        encounter_lookups.append((enc, lookup))

    parcel_results = {}
    for parcel in all_parcels:
        # Extract trajectory for this specific parcel
        trajectory = []
        for enc, lookup in encounter_lookups:
            if parcel not in lookup:
                print(f"for {subject}, parcel {parcel} is missing from encounter {enc} for {task} {contrast}")
                break

            activation = lookup[parcel]
            try:
                activation = float(activation)
            except (ValueError, TypeError):
                print(f"Warning: Could not convert activation '{activation}' to float for {subject}/{task}/{contrast}/{enc}/{parcel}")
                # NOTE(review): substituting 0.0 fabricates a data point and biases the slope;
                # NaN would be more honest but would propagate into the group means.
                activation = 0.0
            trajectory.append(activation)

        parcel_results[parcel] = _parcel_trajectory_stats(np.array(trajectory, dtype=float))

    return parcel_results

In [13]:
# get the parcel trajectory results per subject
# Layout: parcel_traj_results[subject][task][contrast] -> {parcel: stats dict}.
# A task/contrast that raises for a subject is reported and left out of that subject's dict.
parcel_traj_results = {}
for subj in SUBJECTS:
    subject_results = {}
    parcel_traj_results[subj] = subject_results

    for task, task_contrasts in requested_task_contrasts.items():
        task_results = {}
        subject_results[task] = task_results

        for contrast in task_contrasts:
            try:
                task_results[contrast] = analyze_parcel_practice_effects(
                    loaded_mean_parcel_dict, subj, task, contrast
                )
            except Exception as e:
                print(f"Error processing {subj}/{task}/{contrast}: {e}")

sub-s03/nBack/twoBack-oneBack
sub-s03/nBack/match-mismatch
sub-s03/nBack/task-baseline
sub-s03/nBack/response_time
sub-s03/flanker/incongruent-congruent
sub-s03/flanker/task-baseline
sub-s03/directedForgetting/neg-con
sub-s03/directedForgetting/task-baseline
sub-s03/directedForgetting/response_time
sub-s03/goNogo/nogo_success-go
sub-s03/goNogo/nogo_success
sub-s03/goNogo/task-baseline
sub-s03/goNogo/response_time
sub-s03/shapeMatching/DDD
sub-s03/shapeMatching/DDS
sub-s03/shapeMatching/DNN
sub-s03/shapeMatching/DSD
sub-s03/shapeMatching/main_vars
sub-s03/shapeMatching/SDD
sub-s03/shapeMatching/SNN
sub-s03/shapeMatching/SSS
sub-s03/shapeMatching/task-baseline
sub-s03/shapeMatching/response_time
sub-s03/stopSignal/go
sub-s03/stopSignal/stop_failure-go
sub-s03/stopSignal/stop_failure
sub-s03/stopSignal/stop_failure-stop_success
sub-s03/stopSignal/stop_success-go
sub-s03/stopSignal/stop_success
sub-s03/stopSignal/stop_success-stop_failure
sub-s03/stopSignal/task-baseline
sub-s03/stopSignal

In [14]:
# verify numbers for each
# Every subject should have one entry per requested task/contrast pair. A shortfall means the
# per-subject analysis raised (and printed an error) for the missing pairs.
expected_count = sum(len(contrasts) for contrasts in requested_task_contrasts.values())
for subj in SUBJECTS:
    subject_results = parcel_traj_results[subj]
    count = sum(len(task_results) for task_results in subject_results.values())
    print(f"for {subj} there are {count} specific task/contrast combos loaded")

    if count != expected_count:
        missing = [
            f"{task}/{contrast}"
            for task, contrasts in requested_task_contrasts.items()
            for contrast in contrasts
            if contrast not in subject_results.get(task, {})
        ]
        print(f"  Warning: expected {expected_count}; missing: {missing}")

for sub-s03 there are 42 specific task/contrast combos loaded
for sub-s10 there are 42 specific task/contrast combos loaded
for sub-s19 there are 42 specific task/contrast combos loaded
for sub-s29 there are 42 specific task/contrast combos loaded
for sub-s43 there are 42 specific task/contrast combos loaded


# group analysis

In [15]:
# create an averaged parcel df across all participants and save it to a file
# Result layout: avg_parcel_traj_results[task][contrast][parcel] -> dict of across-subject
# summary statistics. The pickle itself is written in the next cell.
# NOTE(review): all *_std / *_sem values use the population std (np.std, ddof=0); the
# conventional SEM uses ddof=1 - confirm which is intended before reporting.


def _average_trajectories(trajectories):
    """
    Per-encounter mean / std / SEM across subjects' activation trajectories.

    Trajectories can differ in length (e.g. a subject without a 5th encounter). Shorter ones
    are right-padded with NaN, so each encounter is averaged over only the subjects that have
    it. Its SEM also divides by that per-encounter count.
    Returns (mean, std, sem) as 1-D arrays of length max(len(t) for t in trajectories).
    """
    lengths = {len(traj) for traj in trajectories}

    if len(lengths) == 1:
        # All same length
        trajectory_array = np.array(trajectories)
        trajectory_mean = np.mean(trajectory_array, axis=0)
        trajectory_std = np.std(trajectory_array, axis=0)
        trajectory_sem = trajectory_std / np.sqrt(len(trajectories))
        return trajectory_mean, trajectory_std, trajectory_sem

    # Different lengths - right-pad with NaN
    max_length = max(lengths)
    trajectory_array = np.full((len(trajectories), max_length), np.nan)
    for row, traj in enumerate(trajectories):
        trajectory_array[row, :len(traj)] = traj

    trajectory_mean = np.nanmean(trajectory_array, axis=0)
    trajectory_std = np.nanstd(trajectory_array, axis=0)
    # Divide by the subjects that actually contributed to each encounter; dividing by the
    # total subject count would understate the SEM of padded encounters.
    n_valid = np.sum(~np.isnan(trajectory_array), axis=0)
    trajectory_sem = trajectory_std / np.sqrt(n_valid)
    return trajectory_mean, trajectory_std, trajectory_sem


def _summarize_parcel(subject_data_list):
    """
    Across-subject summary for one parcel.

    subject_data_list: non-empty list of per-subject stats dicts as produced by
        analyze_parcel_practice_effects (one dict per subject that has this parcel).
    """
    def values(key):
        return [data[key] for data in subject_data_list]

    n_subjects = len(subject_data_list)
    slopes = values('slope')
    p_values = values('p_value')
    trajectories = values('trajectory')
    trajectory_mean, trajectory_std, trajectory_sem = _average_trajectories(trajectories)

    return {
        'n_subjects': n_subjects,
        'trajectory_n_subjects': len(trajectories),

        # Slope statistics
        'slope_mean': np.mean(slopes),
        'slope_std': np.std(slopes),
        'slope_sem': np.std(slopes) / np.sqrt(n_subjects),

        # Other metrics
        'intercept_mean': np.mean(values('intercept')),
        'r_squared_mean': np.mean(values('r_squared')),
        # NOTE(review): a mean of p-values is descriptive only, not a group-level test.
        'p_value_mean': np.mean(p_values),

        # Activation statistics
        'initial_activation_mean': np.mean(values('initial_activation')),
        'initial_activation_std': np.std(values('initial_activation')),
        'final_activation_mean': np.mean(values('final_activation')),
        'final_activation_std': np.std(values('final_activation')),

        # Change metrics
        'percent_change_mean': np.mean(values('percent_change')),
        'percent_change_std': np.std(values('percent_change')),
        'cohens_d_mean': np.mean(values('cohens_d')),

        # Range metrics
        'max_activation_mean': np.mean(values('max_activation')),
        'min_activation_mean': np.mean(values('min_activation')),
        'activation_range_mean': np.mean(values('activation_range')),

        # Trajectory information (per-encounter values)
        'trajectory_mean': trajectory_mean,
        'trajectory_std': trajectory_std,
        'trajectory_sem': trajectory_sem,

        # Summary proportions
        'significant_slope_proportion': np.mean([1 if p < 0.05 else 0 for p in p_values]),
        'positive_slope_proportion': np.mean([1 if s > 0 else 0 for s in slopes]),
    }


avg_parcel_traj_results = {}

count_success = 0
for task in requested_task_contrasts:
    avg_parcel_traj_results[task] = {}

    for contrast in requested_task_contrasts[task]:
        print(f"Processing {task}/{contrast}:")
        avg_parcel_traj_results[task][contrast] = {}

        # Collect all parcel data across subjects: parcel -> [per-subject stats dicts]
        parcel_data = defaultdict(list)
        for subj in SUBJECTS:
            try:
                curr_res = parcel_traj_results[subj][task][contrast]
            except KeyError as e:
                # This task/contrast failed (or was never run) for this subject.
                print(f"Error processing {subj}/{task}/{contrast}: {e}")
                continue

            for parcel_name, parcel_stats in curr_res.items():
                parcel_data[parcel_name].append(parcel_stats)

        # Calculate averages for each parcel (every list here has at least one entry)
        for parcel_name, subject_data_list in parcel_data.items():
            avg_parcel_traj_results[task][contrast][parcel_name] = _summarize_parcel(subject_data_list)

        print(f"Completed averaging for {len(avg_parcel_traj_results[task][contrast])} parcels for {task}/{contrast}")
        if parcel_data:
            count_success += 1

print(f"Averaging complete! for {count_success} task/contrasts")

Processing nBack/twoBack-oneBack:
Completed averaging for 429 parcels for nBack/twoBack-oneBack
Processing nBack/match-mismatch:
Completed averaging for 429 parcels for nBack/match-mismatch
Processing nBack/task-baseline:
Completed averaging for 429 parcels for nBack/task-baseline
Processing nBack/response_time:
Completed averaging for 429 parcels for nBack/response_time
Processing flanker/incongruent-congruent:
Completed averaging for 429 parcels for flanker/incongruent-congruent
Processing flanker/task-baseline:
Completed averaging for 429 parcels for flanker/task-baseline
Processing directedForgetting/neg-con:
Completed averaging for 429 parcels for directedForgetting/neg-con
Processing directedForgetting/task-baseline:
Completed averaging for 429 parcels for directedForgetting/task-baseline
Processing directedForgetting/response_time:
Completed averaging for 429 parcels for directedForgetting/response_time
Processing goNogo/nogo_success-go:
Completed averaging for 429 parcels for g

In [16]:
# save to a pkl file
# The output sits next to the per-subject inputs: <main_dir>/<mean stem>_<date>_averaged.pkl
averaged_output_path = f'{mean_filename}_averaged.pkl'
with open(averaged_output_path, 'wb') as out_file:
    pickle.dump(avg_parcel_traj_results, out_file)