# Dec 30, 2024: compare communities with mSBM communities

In [1]:
import csv
import os
import sys
import numpy as np
import pandas as pd
import scipy as sp 
import dill as pickle 
from os.path import join as pjoin
from itertools import product
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
import subprocess
from scipy import sparse, stats, linalg
from scipy.spatial.distance import jensenshannon, cosine
from multiprocessing import Pool
import glob
import random

from sklearn.cluster import DBSCAN, SpectralCoclustering, SpectralClustering
from scipy.optimize import linear_sum_assignment

import arviz as az

import ants
from nipype.interfaces import afni

from itertools import product, combinations, chain
import multiprocessing as mp
from functools import partial

# networks
import graph_tool.all as gt

# plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.cm import rainbow

# Set fonts on the *active* rcParams. The original assignments went to
# plt.rcParamsDefault, which only changes what plt.rcdefaults() would restore
# and left the current session's font untouched.
plt.rcParams['font.family'] = "sans-serif"
# Put Arial first but keep the existing fallbacks in case Arial is not installed.
plt.rcParams['font.sans-serif'] = ["Arial"] + [
    font for font in plt.rcParams['font.sans-serif'] if font != "Arial"
]
plt.rcParams['font.size'] = 14
plt.rcParams["errorbar.capsize"] = 0.5

import cmasher as cmr  # CITE ITS PAPER IN YOUR MANUSCRIPT
import colorcet as cc

# ignore user warnings
import warnings
# NOTE(review): this silences *all* warnings, not only UserWarning (see the
# commented-out category). That also hides NumPy RuntimeWarnings, e.g. from
# averaging empty slices. Consider restoring category=UserWarning.
warnings.filterwarnings("ignore") #, category=UserWarning)

In [2]:
class ARGS():
    """Mutable attribute bag holding this notebook's configuration."""
    pass

args = ARGS()

args.SEED = 100

def set_seed(args):
    """Seed every random number generator this notebook touches from ``args.SEED``.

    Seeds graph-tool, NumPy's legacy global RNG (used by ``np.random.choice``
    and by scikit-learn estimators created without ``random_state``), and
    Python's ``random`` module (imported above but previously left unseeded).
    """
    gt.seed_rng(args.SEED)
    np.random.seed(args.SEED)
    random.seed(args.SEED)

set_seed(args)

In [3]:
# Parcellation settings. Together they name the ROI results folder (PARC_DESC).
vars(args).update(
    type='spatial',
    roi_size=225,
    maintain_symmetry=True,
    brain_div='whl',
    num_rois=162,
)

PARC_DESC = '_'.join([
    f'type-{args.type}',
    f'size-{args.roi_size}',
    f'symm-{args.maintain_symmetry}',
    f'braindiv-{args.brain_div}',
    f'nrois-{args.num_rois}',
])

In [None]:
# Graph-construction settings. Together they select the results sub-tree.
args.GRAPH_DEF = 'constructed'
args.GRAPH_METHOD = 'pearson-corr'
args.THRESHOLDING = 'positive'
args.EDGE_DEF = 'binary'
args.EDGE_DENSITY = 20
args.LAYER_DEF = 'individual'
args.DATA_UNIT = 'sub'

BASE_path = f'{os.environ["HOME"]}/mouse_dataset'
PARCELS_path = f'{BASE_path}/parcels'
ROI_path = f'{BASE_path}/roi_results_v2/{PARC_DESC}'
TS_path = f'{ROI_path}/runwise_timeseries'
ROI_RESULTS_path = (
    f'{ROI_path}'
    f'/graph-{args.GRAPH_DEF}/method-{args.GRAPH_METHOD}'
    f'/threshold-{args.THRESHOLDING}/edge-{args.EDGE_DEF}/density-{args.EDGE_DENSITY}'
    f'/layer-{args.LAYER_DEF}/unit-{args.DATA_UNIT}'
)
RSN_ROI_path = f'{ROI_path}/rsns'
IC_ROI_path = f'{ROI_path}/ics'
GRAPH_path = f'{ROI_RESULTS_path}/graphs'
SBM_path = f'{ROI_RESULTS_path}/model-fits'
ESTIM_path = f'{ROI_RESULTS_path}/estimates'

# Use os.makedirs instead of os.system('mkdir -p ...'). No shell is involved,
# so paths with spaces or shell metacharacters are safe, and failures raise
# instead of being reported only through an ignored exit status.
for _out_dir in (RSN_ROI_path, IC_ROI_path, GRAPH_path, SBM_path,
                 f'{ESTIM_path}/individual', f'{ESTIM_path}/group'):
    os.makedirs(_out_dir, exist_ok=True)

0

mSBM soft communities per animal: each animal's partition modes averaged, weighted by mode probability (omega)

In [5]:
# mSBM run ('m'). For 'm', the degree-correction flag does not enter the model
# tag (see sbm_name).
args.dc = True
args.sbm = 'm'
args.nested = args.sbm == 'h'  # only the hierarchical model is nested

args.force_niter = 40000
args.num_draws = args.force_niter // 2  # half of the forced iterations

def sbm_name(args):
    """Model tag used in result paths, formatted as 'sbm-{dc|nd}-{sbm}'.

    For the 'm' and 'a' models the degree-correction tag is left empty,
    which yields names such as 'sbm--m'.
    """
    degree_tag = 'dc' if args.dc else 'nd'
    if args.sbm in ('m', 'a'):
        degree_tag = ''
    return 'sbm-{}-{}'.format(degree_tag, args.sbm)

# Model tag used to locate this section's result files.
SBM = sbm_name(args=args)
SBM

'sbm--m'

In [6]:
# Load every animal's group-aligned partition modes (one pickled DataFrame per animal).
# NOTE(review): pickle (dill) executes code on load; only read files this pipeline wrote.
marginals_files = sorted(glob.glob(
    f'{ESTIM_path}/individual/sub-*/partition-modes-group-aligned/{SBM}/desc-mem-mats.pkl'
))
if not marginals_files:
    raise FileNotFoundError(
        f'no desc-mem-mats.pkl files for {SBM} under {ESTIM_path}/individual'
    )

frames = []
for sbm_file in marginals_files:
    with open(sbm_file, 'rb') as f:
        frames.append(pickle.load(f))
marginals_df = pd.concat(frames).reset_index(drop=True)

# Index of each mode within its animal, in row order. groupby().cumcount() does
# not rely on the rows being grouped and sorted by subject, which the previous
# value_counts().sort_index() construction silently assumed.
mode_ids = marginals_df.groupby('sub').cumcount().to_list()
marginals_df['mode_id'] = mode_ids

marginals_df

Unnamed: 0,sub,sbm,pi_aligned,omega,mode_id
0,SLC01,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0....",0.806483,0
1,SLC01,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1....",0.165306,1
2,SLC01,sbm--m,"[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [...",0.019968,2
3,SLC01,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",0.008243,3
4,SLC02,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0....",0.3338,0
5,SLC02,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1....",0.3292,1
6,SLC02,sbm--m,"[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [...",0.26572,2
7,SLC02,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",0.07104,3
8,SLC02,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.00024,4
9,SLC03,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0....",0.575495,0


In [7]:
# Membership-matrix columns of the loaded model (for the mSBM: only 'pi_aligned').
cols = list(filter(lambda name: 'pi_' in name, marginals_df.columns))
cols

['pi_aligned']

In [8]:
def get_membership_matrix(num_rois, df, col='pi'):
    """Stack the per-mode membership matrices stored in ``df[col]`` into one array.

    Args:
        num_rois: number of rows in each stored matrix (ROIs; when reused on
            similarity matrices, the number of mSBM systems).
        df: DataFrame with one row per partition mode.
        col: column holding one (num_rois, K_mode) array per row. A NaN entry
            (mode lacks this column) contributes a single all-zero community.
    Returns:
        M: (num_rois, num_modes, max_K) float array. Modes with fewer than
           max_K communities are zero-padded on the right.
    Raises:
        ValueError: if ``df`` has no rows.
    """
    if len(df) == 0:
        raise ValueError('get_membership_matrix: df has no rows')

    # NaN entries appear after concatenating frames with different columns.
    pis = [np.zeros((num_rois, 1)) if np.isnan(pi).all() else pi for pi in df[col]]

    num_comms = max(pi.shape[-1] for pi in pis)
    M = np.zeros((num_rois, len(pis), num_comms))  # membership profile matrix
    for idx_mode, pi in enumerate(pis):
        M[:, idx_mode, :pi.shape[-1]] = pi

    return M

In [9]:
# SOFT MARGINALS: for each animal, average its mode-wise membership matrices,
# weighting every mode by its probability omega -> one soft membership matrix.
rows = []
for sub, group in marginals_df.groupby('sub'):
    weights = group['omega'].to_list()
    record = dict(sub=[sub], sbm=[SBM])
    for col in cols:
        stacked = get_membership_matrix(args.num_rois, group, col=col)  # rois x modes x comms
        record[col] = [np.average(stacked, axis=1, weights=weights)]  # soft-comms.
    rows.append(pd.DataFrame(record))
soft_marginals_df = pd.concat(rows).reset_index(drop=True)

# Keep only the first 4 soft communities. The original note says the remaining
# columns are all zeros ("useless comms"). NOTE(review): the cut-off is
# hard-coded; confirm that no membership mass is dropped.
soft_marginals_df['pi_aligned'] = soft_marginals_df['pi_aligned'].apply(lambda pi: pi[:, :4])
soft_marginals_msbm_df = soft_marginals_df  # same object, under the name used below
soft_marginals_msbm_df

Unnamed: 0,sub,sbm,pi_aligned
0,SLC01,sbm--m,"[[1.0, 0.0, 0.0, 0.0], [0.0, 0.571297898812130..."
1,SLC02,sbm--m,"[[1.0, 0.0, 0.0, 0.0], [3.592814371257485e-06,..."
2,SLC03,sbm--m,"[[1.0, 0.0, 0.0, 0.0], [0.0, 0.393461292951276..."
3,SLC04,sbm--m,"[[1.0, 0.0, 0.0, 0.0], [5.989221796455099e-07,..."
4,SLC05,sbm--m,"[[1.0, 0.0, 0.0, 0.0], [0.0, 0.683622828784119..."
5,SLC06,sbm--m,"[[1.0, 0.0, 0.0, 0.0], [0.0, 0.655648982630272..."
6,SLC07,sbm--m,"[[1.0, 0.0, 0.0, 0.0], [0.0, 0.519252555831265..."
7,SLC08,sbm--m,"[[1.0, 0.0, 0.0, 0.0], [0.0, 0.410449908714820..."
8,SLC09,sbm--m,"[[1.0, 0.0, 0.0, 0.0], [0.0, 0.523038808933002..."
9,SLC10,sbm--m,"[[1.0, 0.0, 0.0, 0.0], [0.0, 0.683568138957816..."


Nested (hierarchical) SBM communities per mode per animal

In [10]:
# Hierarchical (nested) SBM without degree correction.
args.dc = False
args.sbm = 'h'
args.nested = args.sbm == 'h'  # only the hierarchical model is nested

args.force_niter = 40000
args.num_draws = args.force_niter // 2  # half of the forced iterations

def sbm_name(args):
    """Model tag used in result paths, formatted as 'sbm-{dc|nd}-{sbm}'.

    Re-definition with the same behavior as the earlier ``sbm_name`` in the
    mSBM section, so this section can be run on its own. For the 'm' and 'a'
    models the degree-correction tag is empty (e.g. 'sbm--m').
    """
    degree_tag = 'dc' if args.dc else 'nd'
    if args.sbm in ('m', 'a'):
        degree_tag = ''
    return 'sbm-{}-{}'.format(degree_tag, args.sbm)

# Model tag used to locate this section's result files.
SBM = sbm_name(args=args)
SBM

'sbm-nd-h'

In [11]:
# Load every animal's group-aligned partition modes (one pickled DataFrame per animal).
# NOTE(review): pickle (dill) executes code on load; only read files this pipeline wrote.
marginals_files = sorted(glob.glob(
    f'{ESTIM_path}/individual/sub-*/partition-modes-group-aligned/{SBM}/desc-mem-mats.pkl'
))
if not marginals_files:
    raise FileNotFoundError(
        f'no desc-mem-mats.pkl files for {SBM} under {ESTIM_path}/individual'
    )

frames = []
for sbm_file in marginals_files:
    with open(sbm_file, 'rb') as f:
        frames.append(pickle.load(f))
marginals_df = pd.concat(frames).reset_index(drop=True)

# Index of each mode within its animal, in row order. groupby().cumcount() does
# not rely on the rows being grouped and sorted by subject, which the previous
# value_counts().sort_index() construction silently assumed.
mode_ids = marginals_df.groupby('sub').cumcount().to_list()
marginals_df['mode_id'] = mode_ids

marginals_df

Unnamed: 0,sub,sbm,pi_0,pi_1,pi_2,pi_3,pi_4,pi_5,pi_6,pi_7,pi_8,omega,mode_id
0,SLC01,sbm-nd-h,"[[0.08362369337979095, 0.0, 0.0034843205574912...","[[0.7979094076655052, 0.017421602787456445, 0....","[[0.49477351916376305, 0.0, 0.3205574912891986...","[[0.926829268292683, 0.05574912891986063, 0.0,...","[[0.9860627177700348, 0.010452961672473868, 0....","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,,0.286725,0
1,SLC01,sbm-nd-h,"[[0.06274509803921569, 0.0, 0.0, 0.0, 0.360784...","[[0.8705882352941177, 0.0196078431372549, 0.06...","[[0.5294117647058824, 0.0, 0.30980392156862746...","[[0.9529411764705882, 0.03529411764705882, 0.0...","[[0.9921568627450981, 0.00392156862745098, 0.0...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,,0.254538,1
2,SLC01,sbm-nd-h,"[[0.0196078431372549, 0.0, 0.0, 0.0, 0.3431372...","[[0.9901960784313726, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.8529411764705882, 0.0, 0.10784313725490197...","[[0.9803921568627451, 0.0196078431372549], [0....","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,,,0.102439,2
3,SLC01,sbm-nd-h,"[[0.02531645569620253, 0.0, 0.0, 0.0, 0.316455...","[[0.9620253164556962, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.7468354430379747, 0.0, 0.20253164556962025...","[[0.9746835443037974, 0.02531645569620253], [1...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,,,0.079128,3
4,SLC01,sbm-nd-h,"[[0.07692307692307693, 0.0, 0.0, 0.0, 0.320512...","[[0.7948717948717948, 0.07692307692307693, 0.0...","[[0.7692307692307693, 0.0, 0.15384615384615385...","[[0.9615384615384616, 0.038461538461538464], [...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,,,0.078369,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...
69,SLC10,sbm-nd-h,"[[0.02531645569620253, 0.0, 0.0, 0.0, 0.316455...","[[0.9620253164556962, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.7468354430379747, 0.0, 0.20253164556962025...","[[0.9746835443037974, 0.02531645569620253], [1...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,,,0.080736,3
70,SLC10,sbm-nd-h,"[[0.07692307692307693, 0.0, 0.0, 0.0, 0.320512...","[[0.7948717948717948, 0.07692307692307693, 0.0...","[[0.7692307692307693, 0.0, 0.15384615384615385...","[[0.9615384615384616, 0.038461538461538464], [...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,,,0.052651,4
71,SLC10,sbm-nd-h,"[[0.22666666666666666, 0.0, 0.0, 0.0, 0.133333...","[[0.8133333333333334, 0.06666666666666667, 0.0...","[[0.5733333333333334, 0.0, 0.26666666666666666...","[[0.9466666666666667, 0.05333333333333334], [0...","[[0.9733333333333334, 0.02666666666666667], [0...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,,0.035607,5
72,SLC10,sbm-nd-h,"[[0.09375, 0.0, 0.0, 0.0, 0.078125, 0.0, 0.765...","[[0.734375, 0.0625, 0.078125, 0.0, 0.0, 0.0, 0...","[[0.828125, 0.0, 0.171875], [0.921875, 0.0, 0....","[[0.953125, 0.046875], [0.984375, 0.015625], [...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,,,0.017924,6


In [12]:
# Membership-matrix columns pi_0, pi_1, ... (presumably one per level of the nested SBM).
cols = list(filter(lambda name: 'pi_' in name, marginals_df.columns))
cols

['pi_0', 'pi_1', 'pi_2', 'pi_3', 'pi_4', 'pi_5', 'pi_6', 'pi_7', 'pi_8']

Compare each partition mode of each animal with that animal's mSBM soft communities (cosine similarity between community membership vectors)

In [13]:
def find_similarity(sbm_comm, sys_comm, min_mass=0.1):
    '''Cosine similarity between an SBM community and an mSBM (soft) community.

    Args:
        sbm_comm: membership vector of one SBM community (one entry per ROI).
        sys_comm: (soft) membership vector of one mSBM community.
        min_mass: if either vector sums to less than this, the pair counts as
            dissimilar (similarity 0.0) instead of calling ``cosine``, which is
            undefined (NaN) for an all-zero vector. Default 0.1 matches the
            previous hard-coded threshold.
    Returns:
        1 - cosine distance (1.0 for parallel vectors, 0.0 for orthogonal ones).
    '''
    if np.sum(sbm_comm) < min_mass or np.sum(sys_comm) < min_mass:
        return 0.0
    return 1 - cosine(sbm_comm, sys_comm)

def find_similarities_per_sys(pi, sys_comm):
    '''Similarity of every community (column) of ``pi`` to one mSBM community.

    pi: membership matrix of an SBM, one column per community.
    sys_comm: (soft) membership vector of one mSBM community.
    Returns a 1-D array with one similarity per column of ``pi``.
    '''
    return np.array([
        find_similarity(pi[:, comm_idx], sys_comm)
        for comm_idx in range(pi.shape[-1])
    ])

def fill_similarity_matrix(pi, sys_pi):
    '''Similarity matrix between mSBM communities (rows) and SBM communities (columns).

    pi: SBM membership matrix, one column per SBM community.
    sys_pi: mSBM soft membership matrix, one column per mSBM community.
    Returns X with shape (number of mSBM comms, number of SBM comms).
    '''
    num_sys = sys_pi.shape[-1]
    X = np.zeros((num_sys, pi.shape[-1]))
    for sys_idx, sys_comm in enumerate(sys_pi.T):
        X[sys_idx] = find_similarities_per_sys(pi=pi, sys_comm=sys_comm)
    return X

In [14]:
def get_similarity_matrix(marginals_df, soft_marginals_msbm_df, columns=None):
    """Replace each SBM membership matrix by its similarity to the mSBM soft communities.

    For each row (one mode of one animal) and each membership column, the SBM
    membership matrix is replaced by an (mSBM comms x SBM comms) similarity
    matrix against that animal's mSBM soft communities (``fill_similarity_matrix``).

    Args:
        marginals_df: one row per (animal, mode), with a 'sub' column and the
            membership columns.
        soft_marginals_msbm_df: one row per animal with 'sub' and 'pi_aligned'.
        columns: membership columns to convert. Defaults to the notebook-level
            ``cols`` (the previous implicit behavior).
    Returns:
        A copy of ``marginals_df``. NaN entries (missing levels) are left as is.
    Raises:
        KeyError: if an animal has no mSBM soft communities.
    """
    columns = cols if columns is None else columns
    # Look up each animal's soft communities once instead of filtering the frame per row.
    sys_pi_by_sub = dict(zip(soft_marginals_msbm_df['sub'], soft_marginals_msbm_df['pi_aligned']))

    similarities_df = marginals_df.copy(deep=True)
    for idx, row in tqdm(similarities_df.iterrows(), total=len(similarities_df)):
        if row['sub'] not in sys_pi_by_sub:
            raise KeyError(f"no mSBM soft communities for sub {row['sub']!r}")
        sys_pi = sys_pi_by_sub[row['sub']]
        for col in columns:
            pi = row[col]
            if np.isnan(pi).all():
                continue
            similarities_df.at[idx, col] = fill_similarity_matrix(pi, sys_pi)
    return similarities_df

In [15]:
similarities_df = get_similarity_matrix(
    marginals_df=marginals_df,
    soft_marginals_msbm_df=soft_marginals_msbm_df,
)
similarities_df

74it [00:00, 179.91it/s]


Unnamed: 0,sub,sbm,pi_0,pi_1,pi_2,pi_3,pi_4,pi_5,pi_6,pi_7,pi_8,omega,mode_id
0,SLC01,sbm-nd-h,"[[0.4889715713774476, 0.0, 0.00678104121573541...","[[0.8013731509772752, 0.007942079503983068, 0....","[[0.34033179756497733, 0.0003301286761762112, ...","[[0.45358686428381667, 0.5991144828639039, 0.0...","[[0.46349110085515655, 0.5895524267103552, 0.0...","[[0.46571900968999214], [0.5224075763922196], ...",,,,0.286725,0
1,SLC01,sbm-nd-h,"[[0.5163206740807964, 0.0, 0.21776946747113668...","[[0.7859174632331904, 0.02130614949709697, 0.0...","[[0.3501678926648961, 0.00012876621123603105, ...","[[0.45762672245164526, 0.5199093441187926, 0.0...","[[0.465289667534982, 0.3150262945995852, 0.000...","[[0.46571900968999214], [0.5224075763922196], ...",,,,0.254538,1
2,SLC01,sbm-nd-h,"[[0.509120989630761, 0.0, 0.21471613840957393,...","[[0.7042404218439557, 0.0, 0.00233817227346799...","[[0.44130006766371876, 0.0, 0.5468178073673975...","[[0.46231174668738606, 0.6168789214835422], [0...","[[0.46571900968999214], [0.5224075763922196], ...",,,,,0.102439,2
3,SLC01,sbm-nd-h,"[[0.5124050357884926, 0.0, 0.22449122695767953...","[[0.6609217165725276, 0.00010252790164289038, ...","[[0.41315489418544016, 0.0, 0.6005022518312747...","[[0.4596436592184975, 0.6645069903503574], [0....","[[0.46571900968999214], [0.5224075763922196], ...",,,,,0.079128,3
4,SLC01,sbm-nd-h,"[[0.5111290887489445, 0.0, 0.21995208634396346...","[[0.7756198801392539, 0.06769959781664514, 0.0...","[[0.42468388865904194, 0.0, 0.5182823468295968...","[[0.459102277782767, 0.6268683576709171], [0.5...","[[0.46571900968999214], [0.5224075763922196], ...",,,,,0.078369,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...
69,SLC10,sbm-nd-h,"[[0.5124070410659938, 0.0, 0.22449210549550636...","[[0.6607851440237258, 0.00010172064117908164, ...","[[0.4130749914401399, 0.0, 0.6003789181911121,...","[[0.4595536017715328, 0.6643744788460266], [0....","[[0.46562772084285986], [0.5212587331026804], ...",,,,,0.080736,3
70,SLC10,sbm-nd-h,"[[0.5111310890330756, 0.0, 0.21995294711803726...","[[0.7754945644221192, 0.06767088910665953, 0.0...","[[0.4246013578356558, 0.0, 0.5181556833418024,...","[[0.45901234826705706, 0.6267433518089796], [0...","[[0.46562772084285986], [0.5212587331026804], ...",,,,,0.052651,4
71,SLC10,sbm-nd-h,"[[0.6471084416640377, 0.0, 0.00109120882481450...","[[0.7507606894159694, 0.04381733987234104, 0.0...","[[0.3685655235556192, 0.0, 0.5008459498683804,...","[[0.45583928232671367, 0.6355124265537873], [0...","[[0.4598157962905618, 0.6676234895443506], [0....","[[0.46562772084285986], [0.5212587331026804], ...",,,,0.035607,5
72,SLC10,sbm-nd-h,"[[0.4997880730146472, 0.0, 0.0, 0.090742631301...","[[0.7486121264606073, 0.05339468341285325, 0.0...","[[0.4317782224247173, 0.0, 0.6481168150689558]...","[[0.4599332267969609, 0.521655572130104], [0.5...","[[0.46562772084285986], [0.5212587331026804], ...",,,,,0.017924,6


In [16]:
# Save each animal's mode-wise similarity matrices.
# NOTE(review): 'comparions' is misspelled (the group-level folder below uses
# 'comparisons'). It is kept as-is because existing outputs and readers may
# depend on this path.
for sub, group in similarities_df.groupby('sub'):
    folder = f'{ESTIM_path}/individual/sub-{sub}/comparions-with-msbm-comms/{SBM}'
    os.makedirs(folder, exist_ok=True)

    with open(f'{folder}/sys-msbm_desc-similarities.pkl', 'wb') as f:
        pickle.dump(group, f)

In [17]:
def get_soft_similarity_matrix(similarities_df, num_sys, columns=None):
    """Average each animal's mode-wise similarity matrices, weighted by omega.

    Args:
        similarities_df: one row per (animal, mode) with 'sub', 'omega' and
            similarity-matrix columns (output of ``get_similarity_matrix``).
        num_sys: number of mSBM communities, i.e. rows per similarity matrix.
        columns: columns to average. Defaults to the notebook-level ``cols``
            (the previous implicit behavior).
    Returns:
        DataFrame with one row per animal: 'sub', 'sbm' (the notebook-level
        ``SBM``) and one (num_sys x max comms) soft similarity matrix per column.
    """
    columns = cols if columns is None else columns
    frames = []
    for sub, group in similarities_df.groupby('sub'):
        omegas = group['omega'].to_numpy()
        dct = {'sub': [sub], 'sbm': [SBM]}
        for col in columns:
            # num_sys x num_modes x num_comms; missing levels contribute zeros
            M = get_membership_matrix(num_sys, group, col)
            dct[col] = [np.average(M, axis=1, weights=omegas)]
        frames.append(pd.DataFrame(dct))
    return pd.concat(frames).reset_index(drop=True)

In [18]:
# num_sys=4: the number of mSBM soft communities kept in pi_aligned above.
soft_similarities_df = get_soft_similarity_matrix(similarities_df=similarities_df, num_sys=4)
soft_similarities_df

Unnamed: 0,sub,sbm,pi_0,pi_1,pi_2,pi_3,pi_4,pi_5,pi_6,pi_7,pi_8
0,SLC01,sbm-nd-h,"[[0.5165526768526827, 0.0, 0.12806364115361296...","[[0.7596920150332357, 0.019718214347037, 0.049...","[[0.3793646323438597, 0.00012743216463767233, ...","[[0.4577569142954669, 0.5718765036329632, 0.0,...","[[0.4645335136349067, 0.2994732122188721, 0.00...","[[0.2871220076133542], [0.322071268308342], [0...","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]"
1,SLC02,sbm-nd-h,"[[0.5159420233326213, 2.9086908785939344e-05, ...","[[0.7520753797272884, 0.019697863112648397, 0....","[[0.38814253166742996, 0.0001328312818946584, ...","[[0.45869791155938044, 0.5900862007511206, 9.0...","[[0.46527084953654796, 0.2626325462756104, 0.0...","[[0.23729018554783515], [0.2753908009724088], ...","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]"
2,SLC03,sbm-nd-h,"[[0.49961854459606686, 0.0, 0.0989405153125615...","[[0.7768181293266482, 0.009519235845874854, 0....","[[0.3620199766817718, 0.00021962439661391247, ...","[[0.456242503741539, 0.5854661948161786, 0.0, ...","[[0.4643995654100902, 0.40610248997195364, 0.0...","[[0.37234178677067087], [0.41447642524682204],...","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]"
3,SLC04,sbm-nd-h,"[[0.5179251268544272, 4.830763243681485e-06, 0...","[[0.7603823221131731, 0.019067947500498568, 0....","[[0.3779368812917592, 0.00013556298236887395, ...","[[0.45762106306759687, 0.602803787043445, 0.0,...","[[0.46479190141050786, 0.3257024804758131, 0.0...","[[0.2905244659957725], [0.3359438388357023], [...","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]"
4,SLC05,sbm-nd-h,"[[0.5014481190068404, 0.0, 0.11739788111535188...","[[0.7684016032219989, 0.013092590238389136, 0....","[[0.3703438510640338, 0.00018507577955107594, ...","[[0.45660300382644736, 0.5935598121691739, 0.0...","[[0.4644758329244695, 0.34630750684694767, 0.0...","[[0.31828752934489946], [0.3563152164244627], ...","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]"
5,SLC06,sbm-nd-h,"[[0.5183288729060967, 2.5745328147426476e-05, ...","[[0.7112363541901516, 0.023244534725754135, 0....","[[0.3962407186822072, 0.0006742725700802878, 0...","[[0.45892000382643916, 0.5714369257578082, 0.0...","[[0.4646537541520052, 0.22543672906129147, 0.0...","[[0.22899177621635564, 0.0], [0.25626270062831...","[[0.012351239242050671], [0.013822813568013695...","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]"
6,SLC07,sbm-nd-h,"[[0.5040246746799957, 0.0, 0.11149910312390966...","[[0.7698293374635361, 0.013997162487616016, 0....","[[0.36925699088333364, 0.00018753543319329445,...","[[0.4565398274481788, 0.5944713220688793, 0.0,...","[[0.46441718813833205, 0.3610951701474743, 0.0...","[[0.32666879006188043], [0.36478559056622], [0...","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]"
7,SLC08,sbm-nd-h,"[[0.5185766998322994, 0.0, 0.12819740948326155...","[[0.7546291398812909, 0.021236853894957164, 0....","[[0.3841774053113841, 0.00010828095968408028, ...","[[0.4577281709116555, 0.5945170851085138, 0.0,...","[[0.464544032137886, 0.2770045845453218, 0.000...","[[0.25732107166261237], [0.2866275028691164], ...","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]"
8,SLC09,sbm-nd-h,"[[0.5163390369848163, 0.0, 0.11492636733472206...","[[0.7586177454492486, 0.02148254444055919, 0.0...","[[0.38257834130044355, 0.0001258282632316528, ...","[[0.4573753285984811, 0.5967076098726676, 0.0,...","[[0.4643933924609069, 0.2973932535102931, 0.00...","[[0.26351659687716433], [0.2942896972980016], ...","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]"
9,SLC10,sbm-nd-h,"[[0.5054547897630018, 0.0, 0.10200282482488034...","[[0.7682744285982246, 0.01349803537865391, 0.0...","[[0.37048294974254087, 0.000185334379972564, 0...","[[0.45648795189160635, 0.5954328406496271, 0.0...","[[0.46424900954616105, 0.36677337223659706, 0....","[[0.3230239919750026], [0.3616173807561704], [...","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]","[[0.0], [0.0], [0.0], [0.0]]"


In [19]:
folder = f'{ESTIM_path}/group/comparisons-with-msbm-comms/{SBM}'
os.makedirs(folder, exist_ok=True)

# One PDF per membership column: stacked heatmaps of each animal's soft
# similarity matrix (rows: mSBM communities, columns: SBM communities).
for col in tqdm(cols):
    nrows = len(soft_similarities_df)
    # squeeze=False keeps `axs` 2-D, so indexing also works for a single animal.
    fig, axs = plt.subplots(nrows, 1, figsize=(20, 6 * nrows), squeeze=False)
    fig.suptitle(f'{SBM} {col}', x=0.0, y=1.0)
    for ax, (_, row) in zip(axs[:, 0], soft_similarities_df.iterrows()):
        sns.heatmap(row[col], ax=ax, vmin=0.0, vmax=1.0, annot=True, fmt='.2f')
        ax.set(xlabel='comm', ylabel='mSBM comm', title=f"{row['sub']}: similarity matrix")
    # Lay out after the axes are populated so titles and tick labels are accounted for.
    fig.tight_layout(h_pad=3, w_pad=3)

    fig.savefig(f'{folder}/col-{col}_sys-msbm_desc-similarity-matrices.pdf', bbox_inches='tight')
    plt.close(fig)

100%|██████████| 9/9 [00:39<00:00,  4.43s/it]


Group the SBM communities that resemble each mSBM community (bootstrapped consensus co-clustering of the soft similarity matrices).

In [20]:
def selected_comms_in_group(Xs, thresh=0.1):
    """Indices of the SBM communities kept for the group analysis.

    For each animal's similarity matrix X (mSBM comms x SBM comms), a community
    (column) is selected when its similarities summed over mSBM comms exceed
    ``thresh``. The selection of the animal with the fewest selected
    communities is returned; on ties, the first such animal wins.

    Args:
        Xs: non-empty sequence of 2-D similarity matrices.
        thresh: minimum summed similarity for a community to be selected.
    Returns:
        1-D integer array of column indices.
    Raises:
        ValueError: if ``Xs`` is empty.
    """
    if len(Xs) == 0:
        raise ValueError('selected_comms_in_group: Xs is empty')
    sel_comms = [np.flatnonzero(np.sum(X, axis=0) > thresh) for X in Xs]
    # np.argmin returns the first minimum, which keeps the original tie-breaking
    # without comparing a Python list to a NumPy scalar.
    loc = int(np.argmin([len(sc) for sc in sel_comms]))
    return sel_comms[loc]

In [21]:
def align_labels(labels_to_align, reference_labels, num_clusters):
    """
    Relabel ``labels_to_align`` so that it best matches ``reference_labels``.

    Builds the cluster-overlap (contingency) table and solves the maximum-
    overlap assignment with the Hungarian algorithm. Labels are expected to lie
    in ``range(num_clusters)``; elements with any other label end up as 0.
    """
    cluster_ids = np.arange(num_clusters)
    # one-hot memberships, shape (n_elements, num_clusters)
    onehot_moving = (labels_to_align[:, None] == cluster_ids).astype(float)
    onehot_ref = (reference_labels[:, None] == cluster_ids).astype(float)
    overlap = onehot_moving.T @ onehot_ref  # overlap[i, j] = #(moving == i and ref == j)
    old_ids, new_ids = linear_sum_assignment(0.0 - overlap)

    aligned_labels = np.zeros_like(labels_to_align)
    for old_id, new_id in zip(old_ids, new_ids):
        aligned_labels[labels_to_align == old_id] = new_id
    return aligned_labels

def compute_block_matrix(matrices, row_labels, col_labels, num_clusters):
    """
    Compute the block matrix: the mean value of every (row cluster, column cluster) block.

    Args:
        matrices: list of equally shaped 2-D matrices
        row_labels: cluster label per matrix row (NumPy array)
        col_labels: cluster label per matrix column (NumPy array)
        num_clusters: number of clusters (same for rows and columns)
    Returns:
        block_matrix: (num_clusters, num_clusters) array. Entry (i, j) is the
        mean, over all matrices, of the entries in row cluster i and column
        cluster j. It is NaN when that block (or ``matrices``) is empty.
    """
    if len(matrices) == 0:
        return np.full((num_clusters, num_clusters), np.nan)

    stacked = np.stack(matrices)  # (num_matrices, n_rows, n_cols)
    # Masks depend on a single index each; compute them once instead of per (i, j).
    row_masks = [row_labels == i for i in range(num_clusters)]
    col_masks = [col_labels == j for j in range(num_clusters)]

    block_matrix = np.zeros((num_clusters, num_clusters))
    for i, row_mask in enumerate(row_masks):
        rows = stacked[:, row_mask]
        for j, col_mask in enumerate(col_masks):
            block_matrix[i, j] = np.mean(rows[:, :, col_mask])
    return block_matrix

def consensus_coclustering(matrices, num_clusters, compute_block_pattern=False, random_state=None):
    """
    Get consensus labels, with optional block pattern computation.
    Only compute block pattern for reference labels, skip for bootstrap.

    Each matrix is co-clustered with SpectralCoclustering. The per-matrix row
    (column) labels are turned into co-assignment matrices and averaged. The
    average is then clustered with SpectralClustering (precomputed affinity).

    Args:
        matrices: list of equally shaped 2-D matrices
        num_clusters: number of row and column clusters
        compute_block_pattern: if True, also return the block mean matrix
        random_state: passed to both scikit-learn estimators. None (default)
            uses NumPy's global RNG, as before.
    Returns:
        (final_rows, final_cols) or (final_rows, final_cols, block_matrix)
    """
    # Get individual coclusterings
    models = [
        SpectralCoclustering(n_clusters=num_clusters, random_state=random_state).fit(mat)
        for mat in matrices
    ]
    row_labels = [model.row_labels_ for model in models]
    col_labels = [model.column_labels_ for model in models]

    # Consensus = fraction of fits that put two elements in the same cluster.
    # Built from label equality: np.eye(n)[labels][:, labels] raised IndexError
    # whenever a label was >= n (e.g. fewer columns than clusters).
    row_consensus = np.mean(
        [(labels[:, None] == labels[None, :]).astype(float) for labels in row_labels], axis=0)
    col_consensus = np.mean(
        [(labels[:, None] == labels[None, :]).astype(float) for labels in col_labels], axis=0)

    # Get consensus labels
    final_rows = SpectralClustering(
        n_clusters=num_clusters, affinity='precomputed', random_state=random_state
    ).fit_predict(row_consensus)
    final_cols = SpectralClustering(
        n_clusters=num_clusters, affinity='precomputed', random_state=random_state
    ).fit_predict(col_consensus)

    if compute_block_pattern:
        # Compute block pattern only for reference
        block_matrix = compute_block_matrix(matrices, final_rows, final_cols, num_clusters)
        return final_rows, final_cols, block_matrix

    return final_rows, final_cols

def align_row_col_labels(row_labels, col_labels, block_pattern, n_clusters):
    """
    Relabel column clusters so that column cluster k pairs with row cluster k.

    The pairing maximizes the summed block values (Hungarian algorithm on
    ``-block_pattern``). Row labels are returned unchanged.
    """
    _, partner_col = linear_sum_assignment(-block_pattern)
    # Invert the matching: column cluster partner_col[r] becomes label r.
    col_to_row = np.zeros_like(partner_col)
    col_to_row[partner_col] = np.arange(n_clusters)
    return row_labels, col_to_row[col_labels]

def compute_element_stability(ref_labels, boot_labels_list):
    """
    Per-element stability: the fraction of bootstrap label vectors that agree
    with each element's reference label.

    Args:
        ref_labels: reference labels (1-D array)
        boot_labels_list: list of label arrays, one per bootstrap sample
    Returns:
        element_stability: 1-D float array with one score in [0, 1] per element
    """
    agreement = np.array(boot_labels_list) == np.expand_dims(ref_labels, 0)
    return agreement.mean(axis=0)

def bootstrap_consensus(matrices, num_clusters, n_bootstrap=100):
    """
    Bootstrap with stability analysis.

    A reference consensus co-clustering is computed on all matrices. Then
    ``n_bootstrap`` times, the matrices are resampled with replacement (NumPy's
    global RNG), re-clustered, and the labels are aligned to the reference.

    Args:
        matrices: list of equally shaped 2-D matrices
        num_clusters: number of row/column clusters
        n_bootstrap: number of bootstrap resamples (>= 1)
    Returns:
        ref_rows: Reference row labels
        ref_cols: Reference column labels, aligned to the row clusters
        mean_pattern: Mean block pattern over the bootstrap samples
        ci_lower, ci_upper: 2.5th / 97.5th percentiles of the block patterns
        row_stability: (n_rows, n_rows) fraction of samples putting two rows together
        col_stability: (n_cols, n_cols) same for columns
        row_element_stability: per-row fraction of samples agreeing with ref_rows
        col_element_stability: per-column fraction of samples agreeing with ref_cols
    Raises:
        ValueError: if n_bootstrap < 1.
    """
    if n_bootstrap < 1:
        raise ValueError('n_bootstrap must be >= 1')

    ref_rows, ref_cols, ref_pattern = consensus_coclustering(matrices, num_clusters, compute_block_pattern=True)
    ref_rows, ref_cols = align_row_col_labels(ref_rows, ref_cols, ref_pattern, num_clusters)

    n_rows, n_cols = matrices[0].shape
    row_stability = np.zeros((n_rows, n_rows))
    col_stability = np.zeros((n_cols, n_cols))

    boot_row_labels = []
    boot_col_labels = []
    block_patterns = []

    for _ in tqdm(range(n_bootstrap)):
        # resample the matrices (animals) with replacement
        indices = np.random.choice(len(matrices), size=len(matrices), replace=True)
        boot_matrices = [matrices[i] for i in indices]

        boot_rows, boot_cols = consensus_coclustering(boot_matrices, num_clusters, compute_block_pattern=False)
        aligned_rows = align_labels(boot_rows, ref_rows, num_clusters)
        aligned_cols = align_labels(boot_cols, ref_cols, num_clusters)

        # Co-occurrence: 1 where two elements share a cluster. Built from label
        # equality; np.eye(n)[labels][:, labels] raised IndexError whenever a
        # label was >= n (e.g. fewer columns than clusters).
        row_stability += (aligned_rows[:, None] == aligned_rows[None, :]).astype(float)
        col_stability += (aligned_cols[:, None] == aligned_cols[None, :]).astype(float)

        boot_row_labels.append(aligned_rows)
        boot_col_labels.append(aligned_cols)

        block_matrix = compute_block_matrix(boot_matrices, aligned_rows, aligned_cols, num_clusters)
        block_patterns.append(block_matrix)

    row_stability /= n_bootstrap
    col_stability /= n_bootstrap

    row_element_stability = compute_element_stability(ref_rows, boot_row_labels)
    col_element_stability = compute_element_stability(ref_cols, boot_col_labels)

    mean_pattern = np.mean(block_patterns, axis=0)
    ci_lower = np.percentile(block_patterns, 2.5, axis=0)
    ci_upper = np.percentile(block_patterns, 97.5, axis=0)

    return (
        ref_rows, ref_cols, mean_pattern, ci_lower, ci_upper,
        row_stability, col_stability,
        row_element_stability, col_element_stability,
    )

In [22]:
def concatenate(in_files, out_file):
    """Concatenate volumes into ``out_file`` with AFNI 3dTcat, then delete the inputs.

    Any existing ``out_file`` is removed first. The output type is NIFTI, and
    ``rlt=''`` is passed through to 3dTcat (see its ``-rlt`` option).
    Filesystem errors while removing files are ignored (best effort); errors
    from 3dTcat itself propagate.
    """
    try:
        os.remove(out_file)
    except OSError:
        pass  # typically: nothing to overwrite yet

    tcat = afni.TCat()
    tcat.inputs.in_files = in_files
    tcat.inputs.out_file = out_file
    tcat.inputs.rlt = ''
    tcat.inputs.outputtype = 'NIFTI'
    tcat.run()

    # Callers pass temporary per-group volumes; remove them once concatenated.
    for file in in_files:
        try:
            os.remove(file)
        except OSError:
            pass
    return None

In [23]:
# save sys groups
def save_sys_groups(col, num_clusters, row_labels, sys_name='msbm', out_dir=None):
    """Write one summed volume per row cluster of mSBM communities as a 4-D NIfTI.

    For cluster i, the volumes along the last axis of the mSBM mode-00
    visualization image whose index has row label i are summed. The
    per-cluster volumes are concatenated into
    ``{out_dir}/col-{col}_sys-{sys_name}_desc-sys-groups.nii.gz``.

    Args:
        col: membership column name; only used in the output file name.
        num_clusters: number of row clusters.
        row_labels: cluster label per mSBM community (NumPy array).
        sys_name: tag used in the output file name.
        out_dir: output folder. Defaults to the notebook-level ``folder``
            (the previous implicit behavior).
    Raises:
        FileNotFoundError: if the mSBM visualization image does not exist.
    """
    out_dir = folder if out_dir is None else out_dir

    cmask_img = ants.image_read(
        f'{BASE_path}/voxel/common_brain_mask.nii.gz'
    )

    name = 'mode-00'

    # these comm files are only for visualization purposes
    comms_file = f'{ESTIM_path}/group/membership-mats-group-aligned/sbm--m/marginal-visuals/nii/{name}.nii.gz'
    if not os.path.exists(comms_file):
        raise FileNotFoundError(comms_file)
    comms_vol = ants.image_read(comms_file).numpy()

    # Concatenate exactly the files written here. Globbing group-*.nii.gz could
    # also pick up leftovers from an earlier, interrupted run with more clusters.
    in_files = []
    for i in range(num_clusters):
        comms = np.where(row_labels == i)[0]
        grp_comms_vol = np.sum(comms_vol[..., comms], axis=-1)
        group_file = f'{out_dir}/group-{i:02d}.nii.gz'
        cmask_img.new_image_like(grp_comms_vol).to_file(group_file)
        in_files.append(group_file)
    out_file = f'{out_dir}/col-{col}_sys-{sys_name}_desc-sys-groups.nii.gz'
    concatenate(in_files, out_file)
    return None

In [24]:
# save comm groups
def save_comm_groups(col, num_clusters, sel_comms, col_labels, sys_name='msbm', out_dir=None):
    """Write one summed volume per column cluster of SBM communities as a 4-D NIfTI.

    ``sel_comms[j]`` is the community index (last axis of the visualization
    image) of column j of the clustered similarity matrices. For cluster i, the
    volumes of the communities whose column label is i are summed. The
    per-cluster volumes are concatenated into
    ``{out_dir}/col-{col}_sys-{sys_name}_desc-comm-groups.nii.gz``.

    Args:
        col: membership column, e.g. 'pi_0' or 'pi_aligned'. Its suffix selects
            the image 'mode-00_level-{suffix}' (or 'mode-00' for 'aligned').
        num_clusters: number of column clusters.
        sel_comms: NumPy array of selected community indices.
        col_labels: cluster label per selected community (NumPy array).
        sys_name: tag used in the output file name.
        out_dir: output folder. Defaults to the notebook-level ``folder``
            (the previous implicit behavior).
    Raises:
        FileNotFoundError: if the visualization image does not exist.
    """
    out_dir = folder if out_dir is None else out_dir

    cmask_img = ants.image_read(
        f'{BASE_path}/voxel/common_brain_mask.nii.gz'
    )

    level = col.split('_')[-1]
    name = 'mode-00' if level == 'aligned' else f'mode-00_level-{level}'

    # these comm files are only for visualization purposes
    comms_file = f'{ESTIM_path}/group/membership-mats-group-aligned/{SBM}/marginal-visuals/nii/{name}.nii.gz'
    if not os.path.exists(comms_file):
        raise FileNotFoundError(comms_file)
    comms_vol = ants.image_read(comms_file).numpy()

    # Concatenate exactly the files written here. Globbing group-*.nii.gz could
    # also pick up leftovers from an earlier, interrupted run with more clusters.
    in_files = []
    for i in range(num_clusters):
        comms = sel_comms[col_labels == i]
        grp_comms_vol = np.sum(comms_vol[..., comms], axis=-1)
        group_file = f'{out_dir}/group-{i:02d}.nii.gz'
        cmask_img.new_image_like(grp_comms_vol).to_file(group_file)
        in_files.append(group_file)
    out_file = f'{out_dir}/col-{col}_sys-{sys_name}_desc-comm-groups.nii.gz'
    concatenate(in_files, out_file)
    return None

In [27]:
num_clusters = 4
# `folder` is also the output directory used by save_sys_groups / save_comm_groups.
folder = f'{ESTIM_path}/group/comparisons-with-msbm-comms/{SBM}/num-clusters-{num_clusters}'
os.makedirs(folder, exist_ok=True)

sys_name = 'msbm'
# NOTE(review): only the first two membership columns (pi_0, pi_1) are grouped; confirm intent.
for col in cols[:2]:
    Xs = soft_similarities_df[col].to_list()
    # use the same SBM communities (columns) for every animal
    sel_comms = selected_comms_in_group(Xs)
    Xs = [X[:, sel_comms] for X in Xs]

    (
        row_labels, col_labels,
        mean_pat, ci_lower, ci_upper,
        row_stability, col_stability,
        row_element_stability, col_element_stability,
    ) = bootstrap_consensus(matrices=Xs, num_clusters=num_clusters, n_bootstrap=100)

    # Pickled tuple: sel_comms followed by the 9 outputs of bootstrap_consensus, in order.
    with open(f'{folder}/col-{col}_desc-groups.pkl', 'wb') as f:
        pickle.dump(
            (
                sel_comms,
                row_labels, col_labels,
                mean_pat, ci_lower, ci_upper,
                row_stability, col_stability,
                row_element_stability, col_element_stability,
            ),
            f
        )

    save_sys_groups(col, num_clusters, row_labels, sys_name)
    save_comm_groups(col, num_clusters, sel_comms, col_labels, sys_name)

100%|██████████| 100/100 [00:07<00:00, 12.52it/s]


241230-16:08:00,188 nipype.interface INFO:
	 stderr 2024-12-30T16:08:00.188664:++ 3dTcat: AFNI version=AFNI_20.2.18 (Sep 17 2020) [64-bit]
241230-16:08:00,192 nipype.interface INFO:
241230-16:08:00,195 nipype.interface INFO:
241230-16:08:00,241 nipype.interface INFO:
	 stderr 2024-12-30T16:08:00.241144:++ elapsed time = 0.1 s
241230-16:08:00,594 nipype.interface INFO:
	 stderr 2024-12-30T16:08:00.594443:++ 3dTcat: AFNI version=AFNI_20.2.18 (Sep 17 2020) [64-bit]
241230-16:08:00,597 nipype.interface INFO:
241230-16:08:00,598 nipype.interface INFO:
241230-16:08:00,625 nipype.interface INFO:
	 stderr 2024-12-30T16:08:00.625299:++ elapsed time = 0.0 s


100%|██████████| 100/100 [00:07<00:00, 12.80it/s]


241230-16:08:08,966 nipype.interface INFO:
	 stderr 2024-12-30T16:08:08.966351:++ 3dTcat: AFNI version=AFNI_20.2.18 (Sep 17 2020) [64-bit]
241230-16:08:08,969 nipype.interface INFO:
241230-16:08:08,970 nipype.interface INFO:
241230-16:08:09,2 nipype.interface INFO:
	 stderr 2024-12-30T16:08:09.002264:++ elapsed time = 0.0 s
241230-16:08:09,342 nipype.interface INFO:
	 stderr 2024-12-30T16:08:09.342048:++ 3dTcat: AFNI version=AFNI_20.2.18 (Sep 17 2020) [64-bit]
241230-16:08:09,345 nipype.interface INFO:
241230-16:08:09,348 nipype.interface INFO:
241230-16:08:09,393 nipype.interface INFO:
	 stderr 2024-12-30T16:08:09.393687:++ elapsed time = 0.1 s
