## Result 5: MOD marker genes in scRNA-seq and MERFISH

In [None]:
from pathlib import Path
from scipy.io import mmread
import anndata as ad

import sys
import os
sys.path.append(os.path.abspath("../src"))
from plot import plot_annotate_heatmap
from utils import load_scRNA_data

import numpy as np
import pandas as pd

### Data

#### Signals in the Tissue Section

In [None]:
# Folder holding the MERFISH tables (barcode detections and the per-cell table).
MERFISH_data_folder_path = Path("../data") / "mouse_hypothalamus" / "MERFISH"

In [None]:
columns = [
    "Centroid_X",
    "Centroid_Y",
    "Centroid_Z",
    "Gene_name",
    "Cell_name",
    "Total_brightness",
    "Area",
    "Error_bit",
    "Error_direction",
]

# Barcode (molecule) table; rename the coordinate and gene columns.
signal_coordinate_df = pd.read_csv(
    MERFISH_data_folder_path / "merfish_barcodes_example.csv", usecols=columns
).rename(
    columns={
        "Centroid_X": "x",
        "Centroid_Y": "y",
        "Centroid_Z": "z",
        "Gene_name": "gene",
    }
)

# Remove dummy molecules (blank and negative-control barcodes).
# Copy here, before any column is assigned. That is what avoids
# SettingWithCopyWarning; a .copy() after the assignments comes too late.
signal_coordinate_df = signal_coordinate_df.loc[
    ~signal_coordinate_df["gene"].str.contains("Blank|NegControl")
].copy()

signal_coordinate_df["gene"] = signal_coordinate_df["gene"].astype("category")

# Shift the coordinates so they start at 0 (no negative values).
# The same offsets are subtracted from the BANKSY and segmentation tables
# further down, so all tables share one coordinate frame.
coordinate_x_m = signal_coordinate_df['x'].min()
coordinate_y_m = signal_coordinate_df['y'].min()
signal_coordinate_df['x'] = signal_coordinate_df['x'] - coordinate_x_m
signal_coordinate_df['y'] = signal_coordinate_df['y'] - coordinate_y_m

#### Results of BANKSY

In [None]:
# Folder holding the BANKSY clustering output.
banksy_folder_path = Path("..") / "data" / "banksy_results"

In [None]:
columns = [
    "Centroid_X",
    "Centroid_Y",
    "Bregma",
    "lam0.2",
]

# BANKSY cluster per cell. "lam0.2" presumably holds the clustering for
# lambda = 0.2 (TODO confirm against the BANKSY run).
banksy_result = pd.read_csv(
    banksy_folder_path / 'banksy_cluster.txt', usecols=columns, sep='\t'
).rename(
    columns={
        "Centroid_X": "x",
        "Centroid_Y": "y",
        "lam0.2": "banksy_cluster",
    }
)

# Keep the Bregma -0.24 section.
# Copy right away so the coordinate shift below writes to an independent
# frame (no SettingWithCopyWarning).
banksy_result = banksy_result[banksy_result['Bregma'] == -0.24].copy()

# Same shift as the MERFISH molecule coordinates.
banksy_result['x'] = banksy_result['x'] - coordinate_x_m
banksy_result['y'] = banksy_result['y'] - coordinate_y_m

#### Segmentation Dataset

In [None]:
merfish_data = pd.read_csv(
    MERFISH_data_folder_path / "merfish_all_cells.csv"
).rename(columns={"Centroid_X": "x", "Centroid_Y": "y"})

# Drop the Fos column and the blank-barcode columns.
merfish_data = merfish_data.drop(
    columns=[col for col in merfish_data.columns if col == 'Fos' or col.startswith('Blank_')]
)

# Keep non-ambiguous cells of animal 1 at Bregma -0.24.
# `.copy()` detaches the filtered frame so the assignments below do not
# trigger SettingWithCopyWarning; a trailing copy after them would not.
merfish_data = merfish_data[
    (merfish_data["Cell_class"] != "Ambiguous")
    & (merfish_data['Animal_ID'] == 1)
    & (merfish_data['Bregma'] == -0.24)
].copy()

# Same shift as the MERFISH molecule coordinates.
merfish_data['x'] = merfish_data['x'] - coordinate_x_m
merfish_data['y'] = merfish_data['y'] - coordinate_y_m

# NOTE(review): positional join. This assumes banksy_cluster.txt lists exactly
# these cells in the same order. A length mismatch raises ValueError, but a
# reordering would silently mislabel cells. Consider verifying with the
# shifted x/y columns that both tables carry.
merfish_data['banksy'] = banksy_result['banksy_cluster'].values

In [None]:
# MERFISH cell class -> parent class. Numbered subclasses collapse onto their
# parent; every other class maps to itself. Sorting keeps the alphabetical
# key order.
cell_class_m = dict(sorted({
    **{name: name for name in ('Astrocyte', 'Ependymal', 'Excitatory',
                               'Inhibitory', 'Microglia', 'Pericytes')},
    **{f'Endothelial {i}': 'Endothelial' for i in range(1, 4)},
    **{f'OD Immature {i}': 'OD Immature' for i in range(1, 3)},
    **{f'OD Mature {i}': 'OD Mature' for i in range(1, 5)},
}.items()))

# Collapse numbered subclasses onto their parent class, then order cells by class.
# The parent names also map to themselves, so re-running this cell is a no-op.
# Otherwise a re-run would turn already-collapsed classes such as 'OD Mature'
# into NaN. Classes missing from the map still become NaN, as before.
merfish_data['Cell_class'] = merfish_data['Cell_class'].map(
    {**{parent: parent for parent in cell_class_m.values()}, **cell_class_m}
)
merfish_data = merfish_data.sort_values(by='Cell_class')

#### Matched Single-Cell RNA Sequencing

In [None]:
# Matched scRNA-seq inputs (GSE113576 matrix, barcodes, genes) and the cell metadata table.
data_path = Path("../data", "mouse_hypothalamus", "SingleCell")
mtx_path, barcodes_path, genes_path = (
    data_path / f"GSE113576_{suffix}"
    for suffix in ("matrix.mtx", "barcodes.tsv", "genes.tsv")
)
meta_path = data_path / "aau5324_Moffitt_Table-S1.xlsx"

In [None]:
# scRNA-seq cell class -> display name, passed to load_scRNA_data as
# `cell_class_filter`. Oligodendrocyte classes get short "OD ..." names;
# the rest keep their own. Sorting keeps the alphabetical key order.
cell_class_dict = dict(sorted({
    **{name: name for name in ('Astrocytes', 'Endothelial', 'Ependymal',
                               'Excitatory', 'Fibroblast', 'Inhibitory',
                               'Macrophage', 'Microglia', 'Mural')},
    'Immature oligodendrocyte': 'OD immature',
    'Mature oligodendrocyte': 'OD mature',
    'Newly formed oligodendrocyte': 'OD newly formed',
}.items()))

# Load the scRNA-seq count matrix together with its barcodes, genes and metadata.
adata = load_scRNA_data(
    mtx_path, barcodes_path, genes_path, meta_path,
    cell_class_filter=cell_class_dict,
)

#### Marker Genes

Differentially expressed (DE) genes identified by BANKSY for the two MOD clusters (BANKSY cluster 7 = MOD2, cluster 8 = MOD1).

In [None]:
# all differentially expressed genes
DE_genes = ['Mlc1', 'Dgkk', 'Cbln2', 'Syt4', 'Gad1', 'Plin3', 'Gnrh1', 'Sln', 'Gjc3', 'Mbp', 'Lpar1', 'Trh', 'Ucn3', 'Cck']
# DE_genes_MOD2: 7
DE_genes_MOD2 = ['Mlc1', 'Dgkk', 'Cbln2', 'Syt4', 'Gad1', 'Plin3', 'Gnrh1', 'Sln', 'Gjc3']
# DE_genes_MOD1: 8
DE_genes_MOD1 = ['Mbp', 'Lpar1', 'Trh', 'Ucn3', 'Cck']

In [None]:
# NOTE(review): to_df() densifies the whole scRNA-seq matrix. If sc_data is not
# needed elsewhere, adata[:, DE_genes].to_df() would avoid that.
sc_data = adata.to_df()
# Each frame is named after the MOD whose markers it holds (the original had
# the two names swapped). MOD2 markers come first, then MOD1, which is the gene
# order the heatmap gene groups assume.
sc_DE_MOD2_df = sc_data[DE_genes_MOD2]
sc_DE_MOD1_df = sc_data[DE_genes_MOD1]
sc_DE = pd.concat([sc_DE_MOD2_df, sc_DE_MOD1_df], axis=1)

sc_cell_class = adata.obs['Cell_class']

In [None]:
# MERFISH, MOD, Marker Genes
MOD_merfish = merfish_data[(merfish_data['banksy']==7) | (merfish_data['banksy']==8)]
MOD_merfish = MOD_merfish.sort_values(by='banksy')

common_genes = [gene for gene in DE_genes if gene in MOD_merfish.columns]
MOD_merfish_DE = MOD_merfish[common_genes].T

MOD_banksy = MOD_merfish['banksy']

In [None]:
# Gene-group id per marker gene, in DE_genes order: 0 = MOD2 markers, 1 = MOD1 markers.
# Derived from the marker lists instead of hard-coding 9 zeros and 5 ones, so it
# stays in sync if a list changes.
Gene_Group = pd.DataFrame(
    {'cluster': [0] * len(DE_genes_MOD2) + [1] * len(DE_genes_MOD1)}
)

### Heatmaps

#### Functions

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Reversed red-yellow-blue diverging colormap (low = blue, high = red), returned as a colormap object.
heat_cmap = sns.color_palette(palette="RdYlBu_r", as_cmap=True)

In [None]:
# genes_clusters_heatmap: heatmap of the marker genes

from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import linkage, leaves_list
import matplotlib.patches as patches

# Highlight boxes that used to be hard-coded in genes_clusters_heatmap for the
# scRNA-seq MOD heatmap. They use the same keys as the `box_specs` list built
# for plot_annotate_heatmap. Here, x_offset is the first heatmap column, width
# is a number of columns, and color is the edge colour.
_SC_MOD_BOX_SPECS = (
    {"x_offset": 500, "width": 350, "color": "green"},
    {"x_offset": 1350, "width": 250, "color": "orange"},
    {"x_offset": 2620, "width": 200, "color": "blue"},
    {"x_offset": 4800, "width": 1800, "color": "red"},
)


def _heatmap_prepare(cluster_data, cluster_labels):
    """Attach labels, sort cells by label and clean the expression table.

    Returns
    -------
    expression : DataFrame
        Genes x cells, cells sorted by label.
    labels : ndarray
        Label of each column of `expression`.
    kept_genes : ndarray of int
        Positions, among the gene columns of `cluster_data`, of the genes kept.
    n_input_genes : int
        Number of gene columns before cleaning.
    """
    data = cluster_data.copy()
    data['Cell_class'] = cluster_labels  # a Series is aligned on the index
    data = data.sort_values(by='Cell_class')
    labels = data['Cell_class'].to_numpy()

    numeric = data.drop(columns=['Cell_class']).apply(pd.to_numeric, errors='coerce')
    n_input_genes = numeric.shape[1]

    # Drop genes with any missing/non-numeric value, then genes whose
    # expression vector repeats an earlier gene (first occurrence kept).
    kept_genes = np.flatnonzero(numeric.notna().all(axis=0).to_numpy())
    numeric = numeric.iloc[:, kept_genes]
    unique_genes = ~numeric.T.duplicated().to_numpy()
    numeric = numeric.iloc[:, unique_genes]
    kept_genes = kept_genes[unique_genes]

    # Drop repeated cell ids and filter the labels with the same mask so they stay aligned.
    unique_cells = ~numeric.index.duplicated()
    return numeric.iloc[unique_cells].T, labels[unique_cells], kept_genes, n_input_genes


def _heatmap_gene_groups(gene_groups, kept_genes, n_input_genes):
    """Return the 'cluster' id of every kept gene (positional), or None."""
    if gene_groups is None:
        return None
    group_ids = np.asarray(gene_groups['cluster'])
    if len(group_ids) == n_input_genes:
        return group_ids[kept_genes]
    if len(group_ids) == len(kept_genes):
        return group_ids
    raise ValueError(
        f"gene_groups has {len(group_ids)} rows but cluster_data has "
        f"{n_input_genes} gene columns ({len(kept_genes)} after cleaning)"
    )


def _heatmap_order_cells(expression, labels):
    """Column order: cells grouped by sorted label, Ward-clustered within each label."""
    order = []
    for label in sorted(set(labels)):
        members = np.flatnonzero(labels == label)
        if len(members) == 0:
            # e.g. a NaN label, which never compares equal to itself
            print(f"Warning: No cells found for class {label}")
            continue
        if len(members) > 1:
            tree = linkage(expression.iloc[:, members].T, method='ward')
            members = members[leaves_list(tree)]
        order.extend(members)
    return np.asarray(order, dtype=int)


def _heatmap_order_genes(expression, group_ids):
    """Row order: average-linkage clustering of genes, within each gene group when given."""
    if group_ids is None:
        if expression.shape[0] < 2:
            # linkage() needs at least two observations
            return np.arange(expression.shape[0])
        return leaves_list(linkage(expression, method='average'))
    order = []
    for group in sorted(set(group_ids)):
        members = np.flatnonzero(group_ids == group)
        if len(members) > 1:
            tree = linkage(expression.iloc[members, :], method='average')
            members = members[leaves_list(tree)]
        order.extend(members)
    return np.asarray(order, dtype=int)


def genes_clusters_heatmap(cluster_data, cluster_labels, gene_groups=None, zscore=True, DE_g=True, y_pos=-1.2,
                           cmap=heat_cmap, box_specs=_SC_MOD_BOX_SPECS, de_boundary=None):
    """Plot a genes x cells heatmap of marker-gene expression, grouped by cell label.

    Cells are sorted by label and Ward-clustered within each label. Genes are
    average-linkage clustered, within each gene group when `gene_groups` is given.

    Parameters
    ----------
    cluster_data : DataFrame
        Cells x genes expression table. A gene with any non-numeric or missing
        value is dropped, as is a gene identical to an earlier one.
    cluster_labels : array-like or Series
        One label per row of `cluster_data`.
    gene_groups : DataFrame, optional
        Column 'cluster' with one group id per gene column of `cluster_data`,
        in the same (positional) order.
    zscore : bool
        If True, z-score each gene across cells and use a [-3, 3] colour range.
        Zero-variance genes are shown as 0. If False, use raw values with a
        [0, 5] range.
    DE_g : bool
        Draw a dashed horizontal separator at `de_boundary`.
    y_pos : float
        y position (data units; negative values sit above the first row) of
        the rotated cell-label texts.
    cmap : colormap
        Heatmap colormap.
    box_specs : iterable of dict, optional
        Dashed highlight rectangles, each spanning all gene rows. Keys are
        'x_offset' (first column), 'width' (number of columns) and 'color'.
        Defaults to the boxes previously hard-coded for the scRNA-seq MOD plot;
        pass () to draw none.
    de_boundary : float, optional
        Row position of the DE separator. Defaults to the size of the first
        gene group, or 9 (the former hard-coded value) without groups.

    Raises
    ------
    ValueError
        If `gene_groups` has neither one row per gene column of `cluster_data`
        nor one row per gene kept after cleaning.
    """
    expression, labels, kept_genes, n_input_genes = _heatmap_prepare(cluster_data, cluster_labels)
    group_ids = _heatmap_gene_groups(gene_groups, kept_genes, n_input_genes)

    if zscore:
        # Per-gene z-score (sample std). 0/0 for constant genes gives NaN, which
        # would make linkage() fail, so those genes are shown as 0.
        expression = (
            expression.sub(expression.mean(axis=1), axis=0)
            .div(expression.std(axis=1), axis=0)
            .fillna(0.0)
        )
        vmin, vmax = -3, 3
    else:
        vmin, vmax = 0, 5

    cell_order = _heatmap_order_cells(expression, labels)
    expression = expression.iloc[:, cell_order]
    labels = labels[cell_order]
    expression = expression.iloc[_heatmap_order_genes(expression, group_ids), :]

    _, ax = plt.subplots(figsize=(20, 10), dpi=600)
    sns.heatmap(
        expression,
        vmin=vmin, vmax=vmax,
        annot=False, fmt="g", xticklabels=False, yticklabels=True,
        cmap=cmap,
        cbar=True,
        cbar_kws={"shrink": 0.5},
        ax=ax,
    )

    # Pin the shrunken colour bar just right of the heatmap, near its top.
    ax_pos = ax.get_position()
    ax.collections[0].colorbar.ax.set_position([ax_pos.x1 + 0.01, ax_pos.y1 - 0.3, 0.02, 0.2])
    ax.tick_params(axis='y', labelsize=15, labelrotation=0)

    # One rotated label per cell class, centred over its columns, with dashed
    # separators between consecutive classes.
    class_ends = []
    for label in sorted(set(labels)):
        cols = np.flatnonzero(labels == label)
        if len(cols) == 0:
            continue
        ax.text((cols[0] + cols[-1]) / 2, y_pos, str(label),
                ha='center', va='center', rotation=90, fontsize=15, color='black')
        class_ends.append(cols[-1])
    for end in class_ends[:-1]:
        ax.axvline(x=end + 0.5, color='purple', linestyle='--', linewidth=1)

    if DE_g:
        if de_boundary is None:
            if group_ids is None or len(group_ids) == 0:
                de_boundary = 9
            else:
                # Rows are laid out group by group, smallest id first.
                de_boundary = int(np.count_nonzero(group_ids == min(group_ids)))
        ax.axhline(y=de_boundary, color='green', linestyle='--', linewidth=1)

    # Boxes are in data coordinates: x in heatmap columns, spanning every gene row.
    n_rows = expression.shape[0]
    for spec in box_specs or ():
        ax.add_patch(patches.Rectangle(
            (spec["x_offset"], 0), spec["width"], n_rows,
            linewidth=1.5, edgecolor=spec["color"], facecolor='none', linestyle='--',
        ))

    ax.set_xlabel("")
    ax.set_ylabel("")
    plt.show()


#### scRNA-seq: Marker Genes in MOD

In [None]:
# scRNA-seq cells whose class starts with "OD mat" (mature oligodendrocytes),
# split into an expression table and a matching label Series.
MOD_sc_data = (
    pd.concat([sc_DE, sc_cell_class], axis=1)
    .loc[lambda df: df['Cell_class'].str.startswith("OD mat")]
)
MOD_sc_cell_class = MOD_sc_data['Cell_class']
MOD_sc_data = MOD_sc_data.drop(columns='Cell_class')

In [None]:
# Dashed highlight boxes for the scRNA-seq MOD heatmap, given as
# (x_offset, width, color). How plot_annotate_heatmap scales them is defined
# in plot.py; presumably heatmap column units (TODO confirm).
box_specs = [
    {"x_offset": x_offset, "width": width, "color": color}
    for x_offset, width, color in (
        (500, 350, "green"),
        (1350, 250, "orange"),
        (2620, 200, "blue"),
        (4800, 1800, "red"),
    )
]

In [None]:
# scRNA-seq mature-OD cells: marker genes grouped by MOD, with highlight boxes.
plot_annotate_heatmap(
    cluster_data=MOD_sc_data,
    cluster_labels=MOD_sc_cell_class,
    gene_groups=Gene_Group,
    show_cluster=False,
    box_specs=box_specs,
)

#### scRNA-seq: Marker Genes in non-OD Cell Types

In [None]:
# scRNA-seq cells whose class does not start with "OD" (non-oligodendrocytes),
# split into an expression table and a matching label Series.
noOD_sc_data = (
    pd.concat([sc_DE, sc_cell_class], axis=1)
    .loc[lambda df: ~df['Cell_class'].str.startswith("OD")]
)
noOD_sc_cell_class = noOD_sc_data['Cell_class']
noOD_sc_data = noOD_sc_data.drop(columns='Cell_class')

In [None]:
# scRNA-seq non-OD cell types: marker genes grouped by MOD.
plot_annotate_heatmap(
    cluster_data=noOD_sc_data,
    cluster_labels=noOD_sc_cell_class,
    gene_groups=Gene_Group,
    cluster_text_y=-1.3,
)

#### MERFISH: Marker Genes in MOD

In [None]:
# Rename BANKSY cluster ids to MOD names: 7 -> MOD2, 8 -> MOD1 (other values unchanged).
MOD_banksy = MOD_banksy.replace(to_replace=[7, 8], value=['MOD2', 'MOD1'])

In [None]:
# MERFISH MOD cells (cells x genes again via .T), labelled MOD1 / MOD2.
plot_annotate_heatmap(
    cluster_data=MOD_merfish_DE.T,
    cluster_labels=MOD_banksy,
    gene_groups=Gene_Group,
    cluster_text_y=-0.9,
)

#### MERFISH: Marker Genes in non-OD Cell Types

In [None]:
# MERFISH cells whose (collapsed) class does not start with "OD".
OD_mask = merfish_data['Cell_class'].str.startswith("OD")
no_OD_merfish_data = merfish_data.drop(index=merfish_data.index[OD_mask])

# Marker genes measured by MERFISH, MOD2 markers ("gm") first, then MOD1 ("wm").
common_genes_gm = [g for g in DE_genes_MOD2 if g in no_OD_merfish_data.columns]
common_genes_wm = [g for g in DE_genes_MOD1 if g in no_OD_merfish_data.columns]
no_OD_merfish_DE_gm = no_OD_merfish_data[common_genes_gm]
no_OD_merfish_DE_wm = no_OD_merfish_data[common_genes_wm]

# genes x cells table plus the matching cell labels
no_OD_merfish_DE = pd.concat([no_OD_merfish_DE_gm, no_OD_merfish_DE_wm], axis=1).T
no_OD_cell_class = no_OD_merfish_data['Cell_class']

In [None]:
# MERFISH non-OD cell types (cells x genes again via .T): marker genes grouped by MOD.
plot_annotate_heatmap(
    cluster_data=no_OD_merfish_DE.T,
    cluster_labels=no_OD_cell_class,
    gene_groups=Gene_Group,
)