In [None]:
# Visualizes the group-level parcellated results (averaged across subjects) for every task/contrast.

In [2]:
import os
import numpy as np
import pandas as pd
import nibabel as nib
import nilearn
import json
import datetime
import pickle
import seaborn as sns
import gc
import psutil
import math
import scipy.stats as stats
from matplotlib.patches import Patch
from nilearn import plotting
from nilearn.glm.first_level import FirstLevelModel
from nilearn.glm.second_level import SecondLevelModel
from nilearn.glm import threshold_stats_img
from nilearn.image import concat_imgs, mean_img, index_img
from nilearn.reporting import make_glm_report
from nilearn import masking, image
from nilearn import datasets
from scipy.stats import pearsonr
import matplotlib
matplotlib.use('Agg')  # MUST be before importing pyplot
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
from collections import defaultdict
from nilearn.maskers import NiftiLabelsMasker
from nilearn.plotting.find_cuts import find_cut_slices
import nibabel as nib

# Import shared utilities and configuration
# need to do it this way because in a sub-directory (later turn config and utils into part of a package)
from utils import (
    TASKS, CONTRASTS, SUBJECTS, SESSIONS, ENCOUNTERS,
    build_first_level_contrast_map_path, is_valid_contrast_map, clean_z_map_data,
    convert_to_regular_dict, create_smor_atlas,load_smor_atlas, load_schaefer_atlas, cleanup_memory
)
from config import BASE_DIR, OUTPUT_DIRS

In [3]:
# Example of the resulting list:
# ["twoBack-oneBack", 'task-baseline', "incongruent-congruent", "neg-con", "nogo_success-go", "main_vars", "stop_failure-go","task_switch_cost"]
# Collect every requested contrast across all tasks, keeping first-seen order
# and dropping repeats (dict keys preserve insertion order).
compiled_req_contrasts = list(dict.fromkeys(
    contrast_name
    for task_name in TASKS
    for contrast_name in CONTRASTS[task_name]
))

## load atlas

In [6]:
# --- Schaefer atlas configuration ---
# (the "schafer" spelling in these names is kept because the atlas-selection cell refers to them)
SCHAFER_PARCELLATED_DIR = OUTPUT_DIRS["schaefer"]
# Filename stem of the pickled per-parcel results, keyed by statistic
schafer_files = {'mean':f'discovery_parcel_indiv_mean_updated'}
# Date tag that is inserted into the result filenames when they are loaded
schafer_date_updated = '1208'
indices = [1,2,3]  # NOTE(review): not used by any cell visible here — confirm it is still needed
# Get schaefer atlas
SCHAEFER = load_schaefer_atlas()
SCHAEFER_IMG = nib.load(SCHAEFER.maps)  # .maps is handed to nib.load to obtain the image
SCHAEFER_DATA = SCHAEFER_IMG.get_fdata()  # voxel-wise label array as floats

# --- Smorgasbord atlas configuration (mirrors the Schaefer settings above) ---
SMORG_PARCELLATED_DIR = OUTPUT_DIRS["smor"]
smor_files = {'mean':f'discovery_parcel_indiv_mean_updated'}
smor_date_updated = '1208'
indices = [1,2,3]  # same value as assigned above; re-assignment is a no-op
# get smorgasbord atlas
smorgasbord_atlas = load_smor_atlas()
SMORG_IMG = smorgasbord_atlas.maps  # used directly as an image (no nib.load), unlike SCHAEFER.maps
SMORG_DATA = SMORG_IMG.get_fdata()

Loading Schaefer 400 atlas...


Atlas loaded with 400 regions
Atlas shape: (91, 109, 91)
Loading Smorgasbord atlas...
Atlas loaded with 429 regions
Atlas shape: (193, 229, 193)


In [8]:
req_atlas = "smor"

# Per-atlas settings, in the order:
# (results directory, result-file stems, date tag, atlas object, atlas image, atlas voxel data)
_atlas_settings = {
    "schafer": (SCHAFER_PARCELLATED_DIR, schafer_files, schafer_date_updated,
                SCHAEFER, SCHAEFER_IMG, SCHAEFER_DATA),
    "smor": (SMORG_PARCELLATED_DIR, smor_files, smor_date_updated,
             smorgasbord_atlas, SMORG_IMG, SMORG_DATA),
}

# Fail loudly on an unsupported atlas name before anything is unpacked
if req_atlas not in _atlas_settings:
    raise ValueError(f"Unknown atlas: {req_atlas}. Use 'schafer' or 'smor'")

main_dir, main_files, date_updated, atlas_obj, atlas_img, atlas_data = _atlas_settings[req_atlas]


# Suffix appended to result filenames; only the z-scored variant adds one
file_type = "z"
output_ending = "_z_scored" if file_type == "z" else ""

## Load the parcellated, averaged fixed-effects maps

In [15]:
# Fixed-effects results live in a sibling directory named after the selected atlas directory
FIXED_DIR = f'{main_dir}_fixed'

# Load the already-averaged fixed-effects maps (no averaging happens in this cell).
# The filename combines the date tag and the z-score suffix chosen in the atlas-selection cell.
avg_fixed_file = f'{FIXED_DIR}/discovery_parcel_fixedeffects_mean_updated_{date_updated}{output_ending}_averaged.pkl'
# NOTE: pickle.load executes arbitrary code from the file; only load files produced by this project.
with open(avg_fixed_file, 'rb') as f:
    averaged_fixed_maps = pickle.load(f)

# Load the averaged parcel results

In [17]:
# Placeholder value; it is replaced by the unpickled object below.
avg_parcel_traj_results = {}
# Path to the averaged per-parcel results for the selected atlas, date tag and z-score suffix
mean_filename = f"{main_dir}/{main_files['mean']}_{date_updated}{output_ending}_averaged.pkl"

# NOTE: pickle.load executes arbitrary code from the file; only load files produced by this project.
with open(mean_filename, 'rb') as f:
    avg_parcel_traj_results = pickle.load(f)  # later cells index this as [task][contrast][parcel label]

# Visualizing functions

In [18]:
def create_parcel_practice_heatmap(parcel_traj, title, indiv_data = True, n_rows=50):
    """
    Plot a heatmap of activation trajectories for the parcels whose slope
    has the largest absolute value.

    Parameters
    ----------
    parcel_traj : mapping (or anything accepted by ``pd.DataFrame``)
        Maps each parcel label to its per-parcel fields. The fields read are
        'slope' and 'trajectory' when ``indiv_data`` is True, otherwise
        'slope_mean' and 'trajectory_mean'.
    title : str
        Prefix of the figure title.
    indiv_data : bool
        Selects which pair of fields is read (see ``parcel_traj``).
    n_rows : int
        Maximum number of parcels (heatmap rows) to show.

    Raises
    ------
    ValueError
        If ``parcel_traj`` contains no parcels.
    """
    # One row per parcel, one column per stored field
    df = pd.DataFrame(parcel_traj).T
    if df.empty:
        raise ValueError("parcel_traj contains no parcels to plot")

    if indiv_data:
        slope_col, traj_col = 'slope', 'trajectory'
    else:
        slope_col, traj_col = 'slope_mean', 'trajectory_mean'

    # Rank parcels by |slope|, largest first, so the rows shown are the
    # parcels that changed the most (non-numeric slopes become NaN and sort last).
    abs_slopes = pd.to_numeric(df[slope_col], errors='coerce').abs()
    top_labels = abs_slopes.sort_values(ascending=False).index[:n_rows]
    df_sorted = df.loc[top_labels]

    # Rows = selected parcels, columns = encounters
    trajectory_matrix = np.array(df_sorted[traj_col].tolist(), dtype=float)

    # Label columns from the actual trajectory length instead of assuming five encounters
    encounter_labels = [f'Enc {i + 1}' for i in range(trajectory_matrix.shape[1])]

    # Create the heatmap
    plt.figure(figsize=(30, 12))

    sns.heatmap(trajectory_matrix,
                xticklabels=encounter_labels,
                yticklabels=list(df_sorted.index),
                cmap='RdBu_r', center=0,
                cbar_kws={'label': 'Activation'})

    plt.xlabel('Encounter')
    plt.ylabel('Brain Parcel ID')

    plt.title(f"{title}: first {n_rows} rows")
    plt.tight_layout()
    plt.show()

In [19]:
# Calculate cut slices
# Ask nilearn for 8 candidate cut positions along each axis.
# NOTE(review): the cuts are always derived from SCHAEFER_IMG, even when req_atlas
# selects the smorgasbord atlas — confirm this is intended.
x_cuts = find_cut_slices(SCHAEFER_IMG, direction='x', n_cuts=8)
y_cuts = find_cut_slices(SCHAEFER_IMG, direction='y', n_cuts=8)
z_cuts = find_cut_slices(SCHAEFER_IMG, direction='z', n_cuts=8)

# Keep the six middle cuts per axis (drop the first and last);
# the plotting functions lay these out in 6-column subplot grids.
FIXED_X_CUTS = x_cuts[1:7]  # Middle six X slices
FIXED_Y_CUTS = y_cuts[1:7]  # Middle six Y slices
FIXED_Z_CUTS = z_cuts[1:7]  # Middle six Z slices

In [23]:
def plot_slopes_on_brain(avg_results, task, contrast, n_rois=400, atlas_name=req_atlas, atlas=atlas_obj, atlas_img=atlas_img, atlas_data=atlas_data, title="Average parcel slopes", threshold=None):
    """
    Paint each parcel's mean slope onto the atlas volume and show it on a
    3 x 6 grid of slices (rows: x, y, z views; columns: the fixed cut positions).

    Parameters
    ----------
    avg_results : mapping
        Indexed as ``avg_results[task][contrast][parcel_label]['slope_mean']``.
    task, contrast : str
        Keys selecting which set of parcel results to plot.
    n_rois : int
        Unused; kept so existing calls keep working.
    atlas_name : str
        Name shown in the figure title.
    atlas : object
        Must expose ``labels``; if it also exposes ``roi_values`` those are
        used as the voxel values of each label, otherwise label ``i`` is
        taken to have voxel value ``i + 1``.
    atlas_img : image
        Supplies the affine for the generated slope image.
    atlas_data : ndarray
        Voxel-wise label array of the atlas.
    title : str
        Prefix of the figure title.
    threshold : float or None
        Passed through to ``plotting.plot_stat_map``.

    Raises
    ------
    ValueError
        If no parcel produced a non-zero slope voxel, so no colour range
        can be derived.

    Notes
    -----
    The atlas defaults are bound when this ``def`` cell runs; re-run it after
    changing the atlas selection, or pass the atlas arguments explicitly.
    """
    # Atlas labels may be stored as bytes; normalise to str
    atlas_labels = [label.decode('utf-8') if isinstance(label, bytes) else label
                    for label in atlas.labels]

    parcel_data = avg_results[task][contrast]

    # Voxel value for each label: explicit ROI values when the atlas provides
    # them, consecutive 1-based indices otherwise
    if hasattr(atlas, 'roi_values'):
        roi_values = atlas.roi_values
    else:
        roi_values = range(1, len(atlas_labels) + 1)

    # Volume with each matched parcel filled with its mean slope
    slope_data = np.zeros_like(atlas_data)
    for roi_value, atlas_label in zip(roi_values, atlas_labels):
        if atlas_label in parcel_data:
            slope_data[atlas_data == roi_value] = parcel_data[atlas_label]['slope_mean']

    # Create a NIfTI image
    slope_img = nib.Nifti1Image(slope_data, atlas_img.affine)

    # Colour limits come from the painted (non-zero) voxels only; without any
    # there is nothing to plot and np.percentile would fail obscurely
    nonzero_slopes = slope_data[slope_data != 0]
    if nonzero_slopes.size == 0:
        raise ValueError(
            f"No non-zero slope values to plot for {task}/{contrast}: "
            f"no parcel label matched the {atlas_name} atlas or all slopes are zero"
        )
    vmin, vmax = np.percentile(nonzero_slopes, [2, 98])

    # Ensure vmin and vmax are symmetric for diverging colormap
    abs_max = max(abs(vmin), abs(vmax))
    vmin, vmax = -abs_max, abs_max

    # Set up the layout for plotting
    fig, axes = plt.subplots(3, 6, figsize=(20, 15))
    fig.suptitle(f'{title}: {task}/{contrast} ({atlas_name} atlas)', fontsize=20)

    display_modes = ['x', 'y', 'z']
    cuts_by_view = [FIXED_X_CUTS, FIXED_Y_CUTS, FIXED_Z_CUTS]

    # One row per viewing direction, one column per cut position
    for ax_row, display_mode, cuts in zip(axes, display_modes, cuts_by_view):
        for ax, coord in zip(ax_row, cuts):
            plotting.plot_stat_map(slope_img,
                                   colorbar=True,
                                   cmap='seismic',
                                   vmin=vmin,
                                   vmax=vmax,
                                   threshold=threshold,
                                   display_mode=display_mode,
                                   axes=ax,
                                   cut_coords=[coord],
                                   draw_cross=False)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# visualizing the averaged parcel slopes

In [None]:
# # heatmap of top slope-changed parcels
# for task in requested_task_contrasts:
#     if (count > 3):
#         break
#     contrast = requested_task_contrasts[task][0]
#     title = f"{task}:{contrast} heatmap of avg parcel activation (all subjects) over encounters, sorted from largest avg abs slope to smallest"

#     create_parcel_practice_heatmap(avg_parcel_traj_results[task][contrast],title,indiv_data=False,n_rows=10)

In [None]:
# # Plot all contrasts 
# # top 10% of slope change
# count = 0
# for task in avg_parcel_traj_results.keys():
#     for contrast in avg_parcel_traj_results[task].keys():
#         if (count >2):
#             break
#         if len(avg_parcel_traj_results[task][contrast]) > 0: 
#             plot_slopes_on_brain(avg_parcel_traj_results, task, contrast)
#             count += 1

In [24]:
# Quick preview: plot the unthresholded slope map for only the first
# MAX_PREVIEW_PLOTS non-empty task/contrast combinations.
MAX_PREVIEW_PLOTS = 1
count = 0
for task, task_results in avg_parcel_traj_results.items():
    # Stop scanning tasks once the preview limit is reached
    if count >= MAX_PREVIEW_PLOTS:
        break
    for contrast, parcel_results in task_results.items():
        if count >= MAX_PREVIEW_PLOTS:
            break
        if len(parcel_results) > 0:
            plot_slopes_on_brain(avg_parcel_traj_results, task, contrast, threshold = None)
            count += 1

  plt.tight_layout(rect=[0, 0.03, 1, 0.95])


# plotting fixed effects

In [None]:
# def plot_fixed_effects_maps(subjects, tasks, contrasts, maps, descriptors, 
#                             x_cuts, y_cuts, z_cuts, threshold=None):
#     """
#     Plot fixed effects maps on brain slices.
    
#     Parameters:
#     -----------
#     subjects : list
#         List of subject IDs
#     tasks : list
#         List of task names
#     contrasts : dict
#         Dictionary mapping tasks to contrasts
#     maps : list
#         List of NIfTI images to plot
#     descriptors : list
#         List of descriptive labels for each map
#     x_cuts, y_cuts, z_cuts : list
#         Coordinates for slice cuts
#     threshold : float or None
#         Threshold for stat map (None = show all values)
#     """
#     count = 0
#     display_modes = ['x', 'y', 'z']
    
#     for subj in subjects:
#         for task in tasks:
#             for contrast in contrasts[task]:
#                 fig, axes = plt.subplots(3, 6, figsize=(20, 15))
                
#                 title = descriptors[count]
#                 img = maps[count]
                
#                 fig.suptitle(f'{title}: {task}/{contrast}', fontsize=20)
                
#                 # Define cuts for each view
#                 cuts_by_view = [x_cuts, y_cuts, z_cuts]
                
#                 for idx, ax_row in enumerate(axes):
#                     for j, ax in enumerate(ax_row):
#                         coord = cuts_by_view[idx][j]
                        
#                         plotting.plot_stat_map(
#                             img,
#                             colorbar=True,
#                             cmap='seismic',
#                             symmetric_cbar=True,
#                             threshold=threshold,
#                             display_mode=display_modes[idx],
#                             axes=ax,
#                             cut_coords=[coord],
#                             draw_cross=False
#                         )
                
#                 plt.tight_layout(rect=[0, 0.03, 1, 0.95])
#                 plt.show()
                
#                 count += 1

# plotting averaged fixed effects next to the averaged slope maps

In [33]:
def _plot_stat_map_rows(stat_img, axes_rows, vmin, vmax, threshold, **extra_kwargs):
    """Draw stat_img on three rows of axes (x, y, z views), one column per fixed cut position."""
    display_modes = ['x', 'y', 'z']
    cuts_by_view = [FIXED_X_CUTS, FIXED_Y_CUTS, FIXED_Z_CUTS]

    for ax_row, display_mode, cuts in zip(axes_rows, display_modes, cuts_by_view):
        for j, (ax, coord) in enumerate(zip(ax_row, cuts)):
            plotting.plot_stat_map(stat_img,
                                   colorbar=True,
                                   cmap='RdBu_r',
                                   vmin=vmin,
                                   vmax=vmax,
                                   threshold=threshold,
                                   display_mode=display_mode,
                                   axes=ax,
                                   cut_coords=[coord],
                                   draw_cross=False,
                                   **extra_kwargs)
            # Label each row with its viewing direction on the first column
            if j == 0:
                ax.set_ylabel(display_mode.upper(), rotation=0, labelpad=20, fontsize=12)


def plot_slopes_and_fixed_effects(avg_results, averaged_fixed_maps, task, contrast, n_rois=400, atlas_name=req_atlas, atlas=atlas_obj, atlas_img=atlas_img, atlas_data=atlas_data, title="Average parcel slopes", threshold=None):
    """
    Plot parcel slope means (top three rows) and parcellated average fixed
    effects (bottom three rows) in one 6 x 6 figure and save it as a PNG.

    Parameters
    ----------
    avg_results : mapping
        Indexed as ``avg_results[task][contrast][parcel_label]['slope_mean']``.
    averaged_fixed_maps : mapping
        ``averaged_fixed_maps[task][contrast]`` is a DataFrame with
        'region' and 'activation' columns.
    task, contrast : str
        Keys selecting which results to plot; also used in the output filename.
    n_rois : int
        Unused; kept so existing calls keep working.
    atlas_name : str
        Name shown in the figure title.
    atlas : object
        Must expose ``labels``; if it also exposes ``roi_values`` those give
        each label's voxel value, otherwise label ``i`` has voxel value ``i + 1``.
    atlas_img : image
        Supplies the affine for the generated images.
    atlas_data : ndarray
        Voxel-wise label array of the atlas.
    title : str
        Prefix of the figure title.
    threshold : float or None
        Passed through to ``plotting.plot_stat_map``.

    Returns
    -------
    str
        Path of the saved PNG.

    Raises
    ------
    KeyError
        If a parcel that has a slope has no row in the fixed-effects table.
    ValueError
        If there are no non-zero slope or fixed-effect voxels to derive a
        colour range from.

    Notes
    -----
    The atlas defaults are bound when this ``def`` cell runs; re-run it after
    changing the atlas selection, or pass the atlas arguments explicitly.
    """
    # Atlas labels may be stored as bytes; normalise to str
    atlas_labels = [label.decode('utf-8') if isinstance(label, bytes) else label
                    for label in atlas.labels]

    parcel_data = avg_results[task][contrast]
    fixed_effects_df = averaged_fixed_maps[task][contrast]

    # Build the region -> activation lookup once instead of filtering the
    # DataFrame for every label; the first row wins for a repeated region
    activation_by_region = {}
    for region, activation in zip(fixed_effects_df['region'], fixed_effects_df['activation']):
        activation_by_region.setdefault(region, activation)

    slope_data = np.zeros_like(atlas_data)
    fixed_effects_data = np.zeros_like(atlas_data)
    has_roi_values = hasattr(atlas, 'roi_values')

    # Paint both volumes parcel by parcel
    for i, atlas_label in enumerate(atlas_labels):
        if atlas_label not in parcel_data:
            continue
        if atlas_label not in activation_by_region:
            raise KeyError(
                f"Region {atlas_label!r} has a slope but no fixed-effects activation "
                f"for {task}/{contrast}"
            )

        roi_value = atlas.roi_values[i] if has_roi_values else i + 1
        parcel_mask = atlas_data == roi_value
        slope_data[parcel_mask] = parcel_data[atlas_label]['slope_mean']
        fixed_effects_data[parcel_mask] = activation_by_region[atlas_label]

    # Create NIfTI images for slopes and fixed effects
    slope_img = nib.Nifti1Image(slope_data, atlas_img.affine)
    fixed_effects_img = nib.Nifti1Image(fixed_effects_data, atlas_img.affine)

    # Colour limits come from the painted (non-zero) voxels only
    nonzero_slopes = slope_data[slope_data != 0]
    nonzero_fe = fixed_effects_data[fixed_effects_data != 0]
    if nonzero_slopes.size == 0 or nonzero_fe.size == 0:
        raise ValueError(
            f"No non-zero slope or fixed-effect values to plot for {task}/{contrast} "
            f"({atlas_name} atlas)"
        )

    # Symmetric limits around zero for the diverging colormap
    vmin_slope, vmax_slope = np.percentile(nonzero_slopes, [2, 98])
    abs_max_slope = max(abs(vmin_slope), abs(vmax_slope))

    vmin_fe, vmax_fe = np.percentile(nonzero_fe, [2, 98])
    abs_max_fe = max(abs(vmin_fe), abs(vmax_fe))

    # Set up the layout for plotting
    fig, axes = plt.subplots(6, 6, figsize=(20, 20))
    try:
        fig.suptitle(f'{title}: {task}/{contrast} ({atlas_name} atlas)', fontsize=20)

        # Plot slopes (top 3 rows)
        _plot_stat_map_rows(slope_img, axes[:3], -abs_max_slope, abs_max_slope, threshold)

        # Plot parcellated fixed effects (bottom 3 rows)
        _plot_stat_map_rows(fixed_effects_img, axes[3:], -abs_max_fe, abs_max_fe, threshold,
                            symmetric_cbar=True)

        # Add labels for the two sections
        fig.text(0.5, 1.0, 'Parcel Slopes', ha='center', va='center', fontsize=20)
        fig.text(0.5, 0.50, 'Parcellated Fixed Effects (z-scored)', ha='center', va='center', fontsize=20)

        plt.tight_layout(rect=[0, 0.03, 0.9, 0.95])

        # Create output directory
        output_dir = "11_parcellated_slopes_and_fixed_effects_plots_z"
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f"{task}_{contrast}_slopes_and_fixed_effects.png")
        fig.savefig(output_file, dpi=100, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed
        plt.close(fig)
    return output_file

In [None]:
# Save a slopes-vs-fixed-effects figure for every non-empty task/contrast.
# The atlas passed here is the one chosen via req_atlas, i.e. the same atlas
# the results above were loaded for, so labels and results always match.
for task, task_results in avg_parcel_traj_results.items():
    for contrast, parcel_results in task_results.items():
        if len(parcel_results) > 0:
            plot_slopes_and_fixed_effects(avg_parcel_traj_results, averaged_fixed_maps, task, contrast,
                                          atlas_name=req_atlas, atlas=atlas_obj,
                                          atlas_img=atlas_img, atlas_data=atlas_data)