## Result 3: transcript neighborhood



### Data

In [None]:
import os
import sys
from pathlib import Path

# Make the project-local modules in ../src importable (must run before `from plot import ...`).
sys.path.append(os.path.abspath("../src"))

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display
from matplotlib.colors import Normalize

from plot import plot_VSI_map, plot_celltypes, plot_vsi_with_named_squares, plot_circular_neighborhood

#### signals in tissue section

In [None]:
# Folder holding the MERFISH mouse-hypothalamus input tables (relative to the notebook).
MERFISH_data_folder_path = Path("../data") / "mouse_hypothalamus" / "MERFISH"

In [None]:
# Columns read from the decoded-barcode table; the coordinate and gene columns
# are renamed to short names (x, y, z, gene) right after loading.
columns = [
    "Centroid_X",
    "Centroid_Y",
    "Centroid_Z",
    "Gene_name",
    "Cell_name",
    "Total_brightness",
    "Area",
    "Error_bit",
    "Error_direction",
]

signal_coordinate_df = pd.read_csv(
    MERFISH_data_folder_path / "merfish_barcodes_example.csv", usecols=columns
).rename(
    columns={
        "Centroid_X": "x",
        "Centroid_Y": "y",
        "Centroid_Z": "z",
        "Gene_name": "gene",
    }
)

# Remove dummy molecules (gene names containing "Blank" or "NegControl").
# The filtered frame is copied immediately: the column assignments below would
# otherwise write into a slice and trigger SettingWithCopyWarning. A copy taken
# only after those assignments is too late to prevent the warning.
signal_coordinate_df = signal_coordinate_df.loc[
    ~signal_coordinate_df["gene"].str.contains("Blank|NegControl"),
].copy()

signal_coordinate_df["gene"] = signal_coordinate_df["gene"].astype("category")

# Shift the coordinates so the smallest x and y become 0 (no negative values).
# The two offsets are kept as notebook-level variables because later cells
# apply the same shift to the other tables.
coordinate_x_m = signal_coordinate_df["x"].min()
coordinate_y_m = signal_coordinate_df["y"].min()
signal_coordinate_df["x"] = signal_coordinate_df["x"] - coordinate_x_m
signal_coordinate_df["y"] = signal_coordinate_df["y"] - coordinate_y_m

In [None]:
# Counts of the 20 most frequent genes (value_counts sorts by descending count).
gene_counts = signal_coordinate_df['gene'].value_counts()
top20 = gene_counts.head(20)

#### Results of Ovrlpy

results:  
- signal integrity  
- signal strength  

In [None]:
# Folder with the saved Ovrlpy output files.
ovrlpy_result_folder = Path("../data") / "results" / "VSI"

In [None]:
# Load the two Ovrlpy result arrays from their plain-text files.
signal_integrity, signal_strength = (
    np.loadtxt(ovrlpy_result_folder / file_name)
    for file_name in ("SignalIntegrity.txt", "SignalStrength.txt")
)

#### Results of BANKSY

In [None]:
# Folder with the saved BANKSY clustering output.
banksy_folder_path = Path("../data") / "banksy_results"

In [None]:
# Columns read from the BANKSY output; "lam0.2" is the cluster-label column
# used here (renamed to "banksy_cluster").
columns = [
    "Centroid_X",
    "Centroid_Y",
    "Bregma",
    "lam0.2",
]

banksy_result = pd.read_csv(
    banksy_folder_path / "banksy_cluster.txt", usecols=columns, sep="\t"
).rename(
    columns={
        "Centroid_X": "x",
        "Centroid_Y": "y",
        "lam0.2": "banksy_cluster",
    }
)

# Keep the Bregma -0.24 slice only. Copy right away so the coordinate
# assignments below operate on an independent frame rather than on a slice
# (which is what triggers SettingWithCopyWarning).
banksy_result = banksy_result[banksy_result["Bregma"] == -0.24].copy()

# Apply the same origin shift that was applied to the signal coordinates.
banksy_result["x"] = banksy_result["x"] - coordinate_x_m
banksy_result["y"] = banksy_result["y"] - coordinate_y_m

#### Segmentation Dataset

In [None]:
merfish_data = pd.read_csv(
    MERFISH_data_folder_path / "merfish_all_cells.csv"
).rename(
    columns={
        "Centroid_X": "x",
        "Centroid_Y": "y"
    }
)

# Drop the 'Fos' column and every 'Blank_*' column.
unused_columns = [col for col in merfish_data.columns if col == 'Fos' or col.startswith('Blank_')]
merfish_data = merfish_data.drop(columns=unused_columns)

# Keep non-ambiguous cells of animal 1 in the Bregma -0.24 slice. The result is
# copied immediately so the assignments below write to an independent frame
# instead of a slice (avoids SettingWithCopyWarning).
keep_cell = (
    (merfish_data["Cell_class"] != "Ambiguous")
    & (merfish_data["Animal_ID"] == 1)
    & (merfish_data["Bregma"] == -0.24)
)
merfish_data = merfish_data[keep_cell].copy()

# Apply the same origin shift that was applied to the signal coordinates.
merfish_data['x'] = merfish_data['x'] - coordinate_x_m
merfish_data['y'] = merfish_data['y'] - coordinate_y_m

# The BANKSY labels are attached purely by row position, so both tables must
# contain the same cells in the same order.
# NOTE(review): only the length is checked here; the row order of
# banksy_cluster.txt is assumed to match merfish_all_cells.csv after filtering
# -- confirm against how the BANKSY file was produced.
if len(banksy_result) != len(merfish_data):
    raise ValueError(
        f"banksy_result has {len(banksy_result)} rows but merfish_data has "
        f"{len(merfish_data)}; labels cannot be attached by position"
    )
merfish_data['banksy'] = banksy_result['banksy_cluster'].values

In [None]:
# Collapse numbered sub-classes (e.g. 'OD Mature 1'..'OD Mature 4') into one
# label per major cell class.
cell_class_m = {
    'Astrocyte': 'Astrocyte',
    'Endothelial 1': 'Endothelial',
    'Endothelial 2': 'Endothelial',
    'Endothelial 3': 'Endothelial',
    'Ependymal': 'Ependymal',
    'Excitatory': 'Excitatory',
    'Inhibitory': 'Inhibitory',
    'Microglia': 'Microglia',
    'OD Immature 1': 'OD Immature',
    'OD Immature 2': 'OD Immature',
    'OD Mature 1': 'OD Mature',
    'OD Mature 2': 'OD Mature',
    'OD Mature 3': 'OD Mature',
    'OD Mature 4': 'OD Mature',
    'Pericytes': 'Pericytes',
}

# Labels that are not keys of the mapping are kept as they are. A plain
# `.map(cell_class_m)` would turn them into NaN, which silently wiped the
# already-merged labels ('Endothelial', 'OD Immature', 'OD Mature') whenever
# this cell was executed a second time.
merfish_data['Cell_class'] = merfish_data['Cell_class'].map(
    lambda label: cell_class_m.get(label, label)
)

# sort_values returns a new frame, so no extra copy is needed afterwards.
merfish_data = merfish_data.sort_values(by='Cell_class')

#### Cell Boundaries Dataset

In [None]:
# Read the cell-boundary table and discard rows missing either boundary column.
boundaries_df = (
    pd.read_csv(MERFISH_data_folder_path / 'cellboundaries_example_animal.csv')
    .dropna(subset=['boundaryX', 'boundaryY'])
)

In [None]:
def _shifted_boundary(raw, offset):
    """Parse a ';'-separated coordinate string into floats and subtract `offset`.

    A value that is already a list is only shifted; any other value is
    returned unchanged.
    """
    if isinstance(raw, str):
        raw = [float(token) for token in raw.split(';')]
    if isinstance(raw, list):
        return [value - offset for value in raw]
    return raw


# Restrict the boundaries to the cells kept in merfish_data and attach each
# cell's centroid (x, y) and BANKSY label.
cell_ids = merfish_data['Cell_ID']
is_kept_cell = boundaries_df['feature_uID'].isin(cell_ids)
boundaries_df = (
    boundaries_df[is_kept_cell]
    .merge(
        merfish_data[['Cell_ID', 'x', 'y', 'banksy']],
        left_on='feature_uID',
        right_on='Cell_ID',
        how='inner',
    )
    .drop(columns=['Cell_ID'])
)

# Convert the boundary strings to lists of floats, shifted by the same origin
# offsets that were applied to the signal coordinates.
for boundary_column, axis_offset in (('boundaryX', coordinate_x_m), ('boundaryY', coordinate_y_m)):
    boundaries_df[boundary_column] = boundaries_df[boundary_column].apply(
        _shifted_boundary, args=(axis_offset,)
    )

boundaries_df = boundaries_df.copy()

#### OD cell boundaries

In [None]:
# Split the boundaries into cells with BANKSY label 7 or 8 (MOD_boundaries)
# and all remaining cells (other_boundaries).
in_mod_cluster = boundaries_df['banksy'].isin([8, 7])
MOD_boundaries = boundaries_df[in_mod_cluster]
other_boundaries = boundaries_df[~in_mod_cluster]

#### Marker Genes

differentially expressed genes identified by BANKSY

In [None]:
# DE_genes_gm: 7
DE_genes_gm = [
    'Mlc1', 'Dgkk', 'Cbln2', 'Syt4', 'Gad1',
    'Plin3', 'Gnrh1', 'Sln', 'Gjc3',
]
# DE_genes_wm: 8
DE_genes_wm = ['Mbp', 'Lpar1', 'Trh', 'Ucn3', 'Cck']
# all differentially expressed genes: the two groups above, in that order
DE_genes = DE_genes_gm + DE_genes_wm

### Regions of Interest

Cell types and VSI maps are plotted for the following five regions:

x_range=[250, 350], y_range=[1450, 1550]  
x_range=[1350, 1450], y_range=[1300, 1400]  
x_range=[850, 950], y_range=[950, 1050]  
x_range=[200, 300], y_range=[300, 400]  
x_range=[1580, 1680], y_range=[350, 450] 

#### VSI

In [None]:
# (x, y) start of each region of interest; regions are named "Region 1".."Region 5"
# in this order.
region_origins = [(250, 1450), (1350, 1300), (850, 950), (200, 300), (1580, 350)]
regions = [
    {"x": origin_x, "y": origin_y, "name": f"Region {number}"}
    for number, (origin_x, origin_y) in enumerate(region_origins, start=1)
]

In [None]:
# Complete VSI map with the named regions of interest passed as squares.
VSI_ROI = plot_vsi_with_named_squares(
    signal_integrity,
    signal_strength,
    named_squares=regions,
)
display(VSI_ROI)

#### region 1
x_range=[250, 350], y_range=[1450, 1550]  

In [None]:
# VSI map restricted to region 1.
plot_VSI_map(
    cell_integrity=signal_integrity,
    cell_strength=signal_strength,
    boundary_df=boundaries_df,
    signal_threshold=3.0,
    figure_height=10,
    cmap="BIH",
    x_range=[250, 350],
    y_range=[1450, 1550],
)

In [None]:
# Cell-type view of region 1; cluster 7/8 boundaries are passed separately.
plot_celltypes(
    cell_type=banksy_result,
    boundary_df=other_boundaries,
    MOD_boundary=MOD_boundaries,
    x_range=[250, 350],
    y_range=[1450, 1550],
)

#### region 2
x_range=[1350, 1450], y_range=[1300, 1400]  

In [None]:
# VSI map restricted to region 2.
plot_VSI_map(
    cell_integrity=signal_integrity,
    cell_strength=signal_strength,
    boundary_df=boundaries_df,
    signal_threshold=3.0,
    figure_height=10,
    cmap="BIH",
    x_range=[1350, 1450],
    y_range=[1300, 1400],
)

In [None]:
# Cell-type view of region 2; cluster 7/8 boundaries are passed separately.
# The former `cmap=full_color_map` argument was removed: `full_color_map` is
# not defined anywhere in this notebook, so the call raised NameError on a
# fresh kernel. The function's default colours are used, as for regions 1 and 3.
plot_celltypes(
    cell_type=banksy_result,
    boundary_df=other_boundaries,
    MOD_boundary=MOD_boundaries,
    x_range=[1350, 1450],
    y_range=[1300, 1400],
)

#### region 3
x_range=[850, 950], y_range=[950, 1050]  

In [None]:
# VSI map restricted to region 3.
plot_VSI_map(
    cell_integrity=signal_integrity,
    cell_strength=signal_strength,
    boundary_df=boundaries_df,
    signal_threshold=3.0,
    figure_height=10,
    cmap="BIH",
    x_range=[850, 950],
    y_range=[950, 1050],
)

In [None]:
# Cell-type view of region 3; cluster 7/8 boundaries are passed separately.
plot_celltypes(
    cell_type=banksy_result,
    boundary_df=other_boundaries,
    MOD_boundary=MOD_boundaries,
    x_range=[850, 950],
    y_range=[950, 1050],
)

#### region 4
x_range=[200, 300], y_range=[300, 400]  

In [None]:
# VSI map restricted to region 4.
plot_VSI_map(
    cell_integrity=signal_integrity,
    cell_strength=signal_strength,
    boundary_df=boundaries_df,
    signal_threshold=3.0,
    figure_height=10,
    cmap="BIH",
    x_range=[200, 300],
    y_range=[300, 400],
)

In [None]:
# Cell-type view of region 4; cluster 7/8 boundaries are passed separately.
# The former `cmap=full_color_map` argument was removed: `full_color_map` is
# not defined anywhere in this notebook, so the call raised NameError on a
# fresh kernel. The function's default colours are used, as for regions 1 and 3.
plot_celltypes(
    cell_type=banksy_result,
    boundary_df=other_boundaries,
    MOD_boundary=MOD_boundaries,
    x_range=[200, 300],
    y_range=[300, 400],
)

#### region 5
x_range=[1580, 1680], y_range=[350,450]  

In [None]:
# VSI map restricted to region 5.
plot_VSI_map(
    cell_integrity=signal_integrity,
    cell_strength=signal_strength,
    boundary_df=boundaries_df,
    signal_threshold=3.0,
    figure_height=10,
    cmap="BIH",
    x_range=[1580, 1680],
    y_range=[350, 450],
)

In [None]:
# Cell-type view of region 5; cluster 7/8 boundaries are passed separately.
# The former `cmap=full_color_map` argument was removed: `full_color_map` is
# not defined anywhere in this notebook, so the call raised NameError on a
# fresh kernel. The function's default colours are used, as for regions 1 and 3.
plot_celltypes(
    cell_type=banksy_result,
    boundary_df=other_boundaries,
    MOD_boundary=MOD_boundaries,
    x_range=[1580, 1680],
    y_range=[350, 450],
)

### Circle

#### region1

In [None]:
# Region 1, circular neighbourhoods; the 20 most frequent genes are passed in.
plot_circular_neighborhood(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[250, 350],
    y_range=[1450, 1550],
    diameters=[6, 8, 10, 12],
    true_boundary=True,
    plot_top20=True,
    top20=top20.index,
)

In [None]:
# Region 1, circular neighbourhoods; no top-20 gene list is passed.
plot_circular_neighborhood(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[250, 350],
    y_range=[1450, 1550],
    diameters=[6, 8, 10, 12],
    true_boundary=True,
    plot_top20=False,
    top20=None,
)

#### region2

In [None]:
# Region 2, circular neighbourhoods; the 20 most frequent genes are passed in.
plot_circular_neighborhood(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[1350, 1450],
    y_range=[1300, 1400],
    diameters=[6, 8, 10, 12],
    true_boundary=True,
    plot_top20=True,
    top20=top20.index,
)

In [None]:
# Region 2, circular neighbourhoods; no top-20 gene list is passed.
plot_circular_neighborhood(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[1350, 1450],
    y_range=[1300, 1400],
    diameters=[6, 8, 10, 12],
    true_boundary=True,
    plot_top20=False,
    top20=None,
)

#### region3

In [None]:
# Region 3, circular neighbourhoods; the 20 most frequent genes are passed in.
plot_circular_neighborhood(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[850, 950],
    y_range=[950, 1050],
    diameters=[6, 8, 10, 12],
    true_boundary=True,
    plot_top20=True,
    top20=top20.index,
)

In [None]:
# Region 3, circular neighbourhoods; no top-20 gene list is passed.
plot_circular_neighborhood(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[850, 950],
    y_range=[950, 1050],
    diameters=[6, 8, 10, 12],
    true_boundary=True,
    plot_top20=False,
    top20=None,
)

#### region4

In [None]:
# Region 4, circular neighbourhoods; the 20 most frequent genes are passed in.
plot_circular_neighborhood(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[200, 300],
    y_range=[300, 400],
    diameters=[6, 8, 10, 12],
    true_boundary=True,
    plot_top20=True,
    top20=top20.index,
)

In [None]:
# Region 4, circular neighbourhoods; no top-20 gene list is passed.
plot_circular_neighborhood(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[200, 300],
    y_range=[300, 400],
    diameters=[6, 8, 10, 12],
    true_boundary=True,
    plot_top20=False,
    top20=None,
)

#### region5

In [None]:
# Region 5, circular neighbourhoods; the 20 most frequent genes are passed in.
plot_circular_neighborhood(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[1580, 1680],
    y_range=[350, 450],
    diameters=[6, 8, 10, 12],
    true_boundary=True,
    plot_top20=True,
    top20=top20.index,
)

In [None]:
# Region 5, circular neighbourhoods; no top-20 gene list is passed.
plot_circular_neighborhood(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[1580, 1680],
    y_range=[350, 450],
    diameters=[6, 8, 10, 12],
    true_boundary=True,
    plot_top20=False,
    top20=None,
)

### kNN

In [None]:
from sklearn.neighbors import NearestNeighbors

def compute_knn(coordinate_df, query_points, k):
    """Find the k nearest rows of `coordinate_df` for every query point.

    Neighbours are searched in the 3-D space spanned by the 'x', 'y' and 'z'
    columns, using scikit-learn's `NearestNeighbors` with its default settings.

    Parameters
    ----------
    coordinate_df : pandas.DataFrame
        Table with 'x', 'y' and 'z' columns.
    query_points : array-like
        One (x, y, z) point or a sequence of such points. Lists, tuples and
        NumPy arrays are accepted; a single 1-D point is treated as one query.
    k : int
        Number of neighbours to return per query point.

    Returns
    -------
    list of dict
        One dict per query point with the keys 'query_point' (the query
        coordinates), 'neighbor_indices' (positional row indices into
        `coordinate_df`) and 'neighbor_distances'.

    Raises
    ------
    ValueError
        If `k` exceeds the number of rows in `coordinate_df`.
    """
    if k > len(coordinate_df):
        raise ValueError(f"k ({k}) cannot be greater than the number of points in coordinate_df ({len(coordinate_df)})")

    # Normalise the queries to a 2-D array so that plain Python lists and a
    # single point work too (the per-row `.tolist()` below needs ndarray rows).
    queries = np.atleast_2d(np.asarray(query_points))

    coordinates = coordinate_df[['x', 'y', 'z']].values

    nbrs = NearestNeighbors(n_neighbors=k)
    nbrs.fit(coordinates)

    distances, indices = nbrs.kneighbors(queries)

    return [
        {
            'query_point': query.tolist(),
            'neighbor_indices': idx.tolist(),
            'neighbor_distances': dist.tolist(),
        }
        for query, dist, idx in zip(queries, distances, indices)
    ]


In [None]:
from scipy.spatial import ConvexHull

def _rows_in_window(df, x_range, y_range):
    """Return the rows of `df` whose 'x' and 'y' lie inside the inclusive window."""
    in_x = (df["x"] >= x_range[0]) & (df["x"] <= x_range[1])
    in_y = (df["y"] >= y_range[0]) & (df["y"] <= y_range[1])
    return df[in_x & in_y]


def plot_focus_points_with_boundary_knn(signals_df, centroid_df, MOD_boundaries, boundaries_df, x_range, y_range,
                                        neighbors=(20, 40, 80, 160, 220),
                                        true_boundary=True, plot_top20=False, top20=None, query_z=4.5):
    """Plot one window of the section with kNN neighbourhood outlines per cell.

    For every cell centroid inside the window, the `k` signals nearest to the
    point (centroid x, centroid y, `query_z`) are looked up among *all* rows of
    `signals_df`, and the 2-D convex hull of those signals is drawn. One hull
    per centroid is drawn for each `k` in `neighbors`.

    Parameters
    ----------
    signals_df : pandas.DataFrame
        Signals with 'x', 'y', 'z' columns, plus 'gene' (used when
        `plot_top20` is set) or 'Total_brightness' (used otherwise).
    centroid_df : pandas.DataFrame
        Cell centroids with 'x' and 'y' columns.
    MOD_boundaries, boundaries_df : pandas.DataFrame
        Boundary tables with centroid 'x'/'y' and 'boundaryX'/'boundaryY'
        coordinate lists. Rows of `MOD_boundaries` are drawn in '#00bfae',
        rows of `boundaries_df` in grey.
    x_range, y_range : sequence of two numbers
        Inclusive window limits; also used as the axis limits.
    neighbors : sequence of int
        Values of k for which hulls are drawn.
    true_boundary : bool
        Whether the cell boundaries are drawn.
    plot_top20 : bool
        If true and `top20` is given, signals are drawn in light grey and the
        genes in `top20` are overlaid with per-gene marker/colour; otherwise
        signals are coloured by 'Total_brightness' (colour scale 0-2000).
    top20 : iterable of str or None
        Gene names to highlight.
    query_z : float
        z coordinate assigned to every centroid for the 3-D neighbour search.
        NOTE(review): 4.5 was hard-coded before; presumably the mid-plane of
        the z stack -- confirm.

    Raises
    ------
    ValueError
        If a value in `neighbors` is smaller than 1, or (via `compute_knn`)
        larger than the number of signals.
    """
    signals_filtered = _rows_in_window(signals_df, x_range, y_range)
    centroid_filtered = _rows_in_window(centroid_df, x_range, y_range)
    MOD_filtered = _rows_in_window(MOD_boundaries, x_range, y_range)
    boundaries_filtered = _rows_in_window(boundaries_df, x_range, y_range)

    fig, ax = plt.subplots(figsize=(8, 8), dpi=600)

    if plot_top20 and top20 is not None:
        marker_styles = ['o', 's', 'D', '^']  # 4 marker styles: circle, square, diamond, triangle
        colors = sns.color_palette("tab10", 5)  # 5 distinct colors from seaborn palette
        # 4 markers x 5 colours give 20 distinct (marker, colour) pairs.
        top20_dict = {gene: (marker_styles[i % 4], colors[i % 5]) for i, gene in enumerate(top20)}

        # Plot background signals
        ax.scatter(
            signals_filtered['x'], signals_filtered['y'],
            s=3, color='lightgrey', alpha=0.5, label="Other Genes"
        )

        # Plot top 20 genes
        for gene, (marker, color) in top20_dict.items():
            subset = signals_filtered[signals_filtered['gene'] == gene]
            ax.scatter(
                subset['x'], subset['y'],
                s=3, color=color, marker=marker, alpha=0.8, label=gene
            )
    else:
        norm = Normalize(vmin=0, vmax=2000)
        cmap = plt.cm.Oranges
        ax.scatter(
            signals_filtered["x"], signals_filtered["y"],
            s=3,
            c=cmap(norm(signals_filtered["Total_brightness"])),
        )
        # Colorbar
        cbar_wm = plt.colorbar(
            plt.cm.ScalarMappable(norm=norm, cmap=cmap),
            ax=ax,
            shrink=0.5,
            pad=0.02,
            anchor=(0.0, 0.3)
        )
        cbar_wm.set_label("Signal Brightness")

    # Plot centroids
    ax.scatter(
        centroid_filtered["x"], centroid_filtered["y"],
        s=15,
        c='blue',
        label="Cell Centroids",
        marker="x"
    )

    # Plot boundaries; the empty plots only create one legend entry per group.
    if true_boundary:
        for _, row in boundaries_filtered.iterrows():
            ax.plot(row['boundaryX'], row['boundaryY'], c='grey', lw=1)
        ax.plot([], [], color='grey', lw=1, label="Other Cells Boundary")
        for _, row in MOD_filtered.iterrows():
            ax.plot(row['boundaryX'], row['boundaryY'], c='#00bfae', lw=1)
        ax.plot([], [], color='#00bfae', lw=1, label="MOD Cells Boundary")

    neighbor_counts = list(neighbors)
    if neighbor_counts and min(neighbor_counts) < 1:
        raise ValueError(f"all values in neighbors must be >= 1, got {neighbor_counts}")

    if neighbor_counts and len(centroid_filtered) > 0:
        # Query all centroids once for the largest k. kneighbors returns the
        # neighbours ordered by increasing distance, so the first k entries are
        # the k nearest neighbours for every smaller k. This replaces one model
        # fit on the full signal table per centroid and per k.
        query_points = np.column_stack([
            centroid_filtered["x"].to_numpy(),
            centroid_filtered["y"].to_numpy(),
            np.full(len(centroid_filtered), query_z),
        ])
        knn_results = compute_knn(signals_df, query_points, max(neighbor_counts))
        signal_xy = signals_df[['x', 'y']].to_numpy()

        cmap_rings = mpl.colormaps['tab20']
        for idx, k in enumerate(neighbor_counts):
            color = cmap_rings(idx / len(neighbor_counts))
            label = f"k={k} NN"  # attached to the first outline only (one legend entry per k)
            for result in knn_results:
                neighbor_points = signal_xy[result['neighbor_indices'][:k]]

                # ConvexHull requires at least 3 points
                if len(neighbor_points) < 3:
                    continue

                hull = ConvexHull(neighbor_points)
                # For 2-D input the hull vertices are ordered along the hull;
                # repeat the first vertex to close the outline.
                outline = np.append(hull.vertices, hull.vertices[0])
                ax.plot(
                    neighbor_points[outline, 0],
                    neighbor_points[outline, 1],
                    color=color,
                    lw=1,
                    label=label
                )
                label = None

    ax.set_xlim(x_range)
    ax.set_ylim(y_range)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_aspect('equal')
    ax.legend(
        loc="upper left",
        bbox_to_anchor=(1.02, 1),
        fontsize=10,
        frameon=False,
        markerscale=1.5,
        ncol=1
    )
    plt.show()

#### region1

In [None]:
# Region 1, kNN neighbourhoods with the 20 most frequent genes highlighted.
plot_focus_points_with_boundary_knn(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[250, 350],
    y_range=[1450, 1550],
    neighbors=[40, 80, 160, 240],
    true_boundary=True,
    plot_top20=True,
    top20=top20.index,
)

In [None]:
# Region 1, kNN neighbourhoods with signals coloured by brightness.
plot_focus_points_with_boundary_knn(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[250, 350],
    y_range=[1450, 1550],
    neighbors=[40, 80, 160, 240],
    true_boundary=True,
    plot_top20=False,
    top20=None,
)

#### region2

In [None]:
# Region 2, kNN neighbourhoods with the 20 most frequent genes highlighted.
plot_focus_points_with_boundary_knn(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[1350, 1450],
    y_range=[1300, 1400],
    neighbors=[40, 80, 160, 240],
    true_boundary=True,
    plot_top20=True,
    top20=top20.index,
)

In [None]:
# Region 2, kNN neighbourhoods with signals coloured by brightness.
plot_focus_points_with_boundary_knn(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[1350, 1450],
    y_range=[1300, 1400],
    neighbors=[40, 80, 160, 240],
    true_boundary=True,
    plot_top20=False,
    top20=None,
)

#### region3

In [None]:
# Region 3, kNN neighbourhoods with the 20 most frequent genes highlighted.
plot_focus_points_with_boundary_knn(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[850, 950],
    y_range=[950, 1050],
    neighbors=[40, 80, 160, 240],
    true_boundary=True,
    plot_top20=True,
    top20=top20.index,
)

In [None]:
# Region 3, kNN neighbourhoods with signals coloured by brightness.
plot_focus_points_with_boundary_knn(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[850, 950],
    y_range=[950, 1050],
    neighbors=[40, 80, 160, 240],
    true_boundary=True,
    plot_top20=False,
    top20=None,
)

#### region4

In [None]:
# Region 4, kNN neighbourhoods with the 20 most frequent genes highlighted.
plot_focus_points_with_boundary_knn(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[200, 300],
    y_range=[300, 400],
    neighbors=[40, 80, 160, 240],
    true_boundary=True,
    plot_top20=True,
    top20=top20.index,
)

In [None]:
# Region 4, kNN neighbourhoods with signals coloured by brightness.
plot_focus_points_with_boundary_knn(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[200, 300],
    y_range=[300, 400],
    neighbors=[40, 80, 160, 240],
    true_boundary=True,
    plot_top20=False,
    top20=None,
)

#### region5

In [None]:
# Region 5, kNN neighbourhoods with the 20 most frequent genes highlighted.
plot_focus_points_with_boundary_knn(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[1580, 1680],
    y_range=[350, 450],
    neighbors=[40, 80, 160, 240],
    true_boundary=True,
    plot_top20=True,
    top20=top20.index,
)

In [None]:
# Region 5, kNN neighbourhoods with signals coloured by brightness.
plot_focus_points_with_boundary_knn(
    signals_df=signal_coordinate_df,
    centroid_df=merfish_data,
    MOD_boundaries=MOD_boundaries,
    boundaries_df=other_boundaries,
    x_range=[1580, 1680],
    y_range=[350, 450],
    neighbors=[40, 80, 160, 240],
    true_boundary=True,
    plot_top20=False,
    top20=None,
)