## Result 6: neuron clusters
Specific neuronal clusters identified by single-cell RNA sequencing (scRNA-seq).

### data

In [None]:
from scipy.io import mmread
from pathlib import Path

import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

#### Matched Single-Cell RNA Sequencing

In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
from scipy.io import mmread
import anndata as ad

# Locations of the matched single-cell inputs: the sparse count matrix with
# its barcode and gene tables, plus the per-cell metadata spreadsheet.
single_cell_data_folder_path = Path("../data/mouse_hypothalamus/singlecell/")
mtx_path, barcodes_path, genes_path, meta_path = (
    single_cell_data_folder_path.joinpath(file_name)
    for file_name in (
        "GSE113576_matrix.mtx",
        "GSE113576_barcodes.tsv",
        "GSE113576_genes.tsv",
        "aau5324_Moffitt_Table-S1.xlsx",
    )
)

In [None]:
# Load data. The .mtx matrix is transposed below so AnnData sees cells x genes.
X = mmread(mtx_path).tocsr()
cell_ids = pd.read_csv(barcodes_path, sep="\t", header=None)[0].values
gene_names = pd.read_csv(genes_path, sep="\t", header=None)[1].values  # column 1 is used as the gene name

# Create AnnData object
adata = ad.AnnData(X=X.T)  # Transpose to cell x gene
adata.var_names = gene_names
adata.obs_names = cell_ids
adata.var_names_make_unique()

# Load the per-cell metadata table and give its columns short names.
meta = pd.read_excel(meta_path).rename(columns={
    "Cell name": "Cell_name",
    "Sex": "Sex",
    "Replicate number": "Rep",
    "Cell class (determined from clustering of all cells)": "Cell_class",
    "Non-neuronal cluster (determined from clustering of all cells)": "Non_neuronal_cluster",
    "Neuronal cluster (determined from clustering of inhibitory or excitatory neurons)": "Neuronal_cluster"
})

# Keep only neurons (Excitatory / Inhibitory cell class) present in the metadata.
neuron_types = ['Excitatory', 'Inhibitory']
meta = meta.set_index("Cell_name")
meta = meta.loc[meta["Cell_class"].isin(neuron_types)]

adata = adata[adata.obs_names.isin(meta.index)].copy()
adata.obs = meta.loc[adata.obs_names, ["Cell_class", "Neuronal_cluster"]]

# QC 1: drop cells with >= 20% of counts in mitochondrial genes. Mouse
# mitochondrial symbols carry the "mt-" prefix (e.g. mt-Nd1); matching the dash
# avoids catching any other symbol that merely starts with "mt".
mt_genes = adata.var_names.str.startswith("mt-")
mt_frac = np.array(adata[:, mt_genes].X.sum(axis=1)).flatten() / (
    np.array(adata.X.sum(axis=1)).flatten() + 1e-6  # epsilon guards empty cells
)
adata = adata[mt_frac < 0.2, :].copy()

# QC 2: keep cells with more than 1000 detected (non-zero) genes.
n_detected = np.array((adata.X != 0).sum(axis=1)).flatten()
adata = adata[n_detected > 1000, :].copy()

# Drop features whose name starts with "Blank".
adata = adata[:, ~adata.var_names.str.startswith("Blank")].copy()

# Scale every cell to 10,000 total counts, then log1p. Sparse .multiply with a
# dense column vector may return a COO matrix (SciPy-version dependent), so
# convert back to CSR to keep X row-sliceable.
sc_total = np.array(adata.X.sum(axis=1)).flatten()
adata.X = adata.X.multiply(10_000 / sc_total[:, None]).tocsr()
adata.X = adata.X.log1p()

In [None]:
# Dense cells x genes table of the normalised expression, with each cell's
# neuronal cluster label appended as an extra column (aligned on cell id).
neuron_cluster = adata.obs["Neuronal_cluster"]
sc_data = adata.to_df().assign(Neuronal_cluster=neuron_cluster)

In [None]:
from natsort import natsorted

# Regroup the rows so that each neuronal cluster is one contiguous run, with
# clusters in natural-sort order; cells keep their relative order within a
# cluster. Rows whose label never compares equal to itself (e.g. NaN) are
# not collected.
clusters = sc_data['Neuronal_cluster'].unique()
sorted_clusters = natsorted(clusters)
sorted_index = [
    cell_id
    for cluster_name in sorted_clusters
    for cell_id in sc_data.index[sc_data['Neuronal_cluster'] == cluster_name]
]
sc_data_sorted = sc_data.loc[sorted_index]

#### Marker Genes

Differentially expressed (DE) genes identified by BANKSY.

In [None]:
# DE gene panel, split into two groups. NOTE(review): "gm"/"wm" presumably
# stand for grey/white matter domains — confirm against the BANKSY analysis.
# DE_genes_gm: 9 genes
DE_genes_gm = ['Mlc1', 'Dgkk', 'Cbln2', 'Syt4', 'Gad1', 'Plin3', 'Gnrh1', 'Sln', 'Gjc3']
# DE_genes_wm: 5 genes
DE_genes_wm = ['Mbp', 'Lpar1', 'Trh', 'Ucn3', 'Cck']
# all differentially expressed genes: the gm genes followed by the wm genes
# (14 in total); derived so the three lists cannot drift apart.
DE_genes = DE_genes_gm + DE_genes_wm

In [None]:
# Excitatory neurons: rows whose cluster label starts with "e".
Ex_sc_data = sc_data_sorted.loc[sc_data_sorted['Neuronal_cluster'].str.startswith("e")]
Ex_cluster = Ex_sc_data["Neuronal_cluster"]

# DE genes present in the data: DE_genes_gm columns first, then DE_genes_wm.
sc_MOD2 = Ex_sc_data[[g for g in DE_genes_gm if g in Ex_sc_data.columns]]
sc_MOD1 = Ex_sc_data[[g for g in DE_genes_wm if g in Ex_sc_data.columns]]
Ex_DE = pd.concat((sc_MOD2, sc_MOD1), axis=1)


In [None]:
# Display (excitatory cells, DE genes found in the data).
Ex_DE.shape

In [None]:
# Inhibitory neurons: rows whose cluster label starts with "i".
In_sc_data = sc_data_sorted.loc[sc_data_sorted['Neuronal_cluster'].str.startswith("i")]
In_cluster = In_sc_data["Neuronal_cluster"]

# DE genes present in the data: DE_genes_gm columns first, then DE_genes_wm.
sc_MOD2 = In_sc_data[[g for g in DE_genes_gm if g in In_sc_data.columns]]
sc_MOD1 = In_sc_data[[g for g in DE_genes_wm if g in In_sc_data.columns]]
In_DE = pd.concat((sc_MOD2, sc_MOD1), axis=1)


In [None]:
# Display (inhibitory cells, DE genes found in the data).
In_DE.shape

In [None]:
# Hybrid neurons: rows whose cluster label starts with "h".
Hy_sc_data = sc_data_sorted.loc[sc_data_sorted['Neuronal_cluster'].str.startswith("h")]
Hy_cluster = Hy_sc_data["Neuronal_cluster"]

# DE genes present in the data: DE_genes_gm columns first, then DE_genes_wm.
sc_MOD2 = Hy_sc_data[[g for g in DE_genes_gm if g in Hy_sc_data.columns]]
sc_MOD1 = Hy_sc_data[[g for g in DE_genes_wm if g in Hy_sc_data.columns]]
Hy_DE = pd.concat((sc_MOD2, sc_MOD1), axis=1)


In [None]:
# Display (hybrid cells, DE genes found in the data).
Hy_DE.shape

### Heatmaps

#### functions

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

import matplotlib.colors as mcolors

# Diverging palette: RdYlBu reversed, so low values are blue and high values
# red. Used as the default cmap of the heatmap helpers.
heat_cmap = sns.color_palette(palette="RdYlBu_r", as_cmap=True)

# Sequential white -> orange -> red colormap with 256 levels.
custom_cmap = mcolors.LinearSegmentedColormap.from_list(
    name="white_orange_red",
    colors=["white", "orange", "red"],
    N=256,
)

In [None]:
# genes_clusters_heatmap: heatmap of the marker genes

from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import linkage, leaves_list
import matplotlib.patches as patches

def genes_clusters_heatmap(cluster_data, cluster_labels, gene_groups=None, zscore=True, DE_g=True, cmap=heat_cmap):
    """Plot a genes x cells heatmap with cells grouped by cluster.

    Cells are grouped by ``cluster_labels`` (groups visited in ``sorted``
    order). Within each group they are ordered by the leaf order of a Ward
    linkage. Genes are ordered by average linkage: within each
    ``gene_groups['cluster']`` group when given, otherwise over all genes.
    Dashed purple lines separate consecutive cell clusters, and each
    cluster's label is written just outside the top edge of the heatmap.

    Parameters
    ----------
    cluster_data : pandas.DataFrame
        Cells x genes expression table. The following columns are dropped:
        columns that cannot be converted to numbers, columns containing any
        NaN, and genes whose profile repeats an earlier gene.
    cluster_labels : array-like or pandas.Series
        Cluster label per row of ``cluster_data``.
    gene_groups : pandas.DataFrame, optional
        One row per remaining gene, in column order, with a ``'cluster'``
        column giving the gene's group.
    zscore : bool, default True
        If True, z-score each gene across cells (pandas ``std``, ddof=1) and
        use a colour range of [-3, 3]; otherwise use [0, 5]. Zero-variance
        genes become all-zero rows instead of NaN.
    DE_g : bool, default True
        If True, draw a dashed green line below the first gene group. With no
        ``gene_groups`` the line is drawn after 9 genes.
    cmap : matplotlib colormap
        Colormap passed to ``seaborn.heatmap``.

    Raises
    ------
    ValueError
        If ``gene_groups`` does not have exactly one row per remaining gene.
    """
    cluster_data = cluster_data.copy()
    cluster_data['Cell_class'] = cluster_labels

    numeric_data = (
        cluster_data.drop(columns=['Cell_class'])
        .apply(pd.to_numeric, errors='coerce')
        .dropna(axis=1, how='any')
    )
    # Drop genes whose expression profile repeats an earlier gene.
    numeric_data = numeric_data.loc[:, ~numeric_data.T.duplicated()]

    # Drop repeated cell ids from the expression AND the labels with one mask,
    # so column i of the matrix always carries label i.
    keep_cells = ~numeric_data.index.duplicated()
    expression_data = numeric_data.loc[keep_cells].T  # genes x cells
    cluster_labels_sorted = cluster_data['Cell_class'].values[keep_cells]

    if zscore:
        def _zscore_row(row):
            sd = row.std()
            if not sd > 0:  # constant (or single-cell) gene: avoid 0/0 = NaN
                return row * 0.0
            return (row - row.mean()) / sd

        expression_data = expression_data.apply(_zscore_row, axis=1)
        vmin, vmax = -3, 3
    else:
        vmin, vmax = 0, 5

    # Cell order: clusters in sorted order, cells inside a cluster by Ward leaves.
    unique_labels = sorted(set(cluster_labels_sorted))
    new_order = []
    for label in unique_labels:
        class_indices = np.flatnonzero(cluster_labels_sorted == label)
        if class_indices.size == 0:
            print(f"Warning: No cells found for class {label}")
            continue
        if class_indices.size > 1:
            linkage_matrix = linkage(expression_data.iloc[:, class_indices].T, method='ward')
            class_indices = class_indices[leaves_list(linkage_matrix)]
        new_order.extend(class_indices)

    reordered_expression_data = expression_data.iloc[:, new_order]
    reordered_cluster_labels = cluster_labels_sorted[new_order]

    # Gene order: average-linkage leaves, within each gene group if given.
    n_genes = reordered_expression_data.shape[0]
    if gene_groups is not None:
        group_of_gene = np.asarray(gene_groups['cluster'])
        if group_of_gene.shape[0] != n_genes:
            raise ValueError(
                f"gene_groups has {group_of_gene.shape[0]} rows but "
                f"{n_genes} genes remain after filtering"
            )
        group_order = sorted(set(group_of_gene))
        new_gene_order = []
        for group in group_order:
            gene_indices = np.flatnonzero(group_of_gene == group)
            if gene_indices.size > 1:
                linkage_matrix = linkage(reordered_expression_data.iloc[gene_indices, :], method='average')
                gene_indices = gene_indices[leaves_list(linkage_matrix)]
            new_gene_order.extend(gene_indices)
        # The first group occupies the top rows, so the split sits after it.
        de_split = int(np.sum(group_of_gene == group_order[0])) if group_order else 0
    else:
        new_gene_order = leaves_list(linkage(reordered_expression_data, method='average'))
        de_split = 9

    reordered_expression_data = reordered_expression_data.iloc[new_gene_order, :]

    plt.figure(figsize=(20, 10), dpi=600)
    ax = sns.heatmap(
        reordered_expression_data,
        vmin=vmin, vmax=vmax,
        annot=False, fmt="g", xticklabels=False, yticklabels=True,  # gene names on y only
        cmap=cmap,
        cbar=True,
        cbar_kws={"shrink": 0.5},  # Shrink color bar
    )

    # Move the colour bar into a small box just right of the heatmap.
    cbar = ax.collections[0].colorbar
    ax_pos = ax.get_position()
    cbar.ax.set_position([ax_pos.x1 + 0.01, ax_pos.y1 - 0.3, 0.02, 0.2])

    ax.set_yticks(ax.get_yticks())
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=15, rotation=0)

    # Heatmap column i spans [i, i + 1]. A cluster on columns start..end is
    # therefore centred at (start + end + 1) / 2, and its right edge is end + 1.
    cluster_ends = []
    for label in unique_labels:
        class_indices = np.flatnonzero(reordered_cluster_labels == label)
        if class_indices.size == 0:
            continue
        start_idx, end_idx = class_indices[0], class_indices[-1]
        ax.text(
            (start_idx + end_idx + 1) / 2, -1.2, str(label),
            ha='center', va='center', rotation=90, fontsize=7, color='black',
        )
        cluster_ends.append(end_idx + 1)

    # Separators between consecutive clusters (none after the last one).
    for boundary in cluster_ends[:-1]:
        ax.axvline(x=boundary, color='purple', linestyle='--', linewidth=1)

    if DE_g:
        ax.hlines(y=de_split, xmin=ax.get_xlim()[0], xmax=ax.get_xlim()[1], color='green', linestyle='--', linewidth=1)

    plt.xlabel("")
    plt.ylabel("")
    plt.show()


In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage, leaves_list

def reorder_clusters(cluster_data, cluster_labels, gene_groups=None, zscore=True):
    """Return the expression reordered for plotting, without drawing anything.

    Cells are grouped by ``cluster_labels`` (groups visited in ``natsorted``
    order). Within each group they are ordered by the leaf order of a Ward
    linkage. Genes are ordered by the leaf order of an average linkage:
    within each ``gene_groups['cluster']`` group (groups visited in
    ``sorted`` order), or over all genes when ``gene_groups`` is None.

    Parameters
    ----------
    cluster_data : pandas.DataFrame
        Cells x genes expression table. The following columns are dropped:
        columns that cannot be converted to numbers, columns containing any
        NaN, and genes whose profile repeats an earlier gene.
    cluster_labels : array-like or pandas.Series
        Cluster label per row of ``cluster_data``.
    gene_groups : pandas.DataFrame, optional
        One row per remaining gene, in column order, with a ``'cluster'``
        column giving the gene's group.
    zscore : bool, default True
        If True, z-score each gene across cells (pandas ``std``, ddof=1)
        before clustering. Zero-variance genes become all-zero rows instead
        of NaN, which scipy's ``linkage`` would reject.

    Returns
    -------
    reordered_expression_data : pandas.DataFrame
        Genes x cells table in the new gene and cell order.
    reordered_cluster_labels : array-like
        Cluster label of each column of ``reordered_expression_data``.

    Raises
    ------
    ValueError
        If ``gene_groups`` does not have exactly one row per remaining gene.
    """
    cluster_data = cluster_data.copy()
    cluster_data['Cell_class'] = cluster_labels

    numeric_data = (
        cluster_data.drop(columns=['Cell_class'])
        .apply(pd.to_numeric, errors='coerce')
        .dropna(axis=1, how='any')
    )
    # Drop genes whose expression profile repeats an earlier gene.
    numeric_data = numeric_data.loc[:, ~numeric_data.T.duplicated()]

    # Drop repeated cell ids from the expression AND the labels with one mask,
    # so column i of the matrix always carries label i.
    keep_cells = ~numeric_data.index.duplicated()
    expression_data = numeric_data.loc[keep_cells].T  # genes x cells
    cluster_labels_sorted = cluster_data['Cell_class'].values[keep_cells]

    if zscore:
        def _zscore_row(row):
            sd = row.std()
            if not sd > 0:  # constant (or single-cell) gene: avoid 0/0 = NaN
                return row * 0.0
            return (row - row.mean()) / sd

        expression_data = expression_data.apply(_zscore_row, axis=1)

    # Cell order: clusters in natural-sort order, cells inside by Ward leaves.
    new_order = []
    for label in natsorted(set(cluster_labels_sorted)):
        class_indices = np.flatnonzero(cluster_labels_sorted == label)
        if class_indices.size == 0:
            print(f"Warning: No cells found for class {label}")
            continue
        if class_indices.size > 1:
            linkage_matrix = linkage(expression_data.iloc[:, class_indices].T, method='ward')
            class_indices = class_indices[leaves_list(linkage_matrix)]
        new_order.extend(class_indices)

    reordered_expression_data = expression_data.iloc[:, new_order]
    reordered_cluster_labels = cluster_labels_sorted[new_order]

    # Gene order: average-linkage leaves, within each gene group if given.
    n_genes = reordered_expression_data.shape[0]
    if gene_groups is not None:
        group_of_gene = np.asarray(gene_groups['cluster'])
        if group_of_gene.shape[0] != n_genes:
            raise ValueError(
                f"gene_groups has {group_of_gene.shape[0]} rows but "
                f"{n_genes} genes remain after filtering"
            )
        new_gene_order = []
        for group in sorted(set(group_of_gene)):
            gene_indices = np.flatnonzero(group_of_gene == group)
            if gene_indices.size > 1:
                linkage_matrix = linkage(reordered_expression_data.iloc[gene_indices, :], method='average')
                gene_indices = gene_indices[leaves_list(linkage_matrix)]
            new_gene_order.extend(gene_indices)
    else:
        new_gene_order = leaves_list(linkage(reordered_expression_data, method='average'))

    reordered_expression_data = reordered_expression_data.iloc[new_gene_order, :]

    return reordered_expression_data, reordered_cluster_labels

In [None]:
def cluster_plot_heatmap(re_IN, re_IN_clu, DE_g=True, cmap=heat_cmap, figures=(15,25), split_at=9):
    """Plot a pre-ordered expression table as a cells x genes heatmap.

    Meant for the output of ``reorder_clusters``. ``re_IN`` (genes x cells)
    is transposed, so cells run down the y axis and genes along the x axis.
    Each cell cluster is labelled at its vertical centre, left of the plot,
    and followed by a dashed purple line at its lower edge.

    Parameters
    ----------
    re_IN : pandas.DataFrame
        Genes x cells expression, already ordered.
    re_IN_clu : array-like
        Cluster label of each column of ``re_IN`` (same order). Cells of one
        cluster are assumed contiguous.
    DE_g : bool, default True
        If True, draw a dashed green vertical line at x = ``split_at``.
    cmap : matplotlib colormap
        Colormap passed to ``seaborn.heatmap``; the colour range is [-3, 3].
    figures : tuple of float
        Figure size in inches, passed to ``plt.figure``.
    split_at : float, default 9
        Gene-column position of the DE split line.
    """
    labels = np.asarray(re_IN_clu)

    plt.figure(figsize=figures, dpi=600)
    ax = sns.heatmap(
        re_IN.T,
        vmin=-3, vmax=3,
        annot=False, fmt="g", xticklabels=True, yticklabels=False,  # gene names on x only
        cmap=cmap,
        cbar=True,
        cbar_kws={"shrink": 0.5},  # Shrink color bar
    )

    # Move the colour bar into a small box just right of the heatmap.
    cbar = ax.collections[0].colorbar
    ax_pos = ax.get_position()
    cbar.ax.set_position([ax_pos.x1 + 0.01, ax_pos.y1 - 0.3, 0.02, 0.2])

    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(ax.get_xticklabels(), fontsize=13, rotation=0)

    # Heatmap row i spans [i, i + 1]. A cluster on rows start..end is centred
    # at (start + end + 1) / 2, and its lower edge is at end + 1. Labels of very
    # small adjacent clusters may overlap.
    for label in set(labels):
        rows = np.flatnonzero(labels == label)
        if rows.size == 0:
            continue
        start_idx, end_idx = rows[0], rows[-1]
        ax.text(
            -1.5, (start_idx + end_idx + 1) / 2, str(label),
            ha='center', va='center', rotation=0, fontsize=11, color='black',
        )
        ax.axhline(y=end_idx + 1, color='purple', linestyle='--', linewidth=1)

    if DE_g:
        ax.vlines(x=split_at, ymin=ax.get_ylim()[0], ymax=ax.get_ylim()[1], color='green', linestyle='--', linewidth=1)

    plt.xlabel("")
    plt.ylabel("")
    plt.show()


In [None]:
# Group label per DE gene, in column order: 0 for the first 9 genes
# (DE_genes_gm), 1 for the last 5 (DE_genes_wm).
gene_g = pd.DataFrame({'cluster': [0] * 9 + [1] * 5})

### Excitatory

In [None]:
# Order excitatory cells by cluster and genes within their DE group.
re_Ex_DE, re_Ex_clu = reorder_clusters(
    Ex_DE,
    Ex_cluster,
    zscore=True,
    gene_groups=gene_g,
)

In [None]:
unique_labels = natsorted(set(re_Ex_clu))
# Column index (in re_Ex_DE) of the last cell of every cluster, in label order.
cluster_boundaries = [
    positions[-1]
    for positions in (np.where(re_Ex_clu == lab)[0] for lab in unique_labels)
    if len(positions) > 0
]

In [None]:
# Cluster end positions in ascending order (integer positions, so sorted()
# yields the same order as natsorted()).
re_cluster_boundaries = sorted(cluster_boundaries)
re_cluster_boundaries

In [None]:
# Excitatory panel 1: cell columns [0, 2809). NOTE(review): 2809 presumably
# comes from re_cluster_boundaries (a cluster end + 1) — confirm.
cluster_plot_heatmap(
    re_IN=re_Ex_DE.iloc[:, :2809],
    re_IN_clu=re_Ex_clu[:2809],
    DE_g=True,
    cmap=heat_cmap,
    figures=(15, 25),
)

In [None]:
# Excitatory panel 2: the remaining cell columns from 2809 on.
cluster_plot_heatmap(
    re_IN=re_Ex_DE.iloc[:, 2809:],
    re_IN_clu=re_Ex_clu[2809:],
    DE_g=True,
    cmap=heat_cmap,
    figures=(15, 6),
)

In [None]:
# Single genes x cells overview of all excitatory clusters.
genes_clusters_heatmap(
    cluster_data=Ex_DE,
    cluster_labels=Ex_cluster,
    gene_groups=gene_g,
    zscore=True,
    DE_g=True,
    cmap=heat_cmap,
)

### Inhibitory

In [None]:
# Order inhibitory cells by cluster and genes within their DE group.
re_IN_DE, re_IN_clu = reorder_clusters(
    In_DE,
    In_cluster,
    zscore=True,
    gene_groups=gene_g,
)

In [None]:
unique_labels = natsorted(set(re_IN_clu))
# Column index (in re_IN_DE) of the last cell of every cluster, in label order.
cluster_boundaries = [
    positions[-1]
    for positions in (np.where(re_IN_clu == lab)[0] for lab in unique_labels)
    if len(positions) > 0
]

In [None]:
# Cluster end positions in ascending order.
re_cluster_boundaries = list(cluster_boundaries)
re_cluster_boundaries.sort()
re_cluster_boundaries


In [None]:
# Inhibitory panel 1: cell columns [0, 4865).
cluster_plot_heatmap(
    re_IN=re_IN_DE.iloc[:, :4865],
    re_IN_clu=re_IN_clu[:4865],
    DE_g=True,
    cmap=heat_cmap,
    figures=(15, 25),
)

In [None]:
# Inhibitory panel 2: cell columns [4865, 8702).
cluster_plot_heatmap(
    re_IN=re_IN_DE.iloc[:, 4865:8702],
    re_IN_clu=re_IN_clu[4865:8702],
    DE_g=True,
    cmap=heat_cmap,
    figures=(15, 25),
)

In [None]:
# Inhibitory panel 3: cell columns [8702, 12840).
cluster_plot_heatmap(
    re_IN=re_IN_DE.iloc[:, 8702:12840],
    re_IN_clu=re_IN_clu[8702:12840],
    DE_g=True,
    cmap=heat_cmap,
    figures=(15, 25),
)

In [None]:
# Inhibitory panel 4: the remaining cell columns. The previous panel is the
# end-exclusive slice 8702:12840, so this one must start at 12840; starting at
# 12841 left column 12840 out of every panel.
cluster_plot_heatmap(re_IN_DE.iloc[:, 12840:], re_IN_clu[12840:], DE_g=True, cmap=heat_cmap, figures=(15, 11))

In [None]:
# Single genes x cells overview of all inhibitory clusters.
genes_clusters_heatmap(
    cluster_data=In_DE,
    cluster_labels=In_cluster,
    gene_groups=gene_g,
    zscore=True,
    DE_g=True,
    cmap=heat_cmap,
)

### Hybrid

In [None]:
# Order hybrid cells by cluster and genes within their DE group.
re_Hy_DE, re_Hy_clu = reorder_clusters(
    Hy_DE,
    Hy_cluster,
    zscore=True,
    gene_groups=gene_g,
)

In [None]:
unique_labels = natsorted(set(re_Hy_clu))
# Column index (in re_Hy_DE) of the last cell of every cluster, in label order.
cluster_boundaries = [
    positions[-1]
    for positions in (np.where(re_Hy_clu == lab)[0] for lab in unique_labels)
    if len(positions) > 0
]

In [None]:
# Cluster end positions in ascending order.
re_cluster_boundaries = list(cluster_boundaries)
re_cluster_boundaries.sort()
re_cluster_boundaries


In [None]:
# Display the shape of the reordered hybrid table (genes x cells after reorder_clusters).
re_Hy_DE.shape

In [None]:
# Hybrid neurons fit in one panel.
cluster_plot_heatmap(
    re_IN=re_Hy_DE,
    re_IN_clu=re_Hy_clu,
    DE_g=True,
    cmap=heat_cmap,
    figures=(15, 3),
)

### Astrocytes

In [None]:
# NOTE(review): As_DE and As_cluster are not defined in this section of the
# notebook; presumably they come from a separate astrocyte analysis — confirm
# before running. As_DE is indexed by gene name here (rows = genes).
# Drop Ucn3 and shorten gene_g to 13 rows by removing its LAST row. This
# assumes the remaining rows of As_DE follow gene_g's gene order; it works
# only because Ucn3 and the last gene share group 1 — TODO verify.
As_DE_filtered = As_DE.drop(index='Ucn3')
gene_g_filtered = gene_g[:13]

In [None]:
# Astrocyte heatmap, z-scored per gene (colour range [-3, 3]).
genes_clusters_heatmap(
    cluster_data=As_DE_filtered.T,
    cluster_labels=As_cluster,
    gene_groups=gene_g_filtered,
    zscore=True,
    DE_g=True,
    cmap=heat_cmap,
)

In [None]:
# Astrocyte heatmap on the unscaled values (colour range [0, 5]).
genes_clusters_heatmap(
    cluster_data=As_DE_filtered.T,
    cluster_labels=As_cluster,
    gene_groups=gene_g_filtered,
    zscore=False,
    DE_g=True,
    cmap=custom_cmap,
)