# Oct 7, 2025: enlist rois per comm

conda env: gt

In [1]:
import csv
import os
import sys
import numpy as np
import pandas as pd
import scipy as sp 
import dill as pickle 
from os.path import join as pjoin
from itertools import product
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
import re
from scipy import stats
from scipy.spatial.distance import jensenshannon, squareform, pdist
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from sklearn.metrics import silhouette_score

import glob
import random

from itertools import product, combinations
import multiprocessing as mp
from functools import partial
from joblib import Parallel, delayed

from sklearn.cluster import KMeans
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics.pairwise import cosine_similarity
from munkres import Munkres

# networks
import graph_tool.all as gt

# plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.cm import rainbow
from cycler import cycler

plt.rcParamsDefault['font.family'] = "sans-serif"
plt.rcParamsDefault['font.sans-serif'] = "Arial"
plt.rcParams['font.size'] = 14
plt.rcParams["errorbar.capsize"] = 0.5

import colorcet as cc

# ignore user warnings
import warnings
warnings.filterwarnings("ignore") #, category=UserWarning)

In [2]:
def get_colorblind_palette(n=20):
    """Return a merged, deduplicated colorblind-safe palette from CUD, Seaborn, and extended sources."""
    base = [
        "#0072B2", "#D55E00", "#009E73", "#CC79A7",
        "#F0E442", "#56B4E9", "#E69F00", "#000000",
        "#999999", "#882255", "#44AA99", "#117733"
    ]
    
    new = [
        "#0173B2", "#DE8F05", "#029E73", "#D55E00",  # D55E00 already in base
        "#CC78BC", "#CA9161", "#FBAFE4", "#949494",
        "#ECE133", "#56B4E9"  # 56B4E9 already in base
    ]
    
    # Deduplicate while preserving order
    seen = set()
    full = base + new
    merged = []
    for color in full:
        if color.lower() not in seen:
            merged.append(color)
            seen.add(color.lower())
    
    return merged[:n]

def setup_mpl(fontsize=7):
    """Configure matplotlib for Illustrator export with Helvetica-style fonts and clean styles."""
    
    CUD_COLORS = get_colorblind_palette()
    
    mpl.rcParams.update({
        # Fonts and layout
        "font.family": "sans-serif",
        "font.sans-serif": ["Arial", "DejaVu Sans"],
        "font.size": fontsize,
        "axes.titlesize": fontsize,
        "axes.labelsize": fontsize,
        "xtick.labelsize": fontsize,
        "ytick.labelsize": fontsize,
        "legend.fontsize": fontsize,

        # Export settings
        "svg.fonttype": 'none',
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "figure.dpi": 300,
        "savefig.dpi": 300,
        "text.usetex": False,

        # Axes and ticks
        "axes.linewidth": 0.5,
        "xtick.major.width": 0.5,
        "ytick.major.width": 0.5,
        "xtick.minor.width": 0.5,
        "ytick.minor.width": 0.5,
        "xtick.major.size": 2.5,
        "ytick.major.size": 2.5,

        # Lines
        "lines.linewidth": 1.0,

        # Default color cycle (Color Universal Design)
        "axes.prop_cycle": cycler('color', CUD_COLORS),
    })

    # mpl.rcParams["axes.prop_cycle"] = cycler('color', get_colorblind_palette())

In [3]:
setup_mpl(fontsize=7)
CUD_COLORS = get_colorblind_palette()

In [4]:
import seaborn as sns

In [5]:
class ARGS():
    pass

args = ARGS()

args.SEED = 100

In [6]:
args.source = 'allen' #'spatial' #'allen'
args.space = 'ccfv2' #'ccfv2'
args.brain_div = 'whl' #'whl'
args.num_rois = 172 #162 #172
args.resolution = 200 #200

PARC_DESC = (
    f'source-{args.source}'
    f'_space-{args.space}'
    f'_braindiv-{args.brain_div}'
    f'_nrois-{args.num_rois}'
    f'_res-{args.resolution}'
)
PARC_DESC

'source-allen_space-ccfv2_braindiv-whl_nrois-172_res-200'

In [7]:
args.GRAPH_DEF = f'constructed'
args.GRAPH_METHOD = f'pearson'
args.THRESHOLD = f'signed'
args.EDGE_DEF = f'binary'
args.EDGE_DENSITY = 20
args.LAYER_DEF = f'individual'
args.DATA_UNIT = f'grp'

BASE_path = f'{os.environ["HOME"]}/new_mouse_dataset'
PARCELS_path = f'{BASE_path}/parcels'
ROI_path = (
    f'{BASE_path}/roi-results-v3'
    f'/{PARC_DESC}'
)
TS_path = f'{ROI_path}/roi_timeseries'

In [8]:
ROI_RESULTS_path = (
    f'{ROI_path}'
    f'/graph-{args.GRAPH_DEF}/method-{args.GRAPH_METHOD}'
    f'/threshold-{args.THRESHOLD}/edge-{args.EDGE_DEF}/density-{args.EDGE_DENSITY}'
    f'/layer-{args.LAYER_DEF}/unit-{args.DATA_UNIT}'
)
GRAPH_path = f'{ROI_RESULTS_path}/graphs'
os.system(f'mkdir -p {GRAPH_path}')
SBM_path = f'{ROI_RESULTS_path}/model-fits'
os.system(f'mkdir -p {SBM_path}')
DIAG_path = f'{ROI_RESULTS_path}/diagnostics'
os.system(f'mkdir -p {DIAG_path}')
ESTIM_path = f'{ROI_RESULTS_path}/estimates'
os.system(f'mkdir -p {ESTIM_path}/individual')
os.system(f'mkdir -p {ESTIM_path}/group')

0

In [9]:
args.dc, args.sbm = False, 'h'
args.nested = args.sbm == 'h'

args.force_niter = 100000
args.num_draws = int((1/2) * args.force_niter)

args.epsilon = 0.4 # threshold KSD for convergence
args.delta = np.ceil(args.force_niter / 100).astype(int)

def sbm_name(args):
    dc = f'dc' if args.dc else f'nd'
    dc = f'' if args.sbm in ['a', 'm'] else dc
    file = f'sbm-{dc}-{args.sbm}'
    return file

SBM = sbm_name(args)
SBM

'sbm-nd-h'

roi info

In [10]:
parcels_file = f'{PARCELS_path}/{PARC_DESC}_desc-parcels.nii.gz'
# parcels_img = nib.load(parcels_file)

try:
    roi_table = pd.read_csv(f'{PARCELS_path}/{PARC_DESC}_desc-names.csv')
    roi_labels = roi_table['roi'].to_numpy()
except:
    roi_labels = np.arange(1, args.num_rois+1)

In [11]:
roi_table

Unnamed: 0,old_roi,name,roi
0,1,"R-Frontal pole, cerebral cortex (FRP,184)",1
1,2,"R-Primary motor area (MOp,985)",2
2,3,"R-Secondary motor area (MOs,993)",3
3,4,"R-Primary somatosensory area, nose (SSp-n,353)",4
4,5,"R-Primary somatosensory area, barrel field (SS...",5
...,...,...,...
167,168,"L-Medulla, sensory related (MY-sen,472)",168
168,169,"L-Medulla, motor related (MY-mot,456)",169
169,170,"L-Medulla, behavioral state related (MY-sat,465)",170
170,171,"L-Cerebellar cortex (CBX,614)",171


membership matrices

In [12]:
indiv_files = sorted(glob.glob(f'{ESTIM_path}/individual/sub-SLC01/partition-modes-group-aligned/{SBM}/desc-mem-mats.pkl'))
grp_df = []
for file in indiv_files:
    with open(file, 'rb') as f:
        df = pickle.load(f)
    grp_df += [df]

In [13]:
pis_df = pd.concat(grp_df).reset_index(drop=True)

In [14]:
def make_same_shape(Ms: list):
    num_rois = Ms[0].shape[0]
    num_comms = [M.shape[-1] for M in Ms]
    max_comms = np.max(num_comms)
    Rs = np.zeros((len(Ms), num_rois, max_comms))
    for idx, M in enumerate(Ms):
        Rs[idx, :, :M.shape[-1]] = M
    Rs = [Rs[idx, :, :] for idx in range(Rs.shape[0])]
    return Rs

In [15]:
cols = [c for c in pis_df.columns if 'pi' in c]

In [16]:
pis_df = pd.concat(
    [
        pis_df[['sub', 'sbm', 'omega']],
        pis_df[cols].apply(lambda pis: make_same_shape(pis.to_list())),
    ],
    axis=1
)

In [17]:
pis_df

Unnamed: 0,sub,sbm,omega,pi_0_aligned,pi_1_aligned,pi_2_aligned,pi_3_aligned,pi_4_aligned,pi_5_aligned,pi_6_aligned,pi_7_aligned,pi_8_aligned
0,grp,sbm-nd-h,0.257972,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.9496124031007752, 0.0, 0.0, 0.050387596899...","[[0.647709547170728, 0.08821121761555348, 0.26...","[[0.9174479586394734, 0.018381444206383837, 0....","[[0.9896977992624463, 0.004984344183347049, 0....","[[0.9963806866484217, 0.0024921720916735246, 0...","[[0.996944257278374, 0.0024921720916735246, 0....","[[0.9999999999999998, 0.0], [0.999999999999999...","[[0.9999999999999998, 0.0], [0.999999999999999..."
1,grp,sbm-nd-h,0.141756,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.9014084507042254, 0.0, 0.03521126760563380...","[[0.6123371809806205, 0.08592608281266546, 0.3...","[[0.9233035617069993, 0.040373225267095525, 0....","[[0.9846816259731426, 0.004284669238914727, 0....","[[0.9978750934069948, 0.0, 0.0, 0.002124906593...","[[0.9978750934069948, 0.0, 0.0, 0.002124906593...","[[0.9978750934069948, 0.002124906593005028], [...","[[0.9999999999999998, 0.0], [1.0, 0.0], [0.999..."
2,grp,sbm-nd-h,0.126375,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.9603174603174603, 0.0, 0.02380952380952380...","[[0.8385668276972625, 0.0006901311249137335, 0...","[[0.9879442719116633, 0.012055728088336784, 0....","[[1.0000000000000002, 0.0, 0.0, 0.0], [1.0, 0....","[[1.0000000000000002, 0.0, 0.0, 0.0], [1.0, 0....","[[1.0000000000000002, 0.0, 0.0, 0.0], [1.0, 0....","[[1.0000000000000002, 0.0], [1.0, 0.0], [1.0, ...","[[1.0000000000000002, 0.0], [1.0, 0.0], [1.0, ..."
3,grp,sbm-nd-h,0.121384,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.7768595041322314, 0.05785123966942149, 0.1...","[[0.8862158039468333, 0.011912739185466458, 0....","[[0.9985109076018167, 0.0, 0.00148909239818330...","[[1.0, 0.0, 0.0, 0.0], [0.9999999999999999, 0....","[[1.0, 0.0, 0.0, 0.0], [0.9999999999999999, 0....","[[1.0, 0.0], [0.9999999999999999, 0.0], [1.0, ...","[[1.0, 0.0], [0.9999999999999999, 0.0], [1.0, ..."
4,grp,sbm-nd-h,0.076945,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.961038961038961, 0.0, 0.025974025974025976...","[[0.2759226185662763, 0.13746918443591633, 0.5...","[[0.9306361811886198, 0.05332801790889157, 0.0...","[[0.98401400579353, 0.007734207387983437, 0.00...","[[0.9917482131815135, 0.0, 0.00063349854578763...","[[0.9923817117273012, 0.0, 0.0, 0.007618288272...","[[0.9923817117273012, 0.007618288272698796], [...","[[0.9923817117273012, 0.007618288272698796], [..."
5,grp,sbm-nd-h,0.071034,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.5352112676056338, 0.16901408450704225, 0.2...","[[0.834818165655927, 0.015980878256918792, 0.1...","[[0.9667409649211467, 0.007460047804357703, 0....","[[0.991400337575168, 0.008599662424831801, 0.0...","[[0.9999999999999998, 0.0, 0.0, 0.0], [1.00000...","[[0.9999999999999998, 0.0, 0.0, 0.0], [1.00000...","[[0.9999999999999998, 0.0], [1.000000000000000...","[[0.9999999999999998, 0.0], [1.000000000000000..."
6,grp,sbm-nd-h,0.056213,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.7321428571428571, 0.0, 0.05357142857142857...","[[0.48781218781218777, 0.1417832167832168, 0.3...","[[0.9386988011988012, 0.026457471100328242, 0....","[[1.0, 0.0, 0.0, 0.0], [0.9999999999999999, 0....","[[1.0, 0.0, 0.0, 0.0], [0.9999999999999999, 0....","[[1.0, 0.0, 0.0, 0.0], [0.9999999999999999, 0....","[[1.0, 0.0], [0.9999999999999999, 0.0], [1.0, ...","[[1.0, 0.0], [0.9999999999999999, 0.0], [1.0, ..."
7,grp,sbm-nd-h,0.035825,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.9722222222222222, 0.0, 0.02777777777777777...","[[0.7299382716049383, 0.0, 0.27006172839506176...","[[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [...","[[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [...","[[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [...","[[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [...","[[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0...","[[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0..."
8,grp,sbm-nd-h,0.034593,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.2571428571428571, 0.2857142857142857, 0.34...","[[0.8154081632653061, 0.09035714285714286, 0.0...","[[0.961061990730358, 0.003593014128728414, 0.0...","[[0.9676669096209911, 0.029321185617103982, 0....","[[0.9999999999999999, 0.0, 0.0, 0.0], [1.0, 0....","[[0.9999999999999999, 0.0, 0.0, 0.0], [1.0, 0....","[[0.9999999999999999, 0.0], [1.0, 0.0], [0.999...","[[0.9999999999999999, 0.0], [1.0, 0.0], [0.999..."
9,grp,sbm-nd-h,0.02949,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.6896551724137931, 0.06896551724137931, 0.2...","[[0.9375279892521272, 0.04276757725033588, 0.0...","[[1.0, 0.0, 0.0, 0.0], [0.9999999999999998, 0....","[[1.0, 0.0, 0.0, 0.0], [0.9999999999999998, 0....","[[1.0, 0.0, 0.0, 0.0], [0.9999999999999998, 0....","[[1.0, 0.0], [0.9999999999999998, 0.0], [1.0, ...","[[1.0, 0.0], [0.9999999999999998, 0.0], [1.0, ..."


roi list per comm.

In [18]:
col = cols[1]
level = col.replace('pi', '').replace('_', '').replace('aligned', '')

folder = f'{ESTIM_path}/group/community-compositions/modewise/level-{level}'
os.makedirs(folder, exist_ok=True)

for idx_mode, row in pis_df.iterrows():
    pi = row[col]
    
    comm_rois_df = {}
    for idx_comm in range(pi.shape[1]):
        comm_vec = pi[:, idx_comm]
        rois_df = roi_table.copy(deep=True)
        rois_df['membership'] = comm_vec
        rois_df = rois_df.sort_values(by=['membership'], ascending=False)
        rois_df = rois_df[rois_df['membership'] > 0.2] # just a random threshold, we want to focus on strongly belonging regions...
        comm_rois_df[idx_comm+1] = rois_df[['name', 'membership']] 
        # note the index starts with 1, as convenience for writing text in the paper.
        
    comm_rois_df = pd.concat(
        comm_rois_df.values(), 
        keys=comm_rois_df.keys(),
        names=['comm', 'roi_id'],
    )
    
    comm_rois_df.to_csv(f'{folder}/mode-{idx_mode:02d}_desc-comm-rois.csv')
    # break
    

stable vs peripheral rois per comm across modes

In [19]:
mode_files = sorted(glob.glob(f'{folder}/*.csv'))
mode_files

['/home/govindas/new_mouse_dataset/roi-results-v3/source-allen_space-ccfv2_braindiv-whl_nrois-172_res-200/graph-constructed/method-pearson/threshold-signed/edge-binary/density-20/layer-individual/unit-grp/estimates/group/community-compositions/modewise/level-1/mode-00_desc-comm-rois.csv',
 '/home/govindas/new_mouse_dataset/roi-results-v3/source-allen_space-ccfv2_braindiv-whl_nrois-172_res-200/graph-constructed/method-pearson/threshold-signed/edge-binary/density-20/layer-individual/unit-grp/estimates/group/community-compositions/modewise/level-1/mode-01_desc-comm-rois.csv',
 '/home/govindas/new_mouse_dataset/roi-results-v3/source-allen_space-ccfv2_braindiv-whl_nrois-172_res-200/graph-constructed/method-pearson/threshold-signed/edge-binary/density-20/layer-individual/unit-grp/estimates/group/community-compositions/modewise/level-1/mode-02_desc-comm-rois.csv',
 '/home/govindas/new_mouse_dataset/roi-results-v3/source-allen_space-ccfv2_braindiv-whl_nrois-172_res-200/graph-constructed/method

In [20]:
num_modes = len(mode_files)

all_modes_df = []
for idx_mode, file in enumerate(mode_files):
    df = pd.read_csv(file)
    df = pd.concat(
        [
            pd.Series([idx_mode]*len(df), name='mode'),
            df,
        ],
        axis=1
    )
    all_modes_df += [df]
all_modes_df = pd.concat(all_modes_df).reset_index(drop=True)
all_modes_df

Unnamed: 0,mode,comm,roi_id,name,membership
0,0,1,43,"R-Main olfactory bulb (MOB,507)",0.992248
1,0,1,156,"L-Thalamus, polymodal association cortex relat...",0.992248
2,0,1,149,"L-Lateral septal complex (LSX,361)",0.992248
3,0,1,154,"L-Pallidum, caudal region (PALc,895)",0.992248
4,0,1,129,"L-Main olfactory bulb (MOB,593)",0.992248
...,...,...,...,...,...
2614,11,13,153,"L-Pallidum, medial region (PALm,912)",0.306238
2615,11,13,67,"R-Pallidum, medial region (PALm,826)",0.272212
2616,11,13,71,"R-Periventricular zone (PVZ,157)",0.204159
2617,11,13,74,"R-Hypothalamic lateral zone (LZ,290)",0.204159


In [21]:
comms = list(sorted(all_modes_df['comm'].unique()))
roi_col = 'name'
score_threshold = 0.5 # arbitrary threshold
results = {}
for comm in comms:
    comm_df = all_modes_df[all_modes_df['comm'] == comm]
    num_modes_comm_exists = comm_df['mode'].nunique()
    
    roi_stats = comm_df.groupby(roi_col).agg(
        presence_count=('mode', 'nunique'),
        mean_membership=('membership', 'mean'),
    )
    
    roi_stats['presence_ratio'] = roi_stats['presence_count'] / num_modes_comm_exists
    roi_stats['stability_score'] = roi_stats['presence_ratio'] * roi_stats['mean_membership']
    
    stable_core = roi_stats[roi_stats['stability_score'] >= score_threshold].index.tolist()
    flexible_periphery = roi_stats[roi_stats['stability_score'] < score_threshold].index.tolist()
    
    results[comm] = {
        'stable_core': (stable_core),
        'flexible_periphery': (flexible_periphery),
        'modes_present': num_modes_comm_exists,
    }

In [22]:
stability_df = pd.DataFrame.from_dict(results, orient='index')
stability_df.index.name = 'comm'
stability_df = stability_df.reset_index()
stability_df

Unnamed: 0,comm,stable_core,flexible_periphery,modes_present
0,1,"[L-Accessory olfactory bulb (AOB,237), L-Anter...","[L-Agranular insular area, ventral part (AIv,2...",12
1,2,"[L-Anterior area (VISa,312782632), L-Anteromed...","[L-Agranular insular area, dorsal part (AId,19...",12
2,3,"[L-Agranular insular area, dorsal part (AId,19...","[L-Anterior cingulate area, dorsal part (ACAd,...",12
3,4,"[L-Anterior cingulate area, dorsal part (ACAd,...","[L-Anteromedial visual area (VISam,480), L-Dor...",7
4,5,"[L-Anterolateral visual area (VISal,488), L-La...","[L-Agranular insular area, posterior part (AIp...",12
5,6,"[L-Agranular insular area, posterior part (AIp...","[L-Agranular insular area, ventral part (AIv,2...",12
6,7,"[L-Hypothalamic lateral zone (LZ,376), L-Hypot...","[L-Agranular insular area, ventral part (AIv,2...",12
7,8,"[L-Cerebellar cortex (CBX,614), L-Cerebellar n...","[R-Agranular insular area, posterior part (AIp...",11
8,9,"[L-Midbrain, behavioral state related (MBsta,4...","[L-Cerebellar cortex (CBX,614), L-Cerebellar n...",12
9,13,[],"[L-Agranular insular area, posterior part (AIp...",8


In [29]:
core_df = stability_df[['comm', 'stable_core']].copy()
core_df = core_df.explode('stable_core').dropna()
core_df['type'] = 'stable_core'
core_df = core_df.rename(columns={'stable_core': 'roi_name'})
core_df

periphery_df = stability_df[['comm', 'flexible_periphery']].copy()
periphery_df = periphery_df.explode('flexible_periphery').dropna()
periphery_df['type'] = 'flexible_periphery'
periphery_df = periphery_df.rename(columns={'flexible_periphery': 'roi_name'})
periphery_df

comm_rois_df = pd.concat([core_df, periphery_df], ignore_index=True).reset_index(drop=True)
comm_rois_df = comm_rois_df[['comm', 'roi_name', 'type']].reset_index(drop=True)
comm_rois_df

Unnamed: 0,comm,roi_name,type
0,1,"L-Accessory olfactory bulb (AOB,237)",stable_core
1,1,"L-Anterior olfactory nucleus (AON,245)",stable_core
2,1,"L-Dorsal peduncular area (DP,900)",stable_core
3,1,"L-Frontal pole, cerebral cortex (FLP,184)",stable_core
4,1,"L-Infralimbic area (ILA,130)",stable_core
...,...,...,...
435,13,"L-Visceral area (VISC,763)",flexible_periphery
436,13,"R-Hypothalamic lateral zone (LZ,290)",flexible_periphery
437,13,"R-Pallidum, medial region (PALm,826)",flexible_periphery
438,13,"R-Periventricular region (PVR,141)",flexible_periphery


In [30]:
folder = f'{ESTIM_path}/group/community-compositions/roi-stability/level-{level}'
os.makedirs(folder, exist_ok=True)

comm_rois_df.to_csv(f'{folder}/desc-stable-flexible-rois.csv')