In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
import pandas as pd
from scipy.spatial.distance import pdist, squareform

# Main data table; later cells read its 'lipizone', 'lipizone_names' and 'acronym' columns.
datavignettes = pd.read_parquet("./zenodo/maindata_2.parquet")
 
# Lookup table keyed by 5-digit cluster code. The three lists are positional:
# entry i of "zone" and "color" belongs to entry i of "cluster" (31 entries each).
namingtable = pd.DataFrame({
    "cluster": [
        11111, 11112, 11121, 11122, 11211, 11212, 11221, 11222, 12111, 12112, 
        12121, 12122, 12211, 12212, 12221, 12222, 21111, 21112, 21120, 21211, 
        21212, 21221, 21222, 22111, 22112, 22121, 22122, 22211, 22212, 22221, 22222
    ],
    # Human-readable name for each cluster code.
    "zone": [
        "Mixed and hindbrain white matter", "Core callosal white matter", 
        "Callosal and cerebellar white matter", "Ventral white matter", 
        "Boundary white matter", "Thalamic and mid/hindbrain white matter", 
        "Mid/hindbrain white matter", "Mixed white matter", 
        "Choroid plexus and ventricles", "Ventricular linings", 
        "Thalamic and midbrain regions", "White and gray matter boundary", 
        "Thalamic mixed gray and white matter", "Thalamic mixed gray and white matter #2", 
        "Neuron-rich lateral white matter", "Neuron-rich lateral white matter #2", 
        "Pallidum and projections", "Cortical layer 4", 
        "Subcortical plate, hippocampus and hypothalamus", 
        "GABA-ergic Purkinje cells of the cerebellum", "Cortical layers 2-3 and 4", 
        "Piriform cortex", "Cortical layers 1 and 2-3", "Cortical layer 5", 
        "Cortical layer 6, dentate gyrus", "Striatum, hypothalamus and hippocampus", 
        "Striatum, hypothalamus and hippocampus #2", 
        "Retrosplenial, cortical, cerebellar", "Cortical layer 6 and cerebellar Y", 
        "Cerebellar glutamatergic neurons", "Cortical layer 6 and thalamic"
    ],
    # Hex display colour for each cluster code.
    "color": [
        "#360064", "#980053", "#170b3b", "#ac2f5c", "#2a3f6d", "#002657", 
        "#21366b", "#3e4b6c", "#f75400", "#ef633e", "#a5d4e6", "#6399c6", 
        "#853a00", "#edeef4", "#fdbf71", "#ce710e", "#940457", "#a2d36c", 
        "#d5edb5", "#0065d6", "#bcf18b", "#a68d68", "#79e47e", "#2f0097", 
        "#47029f", "#7500a8", "#d70021", "#ca99c9", "#d4b9da", "#e00085", 
        "#f6f3f8"
    ]
})

# The cluster code is the first five characters of each lipizone label, read as an integer.
# NOTE(review): this assumes every 'lipizone' value is a string starting with five digits.
cluster_codes = [label[:5] for label in datavignettes['lipizone']]
clusters_short = pd.Series(cluster_codes, name='cluster').astype(int)

# Left join keeps one row per observation; codes missing from namingtable get NaN zone/color.
clusters = pd.merge(clusters_short.to_frame(), namingtable, on='cluster', how='left')
clusters

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import anndata
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42

# `atlas` is a second name for the same table object (no copy is made here).
atlas = datavignettes

# Ipsilateral connectome only; the contralateral table is added further down,
# in the subnetwork-density section.
allen = pd.read_csv("./zenodo/csv/allenconnectome_ipsi.csv", index_col=0)

# match the Allen acronyms used in the 2014 Oh et al paper
import difflib

injsitesall = allen.index.dropna().astype(str).values
unique_acronyms = datavignettes['acronym'].dropna().astype(str).unique()

# Map every atlas acronym to its single closest connectome acronym (difflib's default
# similarity cutoff applies); acronyms with no close-enough candidate map to None.
# NOTE(review): fuzzy matching can pair an acronym with a *different* region that merely
# has a similar spelling -- worth spot-checking the non-identical pairs.
best_match_dict = {}
for unique_acr in unique_acronyms:
    if isinstance(unique_acr, str) and unique_acr.strip() != "":
        matches = difflib.get_close_matches(unique_acr, injsitesall, n=1)
        best_match_dict[unique_acr] = next(iter(matches), None)

datavignettes['adaptedacronym'] = datavignettes['acronym'].map(best_match_dict)
print(f"Atlas length: {len(datavignettes)}")
print(f"Adapted acronym column length: {datavignettes['adaptedacronym'].notna().sum()}")
datavignettes['adaptedacronym'].value_counts()

## 1 Connectomic strength matrix

In [None]:
# Connectome rows/columns whose label is one of the matched atlas acronyms.
# The same row-derived boolean mask is applied to the columns, which assumes the
# connectome's columns follow the same order as its index.
matched_sites = allen.index.isin(atlas['adaptedacronym'].unique())
connectome = allen.loc[matched_sites, matched_sites]

# Atlas rows whose matched acronym is present in the connectome.
atlas = atlas.loc[atlas['adaptedacronym'].isin(allen.index), :]

connectome

In [None]:
# symmetrize: entry (i, j) becomes connectome[i, j] + connectome[j, i]
# NOTE: this cell is not idempotent -- running it a second time doubles every value.
connectome = connectome.add(connectome.T)
connectome

In [None]:
# Strengths of the region pairs above the diagonal, so each unordered pair appears once.
connectome_values = connectome.to_numpy()
i_upper = np.triu_indices_from(connectome_values, k=1)
values = connectome_values[i_upper]

# 95th percentile of those strengths, drawn as the dashed reference line.
top_5_percentile = np.percentile(values, 95)

fig, ax = plt.subplots()
ax.hist(values, bins=50)
ax.axvline(x=top_5_percentile, color='r', linestyle='--', label=f'Top 5% threshold: {top_5_percentile:.4f}')
ax.set_xlabel("Value")
ax.set_ylabel("Count")
ax.set_title("Histogram of upper triangle (excluding diagonal)")
ax.legend()
plt.show()

## 2 Physical distance matrix

In [None]:
from allensdk.core.mouse_connectivity_cache import MouseConnectivityCache
import pandas as pd
import numpy as np
from tqdm import tqdm

## use with care!
# This installs a global 'ignore' filter: every warning raised after this point is hidden,
# not only the ones coming from allensdk.
import warnings
warnings.filterwarnings('ignore')

# Connectivity cache configured through the manifest file named here.
mcc = MouseConnectivityCache(manifest_file='mouse_connectivity_manifest.json')
# NOTE(review): structure_tree is not referenced again in the visible cells.
structure_tree = mcc.get_structure_tree()

# Only the first element of the returned pair is kept. Later cells treat it as a
# 3D volume whose voxel values are Allen region IDs.
annotation, _ = mcc.get_annotation_volume()

In [None]:
# Keep the first half of the last axis (mediolateral, per the axis-order note in
# compute_region_center), i.e. one hemisphere.
# NOTE: not idempotent -- re-running this cell halves the volume again.
half_width = annotation.shape[2] // 2
annotation = annotation[:, :, :half_width]

In [None]:
import requests
import numpy as np
import matplotlib.pyplot as plt

# --- Allen Structure Graph Functions ---

def download_structure_graph(timeout=60):
    """
    Downloads the Allen Brain Atlas structure graph from the Allen API.

    Parameters:
        timeout: seconds to wait for the server before giving up. Passed straight to
                 requests.get; without it a stalled connection would block forever.

    Returns the JSON response as a dictionary.

    Raises:
        Exception: if the server answers with a status code other than 200.
        requests.exceptions.RequestException: on connection errors or when the
            timeout expires.
    """
    url = "http://api.brain-map.org/api/v2/structure_graph_download/1.json"
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise Exception("Failed to download structure graph; status code " + str(response.status_code))
    return response.json()

def find_node_by_acronym(node, acronym):
    """
    Depth-first, pre-order search of the structure graph rooted at `node`.

    Returns the first node (a dict) whose "acronym" equals `acronym`, or None
    when no node in the subtree matches.
    """
    # Explicit stack instead of recursion; children are pushed in reverse so they
    # are popped, and therefore visited, in their original order.
    pending = [node]
    while pending:
        current = pending.pop()
        if current.get("acronym") == acronym:
            return current
        pending.extend(reversed(current.get("children", [])))
    return None

def extract_all_ids(node):
    """
    Collects the "id" of `node` and of every descendant, in depth-first pre-order
    (a node first, then each child's subtree in order). Returns them as a list.
    """
    collected = []
    stack = [node]
    while stack:
        current = stack.pop()
        collected.append(current.get("id"))
        # Reverse so the first child is popped (visited) first.
        stack.extend(reversed(current.get("children", [])))
    return collected

def get_downstream_ids(acronym, structure_graph):
    """
    Looks `acronym` up in the Allen structure graph (a dict whose "msg" entry lists
    the root nodes) and returns the IDs of the matching node and all of its
    descendants. Returns an empty list when no root contains the acronym.
    """
    roots = structure_graph.get("msg", [])
    # Lazy generator: the search stops at the first root that contains a match.
    hits = (find_node_by_acronym(root, acronym) for root in roots)
    match = next((hit for hit in hits if hit is not None), None)
    if match is None:
        return []
    return extract_all_ids(match)


# --- Annotation Volume Processing Functions ---

def compute_region_center(annotation, region_ids):
    """
    Centre of the voxels of `annotation` whose value is one of `region_ids`.

    Returns a 4-tuple: the mean voxel index along axis 0, axis 1 and axis 2,
    followed by the number of matching voxels.

    Raises ValueError when no voxel matches any of the given IDs.
    """
    mask = np.isin(annotation, region_ids)
    voxel_count = np.sum(mask)

    if voxel_count == 0:
        raise ValueError("No voxels found matching the provided region IDs: " + str(region_ids))

    # Note: The array shape is (rostral-caudal, dorsoventral, mediolateral)
    voxel_idx = np.nonzero(mask)
    center = tuple(np.mean(voxel_idx[axis]) for axis in range(3))

    return center + (voxel_count,)

# One network download; the resulting dict is what get_downstream_ids receives below.
structure_graph = download_structure_graph()

In [None]:
# Centroid (mean voxel index per axis) and voxel count for every connectome region.
# Regions that cannot be resolved are printed and skipped. The labels of the regions
# that DID resolve are recorded next to their results, so the DataFrame index always
# lines up with its rows regardless of which regions fail (in the original run only
# "SUBv" was lost).
allacro = []
resolved_acronyms = []
failed_acronyms = []

for acronym in tqdm(connectome.index.values):
    try:
        center = compute_region_center(annotation, get_downstream_ids(acronym, structure_graph))
    except Exception:
        # Best effort, e.g. ValueError when no voxel carries the region's IDs.
        # Unlike a bare `except:`, this still lets KeyboardInterrupt through.
        print(acronym)
        failed_acronyms.append(acronym)
        continue
    allacro.append(center)
    resolved_acronyms.append(acronym)

spatialcentroids = pd.DataFrame(allacro, index=resolved_acronyms, columns=['x_index', 'y_index', 'z_index', 'count'])

In [None]:
import pandas as pd
import numpy as np
from scipy.spatial.distance import cdist

def compute_distance_matrix(spatialcentroids):
    """
    Pairwise Euclidean distances between region centroids.

    Parameters:
        spatialcentroids (pd.DataFrame): one row per region, with the centroid
            coordinates in the columns 'x_index', 'y_index' and 'z_index'.

    Returns:
        pd.DataFrame: square distance matrix whose rows and columns both carry
                      the index of the input DataFrame.
    """
    labels = spatialcentroids.index
    xyz = spatialcentroids.loc[:, ['x_index', 'y_index', 'z_index']].to_numpy()
    return pd.DataFrame(cdist(xyz, xyz, metric='euclidean'), index=labels, columns=labels)

# Region x region centroid distances, in voxel-index units (the centroids are mean voxel indices).
distance_matrix_df = compute_distance_matrix(spatialcentroids)
distance_matrix_df

In [None]:
# Restrict and reorder the connectome to the regions that have a centroid.
centroid_rows, centroid_cols = distance_matrix_df.index, distance_matrix_df.columns
connectome = connectome.loc[centroid_rows, centroid_cols]

# filter accordingly
# NOTE(review): this filter uses allen.index, not distance_matrix_df.index, so atlas rows
# of regions without a centroid are still kept here -- confirm that is intended.
in_connectome = datavignettes['adaptedacronym'].isin(allen.index)
atlas = datavignettes.loc[in_connectome, :]
spatialcentroids

In [None]:
# Smallest number of voxels that any centroid was averaged over.
spatialcentroids['count'].min() # sanity check on region size used for estimate...

In [None]:
# Inspect one region's centroid row; x_index runs along the rostral-caudal axis
# (see the axis-order note in compute_region_center).
spatialcentroids.loc["ICd",:] # sanity check on posterior guy

In [None]:
# Regions ordered by their axis-0 (rostral-caudal) centroid coordinate, for eyeballing.
spatialcentroids['x_index'].sort_values() # makes sense

In [None]:
# Centroid distances of the region pairs above the diagonal (each unordered pair once).
distance_matrix_df_values = distance_matrix_df.to_numpy()
i_upper = np.triu_indices_from(distance_matrix_df_values, k=1)
values = distance_matrix_df_values[i_upper]

fig, ax = plt.subplots()
ax.hist(values, bins=100)
ax.set_xlabel("Value")
ax.set_ylabel("Count")
ax.set_title("Histogram of upper triangle (excluding diagonal)")
plt.show()

## Prepare the lipizones-based region x region matrix

In [None]:
# 1) who's anatomical?
# Builds an acronym x lipizone enrichment matrix (`normalized_df`), orders it for display,
# and draws it as a heatmap. `cmat` and `normalized_df` are reused by later cells.

import scipy.cluster.hierarchy as sch

# Per-row anatomical acronym and lipizone name.
acronyms = atlas['acronym'].copy()
lipizones = atlas['lipizone_names'].copy()

# Keep only rows whose acronym occurs more than 50 times, plus the matching lipizone labels.
acronyms = acronyms.loc[acronyms.isin(acronyms.value_counts().index[acronyms.value_counts() > 50])]
lipizones = lipizones.loc[acronyms.index]

# Counts table: acronyms (rows) x lipizones (columns).
cmat = pd.crosstab(acronyms, lipizones)

# Each lipizone column is divided by its column total.
normalized_df = cmat / cmat.sum() # fraction 
# Each acronym row is then divided by its mean across lipizones.
normalized_df = (normalized_df.T / normalized_df.T.mean()).T ## switch to enrichments
normalized_df1 = normalized_df.copy()
# Bare expression in the middle of a cell: evaluated but not displayed.
normalized_df1

# Same construction from the other side: lipizones (rows) x acronyms (columns).
cmat = pd.crosstab(lipizones, acronyms)
# Each acronym column is divided by its column total...
normalized_df = cmat / cmat.sum() 
# ...and each lipizone row by its mean across acronyms.
normalized_df = (normalized_df.T / normalized_df.T.mean()).T 
# Transposed back to acronyms x lipizones so it lines up with normalized_df1.
normalized_df2 = normalized_df.copy().T
# Bare expression in the middle of a cell: evaluated but not displayed.
normalized_df2

# Element-wise product of the two enrichment tables.
normalized_df = normalized_df2 * normalized_df1
# Zero out acronym/lipizone combinations backed by fewer than 20 rows.
normalized_df[cmat.T < 20] = 0

# Order the columns (lipizones) by hierarchical clustering of their enrichment profiles.
linkage = sch.linkage(sch.distance.pdist(normalized_df.T), method='weighted', optimal_ordering=True)
order = sch.leaves_list(linkage)
normalized_df = normalized_df.iloc[:, order]

# Order the rows (acronyms) by the column position of their maximum value.
order = np.argmax(normalized_df.values, axis=1)
order = np.argsort(order)
normalized_df = normalized_df.iloc[order,:]

# Heatmap with the colour scale clipped to the 2nd-98th percentile of the matrix.
plt.figure(figsize=(10, 10))
sns.heatmap(normalized_df, cmap="Grays", cbar_kws={'label': 'Enrichment'}, xticklabels=True, yticklabels=False, vmin = np.percentile(normalized_df, 2), vmax = np.percentile(normalized_df, 98))

# Hide tick marks and x tick labels.
plt.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
plt.tick_params(axis='y', which='both', left=False, right=False)

plt.yticks(rotation=0)

plt.tight_layout()
plt.show()

In [None]:
# Stricter support filter than in the previous cell: combinations backed by fewer
# than 50 rows are zeroed. This modifies normalized_df in place.
normalized_df[cmat.T < 50] = 0
values = normalized_df.to_numpy().flatten()

# Fraction of matrix entries whose value exceeds 10.
np.mean(values > 10)

In [None]:
# A lipizone counts as "present" in a region when its value exceeds 10.
presencematrix = normalized_df > 10

# Region x region product of the presence table: entry (a, b) is non-zero when at
# least one lipizone is present in both region a and region b.
sharingmatrix = presencematrix.dot(presencematrix.T)

# Keep only regions that also have a centroid / distance entry.
has_centroid = sharingmatrix.index.isin(distance_matrix_df.index)
sharingmatrix = sharingmatrix.loc[has_centroid, has_centroid]
sharingmatrix # we lost 60 regions... check later!

In [None]:
# Plain arrays for the tests below, all three in the row/column order of sharingmatrix:
# S = lipizone sharing, C = connectome strength, D = centroid distance.
region_order = sharingmatrix.index
S = sharingmatrix.values.copy()
C = connectome.loc[region_order, region_order].values.copy()
D = distance_matrix_df.loc[region_order, region_order].values.copy()

In [None]:
# Quick visual check of the region x region sharing matrix.
plt.imshow(S)

## Make distance-controlled null distributions (Bernoulli resampling, then Metropolis edge swaps) and test significance accounting for the spatial distance confounder

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# --- Helper functions ---

def get_upper_triangular_indices(n):
    """Row/column index arrays of the strictly-upper triangle (diagonal excluded) of an n x n matrix."""
    first_offset_above_diagonal = 1
    return np.triu_indices(n, first_offset_above_diagonal)

def get_high_connectivity_mask(C, indices, percentile=90):
    """
    Flag the region pairs selected by `indices` whose connectivity in C is at or
    above the given percentile of those same pairs.

    Returns (boolean mask over the pairs, threshold value).
    """
    pair_strengths = C[indices]
    cutoff = np.percentile(pair_strengths, percentile)
    return pair_strengths >= cutoff, cutoff

def compute_T_from_pairs(L, S_pairs):
    """
    Test statistic T: mean of S_pairs over the pairs labelled 1 (highly connected)
    minus the mean of S_pairs over the pairs labelled 0 (not highly connected).
    A group with no pairs contributes 0.

    L: binary connectivity label for each pair
    S_pairs: cell type sharing value for each pair, in the same order as L
    """
    def sharing_rate(selector):
        # Mean sharing within the selected group; 0 when the group is empty.
        return np.mean(S_pairs[selector]) if np.sum(selector) > 0 else 0

    return sharing_rate(L == 1) - sharing_rate(L == 0)

def estimate_probability_function(D, high_mask, indices, bins=20):
    """
    Binned estimate of P(pair is highly connected | physical distance).

    The distances of the pairs selected by `indices` are histogrammed into `bins`
    bins; within each bin the estimate is (#highly connected pairs) / (#pairs),
    and 0 for bins that contain no pair.

    Returns:
      - bin_centers: centers of distance bins
      - p_est: estimated probabilities for each bin
      - bin_edges: the edges of the bins used
    """
    pair_distances = D[indices]
    n_all, bin_edges = np.histogram(pair_distances, bins=bins)
    # Re-use the same edges so both histograms are aligned bin by bin.
    n_high, _ = np.histogram(pair_distances[high_mask], bins=bin_edges)

    p_est = np.zeros_like(n_high, dtype=float)
    np.divide(n_high, n_all, out=p_est, where=n_all != 0)

    lower_edges, upper_edges = bin_edges[:-1], bin_edges[1:]
    bin_centers = (lower_edges + upper_edges) / 2
    return bin_centers, p_est, bin_edges

def get_probability_for_distance(d, bin_edges, p_est):
    """
    Look up the binned probability estimate for distance d (scalar or array).

    Distances below the first edge use the first bin; distances at or beyond the
    last edge use the last bin.
    """
    last_bin = len(p_est) - 1
    # np.digitize is 1-based for values inside the first bin, hence the -1.
    raw_bin = np.digitize(d, bin_edges) - 1
    return p_est[np.clip(raw_bin, 0, last_bin)]

def simulate_connectivity_labels(D, indices, bin_edges, p_est):
    """
    For each region pair (given by indices), simulate a connectivity label (0/1)
    from a Bernoulli distribution with probability determined by its physical distance.

    D: physical distance matrix
    indices: index arrays selecting the region pairs in D
    bin_edges, p_est: binned estimate of P(high connectivity | distance)

    Returns an integer array of 0/1 labels, one per pair. Draws from the global
    NumPy random state, so seed it beforehand for reproducible results.
    """
    d_vals = D[indices]
    # The lookup helper is built on np.digitize / np.clip, which accept arrays, so the
    # whole distance vector is resolved in one call instead of one Python call per pair
    # (this function runs once per permutation).
    p_vals = np.asarray(get_probability_for_distance(d_vals, bin_edges, p_est))
    # Sample a binary label for each pair.
    simulated_labels = np.random.rand(len(p_vals)) < p_vals
    return simulated_labels.astype(int)

# --- Main permutation test function ---

def run_permutation_test(C, D, S, percentile=90, bins=20, n_perm=1000, random_state=42):
    """
    Distance-controlled simulation test of whether highly connected region pairs
    share cell types more often than other pairs.

    Connectivity labels are re-drawn n_perm times from the binned estimate of
    P(high connectivity | distance); the statistic T is recomputed for every draw
    to form the null distribution. Prints T_obs and the p-value, shows two figures
    and saves the second one to "significancelipizonesconnections.pdf".

    C: connectomic strength matrix (n x n)
    D: physical distance matrix (n x n)
    S: cell type similarity matrix (n x n) (assumed binary here)
    percentile: threshold percentile for 'high connectivity'
    bins: number of bins for estimating p(d)
    n_perm: number of simulated label sets
    random_state: seed for the global NumPy random state

    Returns (T_obs, T_null, p_value), where p_value is the fraction of null
    statistics that are >= T_obs.
    """
    np.random.seed(random_state)
    pair_idx = get_upper_triangular_indices(C.shape[0])

    # Observed labels and statistic.
    observed_high, _ = get_high_connectivity_mask(C, pair_idx, percentile=percentile)
    sharing = S[pair_idx]
    T_obs = compute_T_from_pairs(observed_high.astype(int), sharing)
    print("Observed T statistic:", T_obs)

    # p(d) estimated from the observed labels.
    centers, p_est, bin_edges = estimate_probability_function(D, observed_high, pair_idx, bins=bins)

    fig_pd, ax_pd = plt.subplots(figsize=(6, 4))
    ax_pd.plot(centers, p_est, 'o-')
    ax_pd.set_xlabel('Physical Distance')
    ax_pd.set_ylabel('P(high connectivity | distance)')
    ax_pd.set_title('Estimated Connectivity Probability vs. Distance')
    plt.show()

    # Null distribution: one statistic per simulated label set.
    T_null = np.array([
        compute_T_from_pairs(simulate_connectivity_labels(D, pair_idx, bin_edges, p_est), sharing)
        for _ in range(n_perm)
    ])

    # One-tailed p-value.
    p_value = np.mean(T_null >= T_obs)
    print("One-tailed p-value:", p_value)

    fig_null, ax_null = plt.subplots(figsize=(6, 4))
    ax_null.hist(T_null, bins=30, alpha=0.7, color='gray', label='Null distribution')
    ax_null.axvline(T_obs, color='red', linestyle='dashed', linewidth=2, label='T_obs')
    ax_null.set_xlabel('Test Statistic T')
    ax_null.set_ylabel('Frequency')
    ax_null.set_title('Null Distribution of T with Observed T')
    ax_null.legend()
    fig_null.savefig("significancelipizonesconnections.pdf")
    plt.show()

    return T_obs, T_null, p_value

# Main run: "high connectivity" = top 5% of pairs, p(d) estimated on 50 distance bins.
T_obs, T_null, p_value = run_permutation_test(C, D, S, percentile=95, bins=50, n_perm=1000)

In [None]:
# Top-level replay of the first steps of run_permutation_test (95th percentile,
# 50 bins), so that the mask and the p(d) estimate stay available as variables.
percentile, bins, n_perm, random_state = 95, 50, 1000, 42
np.random.seed(random_state)

n = C.shape[0]
indices = get_upper_triangular_indices(n)

high_mask_obs, thresh = get_high_connectivity_mask(C, indices, percentile=percentile)
L_obs = high_mask_obs.astype(int)
S_pairs = S[indices]

bin_centers, p_est, bin_edges = estimate_probability_function(D, high_mask_obs, indices, bins=bins)

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(bin_centers, p_est, 'o-')
ax.set_xlabel('Physical Distance')
ax.set_ylabel('P(high connectivity | distance)')
ax.set_title('Estimated Connectivity Probability vs. Distance')
plt.show()

In [None]:
# try different percentiles
# Sensitivity check with a stricter cut-off (top 0.5% of pairs). Overwrites T_obs, T_null, p_value.
T_obs, T_null, p_value = run_permutation_test(C, D, S, percentile=99.5, bins=50, n_perm=1000)

In [None]:
# Sensitivity check with a looser cut-off (top 25% of pairs). Overwrites T_obs, T_null, p_value.
T_obs, T_null, p_value = run_permutation_test(C, D, S, percentile=75.0, bins=50, n_perm=1000)

In [None]:
# try different percentiles
# Sensitivity check at the top 2.5% of pairs. Overwrites T_obs, T_null, p_value.
T_obs, T_null, p_value = run_permutation_test(C, D, S, percentile=97.5, bins=50, n_perm=1000)

In [None]:
def diagnostic_convergence(C, D, S, percentile=90, bins=20, n_perm=1000, random_state=42):
    """
    Plot the trace of the test statistic T over the permutation iterations and
    its autocorrelation function.

    The null statistics are generated exactly as in run_permutation_test
    (distance-conditioned Bernoulli labels), then shown as a trace plot and as an
    autocorrelation plot with 30 lags.

    C: connectomic strength matrix (n x n)
    D: physical distance matrix (n x n)
    S: cell type similarity matrix (n x n)
    percentile: threshold percentile for 'high connectivity'
    bins: number of bins for estimating p(d)
    n_perm: number of simulated label sets
    random_state: seed for the global NumPy random state

    Returns None; the function only prints T_obs and shows figures.
    """
    # Kept local: statsmodels is only needed for this diagnostic.
    from statsmodels.graphics.tsaplots import plot_acf

    np.random.seed(random_state)
    n = C.shape[0]
    indices = get_upper_triangular_indices(n)
    
    # Step 1: Determine observed high connectivity pairs.
    high_mask_obs, thresh = get_high_connectivity_mask(C, indices, percentile=percentile)
    L_obs = high_mask_obs.astype(int)
    
    # Step 2: Compute observed test statistic.
    S_pairs = S[indices]  # cell type sharing for each region pair
    T_obs = compute_T_from_pairs(L_obs, S_pairs)
    print("Observed T statistic:", T_obs)
    
    # Step 3: Estimate the conditional probability p(d) from observed data.
    bin_centers, p_est, bin_edges = estimate_probability_function(D, high_mask_obs, indices, bins=bins)
    
    # Plot the estimated probability function.
    plt.figure(figsize=(6, 4))
    plt.plot(bin_centers, p_est, 'o-')
    plt.xlabel('Physical Distance')
    plt.ylabel('P(high connectivity | distance)')
    plt.title('Estimated Connectivity Probability vs. Distance')
    plt.show()
    
    # Step 4: Run simulations.
    T_null = []
    for i in range(n_perm):
        sim_labels = simulate_connectivity_labels(D, indices, bin_edges, p_est)
        T_sim = compute_T_from_pairs(sim_labels, S_pairs)
        T_null.append(T_sim)
    T_null = np.array(T_null)
    
    plt.figure(figsize=(10, 4))
    plt.plot(T_null, marker='o', linestyle='-', alpha=0.7)
    plt.xlabel('Iteration')
    plt.ylabel('Test Statistic T')
    plt.title('Trace Plot of T over Permutations')
    plt.show()
    
    # Autocorrelation plot. plot_acf opens its own figure unless it is handed an
    # Axes, so the Axes is created here and passed in; otherwise the requested
    # figsize would land on an empty extra figure.
    fig_acf, ax_acf = plt.subplots(figsize=(6, 4))
    plot_acf(T_null, lags=30, ax=ax_acf)
    ax_acf.set_title('Autocorrelation of T')
    plt.show()

# Trace and autocorrelation diagnostics for the 97.5th-percentile setting.
diagnostic_convergence(C, D, S, percentile=97.5, bins=50, n_perm=1000)

In [None]:
def degree_preserving_randomization_with_distance(matrix, D, bin_edges, p_est, 
                                                  n_swaps=1000, temperature=1.0, 
                                                  random_state=42):
    """
    Degree-preserving edge-swap randomization with a Metropolis acceptance step
    that favours swaps whose new edges are at least as probable, under the
    distance-based estimate p(d), as the edges they replace.

    Parameters:
      matrix: symmetric binary connectivity matrix (observed); it is copied, not modified.
      D: physical distance matrix.
      bin_edges, p_est: parameters of the estimated p(d) function.
      n_swaps: number of attempted edge swaps (rejected attempts count too).
      temperature: controls the acceptance of swaps that lower the summed p(d).
      random_state: seed for the global NumPy random state.

    Returns:
      A new connectivity matrix with the same degree sequence as matrix.
    """
    np.random.seed(random_state)
    M = matrix.copy()

    # Existing edges, taken from the strictly-upper triangle (i < j).
    upper_rows, upper_cols = np.where(np.triu(M, k=1) == 1)
    edges = list(zip(upper_rows, upper_cols))
    num_edges = len(edges)

    # A swap needs two edges; with fewer there is nothing to randomize.
    if num_edges < 2:
        return M

    def edge_probability(a, b):
        # p(d) for the pair (a, b), looked up from its physical distance.
        return get_probability_for_distance(D[a, b], bin_edges, p_est)

    for _ in range(n_swaps):
        # Two distinct edges (i, j) and (k, l), proposed to become (i, l) and (k, j).
        first, second = np.random.choice(num_edges, 2, replace=False)
        i, j = edges[first]
        k, l = edges[second]

        # The swap is only defined when all four endpoints differ...
        if len({i, j, k, l}) != 4:
            continue
        # ...and when neither proposed edge already exists.
        if M[i, l] == 1 or M[k, j] == 1:
            continue

        sum_old = edge_probability(i, j) + edge_probability(k, l)
        sum_new = edge_probability(i, l) + edge_probability(k, j)

        # delta > 0 means the proposal lowers the summed probability; such moves are
        # accepted with probability exp(-delta/temperature), all others always.
        delta = sum_old - sum_new
        if delta > 0:
            accept_prob = np.exp(-delta / temperature)
        else:
            accept_prob = 1.0

        # The uniform draw happens on every valid proposal, even when accept_prob is 1.
        if np.random.rand() < accept_prob:
            M[i, j] = M[j, i] = 0
            M[k, l] = M[l, k] = 0
            M[i, l] = M[l, i] = 1
            M[k, j] = M[j, k] = 1
            # Store the new edges as (smaller index, larger index).
            edges[first] = (min(i, l), max(i, l))
            edges[second] = (min(k, j), max(k, j))

    return M

def run_permutation_test_degree_preserved_with_distance(C, D, S, percentile=95, 
                                                        n_perm=1000, n_swaps=1000, 
                                                        temperature=1.0, bins=50,
                                                        random_state=42):
    """
    Run the permutation test using degree-preserving randomization that also
    biases swaps to preserve the distance structure.

    The observed binary network is defined exactly as in run_permutation_test:
    a pair is 'highly connected' when its strength is at or above the given
    percentile of the off-diagonal (upper-triangle) strengths. Thresholding the
    full matrix instead would let the diagonal (self-connections) shift the
    cut-off and make the two tests disagree on T_obs for the same percentile.
    
    Parameters:
      C: Original connectivity matrix (assumed symmetric).
      D: Physical distance matrix.
      S: Cell type similarity matrix (binary).
      percentile: threshold percentile to define the observed binary connectivity.
      n_perm: number of permutations.
      n_swaps: number of swaps for each randomization.
      temperature: MH temperature for swap acceptance.
      bins: number of bins to estimate the p(d) function.
      random_state: seed for reproducibility; permutation i is seeded with random_state + i.
      
    Returns:
      T_obs: observed test statistic.
      T_null: array of test statistics from permutations.
      p_value: one-tailed p-value (fraction of T_null >= T_obs).
      perm_conn_matrices: list of randomized connectivity matrices.
    """
    np.random.seed(random_state)
    n = C.shape[0]
    indices = np.triu_indices(n, k=1)
    
    # Observed labels from the upper triangle only, via the shared helper.
    high_mask_obs, _ = get_high_connectivity_mask(C, indices, percentile=percentile)

    # Symmetric binary network with an empty diagonal, as the edge-swap routine expects.
    ground_truth = np.zeros((n, n), dtype=int)
    ground_truth[indices] = high_mask_obs
    ground_truth = ground_truth + ground_truth.T
    
    # Compute the observed test statistic.
    S_pairs = S[indices]
    T_obs = compute_T_from_pairs(ground_truth[indices].astype(int), S_pairs)
    print("Observed T statistic:", T_obs)
    
    # Estimate the conditional probability p(d) from observed data.
    bin_centers, p_est, bin_edges = estimate_probability_function(D, high_mask_obs, indices, bins=bins)
    
    T_null = []
    perm_conn_matrices = []
    
    for i in range(n_perm):
        # Every randomization restarts from the observed network with its own seed.
        perm_matrix = degree_preserving_randomization_with_distance(ground_truth, D, bin_edges, p_est, 
                                                                    n_swaps=n_swaps, temperature=temperature, 
                                                                    random_state=random_state + i)
        T_sim = compute_T_from_pairs(perm_matrix[indices].astype(int), S_pairs)
        T_null.append(T_sim)
        perm_conn_matrices.append(perm_matrix)
    
    T_null = np.array(T_null)
    p_value = np.mean(T_null >= T_obs)
    print("One-tailed p-value:", p_value)
    
    # Plot the null distribution.
    plt.figure(figsize=(6, 4))
    plt.hist(T_null, bins=30, alpha=0.7, label='Null distribution')
    plt.axvline(T_obs, color='red', linestyle='dashed', linewidth=2, label='T_obs')
    plt.xlabel('Test Statistic T')
    plt.ylabel('Frequency')
    plt.title('Null Distribution of T with Observed T')
    plt.legend()
    plt.show()
    
    return T_obs, T_null, p_value, perm_conn_matrices

# Full run: 1000 randomizations of 1000 attempted swaps each (lower n_perm, e.g. to 5,
# for a quick smoke test). Overwrites T_obs, T_null and p_value from the earlier test.
T_obs, T_null, p_value, perm_conn_matrices = run_permutation_test_degree_preserved_with_distance(
    C, D, S, percentile=95, n_perm=1000, n_swaps=1000, temperature=1.0, bins=50, random_state=42
)


## Flag lipizones as input-output related based on the density of the local connectivity network

In [None]:
# From here on `allen` is the SUM of the contralateral and ipsilateral tables;
# the earlier sections used the ipsilateral table alone under the same name.
allen_contra = pd.read_csv("./zenodo/csv/allenconnectome_contra.csv", index_col=0)
allen_ipsi = pd.read_csv("./zenodo/csv/allenconnectome_ipsi.csv", index_col=0)
allen = allen_contra + allen_ipsi

# Lipizones (rows) x regions (columns); regions whose column sums to zero are dropped.
cm = normalized_df.T
nonzero_regions = cm.sum() > 0
cm = cm.loc[:, nonzero_regions]
cm

In [None]:
def assess_subnetwork_density(subnet_indices, allen, num_permutations=10000, random_state=None):
    """
    Compare the mean connection strength inside a set of regions with the mean
    strength inside equally sized random sets of regions.

    Parameters:
        subnet_indices: labels of the regions forming the subnetwork; every label
            must exist in both the index and the columns of `allen`.
        allen (pd.DataFrame): square connectivity table, labelled identically on
            rows and columns.
        num_permutations: number of random region sets drawn for the null distribution.
        random_state: optional seed. When None (the default) the global NumPy
            random state is used, exactly as before; pass an int for a
            reproducible null distribution.

    Returns:
        dict with
          "observed_mean": mean off-diagonal strength of the subnetwork (for a single
              region, its diagonal entry),
          "p_value": fraction of null means >= the observed mean,
          "null_distribution": array of the num_permutations null means.

    Raises:
        ValueError: if subnet_indices is empty.
        KeyError: if a label is missing from `allen` (raised by the .loc lookup).
    """
    subnet_indices = np.asarray(subnet_indices)
    n = len(subnet_indices)
    if n == 0:
        # Previously surfaced as an opaque IndexError from values[0, 0].
        raise ValueError("subnet_indices is empty; cannot assess the density of an empty subnetwork")

    # Built once: the off-diagonal selector is the same for every n x n block.
    off_diagonal = ~np.eye(n, dtype=bool)

    def mean_strength(block):
        # Mean over off-diagonal entries; a single region falls back to its diagonal entry.
        if n > 1:
            return block[off_diagonal].mean()
        return block[0, 0]

    observed_mean = mean_strength(allen.loc[subnet_indices, subnet_indices].values)

    rng = np.random if random_state is None else np.random.RandomState(random_state)
    all_nodes = allen.index.to_numpy()
    null_means = np.empty(num_permutations)
    for i in range(num_permutations):
        random_nodes = rng.choice(all_nodes, size=n, replace=False)
        null_means[i] = mean_strength(allen.loc[random_nodes, random_nodes].values)

    p_value = np.mean(null_means >= observed_mean)

    return {"observed_mean": observed_mean,
            "p_value": p_value,
            "null_distribution": null_means}

# One density test per lipizone. `indexes` records the lipizone behind every entry of
# `allanalyses`, so the p-value table stays aligned even when some lipizones fail.
# (`indexes` was previously never defined, which made the final DataFrame raise NameError.)
allanalyses = []
indexes = []

from tqdm import tqdm
for lev7 in tqdm(cm.index.values):
    try:
        # Regions in which this lipizone's value exceeds 49 form its subnetwork.
        topregionshere = cm.loc[lev7,:].sort_values()[::-1]
        topregionshere = topregionshere[topregionshere > 49].index
        subnetwork_nodes = topregionshere.values
        analysis = assess_subnetwork_density(subnetwork_nodes, allen, num_permutations=10000)
    except Exception:
        # Best effort: e.g. no region passes the cut-off, or a region label is missing
        # from the connectome. Unlike a bare `except:`, KeyboardInterrupt still stops the loop.
        print(lev7)
        print("problem!")
        continue
    allanalyses.append(analysis)
    indexes.append(lev7)
        
import pickle

# Cache the raw results (including the null distributions) on disk.
with open("allanalyses.pkl", "wb") as f:
    pickle.dump(allanalyses, f)
    
pvals = [x['p_value'] for x in allanalyses]
pvalsdf = pd.DataFrame(pvals, index=indexes)
# Average p-values that share a label (a no-op when every label is unique).
pvalsdf = pvalsdf.groupby(pvalsdf.index).mean()
pvalsdf