The code to analyze the permutations is adapted from: https://github.com/ZhuangLab/whole_mouse_brain_MERFISH_atlas_scripts_2023/blob/main/scripts/cell_cell_contacts/get_significant_contacts_30um.ipynb

In [None]:
import scanpy as sc
import pandas as pd
import seaborn as sns
import os
import string
import geopandas as gpd
import numpy as np
from shapely.geometry import Point
from tqdm import tqdm
import Mapping
import matplotlib.pyplot as plt
import multiprocessing as mp
from functools import partial
import gseapy as gp
import networkx as nx

import skimage
import cv2
from skimage.morphology import disk, opening, closing
from scipy.ndimage import binary_fill_holes, label, distance_transform_edt
from skimage.segmentation import find_boundaries, watershed
from skimage.feature import peak_local_max
from skimage.measure import regionprops
from skimage.segmentation import watershed
import scipy.stats
import statsmodels.stats.multitest

from sklearn.cluster import SpectralClustering
import networkx as nx
from matplotlib.lines import Line2D

# CLQ Scoring

In [None]:
def clq_array(base_path, batch_list, morphology_list, output_p, compartment='soma'):
    """Load the saved CLQ score arrays for every morphology/batch pair and stack them.

    Args:
        base_path: Unused. Kept so existing callers keep working; the previous
            body re-read each batch's transcript table here only to build a
            gene list that was never used or returned.
        batch_list: Batch folder names found under ``output_p``.
        morphology_list: Morphology labels; each one selects the
            ``morph_<label>`` sub-folder of a batch folder.
        output_p: Root folder of the saved score files, including the trailing
            path separator (paths are built by plain string formatting).
        compartment: File-name prefix of the score file, e.g. ``'soma'`` or
            ``'branches'``.

    Returns:
        np.ndarray: Every loaded array concatenated along axis 0, ordered
        morphology-major, batch-minor.
    """
    final_clq = [
        np.load(f'{output_p}{batch}/morph_{morph}/{compartment}_clq_scores.npy')
        for morph in morphology_list
        for batch in batch_list
    ]
    return np.concatenate(final_clq, axis=0)

In [None]:
# Notebook-level inputs for the CLQ cells.
# output_path is the root folder the saved score files are read from.
output_path = 'permutation_coloc/'
# Morphology class labels, kept as strings because they are pasted into folder names.
morphologies = [str(morph_id) for morph_id in range(5)]
# Batch folder names: six 3-month samples followed by six 24-month samples.
batches = (
    ['3-mo-' + tag for tag in ('male-1', 'male-2', 'male-3-rev2',
                               'female-1-rev2', 'female-2', 'female-3')]
    + ['24-mo-' + tag for tag in ('male-1', 'male-2', 'male-4-rev2',
                                  'female-1', 'female-3', 'female-5')]
)
# adjust this to your base path
base_data_path = ''

In [None]:
# Stack the saved CLQ scores once per compartment (soma first, then branches).
final_array_soma, final_array_branches = (
    clq_array(base_data_path, batches, morphologies, output_path, compartment=part)
    for part in ('soma', 'branches')
)

In [None]:
# graph the soma cluster map and record the order of genes

# `genes` was never defined at notebook scope (inside the functions it is only
# a local), so this cell used to fail with a NameError. Build the labels here
# the same way permutation_analysis does for its matrices.
# NOTE(review): assumes every batch shares one gene panel and that the CLQ
# matrices are ordered like the sorted unique gene names -- confirm against the
# script that wrote the *_clq_scores.npy files. The cell defining
# find_filtered_transcripts must have been run before this one.
genes = np.unique(
    find_filtered_transcripts(base_data_path + batches[0] + '/').gene.unique().tolist()
)

# Average over the stacked axis to get one gene x gene matrix.
average_array_soma = np.mean(final_array_soma, axis=0)
df_soma = pd.DataFrame(average_array_soma, index=genes, columns=genes)

g = sns.clustermap(
    df_soma,
    cmap='Reds',
    figsize=(10, 10),
    method='average',
    metric='euclidean',
    vmax = 2.0,
    cbar_kws={'label': 'Average CLQ Score'}
)
gene_order = g.dendrogram_row.reordered_ind  # Save the ordered row indices
ordered_genes_list = df_soma.index[gene_order].tolist()  # Get the ordered gene names
plt.savefig('coloc_figs/soma_clq.pdf', format='pdf')

In [None]:
# Draw the branches CLQ map with rows/columns in the gene order taken from the soma clustermap.
average_array_branches = final_array_branches.mean(axis=0)
df_branches = pd.DataFrame(average_array_branches, columns=genes, index=genes)

# Re-index both axes at once so rows and columns follow the same saved ordering.
df_branches_ordered = df_branches.loc[ordered_genes_list, ordered_genes_list]

# Clustering is disabled on both axes, so the supplied ordering is drawn unchanged.
sns.clustermap(
    df_branches_ordered,
    row_cluster=False,
    col_cluster=False,
    figsize=(10, 10),
    cmap='Reds',
    vmax=2.0,
    cbar_kws={'label': 'Second Dataset CLQ Score'},
)

plt.title('Clustermap with Preserved Ordering', fontsize=16)
plt.savefig('coloc_figs/branches_clq.pdf', format='pdf')
plt.show()

In [None]:
def plot_clustermap_subset(graph, df, df2, n_genes, save_pdf=None):
    '''
    Plot a heatmap of the first n_genes x n_genes square of a clustered matrix.

    The genes are taken from the start of the clustermap's row ordering, i.e.
    the top-left corner of the clustered heatmap.

    graph = clustermap object whose row dendrogram supplies the ordering
    df = dataframe the clustermap was built from (its index is reordered)
    df2 = dataframe with the values we are graphing; must be indexed by the
          same gene names on rows and columns
    n_genes = size of the square, counted from the start of the ordering
    save_pdf = optional path; when given, the figure is also written as a PDF
    '''
    gene_order = graph.dendrogram_row.reordered_ind  # Row order from the dendrogram
    ordered_genes = df.index[gene_order]  # Ordered gene names

    subset_genes = ordered_genes[:n_genes]

    # Subset the DataFrame for the cluster of interest
    df_subset = df2.loc[subset_genes, subset_genes]

    # Visualize the subset as a heatmap on a fixed 0-2 colour scale
    sns.heatmap(
        df_subset,
        cmap='Reds',
        vmax=2.0,
        vmin=0,
        cbar_kws={'label': 'Average CLQ Score'}
    )

    plt.title('Subset of Genes in Bottom-Right Cluster', fontsize=16)
    if save_pdf:
        plt.savefig(save_pdf, format='pdf')
    plt.show()

In [None]:
# Soma CLQ values for the first 14 genes of the soma clustermap ordering.
plot_clustermap_subset(
    g,
    df_soma,
    df_soma,
    14,
    save_pdf='coloc_figs/soma_clq_subset.pdf',
)

In [None]:
# Branches CLQ values for the same 14 genes; the ordering still comes from the soma clustermap.
plot_clustermap_subset(
    g,
    df_soma,
    df_branches_ordered,
    14,
    save_pdf='coloc_figs/branches_clq_subset.pdf',
)

# Colocalization Networks

In [None]:
def count_zero_pairs(contact_mtx):
    """Count the zero entries on or above the diagonal of a matrix.

    Only the leading ``shape[0] x shape[0]`` square is inspected, so each
    unordered (i, j) pair with i <= j is visited exactly once.
    """
    size = contact_mtx.shape[0]
    upper_rows, upper_cols = np.triu_indices(size)
    return int(np.count_nonzero(contact_mtx[upper_rows, upper_cols] == 0))

def adjust_p_value_matrix_by_BH(p_val_mtx):
    '''Benjamini/Hochberg-adjust the p-values of a symmetric matrix.

    Only the upper triangle (diagonal included) is passed to the correction,
    so each unordered pair is counted once. The adjusted values are then
    written back to both halves of a new N x N matrix.
    '''
    size = p_val_mtx.shape[0]

    # Row-major (i, j) index pairs with j >= i.
    upper_rows, upper_cols = np.triu_indices(size)
    raw_p_values = p_val_mtx[upper_rows, upper_cols]

    # Element [1] of the multipletests result holds the corrected p-values.
    corrected = statsmodels.stats.multitest.multipletests(raw_p_values, method='fdr_bh')[1]

    adjusted_p_val_mtx = np.zeros((size, size))
    adjusted_p_val_mtx[upper_rows, upper_cols] = corrected
    adjusted_p_val_mtx[upper_cols, upper_rows] = corrected
    return adjusted_p_val_mtx

def get_data_frame_from_metrices(cell_types, mtx_dict):
    """Flatten the upper triangle of several matrices into one long table.

    Args:
        cell_types: Labels for the matrix rows/columns, in matrix order.
        mtx_dict: Maps a column name to a matrix indexable as ``m[i, j]``.

    Returns:
        pd.DataFrame: One row per (i, j) pair with i <= j, holding
        ``cell_type1``, ``cell_type2`` and one column per entry of ``mtx_dict``.
    """
    size = len(cell_types)
    # Upper-triangle index pairs in row-major order, diagonal included.
    pairs = [(row, col) for row in range(size) for col in range(row, size)]

    columns = {
        'cell_type1': [cell_types[row] for row, _ in pairs],
        'cell_type2': [cell_types[col] for _, col in pairs],
    }
    for name, matrix in mtx_dict.items():
        columns[name] = [matrix[row, col] for row, col in pairs]

    return pd.DataFrame(columns)
    

def sort_cell_type_contact_p_values(p_val_mtx, cell_types):
    '''Return (cell_type1, cell_type2, p_value) tuples in ascending p-value order.

    Only pairs on or above the diagonal of p_val_mtx are listed.
    '''
    size = p_val_mtx.shape[0]
    triples = [
        (cell_types[row], cell_types[col], p_val_mtx[row, col])
        for row in range(size)
        for col in range(row, size)
    ]
    # In-place sort is stable, so ties keep their row-major order.
    triples.sort(key=lambda triple: triple[2])
    return triples

import scipy.cluster
#from scattermap import scattermap

def get_optimal_order_of_mtx(X):
    """Return the leaf order of a Ward linkage of X after optimal leaf ordering."""
    hierarchy = scipy.cluster.hierarchy
    linkage = hierarchy.ward(X)
    reordered_linkage = hierarchy.optimal_leaf_ordering(linkage, X)
    return hierarchy.leaves_list(reordered_linkage)

def get_ordered_tick_labels(tick_labels):
    """Return the argsort of the labels keyed by their last space-separated word.

    Each label is prefixed with its final word before sorting, so labels that
    end in the same word sort together and ties fall back to the full label.
    """
    sort_keys = []
    for label in tick_labels:
        trailing_word = label.split(' ')[-1]
        sort_keys.append(trailing_word + ' ' + label)
    return np.argsort(sort_keys)

def filter_pval_mtx(pval_mtx, tick_labels, allowed_pairs):
    """Return a copy of pval_mtx with every non-allowed pair set to 1.

    A cell (i, j) is left untouched when (label_i, label_j) or its reverse is
    in ``allowed_pairs``; the input matrix itself is not modified.
    """
    filtered = pval_mtx.copy()
    n_rows, n_cols = pval_mtx.shape[0], pval_mtx.shape[1]

    for row in range(n_rows):
        row_label = tick_labels[row]
        for col in range(n_cols):
            col_label = tick_labels[col]
            forward = (row_label, col_label)
            backward = (col_label, row_label)
            if forward not in allowed_pairs and backward not in allowed_pairs:
                filtered[row, col] = 1

    return filtered

def make_dotplot(pval_mtx, fold_change_mtx, tick_labels, title='', allowed_pairs=None):
    """Reorder both matrices and the labels by the class-first label order.

    NOTE(review): the body stops after the reordering -- nothing is drawn or
    returned, and ``title`` / ``allowed_pairs`` are never read. The plotting
    part appears to be missing; confirm before relying on this function.
    """
    # Label-based ordering; a clustering-based order (get_optimal_order_of_mtx)
    # was the commented-out alternative.
    label_order = get_ordered_tick_labels(tick_labels)

    # Apply the same permutation to rows and then columns of each matrix.
    pval_mtx = pval_mtx[label_order, :][:, label_order]
    fold_change_mtx = fold_change_mtx[label_order, :][:, label_order]
    tick_labels = tick_labels[label_order]
    
    
def find_filtered_transcripts(experiment_path):
    """Read ``baysor/detected_transcripts.csv`` from an experiment folder.

    Args:
        experiment_path: Experiment folder including the trailing path
            separator (the file path is built by plain string formatting).

    Returns:
        pd.DataFrame | None: The table, read with its first column as the
        index, or ``None`` when the file does not exist.
    """
    # NOTE(review): the previous version looped over 'region_0' / 'region_1'
    # but never put the region into the path, so it only re-checked this same
    # file twice. If transcripts can live in per-region sub-folders, that
    # lookup still needs to be added -- confirm the on-disk layout.
    file_path = f'{experiment_path}baysor/detected_transcripts.csv'
    if os.path.exists(file_path):
        return pd.read_csv(file_path, index_col=0)
    return None

def permutation_analysis(base_path, batch_list, morphology_list, output_p, compartment='soma'):
    """Score observed gene-gene colocalization counts against their permutation null.

    For every morphology/batch pair this loads the observed count matrix and
    the permutation mean/std matrices, converts them to z-scores, one-sided
    (upper-tail) normal p-values and BH-adjusted p-values, writes the long-form
    table to ``<compartment>_close_contacts.csv`` next to the inputs, and
    collects it.

    Args:
        base_path: Folder holding one sub-folder per batch with the transcripts.
        batch_list: Batch folder names.
        morphology_list: Morphology labels (``morph_<label>`` sub-folders).
        output_p: Root folder of the permutation outputs, with trailing separator.
        compartment: File-name prefix, e.g. ``'soma'`` or ``'branches'``.

    Returns:
        list[pd.DataFrame]: One table per morphology/batch pair, ordered
        morphology-major, each sorted by descending z-score.

    Raises:
        FileNotFoundError: If a batch has no transcript table to take the gene
            names from.
    """
    full_df = []
    # The gene list depends only on the batch, so read each transcript table
    # once instead of once per morphology.
    genes_by_batch = {}

    for morph in morphology_list:
        for batch in batch_list:
            if batch not in genes_by_batch:
                transcripts = find_filtered_transcripts(base_path + batch + '/')
                if transcripts is None:
                    raise FileNotFoundError(
                        f"No detected_transcripts.csv found for batch '{batch}' "
                        f"under '{base_path}{batch}/'"
                    )
                genes_by_batch[batch] = np.unique(transcripts.gene.unique().tolist())
            genes = genes_by_batch[batch]

            result_dir = f'{output_p}{batch}/morph_{morph}/'
            gene_coloc_counts = np.load(f'{result_dir}{compartment}_no_permutation.npy')
            local_null_means = np.load(f'{result_dir}{compartment}_full_permutation_mean.npy')
            local_null_stds = np.load(f'{result_dir}{compartment}_full_permutation_std.npy')

            # Require all stds to be larger or equal to the minimal observable std value
            # (presumably tied to 1000 permutations -- confirm with the permutation script).
            local_null_stds = np.maximum(local_null_stds, np.sqrt(1 / 1000))

            local_z_scores = (gene_coloc_counts - local_null_means) / local_null_stds
            # Survival function: probability of a z-score at least this large.
            local_p_values = scipy.stats.norm.sf(local_z_scores)
            adjusted_local_p_values = adjust_p_value_matrix_by_BH(local_p_values)

            # Gather all results into a data frame
            contact_result_df = get_data_frame_from_metrices(
                genes,
                {
                    'pval-adjusted': adjusted_local_p_values,
                    'pval': local_p_values,
                    'z_score': local_z_scores,
                    'contact_count': gene_coloc_counts,
                    'permutation_mean': local_null_means,
                    'permutation_std': local_null_stds,
                },
            ).sort_values('z_score', ascending=False)

            contact_result_df.to_csv(f'{result_dir}{compartment}_close_contacts.csv')

            full_df.append(contact_result_df)

    return full_df

In [None]:
# Inputs for the permutation analysis cells.

# Root folder of the per-batch permutation outputs.
output_path = 'permutation_coloc/'
# Morphology class labels, kept as strings because they are pasted into folder names.
morphologies = [str(morph_id) for morph_id in range(5)]
# Batch folder names: six 3-month samples followed by six 24-month samples.
batches = (
    ['3-mo-' + tag for tag in ('male-1', 'male-2', 'male-3-rev2',
                               'female-1-rev2', 'female-2', 'female-3')]
    + ['24-mo-' + tag for tag in ('male-1', 'male-2', 'male-4-rev2',
                                  'female-1', 'female-3', 'female-5')]
)

# replace with where the data is kept
base_data_path = ''

In [None]:
# Run the permutation scoring once per compartment (soma first, then branches).
soma_df, branches_df = (
    permutation_analysis(base_data_path, batches, morphologies, output_path, compartment=part)
    for part in ('soma', 'branches')
)

# Trimming the dataframes and looking at specific morphologies

In [None]:
def filter_disconnected_pairs(df, node1_col, node2_col):
    """
    Drop edges that belong to connected components of two nodes or fewer.

    The rows of ``df`` are treated as undirected edges between the values in
    ``node1_col`` and ``node2_col``.

    Args:
        df (pd.DataFrame): Edge table with at least the two node columns.
        node1_col (str): Column holding the first endpoint (e.g., 'cell_type1').
        node2_col (str): Column holding the second endpoint (e.g., 'cell_type2').
    Returns:
        pd.DataFrame: The rows whose endpoints both sit in a component of more
        than two nodes.
    """
    # Build the undirected graph edge by edge.
    graph = nx.Graph()
    for source, target in zip(df[node1_col], df[node2_col]):
        graph.add_edge(source, target)

    # Collect the nodes of every component that has more than two members.
    kept_nodes = set()
    for component in nx.connected_components(graph):
        if len(component) > 2:
            kept_nodes.update(component)

    # Keep a row only when both of its endpoints survived.
    both_kept = df[node1_col].isin(kept_nodes) & df[node2_col].isin(kept_nodes)
    return df[both_kept]

def _load_filtered_contacts(compartment, morphologies, batches, contact_thresh, pval_thresh, results_root):
    """Load, threshold, average and prune the contact tables of one compartment."""
    tables = [
        pd.read_csv(f'{results_root}{batch}/morph_{morph}/{compartment}_close_contacts.csv', index_col=0)
        for morph in morphologies
        for batch in batches
    ]
    combined = pd.concat(tables)

    # Keep significant, sufficiently frequent contacts between two different genes.
    keep = (
        (combined.contact_count > contact_thresh) &
        (combined['pval-adjusted'] < pval_thresh) &
        (combined.cell_type1 != combined.cell_type2)
    )
    combined = combined[keep]

    # Average the numeric columns over all tables in which the pair survived.
    averaged = combined.groupby(['cell_type1', 'cell_type2']).mean().reset_index()

    # Remove pairs that form an isolated two-node component.
    return filter_disconnected_pairs(averaged, 'cell_type1', 'cell_type2')


def extract_coloc_tables(morphologies, batches, contact_thresh=4, pval_thresh=0.05, results_root=None):
    """
    Extract and filter colocalization tables for soma and branches,
    including filtering out disconnected graph components with only two nodes.

    Args:
        morphologies (list): List of morphology types.
        batches (list): List of batch identifiers.
        contact_thresh (int): Rows are kept only when contact_count is strictly
            greater than this value.
        pval_thresh (float): Rows are kept only when the adjusted p-value is
            strictly below this value.
        results_root (str): Folder (with trailing separator) holding the
            ``<batch>/morph_<m>/<compartment>_close_contacts.csv`` files.
            Defaults to the notebook-level ``output_path``, which is what this
            function always read before the parameter existed.

    Returns:
        tuple: Filtered soma and branches DataFrames.
    """
    if results_root is None:
        results_root = output_path

    soma_full = _load_filtered_contacts(
        'soma', morphologies, batches, contact_thresh, pval_thresh, results_root)
    branches_full = _load_filtered_contacts(
        'branches', morphologies, batches, contact_thresh, pval_thresh, results_root)

    return soma_full, branches_full

def save_dataframes_to_excel(dataframes, sheet_names, output_path):
    """
    Write several DataFrames into one Excel workbook, one sheet per DataFrame.

    Frames and names are paired positionally; the row index is not written.

    Args:
        dataframes (list): pandas DataFrames to save.
        sheet_names (list): Sheet name for each DataFrame, in the same order.
        output_path (str): Destination of the Excel file (e.g., 'output.xlsx').
    """
    sheets = zip(dataframes, sheet_names)
    with pd.ExcelWriter(output_path, engine='openpyxl') as workbook:
        for table, tab_name in sheets:
            table.to_excel(workbook, sheet_name=tab_name, index=False)
    print(f"DataFrames saved to {output_path}")

def plot_contact_network_with_clusters_spectral(df, title, threshold=0.05, n_clusters=3, save_path=None):
    """
    Plot a gene-gene contact network and identify clusters using Spectral Clustering.

    Args:
        df (pd.DataFrame): Gene-gene interaction table with the columns
            'cell_type1', 'cell_type2', 'z_score' and 'pval-adjusted'.
        title (str): Title of the plot (e.g., 'Young' or 'Old').
        threshold (float): Rows are used only when 'pval-adjusted' is below this value.
        n_clusters (int): Number of clusters to identify.
        save_path (str): Optional file name; when given, the figure is written
            as a PDF to ``coloc_figs/<save_path>``.

    Returns:
        dict: A dictionary where keys are cluster IDs and values are lists of nodes in each cluster.

    Raises:
        ValueError: If no interaction passes the p-value threshold.
    """
    # Filter data based on p-value threshold
    df_filtered = df[df['pval-adjusted'] < threshold]
    if df_filtered.empty:
        raise ValueError(
            f"No interactions with 'pval-adjusted' < {threshold}; nothing to cluster or plot."
        )

    # Create a network graph with edges weighted by z_score
    G = nx.Graph()
    for _, row in df_filtered.iterrows():
        G.add_edge(row['cell_type1'], row['cell_type2'], weight=row['z_score'])

    # The weighted adjacency matrix is used directly as the affinity matrix.
    adjacency_matrix = nx.to_numpy_array(G)

    # Perform Spectral Clustering
    clustering = SpectralClustering(
        n_clusters=n_clusters,
        affinity='precomputed',
        random_state=42
    ).fit(adjacency_matrix)

    # Labels follow the node order of G, which is also the adjacency matrix order.
    node_list = list(G.nodes())
    cluster_labels = clustering.labels_

    # Group nodes by cluster
    clusters = {}
    for node, cluster_id in zip(node_list, cluster_labels):
        clusters.setdefault(cluster_id, []).append(node)

    # One shared normalisation for nodes and legend. The legend used to compute
    # cluster_id / max(cluster_ids), which is 0/0 for a single cluster and could
    # disagree with the min-max scaling applied to the node colours.
    cmap = plt.cm.Set3
    norm = plt.Normalize(vmin=cluster_labels.min(), vmax=cluster_labels.max())

    # Draw the network with cluster-based coloring
    pos = nx.spring_layout(G, seed=42)  # Layout for consistent visualization
    plt.figure(figsize=(10, 10))
    nx.draw_networkx_nodes(
        G, pos, node_size=400, node_color=cluster_labels,
        cmap=cmap, vmin=norm.vmin, vmax=norm.vmax, alpha=0.8,
    )
    nx.draw_networkx_edges(G, pos, width=1.0, edge_color='gray', alpha=0.5)
    nx.draw_networkx_labels(G, pos, font_size=8, font_weight='bold')

    # Add a legend for the clusters (sorted so the entry order is stable)
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', label=f'Cluster {cluster_id}',
               markerfacecolor=cmap(norm(cluster_id)), markersize=10)
        for cluster_id in sorted(clusters)
    ]
    plt.legend(handles=legend_elements, title='Clusters', loc='upper right')

    # Add title
    plt.title(title, fontsize=16)

    if save_path:
        plt.savefig(f'coloc_figs/{save_path}', format='pdf')

    plt.show()

    return clusters

def perform_go_enrichment(clusters_dict, organism='Mouse'):
    """
    Run an Enrichr query against 'GO_Biological_Process_2023' for each cluster.

    Nothing is written to disk; the result tables are only returned.

    Args:
        clusters_dict (dict): Dictionary where keys are cluster IDs and values are lists of genes.
        organism (str): Organism passed to the enrichment call ('Human' or 'Mouse').

    Returns:
        dict: Cluster ID -> the ``results`` attribute of that cluster's
        enrichment query, in the iteration order of ``clusters_dict``.
    """
    return {
        cluster_id: gp.enrichr(
            gene_list=cluster_genes,
            gene_sets='GO_Biological_Process_2023',
            organism=organism
        ).results
        for cluster_id, cluster_genes in clusters_dict.items()
    }

def plot_go_enrichment(go_results, top_n=5, save_path=None, global_max_score=None, figure_width=10):
    """
    Plot a horizontal bar graph of the top GO terms for each cluster, ensuring consistent graph sizes.

    Args:
        go_results (dict): Dictionary where keys are cluster IDs and values are DataFrames of GO
            enrichment results with the columns 'Term', 'Adjusted P-value' and 'Combined Score'.
        top_n (int): Number of top GO terms (lowest adjusted p-value) to display per cluster.
        save_path (str): Optional file name; when given, the figure is written as a PDF to
            ``coloc_figs/<save_path>``.
        global_max_score (float): Upper x-axis limit shared between figures; when omitted (or 0),
            the largest 'Combined Score' among the plotted terms is used.
        figure_width (float): Fixed width of the graph area (excluding labels).

    Returns:
        None. When there is nothing to plot, a message is printed and no figure is created.
    """
    # Prepare data for plotting
    plot_data = []
    for cluster_id, df in go_results.items():
        # A cluster can come back without any enriched term; skip it.
        if df is None or df.empty:
            continue
        # Select top N GO terms based on adjusted p-value
        top_terms = df.sort_values('Adjusted P-value').head(top_n)
        for _, row in top_terms.iterrows():
            plot_data.append({
                'Cluster': cluster_id,
                'GO Term': row['Term'],
                '-log10(P-value)': -np.log10(row['Adjusted P-value']),
                'Combined Score': row['Combined Score']
            })

    # Without this guard an empty table fails below with a KeyError on 'Combined Score'.
    if not plot_data:
        print('No GO terms to plot.')
        return

    # Convert to DataFrame
    plot_df = pd.DataFrame(plot_data)

    # Determine the x-axis limit
    max_score = global_max_score or plot_df['Combined Score'].max()

    # Estimate space needed for y-axis labels
    max_label_length = plot_df['GO Term'].str.len().max()
    label_width = max_label_length * 0.1  # Estimate label width in inches (adjust scaling factor as needed)
    total_width = figure_width + label_width

    # Plot horizontal bar graph; the height grows with the number of bars
    plt.figure(figsize=(total_width, len(plot_df) * 0.4))
    sns.barplot(
        data=plot_df,
        y='GO Term',
        x='Combined Score',
        hue='Cluster',
        dodge=False,  # Avoid overlapping bars
        palette='Set2'
    )
    plt.xlabel('Combined Score', fontsize=12)
    plt.ylabel('GO Term', fontsize=12)
    plt.xlim([0, max_score])  # Use the global max score for consistent x-axes
    plt.title('GO Enrichment Analysis by Cluster', fontsize=14)
    plt.legend(title='Cluster', loc='best')
    plt.tight_layout()
    if save_path:
        plt.savefig(f'coloc_figs/{save_path}', format='pdf')
    plt.show()

In [None]:
# Restrict to morphology class '4' (the most complex one) and split the batches by age.
morphologies = ['4']

# 3-month samples
batches_3 = [
    '3-mo-' + tag
    for tag in ('male-1', 'male-2', 'male-3-rev2', 'female-1-rev2', 'female-2', 'female-3')
]

# 24-month samples
batches_24 = [
    '24-mo-' + tag
    for tag in ('male-1', 'male-2', 'male-4-rev2', 'female-1', 'female-3', 'female-5')
]

# Each call returns the (soma, branches) tables for one age group.
young_4_soma, young_4_branches = extract_coloc_tables(morphologies, batches_3)
old_4_soma, old_4_branches = extract_coloc_tables(morphologies, batches_24)

In [None]:
# Export the four filtered tables to one workbook, one sheet per table.
dataframes = [
    young_4_soma,
    young_4_branches,
    old_4_soma,
    old_4_branches,
]
sheet_names = [
    'Young_Soma',
    'Young_Processes',
    'Old_Soma',
    'Old_Processes',
]

output_file = 'coloc_figs/gene_colocalization_networks.xlsx'
save_dataframes_to_excel(dataframes, sheet_names, output_path=output_file)

In [None]:
# Plot the colocalization networks. n_clusters defaults to 3 and is tuned by hand per panel.
young_soma_gene_clusters = plot_contact_network_with_clusters_spectral(
    young_4_soma,
    '3 month soma',
    threshold=0.05,
    save_path='3_month_soma.pdf',
)
young_branches_gene_clusters = plot_contact_network_with_clusters_spectral(
    young_4_branches,
    '3 month branches',
    threshold=0.05,
    n_clusters=2,
    save_path='3_month_branches.pdf',
)
old_soma_gene_clusters = plot_contact_network_with_clusters_spectral(
    old_4_soma,
    '24 month soma',
    threshold=0.05,
    n_clusters=1,
    save_path='24_month_soma.pdf',
)
old_branches_gene_clusters = plot_contact_network_with_clusters_spectral(
    old_4_branches,
    '24 month branches',
    threshold=0.05,
    n_clusters=2,
    save_path='24_month_branches.pdf',
)

In [None]:
# GO enrichment for the 3 month soma clusters; the x-axis cap matches the 24 month soma plot.
results = perform_go_enrichment(young_soma_gene_clusters)
plot_go_enrichment(
    results,
    top_n=5,
    save_path='3_month_soma_GO.pdf',
    global_max_score=12000,
)

In [None]:
# GO enrichment for the 24 month soma clusters; the x-axis cap matches the 3 month soma plot.
results = perform_go_enrichment(old_soma_gene_clusters)
plot_go_enrichment(
    results,
    top_n=5,
    save_path='24_month_soma_GO.pdf',
    global_max_score=12000,
)

In [None]:
# GO enrichment for the 3 month branches clusters; the x-axis cap matches the 24 month branches plot.
results = perform_go_enrichment(young_branches_gene_clusters)
plot_go_enrichment(
    results,
    top_n=5,
    save_path='3_month_branches_GO.pdf',
    global_max_score=20000,
)

In [None]:
# GO enrichment for the 24 month branches clusters; the x-axis cap matches the 3 month branches plot.
results = perform_go_enrichment(old_branches_gene_clusters)
plot_go_enrichment(
    results,
    top_n=5,
    save_path='24_month_branches_GO.pdf',
    global_max_score=20000,
)