# Sep 2, 2024: group align individual level estimates (partition modes)

The script below is run by the accompanying .py and .sh files.

In [1]:
import csv
import os
import sys
import numpy as np
import pandas as pd
import scipy as sp 
import dill as pickle 
from os.path import join as pjoin
from itertools import product
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
import subprocess
from scipy import sparse, stats
from multiprocessing import Pool
import glob
import random

import arviz as az

from itertools import product, combinations
import multiprocessing as mp
from functools import partial

# networks
import graph_tool.all as gt

# plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.cm import rainbow

# Font settings go on the live rcParams, which new figures read. Writing them
# only to rcParamsDefault (as before) changes just the values restored by
# plt.rcdefaults(), not the current session. They are kept on rcParamsDefault
# too, so a reset preserves the choice.
plt.rcParams['font.family'] = "sans-serif"
plt.rcParams['font.sans-serif'] = "Arial"
plt.rcParamsDefault['font.family'] = "sans-serif"
plt.rcParamsDefault['font.sans-serif'] = "Arial"
plt.rcParams['font.size'] = 14
plt.rcParams["errorbar.capsize"] = 0.5

import cmasher as cmr  # CITE ITS PAPER IN YOUR MANUSCRIPT
import colorcet as cc

# Silence warnings of EVERY category for this run, not just UserWarning.
# To silence only user warnings, pass category=UserWarning instead.
import warnings
warnings.filterwarnings("ignore")

In [2]:
class ARGS:
    """Bare attribute bag used as a stand-in for an argparse namespace."""


args = ARGS()

args.SEED = 100


def set_seed(args):
    """Seed graph-tool's RNG and NumPy's global RNG from ``args.SEED``."""
    seed = args.SEED
    gt.seed_rng(seed)
    np.random.seed(seed)


set_seed(args)

In [3]:
args.type = 'spatial'
args.roi_size = 225
args.maintain_symmetry = True
args.brain_div = 'whl'
args.num_rois = 162

# Parcellation descriptor: '<key>-<value>' pairs joined by underscores,
# e.g. 'type-spatial_size-225_symm-True_braindiv-whl_nrois-162'.
PARC_DESC = '_'.join(
    f'{key}-{value}'
    for key, value in (
        ('type', args.type),
        ('size', args.roi_size),
        ('symm', args.maintain_symmetry),
        ('braindiv', args.brain_div),
        ('nrois', args.num_rois),
    )
)

In [4]:
# Graph-construction settings. Together with PARC_DESC they determine where
# this analysis' results live on disk.
args.GRAPH_DEF = 'constructed'
args.GRAPH_METHOD = 'pearson-corr'
args.THRESHOLDING = 'positive'
args.EDGE_DEF = 'binary'
args.EDGE_DENSITY = 10
args.LAYER_DEF = 'individual'
args.DATA_UNIT = 'ses'

BASE_path = f'{os.environ["HOME"]}/mouse_dataset/roi_results_v2'
ROI_path = f'{BASE_path}/{PARC_DESC}'
TS_path = f'{ROI_path}/runwise_timeseries'
ROI_RESULTS_path = (
    f'{ROI_path}'
    f'/graph-{args.GRAPH_DEF}/method-{args.GRAPH_METHOD}'
    f'/threshold-{args.THRESHOLDING}/edge-{args.EDGE_DEF}/density-{args.EDGE_DENSITY}'
    f'/layer-{args.LAYER_DEF}/unit-{args.DATA_UNIT}'
)
GRAPH_path = f'{ROI_RESULTS_path}/graphs'
SBM_path = f'{ROI_RESULTS_path}/model-fits'
ESTIM_path = f'{ROI_RESULTS_path}/estimates'

# os.makedirs instead of `os.system('mkdir -p ...')`: no shell is spawned, so
# unusual characters in the paths are harmless. A failure raises instead of
# being lost in an ignored exit status.
os.makedirs(GRAPH_path, exist_ok=True)
os.makedirs(SBM_path, exist_ok=True)
os.makedirs(f'{ESTIM_path}/individual', exist_ok=True)
os.makedirs(f'{ESTIM_path}/group', exist_ok=True)

0

In [5]:
def collect_indiv_dfs(args, indiv_files):
    """Load per-animal estimate DataFrames and stack them into one table.

    Each file path is expected to look like
    ``.../sub-<ID>/partition-modes/<name>.pkl``. The subject ID is parsed
    from the directory three levels up and stored in a leading ``sub`` column.

    Parameters
    ----------
    args : ARGS
        Run configuration (unused here; kept for a uniform signature).
    indiv_files : iterable of str
        Paths to pickled per-subject DataFrames.

    Returns
    -------
    pandas.DataFrame
        All rows from all files, re-indexed 0..n-1, with ``sub`` as the first
        column.

    Raises
    ------
    ValueError
        If ``indiv_files`` is empty.
    """
    indiv_files = list(indiv_files)
    if not indiv_files:
        raise ValueError('collect_indiv_dfs: no individual estimate files were given')

    indiv_dfs = []
    for indiv_file in tqdm(indiv_files):
        with open(indiv_file, 'rb') as f:
            indiv_df = pickle.load(f)
        # 'sub-<ID>' -> '<ID>'. Split only on the first hyphen so IDs that
        # themselves contain '-' are kept whole (the output folders are
        # written back as f'sub-{sub}').
        sub_dir = Path(indiv_file).parts[-3]
        sub = sub_dir.split('-', 1)[-1]
        indiv_df['sub'] = sub
        # Move 'sub' to the front by name, not by position, so this also
        # works if the loaded frame already had a 'sub' column.
        indiv_df = indiv_df[['sub'] + [c for c in indiv_df.columns if c != 'sub']]
        indiv_dfs.append(indiv_df)
    return pd.concat(indiv_dfs).reset_index(drop=True)

In [6]:
def nested_partitions(g, b):
    """Project each level of a nested partition onto the vertices of ``g``.

    Parameters
    ----------
    g : graph_tool.Graph
        Graph whose vertices the bottom level of ``b`` labels (``b[0]`` is
        turned into an int vertex property of ``g``).
    b : list
        Nested partition, one entry per hierarchy level.
        NOTE(review): presumably follows graph-tool's nested-partition
        convention (``b[l]`` labels the blocks of level ``l-1``) — confirm.

    Returns
    -------
    list of numpy.ndarray
        ``bs[l]`` is the array obtained from ``state.project_level(l)``, i.e.
        the level-``l`` labelling expressed on ``g``'s vertices. Levels are
        emitted bottom-up. The list stops at (and includes) the first level
        whose labels are all identical, so it may be shorter than ``b``.
    """
    # Normalise the nested partition with graph-tool's helper (see
    # nested_partition_clear_null in the graph-tool docs for what it removes).
    b = gt.nested_partition_clear_null(b)
    # Bottom level wrapped as an int vertex property map of g; upper levels passed as-is.
    state = gt.NestedBlockState(g, bs=[g.new_vp("int", vals=b[0])] + b[1:])
    # NOTE(review): the state is then re-created via copy(bs=b) with the raw
    # nested partition; why it is built twice is not evident from here.
    state = state.copy(bs=b)
    bs = []
    for l, bl in enumerate(b):
        # Level-l labels as seen by each vertex of g.
        bl_ = np.array(state.project_level(l).get_state().a)
        bs.append(bl_)
        # Everything is in one block: higher levels would add no information.
        if len(np.unique(bl_)) == 1: break
    return bs

def project_partitions_on_graph(args, mode, max_level=-1, graph=None):
    """Project every nested partition in ``mode`` onto the graph and group them by level.

    Parameters
    ----------
    args : ARGS
        Run configuration (unused; kept for a uniform signature).
    mode : graph_tool PartitionModeState
        Ensemble whose ``get_nested_partitions()`` values are nested partitions.
    max_level : int, optional
        Number of levels to return. ``-1`` (default) uses the deepest
        projected hierarchy in the ensemble. Any other value pads shorter
        hierarchies and truncates deeper ones.
    graph : graph_tool.Graph, optional
        Graph to project onto. Defaults to the notebook-level ``g``, so
        existing callers are unaffected; pass it explicitly to avoid relying
        on that global.

    Returns
    -------
    (list, int)
        ``level_bs[l]`` holds one vertex-label array per partition at level
        ``l``. Partitions whose hierarchy ended earlier are padded with an
        all-zero (single-block) labelling. The second item is the number of
        levels actually used.
    """
    if graph is None:
        graph = g  # backward-compatible fallback to the notebook-level graph

    proj_bs = [
        nested_partitions(graph, bs)
        for bs in tqdm(list(mode.get_nested_partitions().values()))
    ]
    if max_level == -1:
        max_level = np.max([len(bs) for bs in proj_bs])

    level_bs = [[] for _ in range(max_level)]
    for bs in proj_bs:
        for level in range(max_level):
            # Past the top of this hierarchy every vertex is in block 0.
            level_bs[level].append(bs[level] if len(bs) > level else [0] * len(bs[0]))
    return level_bs, max_level

def align_nested_mode_to_pmode(args, mode, pmode_level_bs, pmode_max_level):
    """Align every level of an individual nested mode to the same level of the group mode.

    ``mode``'s partitions are projected to ``pmode_max_level`` levels. For
    each level, a flat PartitionModeState (``converge=True``) is built from
    the individual partitions. A flat reference state (``converge=False``) is
    built from ``pmode_level_bs[level]``, and the former is aligned to the
    latter via ``align_mode``.

    Returns
    -------
    list
        The aligned per-level PartitionModeState objects, indexed by level.
    """
    mode_level_bs, _ = project_partitions_on_graph(args, mode, max_level=pmode_max_level)

    def _aligned_level(level):
        individual = gt.PartitionModeState(mode_level_bs[level], relabel=False, nested=False, converge=True)
        reference = gt.PartitionModeState(pmode_level_bs[level], relabel=False, nested=False, converge=False)
        individual.align_mode(reference)
        return individual

    return [_aligned_level(level) for level in tqdm(range(pmode_max_level))]

def get_pi_matrix(args, mrgnls):
    """Stack per-node marginal counts into a row-normalised membership matrix.

    Parameters
    ----------
    args : ARGS
        Run configuration (unused; kept for a uniform signature).
    mrgnls : sequence of array-likes
        ``mrgnls[i][r]`` is node ``i``'s non-negative marginal count for group
        ``r``. Rows may differ in length; missing trailing groups count as 0.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(mrgnls), longest row length)``, each row divided by its
        sum (a row summing to 0 yields NaNs). Empty ``mrgnls`` gives a
        ``(0, 0)`` array.
    """
    if len(mrgnls) == 0:
        return np.zeros((0, 0))

    rows = [np.asarray(mrgnl, dtype=float) for mrgnl in mrgnls]
    num_comms = max(len(row) for row in rows)
    pi = np.zeros((len(rows), num_comms))
    for idx_node, row in enumerate(rows):
        # Copy the whole row in place. Zeros stay zeros, and the positions no
        # longer come from one mask (nonzero) while the values come from
        # another (positive).
        pi[idx_node, :len(row)] = row

    return pi / pi.sum(axis=-1, keepdims=True)  # marginals matrix

def get_nested_marginals(args, g, level_modes):
    """Return ``{level: pi}`` with one row-normalised marginal matrix per level.

    ``level_modes[l]`` is a flat partition mode for level ``l``. Its
    per-vertex marginals on ``g`` are collected first for all levels, then
    each level is turned into a matrix by ``get_pi_matrix``.
    """
    per_level = [list(mode.get_marginal(g)) for mode in level_modes]
    return {level: get_pi_matrix(args, level_mrgnls)
            for level, level_mrgnls in enumerate(per_level)}

def collect_marginals_single_mode(args, row, pi):
    """Build a one-row DataFrame holding ``pi`` plus the mode's identifying fields.

    Columns (in order): sub, mode_id, pi, omega, sigma, ratio. All fields
    except ``pi`` are copied from ``row``.
    """
    fields = {
        'sub': row['sub'],
        'mode_id': row['mode_id'],
        'pi': pi,
        'omega': row['omega'],
        'sigma': row['sigma'],
        'ratio': row['ratio'],
    }
    record = pd.DataFrame()
    for column, value in fields.items():
        record[column] = [value]
    return record

def collect_nested_marginals_single_mode(args, row, pis):
    """Build a DataFrame with one row per hierarchy level in ``pis``.

    Columns (in order): sub, mode_id, level, pi, omega, sigma, ratio. ``level``
    and ``pi`` come from ``pis.items()``; the rest are copied from ``row``.
    The result is re-indexed 0..n-1.
    """
    def _level_frame(level, pi):
        frame = pd.DataFrame()
        frame['sub'] = [row['sub']]
        frame['mode_id'] = [row['mode_id']]
        frame['level'] = [level]
        frame['pi'] = [pi]
        for key in ('omega', 'sigma', 'ratio'):
            frame[key] = [row[key]]
        return frame

    frames = [_level_frame(level, pi) for level, pi in pis.items()]
    return pd.concat(frames).reset_index(drop=True)

In [7]:
def _post_align_nested(args, indiv_dfs, pmode):
    """Hierarchical SBM: align each mode level by level; one output row per (mode, level)."""
    # Project the group mode once and reuse it for every individual mode.
    pmode_level_bs, pmode_max_level = project_partitions_on_graph(args, pmode)
    marginal_dfs = []
    for _, row in indiv_dfs.iterrows():
        aligned_level_modes = align_nested_mode_to_pmode(
            args, row['mode'], pmode_level_bs, pmode_max_level)
        pis = get_nested_marginals(args, g, aligned_level_modes)
        marginal_dfs.append(collect_nested_marginals_single_mode(args, row, pis))
    return pd.concat(marginal_dfs).reset_index(drop=True)


def _post_align_flat(args, indiv_dfs, pmode):
    """Flat SBM: align each mode to the group mode; one output row per mode."""
    marginal_dfs = []
    for _, row in tqdm(indiv_dfs.iterrows(), total=len(indiv_dfs)):
        mode = row['mode']
        # Called for its effect on `mode` (return value unused): the mode
        # object referenced by indiv_dfs is relabelled to match the group.
        mode.align_mode(pmode)
        pi = get_pi_matrix(args, list(mode.get_marginal(g)))
        marginal_dfs.append(collect_marginals_single_mode(args, row, pi))
    return pd.concat(marginal_dfs).reset_index(drop=True)


def post_align_modes(args, indiv_dfs):
    """Align every individual partition mode to the group mode and collect marginals.

    The group mode is a PartitionModeState built from all rows' ``b_hat``
    partitions (nested when ``args.nested``). Each row's ``mode`` is aligned
    to it, and its per-node marginals are stored as a ``pi`` matrix.

    Parameters
    ----------
    args : ARGS
        Needs ``args.sbm`` ('a', 'd' or 'h') and ``args.nested``.
    indiv_dfs : pandas.DataFrame
        One row per individual mode with columns ``b_hat``, ``mode``, ``sub``,
        ``mode_id``, ``omega``, ``sigma`` and ``ratio``.

    Returns
    -------
    pandas.DataFrame
        For 'h': one row per (mode, level). For 'a'/'d': one row per mode.

    Raises
    ------
    ValueError
        If ``args.sbm`` is not one of 'a', 'd', 'h'.

    Notes
    -----
    Reads the notebook-level graph ``g``.
    """
    pmode = gt.PartitionModeState(indiv_dfs['b_hat'].to_list(), nested=args.nested, converge=True)

    if args.sbm == 'h':
        return _post_align_nested(args, indiv_dfs, pmode)
    if args.sbm in ('a', 'd'):
        return _post_align_flat(args, indiv_dfs, pmode)
    raise ValueError(f"post_align_modes: unsupported args.sbm={args.sbm!r}; expected 'a', 'd' or 'h'")
    

In [8]:
# (dc flag, sbm kind) pairs to process, unpacked as `args.dc, args.sbm` in the
# main loop. The dc flag is rendered as 'dc'/'nd' in file names.
sbms_list = (
    [(True, kind) for kind in ('a', 'd', 'h')]
    + [(False, kind) for kind in ('d', 'h')]
)

In [9]:
# Load a single graph; it is used for projecting and reading partitions of
# every subject. NOTE(review): only the first file in sorted order is used,
# which assumes all graphs in GRAPH_path share the same vertex set — confirm.
graph_files = sorted(glob.glob(f'{GRAPH_path}/*'))
if not graph_files:
    raise FileNotFoundError(f'no graph files found in {GRAPH_path}')
graph_file = graph_files[0]
g = gt.load_graph(graph_file)
g

<Graph object, undirected, with 162 vertices and 1304 edges, 1 internal edge property, at 0x7f9cc411ab20>

In [10]:
def sbm_name(args):
    """Return the file stem for the SBM variant in ``args``.

    Gives 'sbm-dc-<kind>' or 'sbm-nd-<kind>' depending on ``args.dc``. For
    kind 'a' the dc token is left empty, giving 'sbm--a'.
    """
    dc = 'dc' if args.dc else 'nd'
    dc = '' if args.sbm in ['a'] else dc
    return f'sbm-{dc}-{args.sbm}'


# Sampling settings; identical for every SBM variant.
args.force_niter = 40000
args.num_draws = int((1/2) * args.force_niter)

for sbm in sbms_list:
    args.dc, args.sbm = sbm
    # Only the hierarchical ('h') variant uses a nested group mode.
    args.nested = args.sbm == 'h'

    SBM = sbm_name(args)
    print(SBM)

    indiv_files = sorted(glob.glob(f'{ESTIM_path}/individual/sub-*/partition-modes/{SBM}_desc-df.pkl'))
    indiv_dfs = collect_indiv_dfs(args, indiv_files)
    indiv_marginals_dfs = post_align_modes(args, indiv_dfs)

    # Write one pickle per subject next to that subject's other estimates.
    for sub in tqdm(indiv_marginals_dfs['sub'].unique()):
        folder = f'{ESTIM_path}/individual/sub-{sub}/partition-modes-group-aligned/{SBM}'
        os.makedirs(folder, exist_ok=True)
        with open(f'{folder}/desc-marginals-df.pkl', 'wb') as f:
            pickle.dump(indiv_marginals_dfs[indiv_marginals_dfs['sub'] == sub], f)

sbm--a


100%|██████████| 10/10 [01:10<00:00,  7.01s/it]
100%|██████████| 100/100 [00:01<00:00, 82.82it/s]
100it [00:07, 13.91it/s]
100%|██████████| 10/10 [00:00<00:00, 183.70it/s]


sbm-dc-d


100%|██████████| 10/10 [01:10<00:00,  7.03s/it]
100%|██████████| 132/132 [00:01<00:00, 84.29it/s]
132it [00:09, 13.79it/s]
100%|██████████| 10/10 [00:00<00:00, 140.55it/s]


sbm-dc-h


100%|██████████| 10/10 [08:31<00:00, 51.18s/it]
100%|██████████| 115/115 [00:11<00:00, 10.40it/s]
100%|██████████| 3500/3500 [05:34<00:00, 10.46it/s]
100%|██████████| 5/5 [00:10<00:00,  2.06s/it]
100%|██████████| 3014/3014 [04:50<00:00, 10.39it/s]
100%|██████████| 5/5 [00:09<00:00,  1.84s/it]
100%|██████████| 2542/2542 [04:00<00:00, 10.55it/s]
100%|██████████| 5/5 [00:07<00:00,  1.50s/it]
100%|██████████| 1776/1776 [02:49<00:00, 10.49it/s]
100%|██████████| 5/5 [00:04<00:00,  1.08it/s]
100%|██████████| 1498/1498 [02:23<00:00, 10.45it/s]
100%|██████████| 5/5 [00:04<00:00,  1.23it/s]
100%|██████████| 1084/1084 [01:44<00:00, 10.40it/s]
100%|██████████| 5/5 [00:02<00:00,  1.75it/s]
100%|██████████| 906/906 [01:27<00:00, 10.33it/s]
100%|██████████| 5/5 [00:02<00:00,  2.09it/s]
100%|██████████| 681/681 [01:05<00:00, 10.44it/s]
100%|██████████| 5/5 [00:02<00:00,  2.17it/s]
100%|██████████| 6/6 [00:00<00:00,  9.87it/s]
100%|██████████| 5/5 [00:00<00:00, 310.51it/s]
100%|██████████| 3310/3310 [0

sbm-nd-d


100%|██████████| 10/10 [01:09<00:00,  6.96s/it]
100%|██████████| 116/116 [00:01<00:00, 86.94it/s]
116it [00:08, 14.07it/s]
100%|██████████| 10/10 [00:00<00:00, 123.96it/s]


sbm-nd-h


100%|██████████| 10/10 [08:56<00:00, 53.65s/it]
100%|██████████| 94/94 [00:08<00:00, 10.49it/s]
100%|██████████| 3419/3419 [05:23<00:00, 10.56it/s]
100%|██████████| 4/4 [00:15<00:00,  3.94s/it]
100%|██████████| 2599/2599 [04:14<00:00, 10.22it/s]
100%|██████████| 4/4 [00:08<00:00,  2.13s/it]
100%|██████████| 2398/2398 [03:48<00:00, 10.50it/s]
100%|██████████| 4/4 [00:09<00:00,  2.27s/it]
100%|██████████| 2110/2110 [03:19<00:00, 10.57it/s]
100%|██████████| 4/4 [00:06<00:00,  1.74s/it]
100%|██████████| 1588/1588 [02:33<00:00, 10.36it/s]
100%|██████████| 4/4 [00:07<00:00,  1.75s/it]
100%|██████████| 1576/1576 [02:31<00:00, 10.42it/s]
100%|██████████| 4/4 [00:06<00:00,  1.66s/it]
100%|██████████| 805/805 [01:16<00:00, 10.46it/s]
100%|██████████| 4/4 [00:02<00:00,  1.46it/s]
100%|██████████| 505/505 [00:48<00:00, 10.32it/s]
100%|██████████| 4/4 [00:01<00:00,  3.29it/s]
100%|██████████| 3498/3498 [05:36<00:00, 10.38it/s]
100%|██████████| 4/4 [00:10<00:00,  2.59s/it]
100%|██████████| 3459/3459