# Nov 12, 2024: Variability types per RSN

In [1]:
import csv
import os
import sys
import numpy as np
import pandas as pd
import scipy as sp 
import dill as pickle 
from os.path import join as pjoin
from itertools import product
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
import subprocess
from scipy import sparse, stats
from multiprocessing import Pool
import glob
import random
from sklearn.preprocessing import OneHotEncoder, MultiLabelBinarizer

import arviz as az

import ants
from nipype.interfaces import afni

from itertools import product, combinations
import multiprocessing as mp
from functools import partial

# networks
import graph_tool.all as gt

# plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.cm import rainbow
from matplotlib.patches import Rectangle

# Font family settings are written to `rcParamsDefault`, while the size settings below are
# written to the active `rcParams`.
# NOTE(review): presumably values stored in `rcParamsDefault` only take effect once the
# defaults are restored (e.g. `plt.rcdefaults()`) — confirm whether `plt.rcParams` was intended.
plt.rcParamsDefault['font.family'] = "sans-serif"
plt.rcParamsDefault['font.sans-serif'] = "Arial"
# Active settings: base font size and error-bar cap size.
plt.rcParams['font.size'] = 14
plt.rcParams["errorbar.capsize"] = 0.5

import cmasher as cmr  # CITE ITS PAPER IN YOUR MANUSCRIPT
import colorcet as cc

# Silence warnings for the whole notebook.
# NOTE: the `category=UserWarning` restriction is commented out, so this call ignores
# every warning category, not only user warnings.
import warnings
warnings.filterwarnings("ignore") #, category=UserWarning)

In [2]:
class ARGS:
    """Empty container; run settings are attached to an instance as plain attributes."""


args = ARGS()
args.SEED = 100


def set_seed(args):
    """Seed NumPy's global RNG and graph-tool's RNG with ``args.SEED``."""
    seed = args.SEED
    np.random.seed(seed)
    gt.seed_rng(seed)


set_seed(args)

In [3]:
# Parcellation settings; each one becomes a `key-value` piece of the descriptor below.
args.type = 'spatial'
args.roi_size = 225
args.maintain_symmetry = True
args.brain_div = 'whl'
args.num_rois = 162

# Descriptor string used in parcellation file names and result folder names.
PARC_DESC = '_'.join([
    f'type-{args.type}',
    f'size-{args.roi_size}',
    f'symm-{args.maintain_symmetry}',
    f'braindiv-{args.brain_div}',
    f'nrois-{args.num_rois}',
])

In [4]:
# Graph-construction settings; each one becomes a path component of ROI_RESULTS_path.
args.GRAPH_DEF = 'constructed'
args.GRAPH_METHOD = 'pearson-corr'
args.THRESHOLDING = 'positive'
args.EDGE_DEF = 'binary'
args.EDGE_DENSITY = 20
args.LAYER_DEF = 'individual'
args.DATA_UNIT = 'ses'

# Input / output locations, all rooted at $HOME/mouse_dataset.
BASE_path = f'{os.environ["HOME"]}/mouse_dataset'
PARCELS_path = f'{BASE_path}/parcels'
ROI_path = f'{BASE_path}/roi_results_v2/{PARC_DESC}'
TS_path = f'{ROI_path}/runwise_timeseries'
ROI_RESULTS_path = (
    f'{ROI_path}'
    f'/graph-{args.GRAPH_DEF}/method-{args.GRAPH_METHOD}'
    f'/threshold-{args.THRESHOLDING}/edge-{args.EDGE_DEF}/density-{args.EDGE_DENSITY}'
    f'/layer-{args.LAYER_DEF}/unit-{args.DATA_UNIT}'
)
RSN_ROI_path = f'{ROI_path}/rsns'
GRAPH_path = f'{ROI_RESULTS_path}/graphs'
SBM_path = f'{ROI_RESULTS_path}/model-fits'
ESTIM_path = f'{ROI_RESULTS_path}/estimates'

# Create the directories with os.makedirs instead of shelling out to `mkdir -p`:
# no shell quoting issues with unusual paths, and a failure raises instead of
# being reduced to an ignored exit status.
for out_dir in (
    RSN_ROI_path,
    GRAPH_path,
    SBM_path,
    f'{ESTIM_path}/individual',
    f'{ESTIM_path}/group',
):
    os.makedirs(out_dir, exist_ok=True)

0

In [5]:
# Parcellation image read with ANTs, plus its voxel data as a NumPy array.
# `parcels_img` is kept because it is later used as the template for new images
# (`parcels_img.new_image_like(...)`).
parcels_img = ants.image_read(f'{PARCELS_path}/{PARC_DESC}_desc-parcels.nii.gz')
parcels = parcels_img.numpy()
# ROI label values from the labels text file; later matched against voxels via `parcels == roi`.
roi_labels = np.loadtxt(f'{PARCELS_path}/{PARC_DESC}_desc-labels.txt')

In [6]:
# One row per RSN text file found in RSN_ROI_path: a short name plus the array stored in the file.
rsn_files = sorted(glob.glob(f'{RSN_ROI_path}/*.txt', recursive=True))

rsn_rows = []
for file in rsn_files:
    rsn = np.loadtxt(file)
    # Name = the '-'-separated pieces of the file name, minus any piece containing 'desc' or 'rois'.
    pieces = file.split('/')[-1].split('-')
    kept = [piece for piece in pieces if 'desc' not in piece and 'rois' not in piece]
    name = '-'.join(kept)
    rsn_rows.append(pd.DataFrame({'name': [name], 'rsn': [rsn]}))

rsns_df = pd.concat(rsn_rows).reset_index(drop=True)
rsns_df

Unnamed: 0,name,rsn
0,j-amygdala,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,j-basal_ganglia,"[1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, ..."
2,j-default_mode,"[0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, ..."
3,j-limbic,"[1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, ..."
4,j-olfactory,"[0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, ..."
5,j-sensory,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
6,j-somatomotor,"[1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
7,j-somatosensory,"[1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
8,j-subcortical,"[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."
9,j-thalamus,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


In [7]:
def concatenate(in_files, out_file):
    """Concatenate `in_files` into `out_file` with AFNI's TCat, then delete the inputs.

    Parameters
    ----------
    in_files : list of str
        Paths of the volumes to concatenate; they are removed after a successful run.
    out_file : str
        Path of the concatenated output. An existing file at this path is removed first.

    Returns
    -------
    None

    Notes
    -----
    If `tcat.run()` raises, the input files are left in place.
    """
    # Start from a clean slate; a missing output file is the normal case.
    # Any other failure to remove it (e.g. permissions) is allowed to surface here.
    try:
        os.remove(out_file)
    except FileNotFoundError:
        pass

    tcat = afni.TCat()
    tcat.inputs.in_files = in_files
    tcat.inputs.out_file = out_file
    # NOTE(review): presumably this enables 3dTcat's `-rlt` option — confirm it is wanted here.
    tcat.inputs.rlt = ''
    tcat.run()

    # Best-effort cleanup of the per-volume inputs: a file that cannot be removed
    # must not fail the concatenation that already succeeded.
    for file in in_files:
        try:
            os.remove(file)
        except OSError:
            pass
    return None

def var_types_to_nifti(args, X, folder, name='desc-variability-type'):
    """Write the first four columns of `X` as volumes and concatenate them into one file.

    For each of the four column indices, every voxel whose parcel value equals
    `roi_labels[idx]` receives `X[idx, column]`; the volume is saved as
    `{folder}/vartype-{column}.nii.gz`. The four volumes are then merged into
    `{folder}/{name}.nii.gz` by `concatenate`, which also deletes the per-column files.

    Parameters
    ----------
    args : object
        Unused; kept for a uniform call signature.
    X : array
        Indexed as `X[:, column]` and `x[idx]` with `idx` enumerating `roi_labels`.
    folder : str
        Output directory (created if missing).
    name : str
        Base name (without extension) of the concatenated output.

    Returns
    -------
    None

    Notes
    -----
    Relies on the notebook-level `parcels`, `parcels_img` and `roi_labels`.
    """
    # Create the folder without going through a shell (`mkdir -p` breaks on unusual paths
    # and hides failures).
    os.makedirs(folder, exist_ok=True)

    in_files = []
    for idx_type in range(0, 4):
        x = X[:, idx_type]
        x_img = np.zeros_like(parcels)
        for idx, roi in enumerate(roi_labels):
            # Paint this ROI's voxels with the ROI's value for the current column.
            x_img += (parcels == roi) * (x[idx])

        file = f'{folder}/vartype-{idx_type:01d}.nii.gz'
        parcels_img.new_image_like(x_img).to_filename(file)
        in_files.append(file)

    out_file = f'{folder}/{name}.nii.gz'
    concatenate(in_files, out_file)
    return None

def apply_rsn_mask(args, folder, rsn_name):
    """Multiply an RSN's variability-type file by that RSN's mask, overwriting the file.

    Runs AFNI Calc with expression ``a*b`` where ``a`` is
    ``{folder}/rsns/rsn-{rsn_name}_desc-var-type.nii.gz`` and ``b`` is
    ``{RSN_ROI_path}/desc-{rsn_name}-mask.nii.gz``; the result is written back to ``a``.
    `args` is unused.
    """
    var_type_file = f'{folder}/rsns/rsn-{rsn_name}_desc-var-type.nii.gz'
    mask_file = f'{RSN_ROI_path}/desc-{rsn_name}-mask.nii.gz'

    masker = afni.Calc()
    masker.inputs.in_file_a = var_type_file
    masker.inputs.in_file_b = mask_file
    masker.inputs.expr = 'a*b'
    # Input and output are the same path, so overwriting must be enabled.
    masker.inputs.out_file = var_type_file
    masker.inputs.outputtype = 'NIFTI'
    masker.inputs.overwrite = True
    masker.run()
    return None

def rsn_var_types_to_nifti(args, rsns_df, var_type_modes, folder):
    """For every row of `rsns_df`, write RSN-weighted variability types and mask them.

    Each row's ``'rsn'`` array scales the rows of `var_type_modes`; the product is written
    to ``{folder}/rsns`` via `var_types_to_nifti` and then passed through `apply_rsn_mask`.
    """
    rsn_folder = f'{folder}/rsns'
    for _, rsn_row in rsns_df.iterrows():
        rsn_name = rsn_row['name']
        membership = rsn_row['rsn']
        # Broadcast the per-ROI RSN values across the variability-type columns.
        weighted_modes = var_type_modes * membership[:, None]
        var_types_to_nifti(
            args,
            X=weighted_modes,
            folder=rsn_folder,
            name=f'rsn-{rsn_name}_desc-var-type',
        )
        apply_rsn_mask(args, folder, rsn_name)
    return None

In [8]:
def collect_indiv_var_types(args, level):
    """Load every individual's variability-type pickle and stack them on a new last axis.

    Parameters
    ----------
    args : object
        Unused; kept for a uniform call signature.
    level : int
        ``-1`` reads from ``.../compare-entropies/{SBM}``; any other value reads from
        the ``level-{level}`` sub-folder.

    Returns
    -------
    numpy.ndarray
        The loaded objects stacked with ``np.stack(..., axis=-1)``, in sorted folder order.

    Raises
    ------
    ValueError
        If no folder matches the search pattern.

    Notes
    -----
    Relies on the notebook-level `ESTIM_path` and `SBM`.
    """
    pattern = f'{ESTIM_path}/individual/sub-*/compare-entropies/{SBM}'
    if level != -1:
        pattern += f'/level-{level}'
    main_folders = sorted(glob.glob(pattern, recursive=True))

    # Fail with the offending pattern instead of np.stack's generic
    # "need at least one array to stack".
    if not main_folders:
        raise ValueError(f'no individual estimate folders match {pattern!r}')

    var_types = []
    for main_folder in main_folders:
        # NOTE: unpickling executes arbitrary code; only load files produced by this pipeline.
        with open(f'{main_folder}/desc-variability-type.pkl', 'rb') as f:
            var_types.append(pickle.load(f))
    return np.stack(var_types, axis=-1)

In [9]:
def bootstrap_histogram_means(matrix, n_bootstrap=1000, confidence_level=0.95):
    """Bootstrap the column means of `matrix` and return their mean and percentile CI.

    Parameters
    ----------
    matrix : 2-D array, shape (n_realizations, n_bins)
        Rows are resampled with replacement.
    n_bootstrap : int
        Number of bootstrap resamples.
    confidence_level : float
        Coverage of the percentile interval (0.95 gives the 2.5th and 97.5th percentiles).

    Returns
    -------
    mean_estimate, ci_lower, ci_upper : arrays of length n_bins
        Mean of the bootstrap means, and the lower / upper percentile bounds.

    Notes
    -----
    Draws from NumPy's global random state (`np.random.choice`).
    """
    n_realizations, _ = matrix.shape

    # One resample = n_realizations rows drawn with replacement; keep its column means.
    resampled_means = []
    for _ in range(n_bootstrap):
        picks = np.random.choice(n_realizations, size=n_realizations, replace=True)
        resampled_means.append(np.mean(matrix[picks, :], axis=0))
    bootstrap_means = np.array(resampled_means)

    mean_estimate = np.mean(bootstrap_means, axis=0)

    # Percentile confidence interval of the bootstrap means:
    #   lower bound = 100 * (1 - confidence_level) / 2 th percentile
    #   upper bound = 100 * (1 - (1 - confidence_level) / 2) th percentile
    lower_q = 100*(1 - confidence_level) / 2
    upper_q = 100 * (1 - (1 - confidence_level) / 2)
    ci_lower = np.percentile(bootstrap_means, q=lower_q, axis=0)
    ci_upper = np.percentile(bootstrap_means, q=upper_q, axis=0)

    return mean_estimate, ci_lower, ci_upper


In [10]:
def plot_var_type_per_rsn(args, ax, X):
    """Draw per-column points, box plots, connecting lines and a bootstrap mean band on `ax`.

    `X` is 2-D (the caller passes variability types x subjects). Its transpose is handed
    to seaborn, each column of `X` is drawn as a grey line across the rows of `X`, and the
    bootstrap mean with its confidence band (from `bootstrap_histogram_means`) is overlaid
    in blue. `args` is unused.
    """
    transposed = X.T

    strip_kwargs = dict(orient='v', jitter=False, size=4, color='royalblue', alpha=1.0)
    sns.stripplot(transposed, ax=ax, **strip_kwargs)

    box_kwargs = dict(
        orient='v',
        fill=False,
        width=0.7,
        fliersize=0.0,
        color='grey',
        medianprops={"color": "grey", "linewidth": 2},
        whis=[2.5, 97.5],
    )
    sns.boxplot(transposed, ax=ax, **box_kwargs)

    # One grey line per column of X, drawn against the row positions 0..n_rows-1.
    row_positions = np.arange(X.shape[0])[:, None]
    xs = np.tile(row_positions, (1, X.shape[1]))
    ax.plot(xs, X, c='grey', alpha=0.5)

    # Bootstrap mean across columns of X, with its confidence band.
    mu, cil, ciu = bootstrap_histogram_means(transposed)
    ax.plot(mu, c='royalblue', linewidth=3, marker='o')
    ax.fill_between(x=range(len(mu)), y1=cil, y2=ciu, color='royalblue', alpha=0.15)
    return None

def plot_var_type_pval_per_rsn(args, ax, X, alpha=0.05):
    """Heatmap of Bonferroni-adjusted paired t-test p-values between the rows of `X`.

    Every pair of rows of `X` is compared with `scipy.stats.ttest_rel`
    (``nan_policy='omit'``). Each p-value is multiplied by the number of pairs and capped
    at 1, and the symmetric matrix is drawn with annotations. Cells whose adjusted p-value
    exceeds `alpha` are outlined in red.

    Parameters
    ----------
    args : object
        Unused; kept for a uniform call signature.
    ax : matplotlib Axes
        Axes to draw the heatmap on.
    X : 2-D array
        One row per group being compared; columns are the paired observations.
    alpha : float
        Significance level applied to the *adjusted* p-values.

    Returns
    -------
    None

    Notes
    -----
    The p-values are already Bonferroni-adjusted, so they are compared against the
    nominal `alpha`. Dividing `alpha` by the number of comparisons as well would apply
    the correction twice.
    The diagonal of the matrix is left at 0 (no self-comparison is computed).
    """
    n_groups = len(X)
    pairs = list(combinations(range(n_groups), 2))
    corr_factor = len(pairs)  # number of pairwise comparisons, n * (n - 1) / 2

    tests = np.zeros((n_groups, n_groups))
    for (c1, c2) in pairs:
        _, pval = stats.ttest_rel(X[c1, :], X[c2, :], nan_policy='omit')
        pval = np.minimum(pval * corr_factor, 1)
        tests[c1, c2] = pval
        tests[c2, c1] = pval

    sns.heatmap(tests, ax=ax, square=True, cmap=cc.cm.CET_L12, annot=True, fmt='.2f')
    for (c1, c2) in pairs:
        # Outline both mirrored cells of every pair whose adjusted p-value is above alpha.
        if tests[c1, c2] > alpha:
            ax.add_patch(Rectangle((c1, c2), 1, 1, fill=False, edgecolor='tab:red', lw=3))
            ax.add_patch(Rectangle((c2, c1), 1, 1, fill=False, edgecolor='tab:red', lw=3))
    return None

In [11]:
def plot_var_type_all_rsns(args, rsns_df, var_types, level=-1):
    """Plot per-RSN variability-type proportions and their pairwise p-values, and save both.

    Two figures with one panel per row of `rsns_df` (5 panels per figure row) are built:
    one with `plot_var_type_per_rsn`, one with `plot_var_type_pval_per_rsn`. For each RSN,
    the ROIs where the row's ``'rsn'`` array is non-zero are selected from `var_types`
    (rois x types x subs) and averaged over ROIs.

    Parameters
    ----------
    args : object
        Forwarded to the per-panel plotting helpers.
    rsns_df : pandas.DataFrame
        Needs the columns ``'name'`` and ``'rsn'``.
    var_types : array, rois x types x subs
        Indexed on its first axis by the ROI indices of each RSN.
    level : int
        ``-1`` writes to ``.../group/compare-entropies/{SBM}``; any other value writes
        to the ``level-{level}`` sub-folder and is appended to the figure title.

    Returns
    -------
    fig, figp : matplotlib Figures
        The proportion figure and the p-value figure.

    Notes
    -----
    Relies on the notebook-level `SBM` and `ESTIM_path`.
    """
    labels = ['lm,\nnsv', 'lm,\nasv', 'dm,\nnsv', 'dm,\nasv']
    # NOTE(review): the tick labels use 'nsv'/'asv' while this legend explains 'hsv'/'lsv' —
    # confirm which abbreviations are intended.
    full_labels = [
        'lm: localized membership', 'dm: distributed membership',
        'hsv: high solution variability', 'lsv: low solution variability'
    ]
    handles = [plt.Line2D([], [], marker='_', color='w', markersize=12, label=label) for label in full_labels]

    n_rsns = len(rsns_df)
    ncols = 5
    nrows = int(np.ceil(n_rsns / ncols))

    # squeeze=False keeps the axes arrays 2-D even for a single figure row,
    # so panels are always addressed as [row, col].
    fig, axs = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(5*ncols, 4*nrows), dpi=90, sharey=True, squeeze=False,
    )
    fig.tight_layout(h_pad=4, w_pad=3)

    figp, axsp = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(5*ncols, 4*nrows), dpi=90, sharey=False, squeeze=False,
    )
    figp.tight_layout(h_pad=4, w_pad=3)

    title = f'{SBM}' if level == -1 else f'{SBM}_level-{level}'
    fig.suptitle(title, x=0.0, y=1.0)
    figp.suptitle(title, x=0.0, y=1.0)

    # Panel position follows the row's position in the frame, not its index label,
    # so a frame with a non-default index still fills the grid in order.
    for pos, (_, row) in enumerate(rsns_df.iterrows()):
        r, c = divmod(pos, ncols)
        ax, axp = axs[r, c], axsp[r, c]

        name = row['name']
        rsn = row['rsn']
        X = var_types[np.where(rsn)[0], :, :] # rois x types x subs
        X = np.sum(X, axis=0) / X.shape[0] # types x subs

        plot_var_type_per_rsn(args, ax, X)
        ax.set(title=f'rsn: {name}', xlabel=f'variability type', ylabel=f'roi proportion')
        ax.set_xticks(range(4), labels, rotation=0, ha='center')
        ax.grid(alpha=0.3)

        plot_var_type_pval_per_rsn(args, axp, X)
        axp.set(title=f'rsn: {name}', xlabel=f'variability type', ylabel=f'variability type')
        # Heatmap cells are centred at +0.5.
        axp.set_xticks(np.arange(4)+0.5, labels, rotation=0, ha='center')
        axp.set_yticks(np.arange(4)+0.5, labels)
        axp.grid(alpha=0.3)

    fig.legend(handles=handles, loc='lower center', ncol=len(full_labels))
    figp.legend(handles=handles, loc='lower center', ncol=len(full_labels))

    # Drop the grid cells left over after the last RSN.
    for pos in range(n_rsns, nrows * ncols):
        r, c = divmod(pos, ncols)
        fig.delaxes(axs[r, c])
        figp.delaxes(axsp[r, c])

    out_folder = f'{ESTIM_path}/group/compare-entropies/{SBM}'
    if level != -1:
        out_folder += f'/level-{level}'
    # Create the folder without a shell call so failures raise instead of being ignored.
    os.makedirs(out_folder, exist_ok=True)

    fig.savefig(f'{out_folder}/desc-variability-types-rsns.png')
    figp.savefig(f'{out_folder}/desc-variability-types-comparisons-pvals-rsns.png')
    return fig, figp

In [12]:
# Model variants to process. Each tuple is unpacked as (args.dc, args.sbm) by the loop below:
# the boolean picks the 'dc' / 'nd' part of the model name (ignored when the letter is 'a'),
# and the letter 'h' is the variant that is processed per level.
sbms_list = [
    (True, 'a'),
    (True, 'd'),
    (True, 'h'),
    (False, 'd'),
    (False, 'h'),
]

In [13]:
def sbm_name(args):
    """Return the model folder name ``sbm-{dc}-{sbm}``.

    ``dc`` is ``'dc'`` when ``args.dc`` is true and ``'nd'`` otherwise, except for
    ``args.sbm == 'a'`` where it is empty (giving ``'sbm--a'``).
    """
    dc = 'dc' if args.dc else 'nd'
    dc = '' if args.sbm in ['a'] else dc
    return f'sbm-{dc}-{args.sbm}'


# For every model variant: plot the per-RSN variability types, then export the group-level
# variability-type modes per RSN as NIfTI files.
for sbm in sbms_list[:]:
    args.dc, args.sbm = sbm

    args.nested = args.sbm == 'h'

    args.force_niter = 40000
    args.num_draws = int((1/2) * args.force_niter)  # half of force_niter

    # `SBM` is read as a notebook-level name by the collection / plotting helpers.
    SBM = sbm_name(args)
    print(SBM)

    # 'a' and 'd' have a single result set (level -1 = no level sub-folder);
    # 'h' has one result set per level 0..3.
    if args.sbm in ['a', 'd']:
        levels = [-1]
    elif args.sbm in ['h']:
        levels = np.arange(4)
    else:
        levels = []

    for level in levels:
        folder = f'{ESTIM_path}/group/compare-entropies/{SBM}'
        if level != -1:
            folder += f'/level-{level}'

        var_types = collect_indiv_var_types(args, level=level)
        fig, figp = plot_var_type_all_rsns(args, rsns_df, var_types, level=level)

        # common rois across animals per variability type
        var_type_modes = np.loadtxt(f'{folder}/desc-variability-type-modes.txt')
        rsn_var_types_to_nifti(args, rsns_df, var_type_modes, folder)

    # Free the figures of this variant before moving to the next one.
    plt.close('all')

sbm--a
241120-23:29:47,288 nipype.interface INFO:
	 stderr 2024-11-20T23:29:47.288175:++ 3dTcat: AFNI version=AFNI_20.2.18 (Sep 17 2020) [64-bit]
241120-23:29:47,293 nipype.interface INFO:
241120-23:29:47,294 nipype.interface INFO:
241120-23:29:47,342 nipype.interface INFO:
	 stderr 2024-11-20T23:29:47.342577:++ elapsed time = 0.1 s
241120-23:29:47,606 nipype.interface INFO:
	 stderr 2024-11-20T23:29:47.606693:++ 3dcalc: AFNI version=AFNI_20.2.18 (Sep 17 2020) [64-bit]
241120-23:29:47,609 nipype.interface INFO:
	 stderr 2024-11-20T23:29:47.606693:++ Authored by: A cast of thousands
241120-23:29:47,697 nipype.interface INFO:
	 stderr 2024-11-20T23:29:47.697101:++ Output dataset /home/govindas/mouse_dataset/roi_results_v2/type-spatial_size-225_symm-True_braindiv-whl_nrois-162/graph-constructed/method-pearson-corr/threshold-positive/edge-binary/density-20/layer-individual/unit-ses/estimates/group/compare-entropies/sbm--a/rsns/rsn-j-amygdala_desc-var-type.nii.gz
241120-23:29:48,182 nipype.