In [None]:
import os
from pathlib import Path
from typing import Annotated

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import tifffile
from scipy import sparse
from scipy.stats import entropy, chi2_contingency
from scipy.stats import fisher_exact, hypergeom
from skimage.color import label2rgb
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import radius_neighbors_graph
from sklearn.preprocessing import StandardScaler
from statsmodels.stats.multitest import multipletests

# Keep text editable in exported figures: SVG keeps real text elements and
# PDF embeds TrueType (Type 42) fonts rather than outlined glyphs.
plt.rcParams.update({'svg.fonttype': 'none', 'pdf.fonttype': 42})

# Every figure/table below is written to output_dir; relative paths resolve
# against the analysis directory set as the working directory.
output_dir = Path('/diskmnt/Projects/myeloma_scRNA_analysis/MMY_IRD/Xenium/analysis/radial_neighborhoods/Output')
os.chdir('/diskmnt/Projects/myeloma_scRNA_analysis/MMY_IRD/Xenium/analysis/radial_neighborhoods/')
output_dir.mkdir(exist_ok=True, parents=True)
os.getcwd()

In [None]:
# Merged Xenium AnnData for all samples; obs carries the per-cell metadata used
# below (e.g. 'ct' annotation, 'Sample', x/y centroids).
merged_h5ad = '/diskmnt/Projects/myeloma_scRNA_analysis/MMY_IRD/Xenium/analysis/merged.h5ad'
merged = sc.read_h5ad(merged_h5ad)
merged

In [None]:
# All cell-type annotation levels, in the categorical order of merged.obs['ct'].
cts = merged.obs['ct'].cat.categories.tolist()

In [None]:
# Work on a copy so merged.obs itself is not mutated by the steps below.
df = merged.obs.copy(deep=True)

In [None]:
# Sample -> metadata lookups. One row per sample is kept (first occurrence wins);
# this presumes these columns are constant within a sample -- confirm upstream.
base = df.loc[:, ["Sample", "Collection", "DI_Sample", "DI_UPN", "UPN"]].drop_duplicates(subset="Sample")
by_sample = base.set_index("Sample")

sample_to_collection = by_sample["Collection"].to_dict()
sample_to_DI_Sample = by_sample["DI_Sample"].to_dict()
sample_to_DI_UPN = by_sample["DI_UPN"].to_dict()
sample_to_UPN = by_sample["UPN"].to_dict()

In [None]:
# Radial neighborhoods: for every cell, count how many cells of each annotation
# ('ct') lie within `radius` of its centroid, searching only within its own sample.
# Produces:
#   per_cell_df       - one row per cell (barcode index): Sample, x/y centroid,
#                       one count column per annotation category, total_neighbors
#   per_sample_totals - one row per sample: count of each annotation + total_cells
sample_col = "Sample"
annot_col  = "ct"
xcol, ycol = "x_centroid", "y_centroid"
radius = 100.0  # in the units of x/y_centroid (presumably microns for Xenium -- confirm)

per_cell_frames = []
totals_rows = []

# Fixed category order so every sample yields the same count columns.
all_annots = df[annot_col].cat.categories
ann_idx = pd.Index(all_annots)

for s, g in df.groupby(sample_col, sort=False, observed=True):
    print(s)
    g = g.copy()

    # per-sample totals for annot (categories absent from this sample get 0)
    vc = g[annot_col].value_counts().reindex(ann_idx, fill_value=0)
    totals_rows.append(pd.Series({"Sample": s, **vc.to_dict(), "total_cells": len(g)}))

    # neighbor search. The query points are the fitted points themselves, so each
    # cell is returned as its own neighbor (distance 0) and is included in its counts.
    X = g[[xcol, ycol]].to_numpy(dtype=np.float32)
    nn = NearestNeighbors(radius=radius, metric="euclidean", algorithm="kd_tree")
    nn.fit(X)
    neigh_ind = nn.radius_neighbors(X, radius=radius, return_distance=False)  # list of arrays

    # per-annot neighbor counts via bincount over integer category codes.
    # NOTE(review): a missing annotation has code -1, which np.bincount rejects;
    # this assumes 'ct' has no NaN values -- confirm.
    codes = pd.Categorical(g[annot_col], categories=all_annots, ordered=True).codes
    K = len(all_annots)
    n = len(g)
    counts_mat = np.zeros((n, K), dtype=np.int32)

    for i, inds in enumerate(neigh_ind):
        if inds.size:
            counts_mat[i] = np.bincount(codes[inds], minlength=K)

    # neighborhood size (self included), used later as the fraction denominator
    total_n_neighbors = counts_mat.sum(axis=1).astype(np.int32)

    # assemble per-cell output for this sample (retain x/y centroids)
    counts_df = pd.DataFrame(counts_mat, index=g.index, columns=all_annots)
    df_out = pd.concat([pd.DataFrame({sample_col: s}, index=g.index), g[[xcol, ycol]].copy(),counts_df], axis=1)
    df_out["total_neighbors"] = total_n_neighbors
    per_cell_frames.append(df_out)

per_cell_df = pd.concat(per_cell_frames, axis=0)
per_sample_totals = pd.DataFrame(totals_rows).set_index(sample_col)

In [None]:
# Persist the unfiltered per-cell neighbor counts (tab-separated, barcode index kept).
counts_path = output_dir / 'allSamples_neighborsPerCell_annotCounts.txt'
per_cell_df.to_csv(counts_path, sep='\t')
per_cell_df

In [None]:
# Distribution of per-cell neighborhood sizes (total_neighbors), inspected to pick a filtering threshold.
sns.displot(per_cell_df['total_neighbors'])

In [None]:
# Keep only cells whose neighborhood holds at least `thresh` cells (the filter below uses >=).
thresh = 30
grid = sns.displot(per_cell_df['total_neighbors'])
# axvline spans the full y-range of the axes, so no hard-coded ymax is needed
grid.ax.axvline(thresh, color='red')
grid.ax.set_title(f'Neighborhood size per cell (threshold = {thresh})')

In [None]:
# Drop cells whose neighborhood has fewer than `thresh` cells.
keep = per_cell_df['total_neighbors'].ge(thresh)
per_cell_df_filtered = per_cell_df.loc[keep].copy()
per_cell_df_filtered

In [None]:
# Turn per-cell neighbor counts into per-cell neighbor fractions (each row sums to 1
# over the cell-type columns, since total_neighbors is the sum of those counts).
countsdf = per_cell_df_filtered.copy()
denom = countsdf["total_neighbors"]  # positive whenever thresh > 0

id_cols = ["Sample", "x_centroid", "y_centroid", "total_neighbors"]
fractions = countsdf[id_cols].copy()
fractions[cts] = countsdf.loc[:, cts].divide(denom, axis="index").fillna(0.0)
fractions

In [None]:
# Pick k for k-means on the standardized neighborhood-fraction profiles. Each k is scored on:
# 1. sample representation per cluster: mean Shannon entropy of each cluster's Sample mix
#    (higher = clusters draw on many samples),
# 2. cluster "uniqueness": mean pairwise Euclidean distance between cluster centroids in
#    standardized space (higher = more distinct clusters),
# 3. Collection (NBM, NDMM, PT) separation: minus the mean per-cluster Collection entropy
#    (higher = clusters are purer in Collection).
# The composite is the plain (unweighted) mean of the three scores. They are on different
# scales, so whichever score varies most across k dominates the choice.
#
# Per-cell metadata is held in `cell_meta` instead of reassigning `df`, which earlier cells
# use as the copy of merged.obs; re-running those cells after this one stays safe.
cell_meta = fractions[['Sample']].copy()
cell_meta['Collection'] = cell_meta['Sample'].map(sample_to_collection)

# only the per-cell-type fraction columns enter the clustering
X = fractions[[c for c in fractions.columns if c in all_annots]].fillna(0).values
Xs = StandardScaler().fit_transform(X)

ks = range(2, 16)
rows = []

for k in ks:
    print(f"evaluation k = {k}")
    km = KMeans(n_clusters=k, random_state=42, n_init="auto")
    labels = km.fit_predict(Xs)

    tmp = pd.DataFrame({
        'Sample': cell_meta['Sample'].values,
        'Collection': cell_meta['Collection'].values,
        'cluster': labels
    })

    # 1. Shannon entropy (natural log) of each cluster's Sample composition
    sh = []
    for _, g in tmp.groupby('cluster'):
        counts = g['Sample'].value_counts().values.astype(float)
        p = counts / counts.sum()
        sh.append(entropy(p))
    mean_shannon = float(np.mean(sh)) if sh else 0.0

    # 2. mean pairwise distance between cluster centres (upper triangle, diagonal excluded)
    pdists = euclidean_distances(km.cluster_centers_)
    centroid_sep = float(pdists[np.triu_indices(k, 1)].mean()) if k > 1 else 0.0

    # 3. negated mean Collection entropy per cluster (purer clusters score higher)
    collection_entropies = []
    for cid, group in tmp.groupby('cluster'):
        counts = group['Collection'].value_counts()
        p = counts / counts.sum()
        collection_entropies.append(entropy(p, base=np.e))
    collection_sep = -np.mean(collection_entropies)

    # unweighted mean of the three criteria
    composite = (mean_shannon + centroid_sep + collection_sep) / 3.0

    rows.append({
        'k': k,
        'mean_shannon': mean_shannon,
        'centroid_sep': centroid_sep,
        'collection_sep': collection_sep,
        'composite': composite
    })

res = pd.DataFrame(rows).sort_values('composite', ascending=False)
best_k = int(res.iloc[0]['k'])

# Final fit with the chosen k (same seed/settings as the scan) and attach labels
final = KMeans(n_clusters=best_k, random_state=42, n_init="auto")
fractions['cluster'] = final.fit_predict(Xs)

print(res)
print("Chosen k =", best_k)

In [None]:
# Cells per k-means cluster id; the RN renaming in the next cell is ordered by these sizes.
fractions['cluster'].value_counts()

In [None]:
# Rename k-means cluster ids to radial-neighborhood (RN) labels ordered by cluster size:
# RN1 = largest ... RN13 = smallest, read off the value_counts above.
# The map is hard-coded for this particular clustering (13 clusters, ids 0-12).
name_map = {
    12: 'RN1',
    6: 'RN2',
    4: 'RN3',
    5: 'RN4',
    9: 'RN5',
    2: 'RN6',
    3: 'RN7',
    8: 'RN8',
    11: 'RN9',
    1: 'RN10',
    0: 'RN11',
    10: 'RN12',
    7: 'RN13'
}
# Fail loudly instead of silently producing NaN labels if best_k / the clustering changed.
unmapped = sorted(set(fractions["cluster"].unique()) - set(name_map))
assert not unmapped, f"cluster ids {unmapped} have no RN name (best_k={best_k}); update name_map"

fractions["rn"] = fractions["cluster"].map(name_map)
order = list(name_map.values())
print(order)
fractions["rn"] = pd.Categorical(fractions["rn"], categories=order, ordered=True)
fractions

In [None]:
# Attach each filtered cell's RN label (aligned on the barcode index) and its own cell type.
barcode_to_annot = merged.obs["ct"].to_dict()
countsdf = countsdf.assign(
    rn=fractions['rn'],
    ct=countsdf.index.map(barcode_to_annot),
)
countsdf

In [None]:
# RN x cell-type table of cell counts (rows: RN, columns: ct; missing pairs -> 0).
tally = countsdf.groupby(["rn", "ct"], observed=True).size().reset_index(name="n_cells")
wide = tally.pivot(index="ct", columns="rn", values="n_cells").fillna(0)
grouped = wide.transpose()
grouped

In [None]:
# Cell type (rows) x RN (columns) counts, min-max scaled within each cell type;
# RN columns keep their label order (no column clustering).
counts_by_ct = grouped.T
sns.clustermap(counts_by_ct, standard_scale=0, col_cluster=False, cmap='viridis')
plt.savefig(output_dir / 'clusters_heatmap_counts_scaled_by_celltype.svg')

In [None]:
# Over-representation of each cell type in each RN, from the RN x cell-type count table.
# For every (RN, CT) pair the 2x2 table is
#              CT      not CT
#   RN         a       b
#   not RN     c       d
# The Fisher (alternative='greater') odds ratio and the hypergeometric upper tail
# P(X >= a) are recorded, then all p-values are BH-adjusted together.
clusters = grouped.index
cell_types = grouped.columns

odds_matrix = pd.DataFrame(np.nan, index=clusters, columns=cell_types)
pval_matrix = pd.DataFrame(np.nan, index=clusters, columns=cell_types)

N = grouped.values.sum()         # total cells
row_totals = grouped.sum(axis=1) # cluster totals
col_totals = grouped.sum(axis=0) # cell-type totals

for rn in clusters:
    for ct in cell_types:

        a = grouped.loc[rn, ct]               # in RN & of type CT
        b = row_totals[rn] - a                # in RN & not CT
        c = col_totals[ct] - a                # not RN & type CT
        d = N - (a + b + c)                   # not RN & not CT

        # counts arrive as floats (pivot + fillna); the contingency table must be integral
        tbl = np.array([[a, b], [c, d]], dtype=np.int64)

        # Fisher's exact test (enrichment); only the odds ratio is kept
        odds, _ = fisher_exact(tbl, alternative='greater')
        odds_matrix.loc[rn, ct] = odds

        # Hypergeometric upper tail P(X >= a): population N, m cells of type CT,
        # n_drawn cells in this RN. sf(q) = P(X > q), so q = a - 1; for a == 0 this is
        # sf(-1) = 1. Clamping q to 0 would compute P(X >= 1) and could mark an RN with
        # zero cells of a rare type as "enriched".
        m = col_totals[ct]
        n_drawn = row_totals[rn]
        pval_matrix.loc[rn, ct] = hypergeom.sf(a - 1, N, m, n_drawn)

# Floor exact-zero p-values (presumably float underflow) before BH
pvals = pval_matrix.replace(0, 1e-300).values.flatten()

# BH adjust across all (RN, CT) pairs
adj = multipletests(pvals, method='fdr_bh')[1]

# Back to DataFrame
pval_adj = pd.DataFrame(
    adj.reshape(pval_matrix.shape),
    index=clusters, columns=cell_types
)


In [None]:
# log2 odds ratios, z-scored within each cell type (column) for plotting.
# Non-finite log ratios are set to 0: OR == 0 (log2 -> -inf) as before, and also
# OR == inf (b or c == 0 in the 2x2 table), which would otherwise make that
# column's mean/std infinite and turn its whole scaled column into NaN.
with np.errstate(divide='ignore'):
    logOR = np.log2(odds_matrix)
logOR = logOR.replace([np.inf, -np.inf], np.nan).fillna(0)
scaled = (logOR - logOR.mean()) / logOR.std()


def p_to_star(p):
    """Return the heatmap mark for a BH-adjusted p-value: '.' if p < 0.05, else ''."""
    return '.' if p < 0.05 else ''


ann = pval_adj.applymap(p_to_star)

logOR

In [None]:
# Order cell types by hierarchically clustering their scaled log2(OR) profiles across RNs,
# then redraw the matrix in that order as a plain heatmap, marked with '.' where the
# BH-adjusted p-value is below 0.05.
heatmap_kwargs = dict(
    cmap='coolwarm',
    center=0,
    fmt='',
    square=True,
    linewidths=0,
    cbar_kws={'label': 'Scaled log2(OR)'},
)

g = sns.clustermap(scaled.T, cmap='coolwarm', row_cluster=True, col_cluster=False)
row_order = g.dendrogram_row.reordered_ind

scaled_reordered, ann_reordered = (frame.T.iloc[row_order, :] for frame in (scaled, ann))

plt.figure(figsize=(12, 8))
ax = sns.heatmap(scaled_reordered, annot=ann_reordered, **heatmap_kwargs)
plt.tight_layout()
plt.savefig(output_dir / 'clusters_heatmap_oddsratio_scaled_by_celltype_signif.svg')


In [None]:
# Significance heatmap of scaled log2(OR), with cell types ordered by clustering their
# per-cell-type min-max-scaled raw counts (grouped.T) instead of the log2(OR) values.
heatmap_kwargs = dict(
    cmap='coolwarm',
    center=0,
    fmt='',
    square=True,
    linewidths=0,
    cbar_kws={'label': 'Scaled log2(OR)'},
)

g = sns.clustermap(grouped.T, standard_scale=0, cmap='viridis', row_cluster=True, col_cluster=False)
row_order = g.dendrogram_row.reordered_ind

scaled_reordered, ann_reordered = (frame.T.iloc[row_order, :] for frame in (scaled, ann))

plt.figure(figsize=(12, 8))
ax = sns.heatmap(scaled_reordered, annot=ann_reordered, **heatmap_kwargs)
plt.tight_layout()
plt.savefig(output_dir / 'clusters_heatmap_oddsratio_scaled_by_celltype_signif_alternativeDendo.svg')


In [None]:
# RN -> colour palette, listed RN1 ... RN13; each comment names the RN's most prominent cell type.
neighborhood_colors = {
    "RN1":  "#079450",  # later granulocytes
    "RN2":  "#9e9e9e",  # erythroid
    "RN3":  "#ff42ca",  # plasma cells
    "RN4": "#00ff1e",   # early granulocyte/myeloid
    "RN5": "#241717",   # megakaryocytes
    "RN6":  '#00ba9e',  # other myeloid (cDC, ba/eo/ma, low confidence)
    "RN7":  "#00f7ff",  # early B and myeloid
    "RN8":  "#b50d0d",  # cytotoxic T / NK
    "RN9":  "#de9835",  # endothelial
    "RN10": "#c6db02",  # HSPC
    "RN11": "#7875ff",  # lymphoid
    "RN12": "#fabc02",  # fibro/osteo
    "RN13": "#735b2e",  # pericyte
}
colors = [neighborhood_colors[rn] for rn in neighborhood_colors]

# Preview the palette as a row of swatches, RN1 -> RN13 left to right.
plt.figure(figsize=(6,1))
for pos, hexcode in enumerate(colors):
    plt.bar(pos, 1, color=hexcode)
plt.axis("off")
plt.show()

In [None]:
# Persist per-cell fractions (with cluster/rn), per-cell counts (with rn/ct) and the RN x cell-type table.
for frame, fname in (
    (fractions, 'all_fractions.txt'),
    (countsdf, 'all_counts.txt'),
    (grouped, 'cluster_celltype_total.txt'),
):
    frame.to_csv(output_dir / fname, sep='\t')

In [None]:
# Panel grid for the per-sample spatial plots: ncols panels per row and just enough
# rows for every sample (ceiling division, so no empty trailing row when the
# sample count is a multiple of ncols).
sids = sorted(set(fractions['Sample']))
ncols = 5
nrows = max(1, -(-len(sids) // ncols))
nrows, ncols

In [None]:
# Spatial map of every sample, cells coloured by RN and panels titled by sample id.
# `colors` is listed in RN order, matching the categorical order of fractions['rn'].
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20,40))
panels = axs.flatten()
for sid, ax in zip(sids, panels):
    subset = fractions[fractions['Sample'] == sid]

    sns.scatterplot(subset, x='x_centroid', y='y_centroid', hue='rn', s=.1, ax=ax, palette=colors, legend=False)
    ax.axis('equal')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(sid)
    # rasterize the point cloud so the SVG stays small; axes and text remain vector
    ax.collections[0].set_rasterized(True)

# hide grid slots left over after the last sample
for ax in panels[len(sids):]:
    ax.set_visible(False)

plt.savefig(output_dir / 'clusters.svg')

In [None]:
# Spatial map of every sample coloured by RN, panels titled with the sample's DI_Sample id.
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 40))
panels = axs.flatten()
for sid, ax in zip(sids, panels):
    di = sample_to_DI_Sample[sid]
    subset = fractions[fractions['Sample'] == sid]

    sns.scatterplot(subset, x='x_centroid', y='y_centroid', hue='rn', s=.1, ax=ax, palette=colors, legend=False)
    ax.axis('equal')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(di)
    # rasterize the point cloud so the SVG stays small; axes and text remain vector
    ax.collections[0].set_rasterized(True)

# hide grid slots left over after the last sample
for ax in panels[len(sids):]:
    ax.set_visible(False)

plt.savefig(output_dir / 'clusters_DI.svg')

In [None]:
# Filtered (neighborhood-assigned) cells vs. all cells in the merged object.
print(*(obj.shape for obj in (fractions, merged)))

In [None]:
# 2-D PCA of the standardized fraction profiles (Xs), coloured by k-means cluster id.
# Cluster-id colours are derived from name_map (cluster id -> RN), so each cluster gets
# its RN colour. They are kept in a local name rather than overwriting the global,
# RN-keyed `neighborhood_colors` used by other cells.
rn_palette = {
    "RN1":  "#079450",  # later granulo
    "RN2":  "#9e9e9e",  # erythroid
    "RN3":  "#ff42ca",  # PC
    "RN4":  "#00ff1e",  # early granulo/mye
    "RN5":  "#241717",  # MKC
    "RN6":  "#00ba9e",  # other myelo (cDC, ba/eo/ma, low confidence)
    "RN7":  "#00f7ff",  # early B and myelo
    "RN8":  "#b50d0d",  # cytotoxic T NK
    "RN9":  "#de9835",  # endothelial
    "RN10": "#c6db02",  # HSPC
    "RN11": "#7875ff",  # lymphoid
    "RN12": "#fabc02",  # fibro/osteo
    "RN13": "#735b2e",  # pericyte
}
cluster_id_colors = {str(cid): rn_palette[rn] for cid, rn in name_map.items()}

# 2D PCA
pca = PCA(n_components=2, random_state=42)
X_pca = pca.fit_transform(Xs)

# Combine metadata (rows are in the same order as fractions / Xs)
plot_df = pd.DataFrame({
    'PC1': X_pca[:, 0],
    'PC2': X_pca[:, 1],
    'Cluster': fractions['cluster'].astype(str),
    'Collection': fractions['Sample'].map(sample_to_collection)
})

fig, ax = plt.subplots(figsize=(7, 6))
sns.scatterplot(
    data=plot_df,
    x='PC1', y='PC2',
    hue='Cluster',
    palette=cluster_id_colors, s=50, ax=ax,
    rasterized=True
)
ax.set_title(f'K-means (k={best_k}) on Fraction Profiles (PCA 2D)')
ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% var)')
ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% var)')
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', frameon=False)

out_pdf = f"kmeans_PCA_k{best_k}_rasterized.pdf"
plt.tight_layout()
fig.savefig(output_dir / out_pdf, dpi=300, bbox_inches='tight', transparent=False)
plt.close(fig)


In [None]:
# Carry RN labels into merged.obs; cells removed by the neighbor-count filter become 'Unassigned'.
clusts = fractions['rn'].reindex(merged.obs_names)
clusts = clusts.astype("category").cat.add_categories(["Unassigned"]).fillna("Unassigned")

merged.obs['rn'] = clusts.values
merged.obs['rn'].value_counts()

In [None]:
# Inspect the obs metadata, which now includes the 'rn' column.
merged.obs

In [None]:
# Save the merged object with RN labels in obs.
merged.write_h5ad(output_dir / "merged_RN.h5ad")

In [None]:
# Write one CSV per sample mapping each cell to its RN (columns: cell_id, group).
# Requires merged.obs['Original_Barcode'] (presumably the barcode as in the per-sample
# Xenium output -- confirm). If the loaded h5ad lacks it, derive it first, e.g.
#   merged.obs["Original_Barcode"] = merged.obs_names.str.rsplit("_", n=1).str[-1]
assert "Original_Barcode" in merged.obs.columns, "merged.obs needs an 'Original_Barcode' column"
sids = sorted(set(merged.obs['Sample']))

outdir = output_dir / 'annotations_rn'
outdir.mkdir(parents=True, exist_ok=True)  # to_csv does not create missing directories

for samp in sids:
    # per-sample frame under its own name, so the global `df` is left untouched
    annot_df = (
        merged.obs.loc[merged.obs["Sample"] == samp, ["Original_Barcode", "rn"]]
        .rename(columns={"Original_Barcode": "cell_id", "rn": "group"})
    )
    annot_df.to_csv(outdir / f"{samp}_rn.csv", index=False)

In [None]:
# Reload the saved per-cell fractions. index_col=0 restores the cell barcodes as the index
# (the table was written with it), so barcode lookups such as
# fractions.reindex(merged.obs_names) keep working if earlier cells are re-run.
# 'rn' is restored as an ordered categorical in numeric order (RN1 < RN2 < ... < RN10),
# not the lexical order a plain string column would give.
fractions = pd.read_csv(output_dir / 'all_fractions.txt', sep='\t', index_col=0)
rn_levels = sorted(fractions['rn'].dropna().unique(), key=lambda label: int(str(label)[2:]))
fractions['rn'] = pd.Categorical(fractions['rn'], categories=rn_levels, ordered=True)

In [None]:
# Preview the reloaded fractions table.
fractions.head()

In [None]:

# Cell-type fraction columns in a fixed display order for the heatmap below.
cat_order = [
    "HSPC", "Erythroid", "Megakaryocyte",
    "GMP", "Late Myeloid", "Neutrophil", "Ba/Eo/Ma", "cDC", "Monocyte", "Macrophage", "pDC",
    "CD4 T", "CD8 T", "NK",
    "Early B", "Mature B", "PC",
    "MSC", "Fibro/Osteo", "Adipocyte", "Endothelial", "vSMC/Pericyte",
    "Low Confidence",
]
frac_f = fractions.loc[:, cat_order]
frac_f.head()

In [None]:
# Row sums over the cat_order columns: each cell's fractions over all cell types sum to 1,
# so values below 1 would mean cat_order omits some cell types.
frac_f.head().sum(axis=1)

In [None]:
# Heatmap of 25 randomly chosen cells' neighborhood fractions, min-max scaled per cell
# type (column). standard_scale must be the integer 1: seaborn tests `standard_scale == 1`
# for column scaling, so the string '1' silently scaled each row instead.
# The sample is seeded so the figure is reproducible.
sns.clustermap(
    frac_f.sample(25, random_state=42), standard_scale=1,
    cmap='viridis',
    row_cluster=False,
    col_cluster=False,
)
plt.savefig(output_dir/'fractions_heatmap.pdf')

In [None]:
# Bar plot of cells per RN (including 'Unassigned' cells dropped by the neighbor filter),
# in the category order of merged.obs['rn'] and coloured with the RN palette.
rn_order = merged.obs['rn'].cat.categories
print(rn_order)

# Reindex to the category order (0 for empty categories) and name the index/values
# explicitly; reset_index() column names otherwise differ between pandas versions.
rn_counts = (
    merged.obs['rn'].value_counts()
    .reindex(rn_order, fill_value=0)
    .rename_axis('rn')
    .reset_index(name='n_cells')
)

# Map colors
neighborhood_colors = {
    "RN1":  "#079450",  # later granulo
    "RN2":  "#9e9e9e", # erythroid
    "RN3":  "#ff42ca", # PC
    "RN4": "#00ff1e", # early granulo/mye
    "RN5": "#241717", # MKC
    "RN6":  '#00ba9e', # other myelo (cDC, ba/eo/ma, low confidence)
    "RN7":  "#00f7ff",  # early B and myelo
    "RN8":  "#b50d0d",  # cytotoxic T NK
    "RN9":  "#de9835",  # endothelial
    "RN10": "#c6db02",  # HSPC
    "RN11": "#7875ff",  # lymphoid
    "RN12": "#fabc02",  # fibro/osteo
    "RN13": "#735b2e",  # pericyte
}
# labels without a palette entry (e.g. 'Unassigned') are drawn light grey, not white,
# so their bar stays visible on the white background
colors = [neighborhood_colors.get(rn, 'lightgrey') for rn in rn_counts['rn']]

# Plot
fig, ax = plt.subplots(figsize=(8, 5))
ax.bar(rn_counts['rn'], rn_counts['n_cells'], width=1.0, color=colors)

ax.set_ylabel("Cell count")
ax.set_xlabel("Radial Neighborhood (RN)")
ax.set_title("Cell counts per RN")
plt.xticks(rotation=90)

plt.tight_layout()
plt.savefig(output_dir / 'ncells_rn_barplot.pdf')

In [None]:
# Cells per RN: the table behind the bar plot above.
rn_counts