# Nov 25, 2024: membership histograms
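For each community, histogram the ROI membership values of the per-mode marginals and of the soft marginals (the omega-weighted average of the marginals across modes).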

In [1]:
import csv
import os
import sys
import numpy as np
import pandas as pd
import scipy as sp 
import dill as pickle 
from os.path import join as pjoin
from itertools import product
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
import subprocess
from scipy import sparse, stats
from multiprocessing import Pool
import glob
import random

import arviz as az

import ants
from nipype.interfaces import afni

from itertools import product, combinations, chain
import multiprocessing as mp
from functools import partial

# networks
import graph_tool.all as gt

# plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.cm import rainbow

# Global matplotlib settings for every figure produced by this notebook.
# NOTE(review): the two font settings are written to `rcParamsDefault`, whereas the
# size/capsize settings below are written to `rcParams` — confirm the fonts are really
# meant to change only the stored defaults and not the active settings.
plt.rcParamsDefault['font.family'] = "sans-serif"
plt.rcParamsDefault['font.sans-serif'] = "Arial"
plt.rcParams['font.size'] = 14
plt.rcParams["errorbar.capsize"] = 0.5

import cmasher as cmr  # CITE ITS PAPER IN YOUR MANUSCRIPT
import colorcet as cc

# ignore user warnings
import warnings
warnings.filterwarnings("ignore") #, category=UserWarning)

In [2]:
class ARGS:
    """Plain attribute bag: notebook settings are attached to an instance as needed."""


# Single settings object shared by the cells below.
args = ARGS()

# Seed handed to the random number generators.
args.SEED = 100

def set_seed(args):
    """Seed every random number generator this notebook imports.

    Parameters
    ----------
    args : object
        Settings object; only its `SEED` attribute is read.
    """
    gt.seed_rng(args.SEED)     # graph-tool's RNG
    np.random.seed(args.SEED)  # NumPy global RNG (drives the bootstrap resampling)
    random.seed(args.SEED)     # stdlib RNG: imported at the top but previously unseeded

set_seed(args)

In [3]:
# Parcellation settings; each of them ends up in the descriptor string below.
args.type = 'spatial'
args.roi_size = 225
args.maintain_symmetry = True
args.brain_div = 'whl'
args.num_rois = 162

# Descriptor of the parcellation, used as a folder name under the ROI results.
PARC_DESC = ''.join([
    f'type-{args.type}',
    f'_size-{args.roi_size}',
    f'_symm-{args.maintain_symmetry}',
    f'_braindiv-{args.brain_div}',
    f'_nrois-{args.num_rois}',
])

In [4]:
# Graph-construction settings. Each value becomes one component of the results path
# assembled in ROI_RESULTS_path.
args.GRAPH_DEF = 'constructed'
args.GRAPH_METHOD = 'pearson-corr'
args.THRESHOLDING = 'positive'
args.EDGE_DEF = 'binary'
args.EDGE_DENSITY = 20
args.LAYER_DEF = 'individual'
args.DATA_UNIT = 'sub'

# Root of the dataset and of the ROI-level results for the chosen parcellation.
BASE_path = f'{os.environ["HOME"]}/mouse_dataset'
PARCELS_path = f'{BASE_path}/parcels'
ROI_path = f'{BASE_path}/roi_results_v2/{PARC_DESC}'
TS_path = f'{ROI_path}/runwise_timeseries'

# One folder level per graph-construction setting.
ROI_RESULTS_path = (
    f'{ROI_path}'
    f'/graph-{args.GRAPH_DEF}/method-{args.GRAPH_METHOD}'
    f'/threshold-{args.THRESHOLDING}/edge-{args.EDGE_DEF}/density-{args.EDGE_DENSITY}'
    f'/layer-{args.LAYER_DEF}/unit-{args.DATA_UNIT}'
)
GRAPH_path = f'{ROI_RESULTS_path}/graphs'
SBM_path = f'{ROI_RESULTS_path}/model-fits'
ESTIM_path = f'{ROI_RESULTS_path}/estimates'

# Create the output folders with os.makedirs rather than shelling out to `mkdir -p`:
# the path is never parsed by a shell (spaces, metacharacters) and a failure raises
# instead of coming back as an exit status nobody checks.
for out_dir in (GRAPH_path, SBM_path, f'{ESTIM_path}/individual', f'{ESTIM_path}/group'):
    os.makedirs(out_dir, exist_ok=True)

0

In [5]:
# Model selection. `sbm` is the model code used in file names; `dc` picks between the
# 'dc' and 'nd' tags in sbm_name (presumably degree-corrected or not — confirm).
args.dc = True
args.sbm = 'm'

# Only the 'h' model code is flagged as nested.
args.nested = args.sbm in ['h']

# Number of iterations, and half of that as the number of draws.
args.force_niter = 40000
args.num_draws = args.force_niter // 2

def sbm_name(args):
    """Build the `sbm-<tag>-<model>` label used in file and folder names.

    `<tag>` is 'dc' when `args.dc` is truthy and 'nd' otherwise, except for the
    model codes 'm' and 'a', for which the tag is left empty (giving e.g. 'sbm--m').
    """
    dc_tag = {True: 'dc', False: 'nd'}[bool(args.dc)]
    if args.sbm in ('m', 'a'):
        # these model codes carry no dc/nd tag
        dc_tag = ''
    return f'sbm-{dc_tag}-{args.sbm}'

# Label of the selected model; it appears in the input glob pattern and in the
# output folders used further down. The bare name on the last line displays it.
SBM = sbm_name(args)
SBM

'sbm--m'

In [6]:
def get_membership_matrix(num_rois, df, col='pi'):
    """Stack the matrices stored in `df[col]` into one zero-padded 3-D array.

    Each entry of `df[col]` is placed at `M[:, i, :k]`, where `i` is its position
    in the column and `k` the size of its last axis; an entry that is entirely NaN
    is replaced by a `(num_rois, 1)` block of zeros.

    Returns
    -------
    M : ndarray, shape (num_rois, len(df), widest last axis among the entries)
    """
    entries = []
    for entry in df[col]:
        if np.isnan(entry).all():
            entry = np.zeros((num_rois, 1))  # placeholder for a missing matrix
        entries.append(entry)

    widest = np.max([entry.shape[-1] for entry in entries])
    M = np.zeros((num_rois, len(df), widest))  # membership profile matrix
    for idx_mode, entry in enumerate(entries):
        # narrower entries leave their trailing columns at zero
        M[:, idx_mode, :entry.shape[-1]] = entry
    return M

In [7]:
# One `desc-mem-mats.pkl` per matching `sub-*` folder.
marginals_pattern = f'{ESTIM_path}/individual/sub-*/partition-modes-group-aligned/{SBM}/desc-mem-mats.pkl'
marginals_files = sorted(glob.glob(marginals_pattern, recursive=True))
if not marginals_files:
    # fail with the pattern in the message instead of pd.concat's "No objects to concatenate"
    raise FileNotFoundError(f'no membership files match {marginals_pattern}')

marginals_parts = []
for sbm_file in marginals_files:
    # NOTE: unpickling can execute arbitrary code; only load files written by this pipeline.
    with open(sbm_file, 'rb') as f:
        marginals_parts.append(pickle.load(f))
marginals_df = pd.concat(marginals_parts).reset_index(drop=True)

# Number the modes 0, 1, ... within each animal. cumcount() labels every row by its
# position inside its own `sub` group, so the ids stay correct even when the rows are
# not already grouped in sorted `sub` order (the earlier value_counts-based list was
# assigned positionally and relied on that ordering).
mode_ids = marginals_df.groupby('sub').cumcount().to_list()
marginals_df['mode_id'] = mode_ids
marginals_df

Unnamed: 0,sub,sbm,pi_aligned,omega,mode_id
0,SLC01,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0....",0.806483,0
1,SLC01,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1....",0.165306,1
2,SLC01,sbm--m,"[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [...",0.019968,2
3,SLC01,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",0.008243,3
4,SLC02,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0....",0.3338,0
5,SLC02,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1....",0.3292,1
6,SLC02,sbm--m,"[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [...",0.26572,2
7,SLC02,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",0.07104,3
8,SLC02,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.00024,4
9,SLC03,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0....",0.575495,0


In [8]:
# Names of the columns that contain 'pi_'; these are the columns processed below.
cols = [name for name in marginals_df.columns if 'pi_' in name]
cols

['pi_aligned']

In [9]:
# SOFT MARGINALS
# For every animal, average its per-mode membership matrices across modes, weighting
# each mode by its `omega`. Rows are collected as plain dicts and turned into a frame
# once, instead of building and concatenating one single-row DataFrame per animal.
soft_records = []
for sub, group in marginals_df.groupby('sub'):
    omegas = group['omega'].to_list()
    record = {'sub': sub, 'sbm': SBM}
    for col in cols:
        M = get_membership_matrix(args.num_rois, group, col=col)
        record[col] = np.average(M, axis=1, weights=omegas)  # soft-comms.
    soft_records.append(record)
soft_marginals_df = pd.DataFrame(soft_records)
soft_marginals_df

Unnamed: 0,sub,sbm,pi_aligned
0,SLC01,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0...."
1,SLC02,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
2,SLC03,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0...."
3,SLC04,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
4,SLC05,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0...."
5,SLC06,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0...."
6,SLC07,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0...."
7,SLC08,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0...."
8,SLC09,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0...."
9,SLC10,sbm--m,"[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0...."


histograms

In [10]:
def mem_bins_per_marginal(args, pi, min_membership=0.05):
    """Membership histogram of every community (column) of one membership matrix.

    Parameters
    ----------
    args : object
        Only `args.binwidth` is read; bin edges run from 0 in steps of that width.
    pi : 2-D ndarray or NaN
        Membership matrix, indexed as `pi[:, community]`.
    min_membership : float, optional
        Values at or below this cut-off are left out of the histogram. Defaults to
        0.05, the value that used to be hard-coded.

    Returns
    -------
    ndarray of shape (number of bins, pi.shape[-1]), each column normalised to sum
    to 1 and rounded to 3 decimals; a column with no value above the cut-off is all
    zeros. `np.nan` is returned when `pi` is entirely NaN.
    """
    if np.isnan(pi).all():
        return np.nan

    bins = np.arange(0, 1 + args.binwidth, args.binwidth)
    h = np.zeros((len(bins) - 1, pi.shape[-1]))
    for col in range(pi.shape[-1]):
        vals = pi[:, col]
        h[:, col], _ = np.histogram(vals[vals > min_membership], bins=bins)

    # Normalise only the non-empty columns; empty ones stay at zero without going
    # through a 0/0 division that then has to be cleaned up with nan_to_num.
    totals = h.sum(axis=0)
    filled = totals > 0
    h[:, filled] /= totals[filled]
    return np.round(h, decimals=3)

In [11]:
# Width of the membership bins; read by mem_bins_per_marginal and by the plotting cells.
args.binwidth = 0.2

# One histogram matrix per (animal, mode) row. Rows are gathered as dicts and the
# frame is built once, instead of concatenating one single-row DataFrame per row.
bin_records = []
for _, row in marginals_df.iterrows():
    record = {key: row[key] for key in ('sub', 'sbm', 'mode_id', 'omega')}
    for col in cols:
        record[col] = mem_bins_per_marginal(args, row[col])
    bin_records.append(record)
bins_df = pd.DataFrame(bin_records)
bins_df

Unnamed: 0,sub,sbm,mode_id,omega,pi_aligned
0,SLC01,sbm--m,0,0.806483,"[[0.0, 0.0, 0.043, 0.0, 0.0, 0.0, 0.0], [0.0, ..."
1,SLC01,sbm--m,1,0.165306,"[[0.0, 0.0, 0.0, 0.048, 0.0, 1.0], [0.0, 0.0, ..."
2,SLC01,sbm--m,2,0.019968,"[[0.054, 0.019, 0.0, 0.029], [0.0, 0.019, 0.02..."
3,SLC01,sbm--m,3,0.008243,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0...."
4,SLC02,sbm--m,0,0.3338,"[[0.0, 0.0, 0.043, 0.0, 0.0, 0.0, 0.0], [0.0, ..."
5,SLC02,sbm--m,1,0.3292,"[[0.0, 0.0, 0.0, 0.048, 0.0, 1.0], [0.0, 0.0, ..."
6,SLC02,sbm--m,2,0.26572,"[[0.054, 0.019, 0.0, 0.029], [0.0, 0.019, 0.02..."
7,SLC02,sbm--m,3,0.07104,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0...."
8,SLC02,sbm--m,4,0.00024,"[[0.067, 0.217, 0.055, 0.071, 0.0, 1.0, 0.0, 0..."
9,SLC03,sbm--m,0,0.575495,"[[0.0, 0.0, 0.043, 0.0, 0.0, 0.0, 0.0], [0.0, ..."


In [12]:
# One histogram matrix per animal, computed from its soft marginals. Rows are gathered
# as dicts and the frame is built once, instead of concatenating single-row DataFrames.
soft_bin_records = []
for _, row in soft_marginals_df.iterrows():
    record = {'sub': row['sub'], 'sbm': row['sbm']}
    for col in cols:
        record[col] = mem_bins_per_marginal(args, row[col])
    soft_bin_records.append(record)
soft_bins_df = pd.DataFrame(soft_bin_records)
soft_bins_df

Unnamed: 0,sub,sbm,pi_aligned
0,SLC01,sbm--m,"[[0.0, 0.0, 0.062, 0.024, 0.0, 1.0, 0.0, 0.0],..."
1,SLC02,sbm--m,"[[0.0, 0.0, 0.021, 0.024, 0.0, 1.0, 0.0, 1.0, ..."
2,SLC03,sbm--m,"[[0.0, 0.0, 0.042, 0.024, 0.0, 1.0, 0.0], [0.0..."
3,SLC04,sbm--m,"[[0.0, 0.0, 0.021, 0.024, 0.0, 1.0, 0.0, 1.0, ..."
4,SLC05,sbm--m,"[[0.0, 0.0, 0.043, 0.0, 0.0, 0.0, 0.0], [0.0, ..."
5,SLC06,sbm--m,"[[0.0, 0.0, 0.062, 0.0, 0.0, 1.0, 0.0], [0.0, ..."
6,SLC07,sbm--m,"[[0.0, 0.0, 0.042, 0.024, 0.0, 1.0, 0.0], [0.0..."
7,SLC08,sbm--m,"[[0.0, 0.0, 0.042, 0.024, 0.0, 1.0, 0.0, 0.0],..."
8,SLC09,sbm--m,"[[0.0, 0.0, 0.042, 0.024, 0.0, 1.0, 0.0], [0.0..."
9,SLC10,sbm--m,"[[0.0, 0.0, 0.043, 0.0, 0.0, 0.0, 0.0], [0.0, ..."


plot histograms per animal

In [13]:
# plot membership histograms per mode, and overlay soft-marginal histograms 

In [14]:
# Bin edges (same construction as in mem_bins_per_marginal), the number of bins,
# and the bin centres used as x-coordinates of the plots.
bins = np.arange(0, 1 + args.binwidth, args.binwidth)
args.num_bins = bins.size - 1
xs = 0.5 * (bins[:-1] + bins[1:])

In [15]:
def plot_sub_hists_per_comm(idx_comm, ax, hs, soft_h):
    """Draw one community's per-mode histograms and its soft-marginal histogram on `ax`.

    `hs` holds one histogram per column (grey, faint); `soft_h` is the single
    across-modes histogram (blue, bold). The x-coordinates are the notebook-level
    bin centres `xs`. Returns `ax`.
    """
    mode_style = dict(marker='o', c='grey', markersize=7, linewidth=1, alpha=0.3)

    # the first column is drawn on its own so the legend gets exactly one grey entry
    ax.plot(xs, hs[:, 0], label='per mode', **mode_style)
    ax.plot(xs, hs[:, 1:], **mode_style)
    ax.plot(
        xs, soft_h,
        marker='o', c='cornflowerblue', markersize=7, linewidth=3, alpha=1.0,
        label='across modes',
    )

    ax.set(title=f'comm {idx_comm:02d}', xlabel='membership', ylabel='roi proportion')
    ax.grid(alpha=0.3)
    ax.legend()
    return ax

In [16]:
def plot_histograms_per_sub(args, sub, bdf, srow, col):
    """Figure with one panel per community for a single animal.

    Parameters
    ----------
    args : object
        Only `args.num_bins` is read.
    sub : str
        Animal label, shown in the figure title.
    bdf : DataFrame
        Per-mode histogram rows of this animal; `bdf[col]` is stacked.
    srow : DataFrame
        Soft-marginal histogram row(s) of this animal; the first entry of
        `srow[col]` is used.
    col : str
        Column to plot. Its suffix after the last '_' is shown as the level in
        the title unless it is 'aligned'.

    Returns
    -------
    The matplotlib figure (5 panels per row, unused panels removed).
    """
    # get_membership_matrix only stacks and zero-pads, so it serves histograms as well
    hist_stack = get_membership_matrix(args.num_bins, bdf, col)
    soft_hist = srow[col].to_list()[0]

    n_comms = hist_stack.shape[-1]
    ncols = 5
    nrows = int(np.ceil(n_comms / ncols))
    # squeeze=False: always a 2-D grid of axes, also when there is a single row
    fig, axs = plt.subplots(nrows, ncols, figsize=(5*ncols, 4*nrows), squeeze=False)
    fig.tight_layout(h_pad=3, w_pad=3)

    level = col.split('_')[-1]
    if level == 'aligned':
        sbm_title = f'{SBM}'
    else:
        sbm_title = f'{SBM} level-{level}'
    fig.suptitle(f'{sbm_title} {sub}', x=0.0, y=1.0)

    for idx_comm in range(n_comms):
        panel = axs[idx_comm // ncols, idx_comm % ncols]
        plot_sub_hists_per_comm(
            idx_comm, panel,
            hs=hist_stack[:, :, idx_comm],
            soft_h=soft_hist[:, idx_comm]
        )

    # drop the panels of the last row that received no community
    for spare in axs.flat[n_comms:]:
        fig.delaxes(spare)
    return fig

In [17]:
# Per-animal figures: one PDF per animal and per membership column.
groups = list(bins_df.groupby('sub'))

for col in cols:
    for sub, bdf in tqdm(groups):
        srow = soft_bins_df[soft_bins_df['sub'] == sub]
        fig = plot_histograms_per_sub(args, sub, bdf, srow, col)

        folder = f'{ESTIM_path}/individual/sub-{sub}/membership-histograms/{SBM}'
        # os.makedirs instead of a shell `mkdir -p`: no shell parsing of the path,
        # and a failure raises instead of being ignored
        os.makedirs(folder, exist_ok=True)
        fig.savefig(f'{folder}/desc-{col}.pdf', bbox_inches='tight')
        # release each figure as soon as it is saved rather than keeping every
        # animal's figure open until the inner loop has finished
        plt.close(fig)

100%|██████████| 10/10 [00:12<00:00,  1.21s/it]


plot histograms across animals

In [18]:
# plot soft-marginal histograms per animal and their bootstrap mean

In [19]:
def bootstrap_histogram_means(matrix, n_bootstrap=1000, confidence_level=0.95):
    """Bootstrap the column-wise mean of `matrix` with a percentile confidence interval.

    Parameters
    ----------
    matrix : 2-D ndarray
        One realization per row, one histogram bin per column.
    n_bootstrap : int
        Number of resamples; each draws as many rows as `matrix` has, with replacement.
    confidence_level : float
        Coverage of the interval (0.95 gives the 2.5th and 97.5th percentiles).

    Returns
    -------
    (mean_estimate, ci_lower, ci_upper) : one value per column each — the mean of the
    bootstrap means and the lower/upper percentiles of the bootstrap means.

    Resampling uses NumPy's global random state (`np.random.choice`).
    """
    n_realizations, _ = matrix.shape  # also rejects input that is not 2-D

    resampled_means = []
    for _ in range(n_bootstrap):
        picks = np.random.choice(n_realizations, size=n_realizations, replace=True)
        resampled_means.append(np.mean(matrix[picks, :], axis=0))
    bootstrap_means = np.array(resampled_means)

    mean_estimate = np.mean(bootstrap_means, axis=0)

    # Percentile interval: with tail = (1 - confidence_level) / 2, the bounds are the
    # 100*tail-th and 100*(1 - tail)-th percentiles of the bootstrap means. It describes
    # the spread of the mean across resamples, not the spread of the data themselves.
    ci_lower = np.percentile(bootstrap_means, q=100*(1 - confidence_level) / 2, axis=0)
    ci_upper = np.percentile(bootstrap_means, q=100 * (1 - (1 - confidence_level) / 2), axis=0)

    return mean_estimate, ci_lower, ci_upper

In [20]:
def plot_soft_hists_per_comm(idx_comm, ax, xs, hs, mu, cil, ciu):
    """Draw one community's histograms on `ax`: every column of `hs`, the mean, and its CI band.

    `xs` are the x-coordinates shared by all curves, `hs` holds one histogram per
    column (hollow grey markers), `mu` is the mean curve and `cil`/`ciu` bound the
    shaded band, which is labelled '95% CI'. Returns `ax`.
    """
    animal_style = dict(marker='o', c='grey', markersize=7, fillstyle='none', linewidth=1, alpha=0.3)
    highlight = 'cornflowerblue'

    # the first column is drawn separately so the legend shows a single 'animal' entry
    ax.plot(xs, hs[:, 0], label='animal', **animal_style)
    ax.plot(xs, hs[:, 1:], **animal_style)

    ax.plot(xs, mu, marker='o', c=highlight, markersize=10, linewidth=3, alpha=1.0, label='mean')
    ax.fill_between(x=xs, y1=cil, y2=ciu, color=highlight, alpha=0.3, label='95% CI')

    ax.legend()
    ax.grid(alpha=0.3)
    ax.set(title=f'comm {idx_comm:02d}', xlabel='membership', ylabel='roi proportion')
    return ax

In [21]:
def plot_soft_hists(H, col=None):
    """Group figure: one panel per community with every animal's histogram and the bootstrap mean.

    Parameters
    ----------
    H : 3-D ndarray
        Indexed as `H[:, :, community]`; each such slice is passed transposed to
        bootstrap_histogram_means (one row per realization) and untransposed to
        plot_soft_hists_per_comm.
    col : str, optional
        Name of the plotted column. Its suffix after the last '_' is shown as the
        level in the title unless it is 'aligned'. When omitted, the notebook-level
        variable `col` is used if it exists, which is what this function used to rely
        on implicitly; pass it explicitly to avoid depending on loop state.

    Returns
    -------
    The matplotlib figure (5 panels per row, unused panels removed).
    """
    if col is None:
        # legacy call style `plot_soft_hists(H)` inside a `for col in ...` loop
        col = globals().get('col')

    num_comms = H.shape[-1]
    ncols = 5
    nrows = int(np.ceil(num_comms / ncols))
    # squeeze=False: always a 2-D grid of axes, also when there is a single row
    fig, axs = plt.subplots(nrows, ncols, figsize=(5*ncols, 4*nrows), squeeze=False)
    fig.tight_layout(h_pad=3, w_pad=3)

    level = col.split('_')[-1] if col is not None else 'aligned'
    sbm_title = f'{SBM} level-{level}' if level != 'aligned' else f'{SBM}'
    fig.suptitle(f'{sbm_title}', x=0.0, y=1.0)

    for idx_comm in range(num_comms):
        ax = axs[idx_comm // ncols, idx_comm % ncols]
        hs = H[:, :, idx_comm]
        mu, cil, ciu = bootstrap_histogram_means(hs.T)
        plot_soft_hists_per_comm(idx_comm, ax, xs, hs, mu, cil, ciu)

    # drop the panels of the last row that received no community
    for spare in axs.flat[num_comms:]:
        fig.delaxes(spare)
    return fig

In [22]:
# Group figures: one PDF per membership column, all written to the same folder.
folder = f'{ESTIM_path}/group/membership-histograms/{SBM}'
# os.makedirs instead of a shell `mkdir -p`: no shell parsing of the path, and a
# failure raises instead of being ignored
os.makedirs(folder, exist_ok=True)

# The loop variable has to be called `col`: plot_soft_hists(H) looks that name up in
# the notebook namespace to build the figure title.
for col in tqdm(cols):
    H = get_membership_matrix(args.num_bins, soft_bins_df, col=col)
    fig = plot_soft_hists(H)
    fig.savefig(f'{folder}/desc-{col}.pdf', bbox_inches='tight')
    plt.close(fig)  # release each figure once it is on disk

100%|██████████| 1/1 [00:02<00:00,  2.64s/it]
