# Aug 22, 2024: individual level estimates

Note: this notebook was moved to zaratan as `02c-*.sh`.

In [1]:
import csv
import os
import sys
import numpy as np
import pandas as pd
import scipy as sp 
import dill as pickle 
from os.path import join as pjoin
from itertools import product
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
import subprocess
from scipy import sparse, stats
from multiprocessing import Pool
import glob
import random

import arviz as az

from itertools import product, combinations
import multiprocessing as mp
from functools import partial

# networks
import graph_tool.all as gt

# plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.cm import rainbow

# Global matplotlib settings for the notebook.
# NOTE(review): the two font lines write to `rcParamsDefault`, while the size
# and capsize lines write to `rcParams` — presumably the font settings were
# also meant for `rcParams`; confirm before relying on the Arial font.
plt.rcParamsDefault['font.family'] = "sans-serif"
plt.rcParamsDefault['font.sans-serif'] = "Arial"
plt.rcParams['font.size'] = 14
plt.rcParams["errorbar.capsize"] = 0.5

import cmasher as cmr  # CITE ITS PAPER IN YOUR MANUSCRIPT

# ignore warnings: with the `category` argument commented out below, the
# filter applies to every warning category, not only UserWarning
import warnings
warnings.filterwarnings("ignore") #, category=UserWarning)

In [2]:
class ARGS:
    """Empty attribute container; settings are attached to an instance as
    plain attributes (e.g. `args.SEED`)."""

# single settings object that the functions in this notebook receive as `args`
args = ARGS()

# seed value read by set_seed
args.SEED = 100

def set_seed(args):
    """Seed every random number generator this notebook draws from.

    Seeds graph-tool's generator, numpy's global generator and the stdlib
    `random` module with `args.SEED`.

    args: settings object with an integer `SEED` attribute.
    """
    gt.seed_rng(args.SEED)
    np.random.seed(args.SEED)
    # sample_partitions / sample_nested_partitions pick partitions with
    # `random.sample`, so the stdlib generator has to be seeded as well for
    # the sampled partitions to be reproducible
    random.seed(args.SEED)

set_seed(args)

In [3]:
# Parcellation settings. Their values are only used here to build PARC_DESC,
# which names the parcellation directory under BASE_path (see ROI_path).
args.type = 'spatial'
args.roi_size = 225
args.maintain_symmetry = True
args.brain_div = 'whl'
args.num_rois = 162

# Adjacent f-strings are concatenated into a single `key-value_key-value...`
# string; changing any value changes the directory that is read.
PARC_DESC = (
    f'type-{args.type}'
    f'_size-{args.roi_size}'
    f'_symm-{args.maintain_symmetry}'
    f'_braindiv-{args.brain_div}'
    f'_nrois-{args.num_rois}'
)

In [4]:
# Graph-construction settings. Each value becomes one path component of
# ROI_RESULTS_path (graph-/method-/threshold-/edge-/density-/layer-/unit-),
# so together they select which results directory is read and written.
args.GRAPH_DEF = f'constructed'
args.GRAPH_METHOD = f'pearson-corr'
args.THRESHOLDING = f'positive'
args.EDGE_DEF = f'binary'
args.EDGE_DENSITY = 10
args.LAYER_DEF = f'individual'
args.DATA_UNIT = f'ses'

# Directory layout: BASE_path / parcellation / graph settings / {graphs,
# model-fits, estimates}. Raises KeyError if HOME is not set.
BASE_path = f'{os.environ["HOME"]}/mouse_dataset/roi_results_v2'
ROI_path = f'{BASE_path}/{PARC_DESC}'
TS_path = f'{ROI_path}/runwise_timeseries'
ROI_RESULTS_path = (
    f'{ROI_path}'
    f'/graph-{args.GRAPH_DEF}/method-{args.GRAPH_METHOD}'
    f'/threshold-{args.THRESHOLDING}/edge-{args.EDGE_DEF}/density-{args.EDGE_DENSITY}'
    f'/layer-{args.LAYER_DEF}/unit-{args.DATA_UNIT}'
)
GRAPH_path = f'{ROI_RESULTS_path}/graphs'
SBM_path = f'{ROI_RESULTS_path}/model-fits'
ESTIM_path = f'{ROI_RESULTS_path}/estimates'

# Create the output directories (including parents) if they do not exist.
# os.makedirs is used instead of shelling out to `mkdir -p` so that paths
# containing whitespace work and a failure raises instead of being ignored.
for out_dir in (
    GRAPH_path,
    SBM_path,
    f'{ESTIM_path}/individual',
    f'{ESTIM_path}/group',
):
    os.makedirs(out_dir, exist_ok=True)

0

In [5]:
# Model selection: `dc` picks the 'dc' / 'nd' tag in sbm_name, and `sbm` is a
# one-letter model code ('a', 'd' and 'h' are the codes handled in this file).
args.dc = True
args.sbm = 'd'

# 'h' is the only code treated as nested (hierarchical partitions)
args.nested = args.sbm == 'h'

# num_draws is half of force_niter; it is the denominator used for `omega` in
# collect_modes_single_chain.
# NOTE(review): presumably force_niter is the chain length and the other half
# is discarded — confirm against the fitting script.
args.force_niter = 40000
args.num_draws = int((1/2) * args.force_niter)

def sbm_name(args):
    """Return the model tag `sbm-<dc|nd|>-<args.sbm>`.

    The middle part is 'dc' when `args.dc` is truthy and 'nd' otherwise,
    except for model code 'a', where it is left empty (giving 'sbm--a').
    """
    if args.dc:
        degree_tag = 'dc'
    else:
        degree_tag = 'nd'

    # model 'a' carries no degree tag
    if args.sbm in ['a']:
        degree_tag = ''

    return f'sbm-{degree_tag}-{args.sbm}'

# model tag used as a directory name when globbing model fits and as the
# file-name prefix of the saved estimates
SBM = sbm_name(args)
SBM

'sbm-dc-d'

In [6]:
# partitions drawn from a chain for aggregating; each mode of the chain gets
# round(num_samples * omega) of them (see collect_modes_single_chain)
args.num_samples = 100 # from a chain for aggregating

In [7]:
def collect_modes_single_chain(args, sbm_file):
    """Tabulate the partition modes stored in one chain's pickle file.

    `sbm_file` is expected to end in `<unit>/<sbm>/<key>-<value>/<file>`,
    where `<unit>` is a `_`-separated list of `key-value` entities (sub, ses).
    The graph is loaded from `GRAPH_path/<unit>_desc-graph.gt.gz`, and the
    pickle must hold a one-element list whose item is the sequence of modes.

    Returns a DataFrame with one row per mode: the identifiers parsed from
    the path, `graph`, `mode_id`, `mode`, `b_hat` (the mode's maximum
    partition; nested when `args.sbm` is 'h'), `omega`
    (`mode.get_M() / args.num_draws`), `sigma` (`mode.posterior_cdev()`),
    `ratio` (omega / sigma, 0.0 when that is NaN) and `num_samples`
    (`args.num_samples * omega`, rounded).
    """
    parts = sbm_file.split('/')
    unit_dir, sbm_dir, block_dir = parts[-4], parts[-3], parts[-2]

    graph_file = '_'.join([unit_dir, 'desc-graph.gt.gz'])
    graph = gt.load_graph(f"{GRAPH_path}/{graph_file}")

    with open(sbm_file, 'rb') as handle:
        [modes] = pickle.load(handle)

    block_key, *block_vals = block_dir.split('-')  # B

    frames = []
    for mode_id, mode in enumerate(modes):
        frame = pd.DataFrame({})

        # identifiers taken from the directory names
        for entity in unit_dir.split('_'):  # sub ses
            key, *values = entity.split('-')
            frame[key] = values
        frame['sbm'] = sbm_dir
        frame[block_key] = '_'.join(block_vals)

        frame['graph'] = [graph]

        if args.sbm in ['h']:
            b_hat = mode.get_max_nested()
        else:
            b_hat = list(mode.get_max(graph))

        frame['mode_id'] = [mode_id]
        frame['mode'] = [mode]
        frame['b_hat'] = [b_hat]

        omega = np.round(mode.get_M() / args.num_draws, 3)
        sigma = np.round(mode.posterior_cdev(), 3)
        ratio = omega / sigma
        if np.isnan(ratio):
            ratio = 0.0

        frame['omega'] = [omega]
        frame['sigma'] = [sigma]
        frame['ratio'] = [ratio]
        frame['num_samples'] = [round(args.num_samples * omega)]

        frames.append(frame)

    return pd.concat(frames).reset_index(drop=True)

def collect_modes(args, sbm_files):
    """Tabulate the modes of every chain file in `sbm_files`.

    Calls collect_modes_single_chain on each file (with a progress bar),
    stacks the per-chain tables and returns them sorted by
    sub, ses, sbm and B.
    """
    per_chain = [
        collect_modes_single_chain(args, chain_file)
        for chain_file in tqdm(sbm_files)
    ]
    stacked = pd.concat(per_chain).reset_index(drop=True)
    return stacked.sort_values(['sub', 'ses', 'sbm', 'B'])

def sample_partitions(args, sbm_dfs):
    """Draw partitions from every mode listed in `sbm_dfs`.

    For each row, `row['num_samples']` partitions are picked without
    replacement from the partitions stored in `row['mode']`. Returns one flat
    list with the picks of all rows, in row order. `args` is not used.
    """
    pooled = []
    for _, row in tqdm(sbm_dfs.iterrows()):
        candidates = list(row['mode'].get_partitions().values())
        pooled.extend(random.sample(candidates, row['num_samples']))
        # alternatives left disabled in the original notebook:
        #   [row['mode'].sample_partition(MLE=True) for _ in range(row['num_samples'])]
        #   [row['b_hat']]
    return pooled

def sample_nested_partitions(args, sbm_dfs):
    """Draw nested partitions from every mode listed in `sbm_dfs`.

    For each row, `row['num_samples']` nested partitions are picked without
    replacement from `row['mode']`, and each pick is passed through
    `gt.nested_partition_clear_null`. Returns one flat list with the picks of
    all rows, in row order. `args` is not used.
    """
    pooled = []
    for _, row in tqdm(sbm_dfs.iterrows()):
        candidates = list(row['mode'].get_nested_partitions().values())
        picked = random.sample(candidates, row['num_samples'])
        pooled.extend(gt.nested_partition_clear_null(levels) for levels in picked)
        # alternatives left disabled in the original notebook:
        #   [row['mode'].sample_partition(MLE=True) for _ in range(row['num_samples'])]
        #   [row['b_hat']]
    return pooled

def posterior_modes(args, bs):
    """Cluster the partitions `bs` into modes.

    Builds a `gt.ModeClusterState` (nested when `args.nested` is true), runs
    `gt.mcmc_equilibrate` on it with `wait=1` and one sweep per iteration at
    `beta=np.inf`, and returns the state.
    """
    cluster_state = gt.ModeClusterState(bs, nested=args.nested)
    sweep_options = dict(niter=1, beta=np.inf)
    gt.mcmc_equilibrate(cluster_state, wait=1, mcmc_args=sweep_options)
    return cluster_state

def catalog_modes(args, cmode, g, all_bs):
    """Summarise the modes of `cmode` in a DataFrame, one row per mode.

    Columns: `mode_id`, `mode`, `b_hat` (the mode's maximum partition; nested
    when `args.sbm` is 'h'), `omega` (`mode.get_M() / len(all_bs)`), `sigma`
    (`mode.posterior_cdev(MLE=False)`) and `ratio` (omega / sigma, 0.0 when
    that is NaN). omega, sigma and ratio are rounded to 3 decimals.
    """
    def summarise(mode_id, mode):
        # one single-row table per mode
        if args.sbm in ['h']:
            b_hat = mode.get_max_nested()
        else:
            b_hat = list(mode.get_max(g))

        omega = np.round(mode.get_M() / len(all_bs), 3)
        sigma = np.round(mode.posterior_cdev(MLE=False), 3)

        ratio = omega / sigma
        ratio = 0.0 if np.isnan(ratio) else np.round(ratio, 3)

        return pd.DataFrame(dict(
            mode_id=[mode_id],
            mode=[mode],
            b_hat=[b_hat],
            omega=[omega],
            sigma=[sigma],
            ratio=[ratio],
        ))

    rows = [summarise(i, mode) for i, mode in enumerate(cmode.get_modes())]
    return pd.concat(rows).reset_index(drop=True)

def nested_partitions(g, b):
    """Project a nested partition `b` of graph `g` level by level.

    `b` is first passed through `gt.nested_partition_clear_null`, then loaded
    into a `gt.NestedBlockState`. Returns a list with one label array per
    level (from `project_level(level)`), stopping after the first level whose
    labels are all identical.
    """
    hierarchy = gt.nested_partition_clear_null(b)

    # the lowest level is handed over as a vertex property map of `g`
    base_level = g.new_vp("int", vals=hierarchy[0])
    nested_state = gt.NestedBlockState(g, bs=[base_level] + hierarchy[1:])
    nested_state = nested_state.copy(bs=hierarchy)

    projected = []
    for level, _ in enumerate(hierarchy):
        labels = np.array(nested_state.project_level(level).get_state().a)
        projected.append(labels)
        # a single remaining group means nothing is left to project
        if len(np.unique(labels)) == 1:
            break
    return projected

def post_align_modes(args, indiv_dfs, g):
    """Replace `indiv_dfs['b_hat']` with the partitions held by a
    `gt.ModeClusterState` built from those same `b_hat` values.

    For model codes 'a' and 'd' the first level of each stored partition is
    kept; for 'h' each stored partition is expanded with nested_partitions.
    Any other code leaves the column untouched. `indiv_dfs` is modified in
    place and also returned.
    NOTE(review): presumably the state relabels the partitions consistently,
    which is the alignment the function name refers to — confirm.
    """
    joint_state = gt.ModeClusterState(
        indiv_dfs['b_hat'].to_list(),
        nested=args.nested,
    )

    if args.sbm in ['h']:
        indiv_dfs['b_hat'] = [
            nested_partitions(g, levels) for levels in joint_state.bs
        ]
    elif args.sbm in ['a', 'd']:
        indiv_dfs['b_hat'] = [levels[0] for levels in joint_state.bs]

    return indiv_dfs

In [8]:
def individual_level_estimates(args, sub):
    """Aggregate one subject's partition modes across all of its chains.

    Steps: find the subject's `desc-partition-modes.pkl` files under SBM_path,
    tabulate their modes, draw partitions from each mode (the table's
    `num_samples` column says how many), cluster the pooled partitions into
    modes, catalog those modes and post-align their `b_hat` partitions.

    args: settings object; `args.sbm` must be 'a', 'd' or 'h'.
    sub: subject identifier matched as `*{sub}*` in the directory names.

    Returns `(indiv_dfs, sbm_dfs)`: the subject-level mode table and the
    per-chain mode table it was built from.

    Raises ValueError for an unsupported `args.sbm` and FileNotFoundError
    when no model-fit files match the subject.
    """
    print(f'sub {sub}')

    # pick the sampler before the slow loading step, so that an unsupported
    # model code fails immediately instead of as an unbound variable later
    if args.sbm in ['h']:
        sampler = sample_nested_partitions
    elif args.sbm in ['a', 'd']:
        sampler = sample_partitions
    else:
        raise ValueError(
            f"unsupported args.sbm {args.sbm!r}; expected 'a', 'd' or 'h'"
        )

    pattern = f'{SBM_path}/*{sub}*/{SBM}/*/desc-partition-modes.pkl'
    sbm_files = sorted(glob.glob(pattern, recursive=True))
    if not sbm_files:
        # otherwise this surfaces as an opaque pd.concat error
        raise FileNotFoundError(f'no partition-mode files match {pattern}')

    sbm_dfs = collect_modes(args, sbm_files)

    # the graph of the first row is the one handed to the mode summaries
    g = sbm_dfs.iloc[0]['graph']

    # sample b's from mode proportional to omega
    all_bs = sampler(args, sbm_dfs)

    cmode = posterior_modes(args, all_bs)
    indiv_dfs = catalog_modes(args, cmode, g, all_bs)
    indiv_dfs = post_align_modes(args, indiv_dfs, g)
    display(indiv_dfs)
    return indiv_dfs, sbm_dfs

In [11]:
# Compute and save the individual-level estimates, subject by subject.
for sub_num in np.arange(1, 11):
    # re-seed before every subject so each one starts from the same state
    set_seed(args)

    sub = f'SLC{sub_num:02d}'
    indiv_dfs, sbm_dfs = individual_level_estimates(args, sub=sub)

    folder = f'{ESTIM_path}/individual/sub-{sub}/partition-modes'
    # os.makedirs instead of shelling out to `mkdir -p`: handles paths with
    # whitespace and raises on failure rather than failing silently
    os.makedirs(folder, exist_ok=True)
    with open(f'{folder}/{SBM}_desc-df.pkl', 'wb') as f:
        pickle.dump(indiv_dfs, f)
    # only the first subject is processed in this notebook
    # NOTE(review): presumably the full run lives in the 02c-*.sh scripts
    # mentioned at the top — confirm before removing this break
    break

sub SLC01


  0%|          | 0/15 [00:00<?, ?it/s]

100%|██████████| 15/15 [02:12<00:00,  8.86s/it]
115it [00:00, 309.90it/s]


Unnamed: 0,mode_id,mode,b_hat,omega,sigma,ratio
0,0,<graph_tool.inference.partition_modes.Partitio...,"[0, 1, 2, 3, 3, 0, 4, 0, 4, 4, 0, 2, 5, 0, 0, ...",0.334,0.131,2.55
1,1,<graph_tool.inference.partition_modes.Partitio...,"[0, 2, 2, 3, 3, 0, 4, 0, 5, 4, 0, 2, 5, 3, 0, ...",0.324,0.126,2.571
2,2,<graph_tool.inference.partition_modes.Partitio...,"[5, 2, 11, 0, 3, 0, 4, 0, 5, 4, 0, 11, 5, 0, 0...",0.181,0.176,1.028
3,3,<graph_tool.inference.partition_modes.Partitio...,"[5, 2, 11, 6, 7, 0, 4, 3, 5, 4, 0, 11, 5, 3, 3...",0.161,0.158,1.019


In [10]:
# sub = 'SLC02'
# folder = f'{ESTIM_path}/individual/sub-{sub}/partition-modes'
# with open(f'{folder}/{SBM}_desc-df.pkl', 'rb') as f:
#     indiv_dfs = pickle.load(f)
# indiv_dfs