# Aug 20, 2025: group-averaged graph
unit-grp

Create animal-level FC matrices: compute run-level FC matrices, Fisher-transform them, average them, apply the inverse Fisher transform, then threshold.

In [1]:
import os
import glob
import pandas as pd
import re
import numpy as np
from tqdm import tqdm
from itertools import product, combinations
from sklearn.covariance import GraphicalLasso
from scipy.stats import entropy, zscore
from sklearn.metrics import mutual_info_score
from joblib import Parallel, delayed
import graph_tool.all as gt 
import seaborn as sns

In [2]:
class ARGS:
    """Bare attribute container; later cells attach notebook settings to the `args` instance."""


args = ARGS()

# Seed numpy's legacy global RNG once, up front, so np.random.* draws are reproducible.
args.SEED = 100
np.random.seed(args.SEED)

In [3]:
# Parcellation used to extract the ROI time series.
args.source = 'allen'
args.space = 'ccfv2'
args.brain_div = 'whl'
args.num_rois = 172  # other parcellations tried: 216, 334, 162
args.resolution = 200

# Descriptor used in result paths,
# e.g. 'source-allen_space-ccfv2_braindiv-whl_nrois-172_res-200'.
PARC_DESC = '_'.join(
    f'{label}-{value}'
    for label, value in (
        ('source', args.source),
        ('space', args.space),
        ('braindiv', args.brain_div),
        ('nrois', args.num_rois),
        ('res', args.resolution),
    )
)
PARC_DESC

'source-allen_space-ccfv2_braindiv-whl_nrois-172_res-200'

In [4]:
# Dataset layout: results for this parcellation live under ROI_path;
# the per-run ROI time series are read from TS_path.
BASE_path = f'{os.environ["HOME"]}/new_mouse_dataset'
PARCELS_path = f'{BASE_path}/parcels'
ROI_path = f'{BASE_path}/roi-results-v3/{PARC_DESC}'
TS_path = f'{ROI_path}/roi-timeseries'

# os.makedirs replaces `os.system('mkdir -p ...')`: no shell is involved (paths with
# spaces or shell metacharacters are safe) and failures raise instead of returning
# a status code nobody checks.
os.makedirs(ROI_path, exist_ok=True)
os.makedirs(TS_path, exist_ok=True)

0

In [5]:
args.eps = 1.0e-7  # margin used by fisher_z to clip correlations to [-1 + eps, 1 - eps]

In [6]:
def collect_timeseries(files):
    """Load per-run ROI time-series files into a DataFrame.

    File names must look like ``sub-<sub>_ses-<ses>_run-<run>_task-<task>_desc-ts.txt``
    (e.g. ``sub-SLC01_ses-1_run-11_task-rest_desc-ts.txt``). The whole name must
    match, so stray files (notes, ``*.bak`` copies, ...) are skipped with a warning
    instead of crashing on ``None.groupdict()`` or being loaded by accident.

    Parameters
    ----------
    files : iterable of str
        Paths to whitespace-delimited text files readable by ``np.loadtxt``.

    Returns
    -------
    pandas.DataFrame
        One row per matched file. ``sub``, ``ses``, ``run`` and ``task`` are strings
        parsed from the name. ``ts`` is the loaded array z-scored along axis 0, with
        NaNs ignored by ``zscore``.
    """
    import warnings

    pattern = re.compile(
        r"sub-(?P<sub>\w+)_ses-(?P<ses>\d+)_run-(?P<run>\d+)_task-(?P<task>\w+)_desc-ts\.txt"
    )

    records = []
    for file in tqdm(files):
        file_name = os.path.basename(file)
        match = pattern.fullmatch(file_name)
        if match is None:
            warnings.warn(f'skipping file with unexpected name: {file_name}')
            continue
        metadata = match.groupdict()
        ts = np.loadtxt(file)
        metadata['ts'] = zscore(ts, axis=0, nan_policy='omit')
        records.append(metadata)
    return pd.DataFrame(records).reset_index(drop=True)

In [7]:
# One row per run: metadata parsed from each file name plus the z-scored time series.
data_df = collect_timeseries(files=sorted(glob.glob(os.path.join(TS_path, '*'))))

  0%|          | 0/86 [00:00<?, ?it/s]

100%|██████████| 86/86 [00:00<00:00, 130.63it/s]


In [8]:
# One row per run file: sub, ses, run and task parsed from the file name, plus the z-scored time series.
data_df

Unnamed: 0,sub,ses,run,task,ts
0,SLC01,1,11,rest,"[[-0.0955117651496492, -0.9480141361079346, 0...."
1,SLC01,1,15,rest,"[[0.2300138323488809, -1.7545658974959732, -1...."
2,SLC01,1,19,rest,"[[0.17467503871495776, 2.0563678304287225, 1.3..."
3,SLC01,2,10,rest,"[[2.193628914198822, -0.10906563819651435, 1.5..."
4,SLC01,2,6,rest,"[[0.6019518044701574, -0.13158998274825498, -1..."
...,...,...,...,...,...
81,SLC10,2,9,rest,"[[-1.4238413632568756, 1.0018252019653109, 1.1..."
82,SLC10,3,13,rest,"[[0.12182039381642772, 0.7583543245732659, 0.3..."
83,SLC10,3,17,rest,"[[-0.6009069370177516, -0.4645906291763468, -0..."
84,SLC10,3,5,rest,"[[-1.8217132895432822, 0.3352282395733613, 0.6..."


In [9]:
def get_cols(args):
    """Return the grouping columns that identify one data unit.

    ``args.DATA_UNIT`` selects the unit:
    'ses' -> ['sub', 'ses', 'task'], 'sub' -> ['sub', 'task'], 'grp' -> ['task'].

    Raises
    ------
    ValueError
        For any other ``DATA_UNIT``. The previous if-chain hit an
        UnboundLocalError instead.
    """
    cols_by_unit = {
        'ses': ['sub', 'ses', 'task'],
        'sub': ['sub', 'task'],
        'grp': ['task'],
    }
    if args.DATA_UNIT not in cols_by_unit:
        raise ValueError(
            f"unknown DATA_UNIT {args.DATA_UNIT!r}; expected one of {sorted(cols_by_unit)}"
        )
    return cols_by_unit[args.DATA_UNIT]

In [10]:
# normalized mutual information
def optimal_bin_size(ts, method="fd"):
    """Number of histogram bins for a time series under a standard binning rule.

    Parameters
    ----------
    ts : array_like
        Samples. N is ``len(ts)`` (number of time points). Range, IQR and standard
        deviation are taken over all values, so a 2-D array is effectively flattened.
    method : {'fd', 'sturges', 'rice', 'scott'}
        'sturges': ceil(log2(N) + 1).
        'rice': ceil(2 * N^(1/3)).
        'fd' (Freedman-Diaconis): bin width 2 * IQR / N^(1/3).
        'scott': bin width 3.5 * std / N^(1/3).

    Returns
    -------
    int
        Number of bins, at least 1. For the width-based rules ('fd', 'scott'), a
        degenerate width (zero or non-finite, e.g. a constant series or zero IQR)
        or a non-finite data range falls back to Sturges' rule instead of dividing
        by zero.

    Raises
    ------
    ValueError
        If ``ts`` is empty or ``method`` is not one of the supported rules.
    """
    ts = np.asarray(ts)
    N = len(ts)  # number of time points
    if N == 0:
        raise ValueError("optimal_bin_size needs at least one sample")

    def _sturges():
        # Sturges' rule; also the fallback for degenerate width-based rules.
        return int(np.ceil(np.log2(N) + 1))

    if method == "sturges":
        return _sturges()
    if method == "rice":
        return int(np.ceil(2 * N ** (1/3)))

    if method == "fd":  # Freedman-Diaconis rule
        iqr = np.percentile(ts, 75) - np.percentile(ts, 25)
        bin_width = (2 * iqr) / (N ** (1/3))
    elif method == "scott":  # Scott's rule
        bin_width = (3.5 * np.std(ts)) / (N ** (1/3))
    else:
        raise ValueError(
            f"unknown method {method!r}; expected 'fd', 'sturges', 'rice' or 'scott'"
        )

    data_range = np.max(ts) - np.min(ts)
    if not (np.isfinite(bin_width) and bin_width > 0 and np.isfinite(data_range)):
        return _sturges()
    return max(1, int(np.ceil(data_range / bin_width)))
    
def compute_joint_density(ts1, ts2, bins=100):
    """Histogram-based joint and marginal probability tables for two series.

    The marginals reuse the joint histogram's bin edges, so all three tables refer
    to the same bins. Each table is normalized to sum to 1.

    Returns
    -------
    (p_xy, p_x, p_y) : tuple of numpy.ndarray
        Joint table of shape (bins, bins), marginal of ``ts1``, marginal of ``ts2``.
    """
    joint, x_edges, y_edges = np.histogram2d(ts1, ts2, bins=bins, density=True)
    marginal_x, _ = np.histogram(ts1, bins=x_edges, density=True)
    marginal_y, _ = np.histogram(ts2, bins=y_edges, density=True)

    def as_probabilities(hist):
        return hist / np.sum(hist)

    return as_probabilities(joint), as_probabilities(marginal_x), as_probabilities(marginal_y)

def compute_nmi(ts1, ts2, bins=100):
    """Normalized mutual information between two series, in bits.

    I(X;Y) = H(X) + H(Y) - H(X,Y), estimated from histograms, divided by
    sqrt(H(X) * H(Y)). Returns 0 when either marginal entropy is not positive
    (e.g. a constant series).
    """
    p_xy, p_x, p_y = compute_joint_density(ts1, ts2, bins)

    h_x, h_y = (entropy(p, base=2) for p in (p_x, p_y))
    # Joint entropy over all (x, y) cells: sum of -p_xy * log2(p_xy).
    h_xy = entropy(p_xy.ravel(), base=2)

    if not (h_x > 0 and h_y > 0):
        return 0
    mutual_info = h_x + h_y - h_xy
    return mutual_info / np.sqrt(h_x * h_y)

def compute_nmi_matrix(ts, bins=100, n_jobs=10):
    """Symmetric matrix of pairwise normalized mutual information between ROIs.

    Parameters
    ----------
    ts : numpy.ndarray
        Time series, shape (time, rois).
    bins : int
        Histogram bins passed to ``compute_nmi``.
    n_jobs : int
        Number of joblib workers.

    Returns
    -------
    numpy.ndarray
        (rois, rois) matrix. Off-diagonal entries hold the NMI of each ROI pair;
        the diagonal is left at 0.
    """
    n_rois = ts.shape[1]
    roi_pairs = list(combinations(range(n_rois), 2))

    pair_values = Parallel(n_jobs=n_jobs)(
        delayed(compute_nmi)(ts[:, a], ts[:, b], bins) for a, b in roi_pairs
    )

    nmi = np.zeros((n_rois, n_rois))
    for (a, b), value in zip(roi_pairs, pair_values):
        nmi[a, b] = nmi[b, a] = value
    return nmi

In [11]:
def compute_fc(args, ts):
    """Compute a functional-connectivity (FC) matrix from one ROI time-series array.

    Parameters
    ----------
    args : ARGS
        Reads ``args.GRAPH_METHOD``: 'pearson', 'partial' or 'mutualinfo'.
    ts : numpy.ndarray
        Time series, shape (time, rois).

    Returns
    -------
    numpy.ndarray
        (rois, rois) FC matrix passed through ``np.nan_to_num``, so NaNs (e.g.
        from a constant ROI) become 0.

    Raises
    ------
    ValueError
        If ``args.GRAPH_METHOD`` is not a supported method.
    """
    method = args.GRAPH_METHOD
    if method == 'pearson':
        fc = np.corrcoef(ts.T)
    elif method == 'partial':
        model = GraphicalLasso(alpha=0.01)
        model.fit(ts)
        precision = model.precision_  # sparse inverse covariance
        # Rescale to partial correlations, rho_ij = -P_ij / sqrt(P_ii * P_jj), so values
        # lie in [-1, 1] like the other methods. The raw -precision_ is unbounded and
        # would be clipped by fisher_z downstream.
        scale = np.sqrt(np.diag(precision))
        fc = -precision / np.outer(scale, scale)
        np.fill_diagonal(fc, 1.0)
    elif method == 'mutualinfo':
        bins = optimal_bin_size(ts)
        fc = compute_nmi_matrix(ts, bins=bins, n_jobs=10)
    else:
        raise ValueError(
            f"unknown GRAPH_METHOD {method!r}; expected 'pearson', 'partial' or 'mutualinfo'"
        )
    return np.nan_to_num(fc)

def threshold_fc(args, fc_matrix):
    """Keep the strongest ``args.EDGE_DENSITY`` percent of connections in an FC matrix.

    The cut-off is the ``100 - EDGE_DENSITY`` percentile of the strictly
    upper-triangle values, so density counts ROI pairs (the diagonal is excluded
    from the ranking). Ties at the cut-off are all kept.

    Parameters
    ----------
    args : ARGS
        Reads three settings:
        - ``THRESHOLD``: 'signed' ranks by value, so only the most positive
          connections survive; 'unsigned' ranks by absolute value, so strong
          negative connections survive too.
        - ``EDGE_DEF``: 'binary' or 'weighted'.
        - ``EDGE_DENSITY``: percent of ROI pairs to keep.
    fc_matrix : numpy.ndarray
        Square FC matrix.

    Returns
    -------
    numpy.ndarray
        'binary': boolean mask of kept entries. 'weighted': ``fc_matrix`` with
        dropped entries set to 0; kept entries retain their sign. Diagonal entries
        are kept whenever they pass the cut-off.

    Raises
    ------
    ValueError
        For an unknown ``THRESHOLD`` or ``EDGE_DEF``.
    """
    if args.THRESHOLD == 'signed':
        strength = fc_matrix
    elif args.THRESHOLD == 'unsigned':
        strength = np.abs(fc_matrix)
    else:
        raise ValueError(
            f"unknown THRESHOLD {args.THRESHOLD!r}; expected 'signed' or 'unsigned'"
        )
    if args.EDGE_DEF not in ('binary', 'weighted'):
        raise ValueError(
            f"unknown EDGE_DEF {args.EDGE_DEF!r}; expected 'binary' or 'weighted'"
        )

    keep_ratio = args.EDGE_DENSITY / 100
    upper = strength[np.triu_indices_from(strength, k=1)]
    cutoff = np.percentile(upper, 100 * (1 - keep_ratio))

    # Compare the same quantity the percentile was computed on. Comparing the signed
    # matrix against an |r| cut-off dropped every negative edge in 'unsigned' mode
    # and under-filled the requested density.
    mask = strength >= cutoff

    if args.EDGE_DEF == 'binary':
        return mask
    return fc_matrix * mask

def make_graph(fc):
    """Build an undirected graph-tool graph from a (thresholded) FC matrix.

    Every nonzero entry strictly below the diagonal becomes one edge (i, j). Its
    ``weight`` edge property is the entry's value (1.0 for a boolean matrix).
    The graph always has ``fc.shape[0]`` vertices, one per ROI, so ROIs with no
    surviving edges remain as isolated vertices and vertex indices line up
    across graphs.
    """
    n_rois = fc.shape[0]
    lower = np.tril(fc, k=-1)

    edges = np.where(lower)
    edge_list = list(zip(*[*edges, lower[edges]]))

    g = gt.Graph(
        edge_list,
        eprops=[('weight', 'double')],
        directed=False,
    )

    # The edge-list constructor only creates vertices as far as the largest index
    # that appears in an edge. Trailing ROIs left without edges after thresholding
    # would otherwise be missing, so pad with isolated vertices.
    missing = n_rois - g.num_vertices()
    if missing > 0:
        g.add_vertex(missing)

    return g

def save_graph(g, identity, GRAPH_path):
    """Save ``g`` as ``{GRAPH_path}/{identity}_desc-graph.gt.gz`` and return that path."""
    out_file = f'{GRAPH_path}/{identity}_desc-graph.gt.gz'
    g.save(out_file)
    return out_file

In [12]:
def fisher_z(r, eps=None):
    """Fisher z-transform, arctanh(r), of correlation values.

    Parameters
    ----------
    r : array_like
        Correlation values.
    eps : float, optional
        Values are first clipped to [-1 + eps, 1 - eps], so |r| = 1 (e.g. an FC
        diagonal) maps to a large finite z instead of +/-inf. Defaults to
        ``args.eps`` from the notebook settings; pass it explicitly to avoid
        depending on that global.
    """
    if eps is None:
        eps = args.eps
    return np.arctanh(np.clip(r, -1 + eps, 1 - eps))

def tidy_corr(R):
    """Return a symmetrized copy of ``R`` ((R + R.T) / 2) with a unit diagonal; ``R`` is not modified."""
    symmetric = (R + R.T) * 0.5
    np.fill_diagonal(symmetric, 1.0)
    return symmetric

In [13]:
# Settings swept or recorded when building graphs.
GRAPH_DEFS = ['constructed']
GRAPH_METHODS = ['pearson']            # alternative: 'mutualinfo'
THRESHOLDINGS = ['signed', 'unsigned']
EDGE_DEFS = ['binary', 'weighted']
EDGE_DENSITIES = [10, 20, 30]          # percent of ROI pairs kept; earlier: [10, 15, 20, 25]
LAYER_DEFS = ['individual']            # 'multilayer' not used here
DATA_UNITS = ['ses', 'sub', 'grp']

In [14]:
# Fixed (non-swept) settings: compute_fc reads GRAPH_METHOD, and all three appear in
# the output directory names built by save_graphs.
args.GRAPH_DEF, args.GRAPH_METHOD, args.LAYER_DEF = 'constructed', 'pearson', 'individual'


In [15]:
# Per-run FC matrix, its length in time points, a weight, and its Fisher z-transform.
data_df['t_eff'] = data_df['ts'].apply(len)
data_df['fc'] = data_df['ts'].apply(lambda ts: compute_fc(args, ts))
# w = T - 3 (floored just above 0): inverse of the usual approximate Fisher-z
# sampling variance 1/(T - 3), which assumes independent time points.
data_df['w'] = data_df['t_eff'].clip(lower=3 + 1e-9) - 3
data_df['z'] = data_df['fc'].apply(fisher_z)
data_df

Unnamed: 0,sub,ses,run,task,ts,t_eff,fc,w,z
0,SLC01,1,11,rest,"[[-0.0955117651496492, -0.9480141361079346, 0....",533,"[[1.0, -0.10886562658853842, 0.311368884476664...",530,"[[8.40562139102231, -0.1092987928567174, 0.322..."
1,SLC01,1,15,rest,"[[0.2300138323488809, -1.7545658974959732, -1....",531,"[[0.9999999999999998, -0.12575106474913797, 0....",528,"[[8.40562139102231, -0.12642027347584914, 0.21..."
2,SLC01,1,19,rest,"[[0.17467503871495776, 2.0563678304287225, 1.3...",531,"[[0.9999999999999998, -0.0342111144644278, 0.1...",528,"[[8.40562139102231, -0.034224470745147936, 0.1..."
3,SLC01,2,10,rest,"[[2.193628914198822, -0.10906563819651435, 1.5...",532,"[[1.0, 0.0359209863381718, 0.3441747585209697,...",529,"[[8.40562139102231, 0.03593644813319431, 0.358..."
4,SLC01,2,6,rest,"[[0.6019518044701574, -0.13158998274825498, -1...",521,"[[1.0, -0.21089558937657402, 0.220005265470504...",518,"[[8.40562139102231, -0.21410843818645034, 0.22..."
...,...,...,...,...,...,...,...,...,...
81,SLC10,2,9,rest,"[[-1.4238413632568756, 1.0018252019653109, 1.1...",533,"[[1.0, -0.019341345642795108, 0.08169658468217...",530,"[[8.40562139102231, -0.019343757970459452, 0.0..."
82,SLC10,3,13,rest,"[[0.12182039381642772, 0.7583543245732659, 0.3...",536,"[[1.0, -0.2165871011678161, 0.2605791576485126...",533,"[[8.40562139102231, -0.2200724365902028, 0.266..."
83,SLC10,3,17,rest,"[[-0.6009069370177516, -0.4645906291763468, -0...",537,"[[0.9999999999999998, -0.022806096757538592, 0...",534,"[[8.40562139102231, -0.022810051946096956, 0.1..."
84,SLC10,3,5,rest,"[[-1.8217132895432822, 0.3352282395733613, 0.6...",540,"[[1.0, -0.004297112101083848, 0.22982908059433...",537,"[[8.40562139102231, -0.004297138550348819, 0.2..."


In [16]:
def average_z(group):
    """Weighted average of Fisher-z FC matrices within one group, mapped back to r.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows to pool. Must have a ``z`` column (equal-shape 2-D arrays of Fisher-z
        values) and a numeric ``w`` column (per-row weight).

    Returns
    -------
    pandas.Series
        ``z``: entry-wise weighted mean of the z matrices. An entry that is NaN in
        some rows is averaged over the remaining rows only; it is NaN if no row
        has a value.
        ``w``: summed weight, so the result can be pooled again at the next level.
        ``fc``: ``tanh(z)`` passed through ``tidy_corr`` (symmetrized, unit diagonal).
    """
    Z = np.stack(group['z'].to_list())
    W = group['w'].to_numpy()[:, None, None]

    # Per-entry weights: a NaN in one matrix removes only that entry's weight from
    # the denominator. Dividing by the total weight, as nansum(W) did, biased such
    # entries toward 0. With no NaNs the result is unchanged.
    valid = ~np.isnan(Z)
    numerator = np.sum(np.where(valid, W * Z, 0.0), axis=0)
    denominator = np.sum(np.where(valid, W, 0.0), axis=0)
    with np.errstate(invalid='ignore', divide='ignore'):
        z = numerator / denominator

    fc = tidy_corr(np.tanh(z))
    return pd.Series({
        'z': z,
        'w': group['w'].sum(),
        'fc': fc,
    })

In [17]:
# Session level: pool the runs of each (sub, ses, task), weighting each run by w.
# include_groups=False: average_z reads only `z` and `w`, and include_groups=True is
# deprecated in pandas (the warning echoed in this cell's output).
ses_df = (
    data_df
    .groupby(by=['sub', 'ses', 'task'])
    .apply(average_z, include_groups=False)
    .reset_index()
)
ses_df

  ses_df = data_df.groupby(by=['sub', 'ses', 'task']).apply(average_z, include_groups=True).reset_index()


Unnamed: 0,sub,ses,task,z,w,fc
0,SLC01,1,rest,"[[8.40562139102231, -0.09000553919467004, 0.24...",1586,"[[1.0, -0.08976327931279841, 0.236649683137248..."
1,SLC01,2,rest,"[[8.40562139102231, -0.08777248320737485, 0.29...",1047,"[[1.0, -0.08754777564700199, 0.283929635154008..."
2,SLC01,3,rest,"[[8.405621391022311, 0.014798909570876604, 0.2...",2098,"[[1.0, 0.014797829307007621, 0.213540796490828..."
3,SLC03,1,rest,"[[8.40562139102231, 0.0037545551174770603, 0.2...",1592,"[[1.0, 0.003754537475317359, 0.259644334930892..."
4,SLC03,2,rest,"[[8.40562139102231, 0.06688078786549213, 0.361...",2118,"[[1.0, 0.06678124582223768, 0.3466828335548699..."
5,SLC03,3,rest,"[[8.405621391022311, 0.04500849967106798, 0.37...",2127,"[[1.0, 0.04497813206280121, 0.3605277001036738..."
6,SLC04,1,rest,"[[8.40562139102231, -0.06940542974987683, 0.25...",2041,"[[1.0, -0.06929419945354481, 0.249712571831157..."
7,SLC04,2,rest,"[[8.40562139102231, -0.07243789477580885, 0.26...",2103,"[[1.0, -0.07231146026187109, 0.254388277161901..."
8,SLC04,3,rest,"[[8.40562139102231, 0.0964413439000736, 0.4177...",2094,"[[1.0, 0.09614345395939361, 0.3950000876275014..."
9,SLC05,1,rest,"[[8.40562139102231, 0.0797431017588722, 0.3839...",1074,"[[1.0, 0.07957450279965034, 0.3661385256547809..."


In [18]:
# Animal level: pool each animal's sessions, weighted by each session's summed w.
# include_groups=False: average_z reads only `z` and `w` (True is deprecated in pandas).
sub_df = (
    ses_df
    .groupby(by=['sub', 'task'])
    .apply(average_z, include_groups=False)
    .reset_index()
)
# Reset the weights so every animal counts equally in the group average.
sub_df['w'] = 1
sub_df

  sub_df = ses_df.groupby(by=['sub', 'task']).apply(average_z, include_groups=True).reset_index()


Unnamed: 0,sub,task,z,w,fc
0,SLC01,rest,"[[8.40562139102231, -0.04303497417061278, 0.24...",1,"[[1.0, -0.04300842678365287, 0.237056769502204..."
1,SLC03,rest,"[[8.40562139102231, 0.041693308077179615, 0.34...",1,"[[1.0, 0.04166916592751896, 0.3286083377113853..."
2,SLC04,rest,"[[8.405621391022311, -0.014755562793566932, 0....",1,"[[1.0, -0.014754491993820052, 0.30168592145432..."
3,SLC05,rest,"[[8.40562139102231, 0.10473434378424565, 0.306...",1,"[[1.0, 0.10435306310298437, 0.2968590679713448..."
4,SLC06,rest,"[[8.40562139102231, -0.06182466547252938, 0.33...",1,"[[1.0, -0.06174601513511453, 0.322730455419894..."
5,SLC07,rest,"[[8.40562139102231, 0.05746150258193489, 0.255...",1,"[[1.0, 0.05739834340205185, 0.2501561095635521..."
6,SLC08,rest,"[[8.40562139102231, 0.10045713202674197, 0.353...",1,"[[1.0, 0.10012056498382152, 0.3394780844648069..."
7,SLC09,rest,"[[8.40562139102231, -0.05761223699354435, 0.24...",1,"[[1.0, -0.05754857990764729, 0.240632688470436..."
8,SLC10,rest,"[[8.40562139102231, -0.034438876417996755, 0.2...",1,"[[1.0, -0.03442526762269431, 0.212924209719911..."


In [19]:
# Group level: equally weighted average of the animal-level z matrices per task.
# include_groups=False: average_z reads only `z` and `w` (True is deprecated in pandas).
grp_df = (
    sub_df
    .groupby(by=['task'])
    .apply(average_z, include_groups=False)
    .reset_index()
)
grp_df

  grp_df = sub_df.groupby(by=['task']).apply(average_z, include_groups=True).reset_index()


Unnamed: 0,task,z,w,fc
0,rest,"[[8.40562139102231, 0.010297774513539103, 0.28...",9,"[[1.0, 0.010297410522696252, 0.281707018799712..."


In [20]:
def save_graphs(args, df, ITERS):
    """Threshold each unit-level FC matrix for every setting in ITERS and save the graphs.

    Parameters
    ----------
    args : ARGS
        Reads GRAPH_DEF, GRAPH_METHOD, LAYER_DEF and DATA_UNIT. Also mutated:
        THRESHOLD, EDGE_DEF and EDGE_DENSITY are set for each setting, because
        threshold_fc reads them from ``args``.
    df : pandas.DataFrame
        One row per data unit, with the grouping columns from ``get_cols(args)``
        and an ``fc`` column holding that unit's FC matrix.
    ITERS : iterable of (THRESHOLD, EDGE_DEF, EDGE_DENSITY)
        Settings to sweep, e.g. an ``itertools.product``.

    Raises
    ------
    ValueError
        If a group contains more than one row. Previously only the first row's
        matrix was used.

    Graphs are written under
    ``{ROI_path}/graph-*/method-*/threshold-*/edge-*/density-*/layer-*/unit-*/graphs``.
    """
    for THRESHOLD, EDGE_DEF, EDGE_DENSITY in ITERS:
        args.THRESHOLD = THRESHOLD
        args.EDGE_DEF = EDGE_DEF
        args.EDGE_DENSITY = EDGE_DENSITY

        ROI_RESULTS_path = (
            f'{ROI_path}'
            f'/graph-{args.GRAPH_DEF}/method-{args.GRAPH_METHOD}'
            f'/threshold-{args.THRESHOLD}/edge-{args.EDGE_DEF}/density-{args.EDGE_DENSITY}'
            f'/layer-{args.LAYER_DEF}/unit-{args.DATA_UNIT}'
        )
        GRAPH_path = f'{ROI_RESULTS_path}/graphs'
        os.makedirs(GRAPH_path, exist_ok=True)

        cols = get_cols(args)

        for key, group in tqdm(df.groupby(by=cols)):
            # Grouping by a list yields tuple keys in pandas >= 2; normalize in case a
            # scalar key comes back, so zip() pairs column names with whole values.
            key = key if isinstance(key, tuple) else (key,)
            if len(group) != 1:
                raise ValueError(
                    f'expected one row per {cols} group, got {len(group)} for {key}'
                )
            identity = '_'.join(f'{c}-{k}' for c, k in zip(cols, key))
            fc = threshold_fc(args, group['fc'].iloc[0])
            save_graph(make_graph(fc), identity, GRAPH_path)

In [21]:
# Session-level graphs: one per (sub, ses, task) for every threshold/edge/density setting.
args.DATA_UNIT = 'ses'
ITERS = product(THRESHOLDINGS, EDGE_DEFS, EDGE_DENSITIES)
save_graphs(args, ses_df, ITERS)

100%|██████████| 27/27 [00:00<00:00, 206.87it/s]
100%|██████████| 27/27 [00:00<00:00, 152.84it/s]
100%|██████████| 27/27 [00:00<00:00, 103.94it/s]
100%|██████████| 27/27 [00:00<00:00, 230.13it/s]
100%|██████████| 27/27 [00:00<00:00, 165.42it/s]
100%|██████████| 27/27 [00:00<00:00, 109.35it/s]
100%|██████████| 27/27 [00:00<00:00, 357.57it/s]
100%|██████████| 27/27 [00:00<00:00, 183.33it/s]
100%|██████████| 27/27 [00:00<00:00, 152.32it/s]
100%|██████████| 27/27 [00:00<00:00, 353.93it/s]
100%|██████████| 27/27 [00:00<00:00, 225.15it/s]
100%|██████████| 27/27 [00:00<00:00, 204.71it/s]


In [22]:
# Animal-level graphs: one per (sub, task) for every threshold/edge/density setting.
args.DATA_UNIT = 'sub'
ITERS = product(THRESHOLDINGS, EDGE_DEFS, EDGE_DENSITIES)
save_graphs(args, sub_df, ITERS)

100%|██████████| 9/9 [00:00<00:00, 247.05it/s]
100%|██████████| 9/9 [00:00<00:00, 168.61it/s]
100%|██████████| 9/9 [00:00<00:00, 121.11it/s]
100%|██████████| 9/9 [00:00<00:00, 267.65it/s]
100%|██████████| 9/9 [00:00<00:00, 164.48it/s]
100%|██████████| 9/9 [00:00<00:00, 119.08it/s]
100%|██████████| 9/9 [00:00<00:00, 167.25it/s]
100%|██████████| 9/9 [00:00<00:00, 249.66it/s]
100%|██████████| 9/9 [00:00<00:00, 209.82it/s]
100%|██████████| 9/9 [00:00<00:00, 342.20it/s]
100%|██████████| 9/9 [00:00<00:00, 255.73it/s]
100%|██████████| 9/9 [00:00<00:00, 205.04it/s]


In [23]:
# Group-level graphs: one per task for every threshold/edge/density setting.
args.DATA_UNIT = 'grp'
ITERS = product(THRESHOLDINGS, EDGE_DEFS, EDGE_DENSITIES)
save_graphs(args, grp_df, ITERS)

100%|██████████| 1/1 [00:00<00:00, 83.44it/s]
100%|██████████| 1/1 [00:00<00:00, 56.16it/s]
100%|██████████| 1/1 [00:00<00:00, 68.51it/s]
100%|██████████| 1/1 [00:00<00:00, 194.52it/s]
100%|██████████| 1/1 [00:00<00:00, 107.16it/s]
100%|██████████| 1/1 [00:00<00:00, 68.18it/s]
100%|██████████| 1/1 [00:00<00:00, 240.58it/s]
100%|██████████| 1/1 [00:00<00:00, 151.51it/s]
100%|██████████| 1/1 [00:00<00:00, 132.50it/s]
100%|██████████| 1/1 [00:00<00:00, 204.11it/s]
100%|██████████| 1/1 [00:00<00:00, 121.17it/s]
100%|██████████| 1/1 [00:00<00:00, 108.78it/s]
