Create and save representational similarity matrices (RSMs) for every task/contrast map, per subject and per encounter

In [1]:
import os
import numpy as np
import pandas as pd
import nibabel as nib
import nilearn
import json
import datetime
import pickle
import seaborn as sns
import gc
import psutil
import math
import scipy.stats as stats
from matplotlib.patches import Patch
from nilearn import plotting
from nilearn.glm.first_level import FirstLevelModel
from nilearn.glm.second_level import SecondLevelModel
from nilearn.glm import threshold_stats_img
from nilearn.image import concat_imgs, mean_img, index_img
from nilearn.reporting import make_glm_report
from nilearn import masking, image
from nilearn import datasets
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from collections import defaultdict

In [2]:
# Root of the derivatives tree and the sub-folder in which the first-level contrast maps are stored
BASE_DIR = '/oak/stanford/groups/russpold/data/network_grant/discovery_BIDS_20250402/derivatives/'
LEVEL = 'output_lev1_mni_no_ted_comp'

# Subject IDs of the discovery sample, built from their numeric part
SUBJECTS = [f'sub-s{subject_number:02d}' for subject_number in (3, 10, 19, 29, 43)]

In [3]:
# Tasks, the contrasts analysed for each task, and the sessions that are searched for maps.
TASKS = ["nBack", "flanker", "directedForgetting", "goNogo", "shapeMatching", "stopSignal", "cuedTS", "spatialTS"]

# Contrast names shared by the two task-switching tasks (copied per task below so the
# two lists stay independent objects).
_TASK_SWITCHING_CONTRASTS = [
    "cue_switch_cost",
    "task_switch_cost",
    "task_switch_cue_switch-task_stay_cue_stay",
    "task-baseline",
    "response_time",
]

# Every contrast name appears once per task: the downstream cells loop over these lists,
# so a repeated name would index (and append) the same files twice.
CONTRASTS = {
    "nBack": ["twoBack-oneBack", "match-mismatch", "task-baseline", "response_time"],  # the nback contrasts
    "flanker": ["incongruent-congruent", "task-baseline", "response_time"],
    "directedForgetting": ["neg-con", "task-baseline", "response_time"],
    "goNogo": ["nogo_success-go", "nogo_success", "task-baseline", "response_time"],  # go_rtModel check
    "shapeMatching": ["DDD", "DDS", "DNN", "DSD", "main_vars", "SDD", "SNN", "SSS", "task-baseline", "response_time"],
    "stopSignal": [
        "go",
        "stop_failure-go",
        "stop_failure",
        "stop_failure-stop_success",
        "stop_success-go",
        "stop_success",
        "stop_success-stop_failure",
        "task-baseline",
        "response_time",
    ],
    "cuedTS": list(_TASK_SWITCHING_CONTRASTS),
    "spatialTS": list(_TASK_SWITCHING_CONTRASTS),
}

# Session labels ses-01 ... ses-10, in the order in which they are scanned for maps
SESSIONS = [f'ses-{session_number:02d}' for session_number in range(1, 11)]

# number of encounters each subject has with a task
max_num_encounters = 5

In [4]:
# Yeo 2011 atlas: fetch it through nilearn and load the image stored under 'thick_7'
yeo = datasets.fetch_atlas_yeo_2011()
yeo_7network = yeo['thick_7'] 
yeo_atlas_nifti = nib.load(yeo_7network)

if yeo_atlas_nifti.get_fdata().ndim != 4:
    # nothing to fix: the loaded image is used as it is
    yeo_atlas_3d = yeo_atlas_nifti
else:
    # a 4-D array is squeezed (singleton axes dropped) and wrapped in a new image
    # that reuses the original affine and header
    yeo_data_3d = np.squeeze(yeo_atlas_nifti.get_fdata())
    yeo_atlas_3d = nib.Nifti1Image(yeo_data_3d, yeo_atlas_nifti.affine, yeo_atlas_nifti.header)
    print(f"Fixed shape: {yeo_atlas_3d.shape}")

# Atlas label value -> network name (labels 1 to 7, in this order)
network_names = dict(enumerate(
    [
        "Visual",
        "Somatomotor",
        "Dorsal Attention",
        "Ventral Attention",
        "Limbic",
        "Frontoparietal Control",
        "Default Mode",
    ],
    start=1,
))

Fixed shape: (256, 256, 256)


In [5]:
# helper functions:
def build_contrast_map_path(base_dir, level, subject, session, task, contrast_name):
    """
    Return the expected location of one first-level contrast map.

    The file lives in <base_dir>/<level>/<subject>/<task>/indiv_contrasts/ and its
    name is assembled from the subject, session, task and contrast labels.
    """
    name_parts = [
        subject,
        session,
        f'task-{task}',
        f'contrast-{contrast_name}',
        'rtmodel-rt_centered',
        'stat-effect-size.nii.gz',
    ]
    contrast_dir = os.path.join(base_dir, level, subject, task, 'indiv_contrasts')
    return os.path.join(contrast_dir, '_'.join(name_parts))

def is_valid_contrast_map(img_path):
    """
    Tell whether the image at img_path is usable.

    A map counts as valid when the standard deviation of its data exceeds 1e-10
    and the data contain no NaN. Any error while loading or checking the file is
    printed and reported as "not valid".
    """
    try:
        voxels = nib.load(img_path).get_fdata()
        has_variance = np.std(voxels) > 1e-10
        if not has_variance:
            # a flat map (or one whose std is NaN) is rejected without the NaN scan
            return has_variance
        return not np.isnan(voxels).any()
    except Exception as e:
        print(f"Error validating {img_path}: {e}")
        return False
        
def clean_z_map_data(z_map, task, contrast_name, encounter):
    """
    Return z_map with NaN / infinite voxels replaced via np.nan_to_num.

    An image whose data are all finite is returned unchanged (same object).
    Otherwise a new image with the cleaned data and the original affine and
    header is returned and a warning naming task, contrast and the 1-based
    encounter is printed.
    """
    voxels = z_map.get_fdata()
    if np.isfinite(voxels).all():
        return z_map

    repaired_map = nib.Nifti1Image(np.nan_to_num(voxels), z_map.affine, z_map.header)
    print(f"Warning: Fixed NaN/Inf values in {task}:{contrast_name}:encounter-{encounter+1}")
    return repaired_map

def save_rsm(rsm_results, filename):
    """
    Pickle an RSM results dictionary to disk and report the file size.

    Parameters:
        rsm_results: RSM results dictionary
        filename: target file name; '.pkl' is appended unless it already ends with it
    """
    target_path = filename if filename.endswith('.pkl') else filename + '.pkl'

    with open(target_path, 'wb') as handle:
        pickle.dump(rsm_results, handle)

    # size on disk, in megabytes, for the confirmation message
    file_size = os.path.getsize(target_path) / (1024 * 1024)
    print(f"Saved to {target_path} ({file_size:.1f} MB)")

def load_rsm(filename):
    """
    Read back an RSM results dictionary written by pickle.

    Parameters:
        filename: file to read; '.pkl' is appended unless it already ends with it

    Returns:
        rsm_results: the unpickled object
    """
    source_path = filename if filename.endswith('.pkl') else filename + '.pkl'

    # NOTE: unpickling can execute arbitrary code - only open files you created yourself
    with open(source_path, 'rb') as handle:
        rsm_results = pickle.load(handle)

    print(f"Loaded from {source_path}")
    return rsm_results

def cleanup_memory():
    """
    Run the garbage collector and print the current system memory usage.

    Intended to be called between batches, after large objects were deleted.
    """
    gc.collect()

    # system-wide figures as reported by psutil
    memory_status = psutil.virtual_memory()
    available_gb = memory_status.available / (1024**3)
    print(f"Memory after cleanup: {memory_status.percent:.1f}% used ({available_gb:.1f}GB available)")

In [6]:
# Index every first-level map that exists on disk, per task / contrast / subject:
#   all_contrast_maps -> list of paths in SESSIONS order
#   encounter_maps    -> {encounter index: path}; the index counts, from 0, the
#                        sessions in which a map was found for that subject
all_contrast_maps = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
encounter_maps = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))

for task in TASKS:
    for contrast_name in CONTRASTS[task]:
        for subject in SUBJECTS:
            candidate_paths = [
                build_contrast_map_path(BASE_DIR, LEVEL, subject, session, task, contrast_name)
                for session in SESSIONS
            ]
            found_paths = [path for path in candidate_paths if os.path.exists(path)]

            # entries are only created for subjects that have at least one map
            for encounter_index, found_path in enumerate(found_paths):
                all_contrast_maps[task][contrast_name][subject].append(found_path)
                encounter_maps[task][contrast_name][subject][encounter_index] = found_path

first_level_session_maps = all_contrast_maps
first_level_encounter_maps = encounter_maps

## generate the RSMs

In [7]:
# function to gather maps of a certain task/contrast from first_level_encounter_maps
def gather_tc_maps(req_tasks, req_contrasts, all_maps=first_level_encounter_maps, req_encounters=(0, 1, 2, 3, 4), req_subjects=SUBJECTS):
    '''
    Get a list of loaded niftis for specific task/contrast/encounter combinations of first level maps 
    
    Parameters
        req_tasks: list of tasks as strings (tasks that are not in TASKS are reported and skipped)
        req_contrasts: list of contrasts as strings (contrasts that are not in CONTRASTS[task] are reported and skipped)
        all_maps: [task][contrast_name][subject][overall_encounter_count] -> one map each (here it is in a filepath format)
        req_encounters: 0-based encounter numbers that are requested (default is all 5);
            values outside 0 .. max_num_encounters - 1 are ignored silently
        req_subjects: list of subject id strings that are requested (default is all in SUBJECTS)
    Return
        specified_maps: list of loaded nifti files that fit the requested task, contrast, and encounter (this returns this for all subjects)
        specified_descriptors: list of descriptions of each file (i.e. titles), aligned with specified_maps
        data_title: informative title for the RSM that will later be created
    
    Combinations for which no usable map exists are printed and left out, so both
    returned lists can be shorter than the number of requested combinations.
    '''
    specified_maps = []
    specified_descriptors = []

    if (len(req_tasks) == 0) or (len(req_contrasts) == 0):
        return [], [], ''

    for task in req_tasks:
        if task not in TASKS:
            print(f"task {task} not in task masterlist")
            continue
    
        for contrast in req_contrasts:
            if contrast not in CONTRASTS[task]: # make sure this contrast exists in the given task
                print(f"skipped for contrast {contrast} and task {task}")
                continue
                
            for subject in req_subjects:
                if subject not in SUBJECTS:
                    print(f"subject: {subject} is not in this dataset, so skipped")
                    continue
                    
                for encounter in req_encounters:
                    # the limit is the notebook-level max_num_encounters constant
                    if encounter < 0 or encounter >= max_num_encounters:
                        continue

                    descriptor_name = f"{subject}:encounter-0{encounter + 1}"

                    # Membership tests (rather than direct indexing) keep a defaultdict
                    # input from growing new empty entries. Each miss prints the level
                    # at which the lookup stopped.
                    if task not in all_maps:
                        print(f"{task}")
                        continue
                    if contrast not in all_maps[task]:
                        print(f"{task}:{contrast}")
                        continue
                    if subject not in all_maps[task][contrast]:
                        print(f"{task}|{contrast} subject {subject}")
                        continue
                    if encounter not in all_maps[task][contrast][subject]:
                        print(f"{task}|{contrast}|{subject}: {encounter}")
                        continue

                    map_data = all_maps[task][contrast][subject][encounter]

                    # only file paths are supported as stored values
                    if not isinstance(map_data, str):
                        print(f"Unexpected data type for {descriptor_name}: {type(map_data)}")
                        continue

                    if not os.path.exists(map_data):
                        print(f"File not found: {map_data}")
                        continue

                    try:
                        loaded_map = nib.load(map_data)
                    except Exception as e:
                        # one unreadable file must not stop the gathering of the others
                        print(f"Error loading {map_data}: {str(e)}")
                        continue

                    specified_maps.append(loaded_map)
                    specified_descriptors.append(descriptor_name)

    # create RSM title, e.g. 'Task:nBack|Contrast:task-baseline' (several names are comma separated)
    task_part = ','.join(str(task) for task in req_tasks)
    contrast_part = ','.join(str(contrast) for contrast in req_contrasts)
    data_title = f"Task:{task_part}|Contrast:{contrast_part}"
    
    return specified_maps, specified_descriptors, data_title

In [8]:
# Gather, for every task/contrast pair, the loaded maps of all subjects and encounters.
# gather_tc_maps prints each combination it cannot find, hence the header line.
task_contrast_all_maps = {}
not_included = {}
print("MISSING:") # to see what maps are missing
for task in TASKS:
    task_contrast_all_maps[task] = {}
    for contrast in CONTRASTS[task]:
        req_tasks, req_contrasts = [task], [contrast]

        maps_list, descriptors_list, data_title = gather_tc_maps(
            req_tasks,
            req_contrasts,
            all_maps=first_level_encounter_maps,
            req_encounters=[0,1,2,3,4],
            req_subjects=SUBJECTS,
        )

        task_contrast_all_maps[task][contrast] = {
            "maps_list": maps_list,
            "descriptors_list": descriptors_list,
            "data_title": data_title,
        }
        

MISSING:
goNogo|nogo_success-go|sub-s19: 4
goNogo|nogo_success-go|sub-s29: 4
goNogo|nogo_success-go|sub-s43: 4
goNogo|nogo_success|sub-s19: 4
goNogo|nogo_success|sub-s29: 4
goNogo|nogo_success|sub-s43: 4
goNogo|task-baseline|sub-s19: 4
goNogo|task-baseline|sub-s29: 4
goNogo|task-baseline|sub-s43: 4
goNogo|response_time|sub-s19: 4
goNogo|response_time|sub-s29: 4
goNogo|response_time|sub-s43: 4
cuedTS|cue_switch_cost|sub-s29: 4
cuedTS|task_switch_cost|sub-s29: 4
cuedTS|task_switch_cue_switch-task_stay_cue_stay|sub-s29: 4
cuedTS|task-baseline|sub-s29: 4
cuedTS|response_time|sub-s29: 4


In [9]:
# resample Yeo atlas to get each network activation for each task/contrast pair as well
def resample_atlas(reference_map_object, yeo_atlas_data=yeo_atlas_3d):
    """
    Bring the atlas onto the grid of a reference map.

    Parameters:
        reference_map_object: loaded NIfTI object whose grid is the target
        yeo_atlas_data: loaded Yeo atlas image (default: yeo_atlas_3d)

    Returns:
        resampled_atlas: the atlas returned by nilearn's resample_to_img

    The shapes before and after resampling are printed.
    """
    for shape_label, image in (("Reference", reference_map_object), ("Atlas original", yeo_atlas_data)):
        print(f"{shape_label} shape: {image.shape}")

    # 'nearest' interpolation is requested because the atlas holds label values
    resample_options = {'interpolation': 'nearest', 'force_resample': True}
    atlas_on_reference_grid = nilearn.image.resample_to_img(
        yeo_atlas_data,
        reference_map_object,
        **resample_options,
    )

    print(f"Atlas resampled shape: {atlas_on_reference_grid.shape}")
    return atlas_on_reference_grid

# Get a reference map (use the first map from first task/contrast combination)
def get_reference_map(task_contrast_all_maps=task_contrast_all_maps):
    """
    Return the first available map to serve as resampling reference.

    Tasks and contrasts are visited in dictionary order; the first entry of the
    first non-empty "maps_list" is returned (its origin is printed). Returns None
    when every list is empty.
    """
    for task, contrast_entries in task_contrast_all_maps.items():
        for contrast, entry in contrast_entries.items():
            candidate_maps = entry["maps_list"]
            if len(candidate_maps) > 0:
                print(f"Using reference map from {task}, {contrast}")
                return candidate_maps[0]
    return None

# resample the atlas
print("Resampling atlas:")
reference_map = get_reference_map()
if reference_map is None:
    # Stop here: without a reference there is no resampled atlas, and continuing would
    # leave yeo_atlas_resampled undefined (or stale from an earlier run of this cell)
    # for the cells below that depend on it.
    raise RuntimeError("Error: No reference map found")

yeo_atlas_resampled = resample_atlas(reference_map, yeo_atlas_3d)
print("resampling complete!")

Resampling atlas:
Using reference map from nBack, twoBack-oneBack
Reference shape: (97, 115, 97)
Atlas original shape: (256, 256, 256)


  return resample_img(


Atlas resampled shape: (97, 115, 97)
resampling complete!


In [10]:
# Pattern-based RSM computation for all networks and whole brain
def _pattern_similarity(pattern_a, pattern_b, correlation_metric):
    """Similarity of two equally long 1-D voxel patterns under the given metric."""
    if correlation_metric == 'pearson':
        corr, _ = pearsonr(pattern_a, pattern_b)
        return corr
    if correlation_metric == 'spearman':
        # imported here because the notebook's import cell only brings in pearsonr
        from scipy.stats import spearmanr
        corr, _ = spearmanr(pattern_a, pattern_b)
        return corr
    if correlation_metric == 'cosine':
        norm_product = np.linalg.norm(pattern_a) * np.linalg.norm(pattern_b)
        return np.dot(pattern_a, pattern_b) / norm_product if norm_product > 0 else 0
    raise ValueError(f"Unknown correlation_metric: {correlation_metric!r}")


def compute_rsms(specified_maps, specified_descriptors, data_title, 
                              yeo_atlas_data=yeo_atlas_resampled, network_names=network_names, 
                              correlation_metric='pearson'):
    '''
    RSM computation using full voxel patterns for all networks and whole brain
    
    Parameters:
        specified_maps: list of loaded nifti files (plain arrays are accepted as well)
        specified_descriptors: list of descriptions for each map (same order as specified_maps)
        data_title: title for the RSM
        yeo_atlas_data: resampled and loaded Yeo atlas, image or array (default yeo_atlas_resampled)
        network_names: dictionary mapping network numbers to names
        correlation_metric: 'pearson', 'spearman', or 'cosine'

    Returns:
        rsm_results: {region name: dict}, one entry for 'whole_brain' (all voxels with an
            atlas label > 0) and one per atlas label, named '<label>_<network name>'.
            Each entry holds
              'rsm'           square map-by-map similarity matrix (NaN results set to 0)
              'n_voxels'      whole brain: non-NaN atlas voxels of the first used map;
                              networks: number of atlas voxels carrying that label
              'analysis_type' always 'pattern'
              'network_name'  the region name
              'full_title'    data_title + '|' + region name
              'descriptors'   descriptors of the maps in the matrix, in matrix order
        network_labels_named: list of the '<label>_<network name>' strings
        descriptors: descriptors of the maps that entered the matrices

    Raises:
        ValueError: if correlation_metric is not one of the three supported names.

    Notes:
        * Every pair of maps is compared on the voxels that are non-NaN in BOTH maps,
          so the two patterns always have equal length and refer to the same voxels.
          A pair without any shared voxel gets similarity 0.
        * A map whose shape differs from the atlas, or that cannot be read, is reported
          and left out; the matrices and the returned descriptors then cover the
          remaining maps only.
    '''
    if correlation_metric not in ('pearson', 'spearman', 'cosine'):
        raise ValueError(
            f"correlation_metric must be 'pearson', 'spearman' or 'cosine', got {correlation_metric!r}"
        )

    if len(specified_maps) == 0:
        print("No maps provided for RSM computation")
        return {}, [], []
    
    # Use pre-resampled atlas
    if yeo_atlas_data is None:
        print("Error: yeo_atlas_data is None")
        return {}, [], []
    if hasattr(yeo_atlas_data, 'get_fdata'):
        yeo_data = yeo_atlas_data.get_fdata()
        print(f"Using pre-resampled Yeo atlas shape: {yeo_data.shape}")
    else:
        yeo_data = yeo_atlas_data
        print(f"Using Yeo atlas array with shape: {yeo_data.shape}")
    
    # Get network labels and make the names
    atlas_mask = yeo_data > 0
    voxel_labels = yeo_data[atlas_mask]  # atlas label of every in-atlas voxel
    network_labels = [int(label) for label in np.unique(voxel_labels)]
    print(f"Found {len(network_labels)} networks: {network_labels}")
    network_labels_named = [f"{label}_{network_names.get(label, f'Network{label}')}" 
                               for label in network_labels]
    print(f"Network names: {network_labels_named}")
    
    # Extract the in-atlas voxels of each usable map (NaNs are kept for now)
    kept_indices = []
    kept_patterns = []
    for i, nifti_map in enumerate(specified_maps):
        try:
            if hasattr(nifti_map, 'get_fdata'):
                map_data = nifti_map.get_fdata()
            else:
                map_data = nifti_map
            
            # Check shape alignment
            if map_data.shape != yeo_data.shape:
                print(f"Warning: Map {i} shape {map_data.shape} doesn't match atlas shape {yeo_data.shape}")
                continue
            
            kept_patterns.append(np.asarray(map_data[atlas_mask], dtype=float))
            kept_indices.append(i)
        except Exception as e:
            print(f"Error processing map {i} ({specified_descriptors[i]}): {e}")
            continue
    
    n_maps = len(kept_indices)
    if n_maps == 0:
        print("No usable maps for RSM computation")
        return {}, network_labels_named, []

    if n_maps < len(specified_maps):
        # keep the labels aligned with the rows/columns of the matrices
        print(f"Warning: {len(specified_maps) - n_maps} map(s) skipped, RSMs cover the remaining {n_maps}")
        used_descriptors = [specified_descriptors[i] for i in kept_indices]
    else:
        used_descriptors = specified_descriptors

    patterns = np.vstack(kept_patterns)  # one row per map, one column per in-atlas voxel
    del kept_patterns
    not_nan = ~np.isnan(patterns)

    # Column selectors per region; None stands for "all in-atlas voxels"
    region_columns = {'whole_brain': None}
    for label, network_key in zip(network_labels, network_labels_named):
        region_columns[network_key] = voxel_labels == label
    
    # Compute RSMs
    rsm_results = {}
    
    for region_name, columns in region_columns.items():
        try:
            if columns is None:
                region_patterns, region_valid = patterns, not_nan
            else:
                region_patterns, region_valid = patterns[:, columns], not_nan[:, columns]

            # without any NaN in the region the per-pair masking can be skipped
            region_is_complete = bool(region_valid.all())

            # correlate full activation patterns
            rsm_matrix = np.zeros((n_maps, n_maps))
            
            for i in range(n_maps):
                for j in range(i, n_maps):  # symmetric matrix
                    if region_is_complete:
                        pattern_i, pattern_j = region_patterns[i], region_patterns[j]
                    else:
                        shared = region_valid[i] & region_valid[j]
                        pattern_i = region_patterns[i][shared]
                        pattern_j = region_patterns[j][shared]

                    # Handle empty patterns
                    if len(pattern_i) == 0:
                        corr = 0.0
                    else:
                        corr = _pattern_similarity(pattern_i, pattern_j, correlation_metric)
                    
                    rsm_matrix[i, j] = corr
                    rsm_matrix[j, i] = corr 
            
            # Handle NaNs (e.g. a constant pattern has no defined correlation)
            rsm_matrix = np.nan_to_num(rsm_matrix, nan=0.0)
            
            # Calculate network size
            if columns is None:
                n_voxels = int(not_nan[0].sum())
            else:
                n_voxels = int(columns.sum())
            
            rsm_results[region_name] = {
                'rsm': rsm_matrix,
                'n_voxels': n_voxels,
                'analysis_type': 'pattern',
                'network_name': region_name,
                'full_title': data_title +'|'+ region_name,
                'descriptors': used_descriptors
            }
            
        except Exception as e:
            # best effort per region: report the failure and go on with the next region
            print(f"Error computing RSM for {region_name}: {e}")
            continue
    
    print(f"Successfully computed pattern-based RSMs for {len(rsm_results)} regions")
    return rsm_results, network_labels_named, used_descriptors

In [11]:
# The RSMs are computed in three batches of tasks. Each batch is written to its own
# pickle file, and the batch files are finally merged into "complete_rsm_results".
first_batch_tasks = ["nBack","flanker","directedForgetting"]
second_batch_tasks = ["goNogo", "shapeMatching", "stopSignal"]
third_batch_tasks = [ "cuedTS", "spatialTS"]

# (batch file name, tasks stored in that file), in merge order
rsm_batches = [
    ("first_batch_rsm", first_batch_tasks),
    ("second_batch_rsm", second_batch_tasks),
    ("third_batch_rsm", third_batch_tasks),
]
complete_rsm_file = "complete_rsm_results"


def _compute_batch_rsms(batch_tasks):
    """Compute the RSMs of every contrast of the given tasks.

    Returns {task: {contrast: rsm_results}}. The dictionary has one key per task in
    task_contrast_all_maps; tasks outside the batch keep an empty dictionary, which
    is the layout the batch files have always had. Contrasts without maps are left out.
    """
    batch_results = {task: {} for task in task_contrast_all_maps}
    for task in task_contrast_all_maps:
        if task not in batch_tasks:
            continue

        for contrast, entry in task_contrast_all_maps[task].items():
            if len(entry["maps_list"]) > 0:
                rsm_results, _, _ = compute_rsms(
                    specified_maps=entry["maps_list"],
                    specified_descriptors=entry["descriptors_list"],
                    data_title=entry["data_title"],
                    yeo_atlas_data=yeo_atlas_resampled,
                )
                batch_results[task][contrast] = rsm_results
    return batch_results


# calculate each batch if necessary (skipped when the batch file or the merged file exists)
for batch_file, batch_tasks in rsm_batches:
    if os.path.exists(f"{batch_file}.pkl") or os.path.exists(f"{complete_rsm_file}.pkl"):
        print("File exists!")
    else:
        print("File not found...calculating now")
        batch_results = _compute_batch_rsms(batch_tasks)

        # save this batch's RSMs
        save_rsm(batch_results, batch_file)

        del batch_results  # Delete the large dictionary
        cleanup_memory()

# load all three RSM batches in one file called "complete_rsm_results"
if os.path.exists(f"{complete_rsm_file}.pkl"):
    print("File exists!")
    all_rsms = load_rsm(complete_rsm_file)
else:
    print("File not found...calculating now")

    # Take from every batch file only the tasks that belong to that batch; the empty
    # placeholder entries of the other tasks are ignored, whether present or not.
    all_rsms = {}
    for batch_file, batch_tasks in rsm_batches:
        batch_results = load_rsm(batch_file)
        for task in batch_tasks:
            if task in batch_results:
                all_rsms[task] = batch_results[task]
            else:
                print(f"Warning: task {task} is missing from {batch_file}.pkl")
        del batch_results

    save_rsm(all_rsms, complete_rsm_file)

File exists!
File exists!
File exists!
File exists!
Loaded from complete_rsm_results.pkl
