# Nov 25, 2024: visualize membership matrices on brain per solution mode per animal

In [1]:
import csv
import os
import sys
import numpy as np
import pandas as pd
import scipy as sp 
import dill as pickle 
from os.path import join as pjoin
from itertools import product
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
import subprocess
from scipy import sparse, stats
from multiprocessing import Pool
import glob
import random

import arviz as az

import ants
from nipype.interfaces import afni

from itertools import product, combinations, chain
import multiprocessing as mp
from functools import partial

# networks
import graph_tool.all as gt

# plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.cm import rainbow

# Font and error-bar settings for figures drawn in this notebook.
# They are written to `plt.rcParams` (the live settings). The earlier version
# wrote the two font entries to `plt.rcParamsDefault`, which only changes what
# `plt.rcdefaults()` would restore and leaves the active font untouched.
# NOTE(review): if Arial is not installed, matplotlib falls back to another
# sans-serif font and logs a "findfont" message.
plt.rcParams['font.family'] = "sans-serif"
plt.rcParams['font.sans-serif'] = "Arial"
plt.rcParams['font.size'] = 14
plt.rcParams["errorbar.capsize"] = 0.5

import cmasher as cmr  # CITE ITS PAPER IN YOUR MANUSCRIPT
import colorcet as cc

# Silence warnings of every category for the rest of the session; the narrower
# UserWarning-only variant is kept commented out at the end of the call.
import warnings
warnings.filterwarnings("ignore") #, category=UserWarning)

In [2]:
class ARGS:
    """Empty attribute container; settings are attached to an instance as plain attributes."""


# Single shared settings object used by the cells below.
args = ARGS()

# One seed shared by every random number generator seeded below.
args.SEED = 100


def set_seed(args):
    """Seed graph-tool's and NumPy's global random number generators with ``args.SEED``."""
    seed_value = args.SEED
    gt.seed_rng(seed_value)
    np.random.seed(seed_value)


set_seed(args)

In [3]:
# Parcellation settings. The values after '#' look like earlier/alternative
# choices kept for reference.
args.type = 'allen'  # 'spatial'
args.roi_size = 'x'  # 225
args.maintain_symmetry = True  # True
args.brain_div = 'whl'  # 'whl'
args.num_rois = 172  # 162

# Descriptor string that prefixes the parcellation files and names the
# per-parcellation results folder.
PARC_DESC = ''.join((
    f'type-{args.type}',
    f'_size-{args.roi_size}',
    f'_symm-{args.maintain_symmetry}',
    f'_braindiv-{args.brain_div}',
    f'_nrois-{args.num_rois}',
))

In [4]:
# Graph-construction settings; each one becomes a component of the results path.
args.GRAPH_DEF = 'constructed'
args.GRAPH_METHOD = 'pearson-corr'
args.THRESHOLDING = 'positive'
args.EDGE_DEF = 'binary'
args.EDGE_DENSITY = 20
args.LAYER_DEF = 'individual'
args.DATA_UNIT = 'sub'

# Folder layout, rooted at $HOME/mouse_dataset.
BASE_path = f'{os.environ["HOME"]}/mouse_dataset'
PARCELS_path = f'{BASE_path}/parcels'
ROI_path = f'{BASE_path}/roi_results_v2/{PARC_DESC}'
TS_path = f'{ROI_path}/runwise_timeseries'
ROI_RESULTS_path = (
    f'{ROI_path}'
    f'/graph-{args.GRAPH_DEF}/method-{args.GRAPH_METHOD}'
    f'/threshold-{args.THRESHOLDING}/edge-{args.EDGE_DEF}/density-{args.EDGE_DENSITY}'
    f'/layer-{args.LAYER_DEF}/unit-{args.DATA_UNIT}'
)
GRAPH_path = f'{ROI_RESULTS_path}/graphs'
SBM_path = f'{ROI_RESULTS_path}/model-fits'
ESTIM_path = f'{ROI_RESULTS_path}/estimates'

# Create the output folders with os.makedirs instead of shelling out to
# `mkdir -p`: no shell parsing of the path (spaces or metacharacters are safe)
# and a failure raises instead of being reported only through an exit status.
for out_dir in (
    GRAPH_path,
    SBM_path,
    f'{ESTIM_path}/individual',
    f'{ESTIM_path}/group',
):
    os.makedirs(out_dir, exist_ok=True)

0

In [5]:
# Parcellation volume (as an ANTs image and as a plain array) and the label
# values that are later matched against the voxel values of that volume.
parcels_file = f'{PARCELS_path}/{PARC_DESC}_desc-parcels.nii.gz'
labels_file = f'{PARCELS_path}/{PARC_DESC}_desc-labels.txt'

parcels_img = ants.image_read(parcels_file)
parcels = parcels_img.numpy()
roi_labels = np.loadtxt(labels_file)

In [6]:
# Display the image summary (pixel type, dimensions, spacing, origin, direction).
parcels_img

ANTsImage (LPI)
	 Pixel Type : float (float32)
	 Components : 1
	 Dimensions : (58, 79, 45)
	 Spacing    : (0.2, 0.2, 0.2)
	 Origin     : (18.1, 2.7, -7.8)
	 Direction  : [-1.  0.  0.  0. -1.  0.  0.  0.  1.]

In [7]:
def concatenate(in_files, out_file):
    """Combine several image files into one with AFNI ``3dTcat``, then delete the inputs.

    Args:
        in_files: paths of the files to concatenate, in the order they should
            appear in the output.
        out_file: path of the combined file. A file already at this path is
            removed before ``3dTcat`` runs.

    Returns:
        None.

    Notes:
        File removal is best-effort: OS-level failures (for example the file
        does not exist) are ignored. If ``tcat.run()`` raises, the exception
        propagates and the input files are left in place.
    """
    def _remove_quietly(path):
        # Ignore only OS-level failures; unlike a bare `except`, this lets
        # KeyboardInterrupt and programming errors surface.
        try:
            os.remove(path)
        except OSError:
            pass

    _remove_quietly(out_file)

    tcat = afni.TCat()
    tcat.inputs.in_files = in_files
    tcat.inputs.out_file = out_file
    # NOTE(review): an empty-string `rlt` appears to add 3dTcat's `-rlt`
    # (linear detrend) flag -- confirm this is wanted for these inputs.
    tcat.inputs.rlt = ''
    tcat.run()

    for file in in_files:
        _remove_quietly(file)
    return None

def marginal_to_nifti(args, X, mode_id, folder, level=-1):
    """Paint each column of a membership matrix onto the parcellation and save one combined file.

    For every column of ``X`` a volume is built in which each voxel labelled
    ``roi_labels[i]`` in the module-level ``parcels`` array takes the value
    ``X[i, column]``. The per-column volumes are written as temporary files
    and then merged (and deleted) by ``concatenate``.

    Args:
        args: settings object. Side effect: ``args.num_rois`` and
            ``args.num_comms`` are overwritten with ``X.shape``.
        X: 2-D array, rows indexed like ``roi_labels``, one column per community.
        mode_id: integer mode index; formatted as two digits in the output name.
        folder: output directory; created if missing.
        level: tag added to the output name as ``_level-<level>``; the default
            ``-1`` omits the tag.

    Returns:
        None. The result is written to ``<folder>/mode-XX[_level-L].nii.gz``.

    Notes:
        Reads the module-level ``parcels``, ``roi_labels`` and ``parcels_img``.
    """
    # os.makedirs avoids the shell and raises on failure, unlike `mkdir -p`
    # through os.system.
    os.makedirs(folder, exist_ok=True)
    # Kept for backward compatibility: later cells may read these attributes.
    args.num_rois, args.num_comms = X.shape

    # The ROI masks are the same for every community, so build them once
    # rather than re-comparing the whole volume inside the community loop.
    roi_masks = [parcels == roi for roi in roi_labels]

    in_files = []
    for idx_comm in range(args.num_comms):
        x = X[:, idx_comm]
        x_img = np.zeros_like(parcels)
        for idx, mask in enumerate(roi_masks):
            x_img += mask * (x[idx])

        # Temporary single-community volume; removed by `concatenate`.
        file = f'{folder}/mode-{mode_id}_comm-{idx_comm}.nii.gz'
        parcels_img.new_image_like(x_img).to_filename(file)
        in_files.append(file)

    if level == -1:
        out_file = f'{folder}/mode-{mode_id:02d}.nii.gz'
    else:
        out_file = f'{folder}/mode-{mode_id:02d}_level-{level}.nii.gz'
    concatenate(in_files, out_file)
    return None

In [8]:
# Model selection: `dc` flag and model code (both feed into the model tag).
args.dc = False
args.sbm = 'h'

# Only the 'h' model code is flagged as nested.
args.nested = args.sbm in ['h']

# The number of draws is half of the forced iteration count.
args.force_niter = 40000
args.num_draws = int(args.force_niter / 2)

def sbm_name(args):
    """Return the model tag ``sbm-<dc>-<sbm>`` built from ``args.dc`` and ``args.sbm``.

    ``<dc>`` is 'dc' when ``args.dc`` is truthy and 'nd' otherwise, except for
    the model codes 'm' and 'a', where it is empty (giving e.g. 'sbm--m').
    """
    dc_tag = 'dc' if args.dc else 'nd'
    if args.sbm in ['m', 'a']:
        dc_tag = ''
    return f'sbm-{dc_tag}-{args.sbm}'

# Model tag for the current settings; it is used in the result paths of the
# cells below. Displayed to confirm the selection.
SBM = sbm_name(args)
SBM

'sbm-nd-h'

In [9]:
def get_membership_matrix(num_rois, df, col='pi'):
    """Stack the matrices stored in ``df[col]`` into one zero-padded 3-D array.

    Args:
        num_rois: size of the first axis of the result.
        df: table with one row per mode; each ``df[col]`` entry is either a
            matrix (expected to have ``num_rois`` rows) or an all-NaN placeholder.
        col: name of the column to read.

    Returns:
        Array ``M`` of shape ``(num_rois, len(df), num_comms)``, where
        ``num_comms`` is the largest column count among the entries.
        ``M[:, i, :k]`` holds the i-th entry (``k`` columns); the remaining
        columns stay zero.

    Notes:
        All-NaN entries are replaced by a single all-zero column.
        NOTE(review): averaging over the mode axis therefore pulls rows with a
        missing entry toward zero -- confirm this is intended.
    """
    blank = np.zeros((num_rois, 1))
    matrices = []
    for entry in df[col]:
        matrices.append(blank if np.isnan(entry).all() else entry)

    widest = max(mat.shape[-1] for mat in matrices)
    M = np.zeros((num_rois, len(df), widest))  # membership profile matrix
    for idx_mode, mat in enumerate(matrices):
        M[:, idx_mode, :mat.shape[-1]] = mat

    return M

In [10]:
# Collect the group-aligned membership tables of all subjects (one pickle per
# matching `sub-*` folder).
# NOTE: unpickling can execute arbitrary code; only load files written by this
# pipeline.
marginals_files = sorted(glob.glob(
    f'{ESTIM_path}/individual/sub-*/partition-modes-group-aligned/{SBM}/desc-mem-mats.pkl'
))
if not marginals_files:
    # Fail with a clear message instead of pandas' "No objects to concatenate".
    raise FileNotFoundError(
        f'no desc-mem-mats.pkl files found under {ESTIM_path}/individual for {SBM}'
    )

marginals_tables = []
for sbm_file in marginals_files:
    with open(sbm_file, 'rb') as f:
        marginals_tables.append(pickle.load(f))
marginals_df = pd.concat(marginals_tables).reset_index(drop=True)

# Number each subject's rows 0, 1, 2, ... in row order. cumcount gives the
# same ids as the earlier value_counts-based construction when rows are grouped
# by subject in sorted order, and stays correct when they are not.
marginals_df['mode_id'] = marginals_df.groupby('sub').cumcount()
marginals_df

Unnamed: 0,sub,sbm,pi_0,pi_1,pi_2,pi_3,pi_4,pi_5,pi_6,pi_7,pi_8,omega,mode_id
0,SLC01,sbm-nd-h,"[[0.9901153212520593, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.8813838550247117, 0.0032948929159802307, 0...","[[0.957166392092257, 0.026359143327841845, 0.0...","[[0.9868204283360791, 0.013179571663920923], [...","[[0.9983525535420099, 0.0016474464579901153], ...","[[0.9983525535420099, 0.0016474464579901153], ...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,0.60744,0
1,SLC01,sbm-nd-h,"[[0.9892857142857143, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.9142857142857143, 0.0, 0.0, 0.0, 0.0857142...","[[0.9357142857142857, 0.039285714285714285, 0....","[[0.9857142857142858, 0.014285714285714285], [...","[[0.9964285714285714, 0.0035714285714285713], ...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,,0.28032,1
2,SLC01,sbm-nd-h,"[[0.16964285714285715, 0.0, 0.0, 0.0, 0.0, 0.0...","[[0.7142857142857143, 0.0, 0.0, 0.0, 0.25, 0.0...","[[0.9107142857142857, 0.07142857142857142, 0.0...","[[0.9910714285714286, 0.008928571428571428], [...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,,,0.11224,2
3,SLC02,sbm-nd-h,"[[0.9901153212520593, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.8813838550247117, 0.0032948929159802307, 0...","[[0.957166392092257, 0.026359143327841845, 0.0...","[[0.9868204283360791, 0.013179571663920923], [...","[[0.9983525535420099, 0.0016474464579901153], ...","[[0.9983525535420099, 0.0016474464579901153], ...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,0.50516,0
4,SLC02,sbm-nd-h,"[[0.9892857142857143, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.9142857142857143, 0.0, 0.0, 0.0, 0.0857142...","[[0.9357142857142857, 0.039285714285714285, 0....","[[0.9857142857142858, 0.014285714285714285], [...","[[0.9964285714285714, 0.0035714285714285713], ...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,,0.338587,1
5,SLC02,sbm-nd-h,"[[0.16964285714285715, 0.0, 0.0, 0.0, 0.0, 0.0...","[[0.7142857142857143, 0.0, 0.0, 0.0, 0.25, 0.0...","[[0.9107142857142857, 0.07142857142857142, 0.0...","[[0.9910714285714286, 0.008928571428571428], [...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,,,0.156253,2
6,SLC03,sbm-nd-h,"[[0.9901153212520593, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.8813838550247117, 0.0032948929159802307, 0...","[[0.957166392092257, 0.026359143327841845, 0.0...","[[0.9868204283360791, 0.013179571663920923], [...","[[0.9983525535420099, 0.0016474464579901153], ...","[[0.9983525535420099, 0.0016474464579901153], ...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,0.161628,0
7,SLC03,sbm-nd-h,"[[0.9892857142857143, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.9142857142857143, 0.0, 0.0, 0.0, 0.0857142...","[[0.9357142857142857, 0.039285714285714285, 0....","[[0.9857142857142858, 0.014285714285714285], [...","[[0.9964285714285714, 0.0035714285714285713], ...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,,0.15539,1
8,SLC03,sbm-nd-h,"[[0.16964285714285715, 0.0, 0.0, 0.0, 0.0, 0.0...","[[0.7142857142857143, 0.0, 0.0, 0.0, 0.25, 0.0...","[[0.9107142857142857, 0.07142857142857142, 0.0...","[[0.9910714285714286, 0.008928571428571428], [...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,,,,0.143074,2
9,SLC03,sbm-nd-h,"[[0.9405940594059405, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.9821782178217822, 0.0, 0.0, 0.0, 0.0138613...","[[0.9465346534653465, 0.033663366336633666, 0....","[[0.9900990099009901, 0.009900990099009901, 0....","[[0.996039603960396, 0.0039603960396039604], [...","[[1.0, 0.0], [1.0, 0.0], [0.998019801980198, 0...","[[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0...","[[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1....",,0.12468,3


In [11]:
# Membership-matrix columns: every column whose name contains 'pi_'
# (the suffix after the underscore is later used as the level tag).
cols = list(filter(lambda name: 'pi_' in name, marginals_df.columns))
cols

['pi_0', 'pi_1', 'pi_2', 'pi_3', 'pi_4', 'pi_5', 'pi_6', 'pi_7', 'pi_8']

In [12]:
# SOFT MARGINALS
# Per subject and per 'pi_*' column: average the membership matrices over the
# subject's modes, weighted by each mode's `omega`.
# NOTE(review): get_membership_matrix zero-fills entries that are all NaN, so
# for columns missing in some modes the averaged rows sum to less than 1
# -- confirm this is the intended treatment.
soft_rows = []
for sub, sub_modes in marginals_df.groupby('sub'):
    mode_weights = sub_modes['omega'].to_list()
    record = {'sub': [sub], 'sbm': [SBM]}
    for col in cols:
        M = get_membership_matrix(args.num_rois, sub_modes, col=col)
        # soft communities: weighted mean over the mode axis
        record[col] = [np.average(M, axis=1, weights=mode_weights)]
    soft_rows.append(pd.DataFrame(record))

soft_marginals_df = pd.concat(soft_rows).reset_index(drop=True)
soft_marginals_df

Unnamed: 0,sub,sbm,pi_0,pi_1,pi_2,pi_3,pi_4,pi_5,pi_6,pi_7,pi_8
0,SLC01,sbm-nd-h,"[[0.8977929364556366, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.8718518088962108, 0.0020014497528830313, 0...","[[0.9459391532125205, 0.035041312308778536, 0....","[[0.9869874867027537, 0.013012513297246411], [...","[[0.9979981322664156, 0.002001867733584373], [...","[[0.8867592751235585, 0.0010007248764415156], ...","[[0.60744], [0.60744], [0.60744], [0.60744], [...","[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0....","[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0...."
1,SLC02,sbm-nd-h,"[[0.8616335531836663, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.866414497438449, 0.0016644494656771884, 0....","[[0.9426447281233193, 0.03777812415626341, 0.0...","[[0.9871101321431203, 0.012889867856879716], [...","[[0.9979585356708509, 0.0020414643291491565], ...","[[0.8429152750671455, 0.0008322247328385942], ...","[[0.5051604128330267], [0.5051604128330267], [...","[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0....","[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0...."
2,SLC03,sbm-nd-h,"[[0.7347218143727945, 0.0, 0.01904950922275958...","[[0.8027296121914562, 0.09455013968026232, 0.0...","[[0.8805463280375332, 0.056739282130032, 0.0, ...","[[0.9787149023468772, 0.019748842924379402, 0....","[[0.9979020767759192, 0.0020979232240807954], ...","[[0.8163130106637612, 0.0006255689907493272], ...","[[0.5958893154190659, 0.0], [0.595350722336753...","[[0.2119321817018554, 0.0], [0.211393588619543...","[[0.08725207933461292], [0.08725207933461292],..."
3,SLC04,sbm-nd-h,"[[0.9459125043909121, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.8863237124765099, 0.0017297001457630258, 0...","[[0.9456396553691001, 0.03421771912783523, 0.0...","[[0.9865813022905506, 0.01341869770944947], [0...","[[0.9976294804485735, 0.002370519551426652], [...","[[0.9456865981588357, 0.0008648500728815129], ...","[[0.5249639942390784], [0.5249639942390784], [...","[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0....","[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0...."
4,SLC05,sbm-nd-h,"[[0.8130407774443612, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.874092392812236, 0.0012080019852788717, 0....","[[0.9404885581915815, 0.040065316915495336, 0....","[[0.9880332791031377, 0.011966249222332899, 2....","[[0.9977904353375576, 0.002209564662442389], [...","[[0.7951267313386685, 0.0006040009926394359], ...","[[0.5538055644387593, 0.0], [0.553805564438759...","[[0.18724016629357212], [0.18724016629357212],...","[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0...."
5,SLC06,sbm-nd-h,"[[0.8973256100813988, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.8729552523304034, 0.0018802736107749298, 0...","[[0.9451368866106166, 0.03553384231794725, 0.0...","[[0.9869496611469873, 0.013050338853012768], [...","[[0.9979292773923479, 0.002070722607652064], [...","[[0.8862869286988916, 0.0009401368053874649], ...","[[0.5706630408701912], [0.5706630408701912], [...","[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0....","[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0...."
6,SLC07,sbm-nd-h,"[[0.916710109513164, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[[0.8846781446861861, 0.0014928319833808522, 0...","[[0.9436751094882899, 0.03600391625820136, 0.0...","[[0.9868477994721776, 0.01315172866466849, 2.3...","[[0.9975946356212689, 0.002405364378731002], [...","[[0.9125543918144361, 0.0007464159916904261], ...","[[0.49172198672318645, 0.0], [0.49172198672318...","[[0.03871070942973686], [0.03871070942973686],...","[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0...."
7,SLC08,sbm-nd-h,"[[0.8483587803360707, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.8785646411340852, 0.0013375125218844722, 0...","[[0.9419809372493975, 0.03843527216566195, 0.0...","[[0.9877036299233776, 0.012296370076622522, 0....","[[0.9977338981828051, 0.0022661018171949614], ...","[[0.8359973771577107, 0.0006687562609422361], ...","[[0.5576707726763718, 0.0], [0.557670772676371...","[[0.1517357222844345], [0.1517357222844345], [...","[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0...."
8,SLC09,sbm-nd-h,"[[0.8244173424603205, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0.8744602200433057, 0.0013709832785897586, 0...","[[0.9417882510415184, 0.039047373507553454, 0....","[[0.9879771131182439, 0.01202288688175625, 0.0...","[[0.9978421725058705, 0.0021578274941295343], ...","[[0.8083050698708635, 0.0006854916392948793], ...","[[0.5938249880019197, 0.0], [0.593824988001919...","[[0.177731562949928], [0.177731562949928], [0....","[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0...."
9,SLC10,sbm-nd-h,"[[0.7989304771524457, 0.0, 0.01264977847769932...","[[0.8338968592681216, 0.06031522983244758, 0.0...","[[0.903433128280149, 0.04956069161920257, 0.0,...","[[0.9811809514725756, 0.017688507861423543, 0....","[[0.9979195244653538, 0.0020804755346461056], ...","[[0.8500143886920136, 0.000805939439238832], [...","[[0.6384153661464586, 0.0], [0.638053740014524...","[[0.1621448579431773, 0.0], [0.161783231811243...","[[0.05858343337334934], [0.05858343337334934],..."


In [13]:
# Write one NIfTI file per (subject, mode, 'pi_*' column) into the subject's
# marginal-visuals folder.
for _, mode_row in marginals_df.iterrows():
    sub = mode_row['sub']
    mode_id = mode_row['mode_id']
    folder = f'{ESTIM_path}/individual/sub-{sub}/partition-modes-group-aligned/{SBM}/marginal-visuals/nii'

    for col in cols:
        pi = mode_row[col]
        if np.isnan(pi).any():
            # Entry contains NaN (e.g. this column is empty for this mode): skip it.
            continue

        # The column suffix is the level tag; 'aligned' maps to the default (-1).
        level = col.split('_')[-1]
        if level == 'aligned':
            level = -1

        marginal_to_nifti(args, pi, mode_id, folder, level=level)

250226-12:03:09,310 nipype.interface INFO:
	 stderr 2025-02-26T12:03:09.310254:++ 3dTcat: AFNI version=AFNI_20.2.18 (Sep 17 2020) [64-bit]
250226-12:03:09,321 nipype.interface INFO:
250226-12:03:09,322 nipype.interface INFO:
250226-12:03:09,497 nipype.interface INFO:
	 stderr 2025-02-26T12:03:09.497227:++ elapsed time = 0.2 s
250226-12:03:10,873 nipype.interface INFO:
	 stderr 2025-02-26T12:03:10.873695:++ 3dTcat: AFNI version=AFNI_20.2.18 (Sep 17 2020) [64-bit]
250226-12:03:10,887 nipype.interface INFO:
250226-12:03:10,889 nipype.interface INFO:
250226-12:03:11,28 nipype.interface INFO:
	 stderr 2025-02-26T12:03:11.028640:++ elapsed time = 0.2 s
250226-12:03:11,619 nipype.interface INFO:
	 stderr 2025-02-26T12:03:11.619611:++ 3dTcat: AFNI version=AFNI_20.2.18 (Sep 17 2020) [64-bit]
250226-12:03:11,624 nipype.interface INFO:
250226-12:03:11,627 nipype.interface INFO:
250226-12:03:11,692 nipype.interface INFO:
	 stderr 2025-02-26T12:03:11.692150:++ elapsed time = 0.1 s
250226-12:03:12,4

In [14]:
# Group-level maps: for every 'pi_*' column, take the unweighted mean of the
# subjects' soft membership matrices and write it as mode 0 in the group folder.
folder = f'{ESTIM_path}/group/membership-mats-group-aligned/{SBM}/marginal-visuals/nii'
# os.makedirs avoids the shell and raises on failure, unlike `mkdir -p`
# through os.system.
os.makedirs(folder, exist_ok=True)

for col in cols:
    M = get_membership_matrix(args.num_rois, soft_marginals_df, col=col)
    grp_pi = np.mean(M, axis=1)  # mean over the subject axis

    # The column suffix is the level tag; 'aligned' maps to the default (-1).
    level = col.split('_')[-1]
    if level == 'aligned':
        level = -1

    marginal_to_nifti(args, grp_pi, 0, folder, level)


250226-12:08:14,284 nipype.interface INFO:
	 stderr 2025-02-26T12:08:14.284653:++ 3dTcat: AFNI version=AFNI_20.2.18 (Sep 17 2020) [64-bit]
250226-12:08:14,311 nipype.interface INFO:
250226-12:08:14,313 nipype.interface INFO:
250226-12:08:14,561 nipype.interface INFO:
	 stderr 2025-02-26T12:08:14.561175:++ elapsed time = 0.3 s
250226-12:08:16,47 nipype.interface INFO:
	 stderr 2025-02-26T12:08:16.047850:++ 3dTcat: AFNI version=AFNI_20.2.18 (Sep 17 2020) [64-bit]
250226-12:08:16,61 nipype.interface INFO:
250226-12:08:16,64 nipype.interface INFO:
250226-12:08:16,210 nipype.interface INFO:
	 stderr 2025-02-26T12:08:16.210149:++ elapsed time = 0.2 s
250226-12:08:17,10 nipype.interface INFO:
	 stderr 2025-02-26T12:08:17.010641:++ 3dTcat: AFNI version=AFNI_20.2.18 (Sep 17 2020) [64-bit]
250226-12:08:17,15 nipype.interface INFO:
250226-12:08:17,17 nipype.interface INFO:
250226-12:08:17,104 nipype.interface INFO:
	 stderr 2025-02-26T12:08:17.104558:++ elapsed time = 0.1 s
250226-12:08:17,778 ni