In [None]:
import os
import re 
from pathlib import Path

import pandas as pd
import numpy as np
import scanpy as sc
import scanpy.external as sce
import anndata as ad
import matplotlib.pyplot as plt
import seaborn as sns

from matplotlib.backends.backend_pdf import PdfPages

# Work from the Xenium analysis directory; every relative path used below
# (merged.h5ad, output PDFs/CSVs) resolves against it.
os.chdir('/diskmnt/Projects/myeloma_scRNA_analysis/MMY_IRD/Xenium/analysis/')
os.getcwd()

import matplotlib as mpl

# Vector-export settings, applied in one call.
mpl.rcParams.update({
    'pdf.fonttype': 42,      # make text editable in pdf
    'svg.fonttype': 'none',  # keep SVG text as text instead of outlines
})

In [None]:
# Load the merged AnnData object (path is relative to the working directory set above).
merged = sc.read_h5ad('merged.h5ad')

In [None]:
# Display the object's summary repr as the cell output.
merged

In [None]:
# clustering-based annotations from WC
# index_col=False keeps the first CSV column as data instead of using it as the index.
clustmeta_csv = "/diskmnt/Users2/chouw/Projects/BM_spatial/IRD/IRD_JW_Xenium_merge_sketched_projected_annot.csv"
clustmeta = pd.read_csv(clustmeta_csv, sep=",", header=0, index_col=False)

In [None]:
# Key every annotation row by "<Sample>_<Original_Barcode>"; the next step looks rows up
# by the cell names of `merged` using this index.
barcode = clustmeta['Sample'].astype(str) + "_" + clustmeta['Original_Barcode'].astype(str)
clustmeta = clustmeta.assign(barcode=barcode).set_index('barcode')

In [None]:
# Align the annotation table to the cells of `merged`: one row per cell, in cell order.
# Cells that have no row in the CSV come back as all-NaN rows.
cell_names = merged.obs.index
clustmeta_matched = clustmeta.reindex(index=cell_names)
clustmeta_matched

In [None]:
# Cell counts per `annot` label (a column that is already present in the loaded object).
merged.obs['annot'].value_counts()

In [None]:
# Copy WC's cluster-based label onto the cells. `.values` drops the index, so the
# assignment is positional and relies on the row order produced by the reindex above.
wc_labels = clustmeta_matched['celltype.full']
merged.obs['clust'] = wc_labels.values
merged.obs['clust'].value_counts()

In [None]:
# Display order for the WC cluster labels.
cat_order = [
    "HSPC", "Erythroid", "Megakaryocyte",
    "GMP", "Late Myeloid", "Neutrophil", "Ba/Eo/Ma",
    "cDC2", "Monocyte", "Macrophage", "pDC",
    "CD4 T", "CD8 T", "NK_T", "NK",
    "Early B", "Mature B", "Plasma Cell",
    "MSC", "Osteoblast", "Adipocyte", "Endothelial", "vSMC",
    "Low Confidence",
]

# Turn `clust` into an ordered categorical with exactly these categories;
# any value not listed above becomes NaN.
clust_dtype = pd.CategoricalDtype(categories=cat_order, ordered=True)
merged.obs['clust'] = merged.obs['clust'].astype(clust_dtype)

In [None]:
# Marker genes per lineage group, in the order the groups appear on the dotplots.
ctmarkers_simple = {
    "HSPC": ["CD34", "AVP", "SPINK2", "SMIM24", "KIT", "GATA2"],
    "Erythroid": ["GATA1", "AHSP", "ALAS2", "HEMGN", "SLC4A1"],
    "MKC": ["PF4", "PLEK"],
    "Granulo": ["MPO", "CAMP", "LTF", "MMP9", "S100A12", "MS4A2", "CPA3"],
    "cDC": ["CLEC10A", "CD1C", "CD1A", "CD1E", "GPR183"],
    "Mc/Mp": [
        "CD14", "FCN1", "FCGR1A", "FCGR3A", "CD68",
        "AIF1", "MRC1", "CD163", "VSIG4",
    ],
    "pDC": ["IRF8", "RUNX2", "LILRA4", "IL3RA", "GZMB"],
    "T Cell": [
        "CD3D", "CD3E", "TRAC", "CD2", "CD4",
        "CD8A", "CD247", "IL7R", "FOXP3",
    ],
    "NK Cell": ["NKG7", "GNLY", "GZMA", "GZMK", "PRF1", "KLRB1", "KLRC1", "KLRD1"],
    "B Cell": ["VPREB1", "SOX4", "PAX5", "CD19", "MS4A1", "CD79A"],
    "PC": ["MZB1", "SLAMF7", "TNFRSF17", "TENT5C", "PRDM1"],
    "MSC": ["LEPR", "KITLG", "CXCL12", "THY1"],
    "Fibro.": ["PDGFRA", "PDGFRB", "COL5A2", "FBLN1"],
    "Osteo": ["BGLAP", "SPP1"],
    "Adipo": ["FABP4", "ADIPOQ", "PPARG"],
    "Endo": [
        "PECAM1", "VWF", "EGFL7", "CLEC14A", "KDR",
        "ENG", "FLT4", "ACTA2", "CNN1", "MYH11",
    ],
}

In [None]:
# Marker dotplot grouped by WC cluster label, drawn without showing so it can be saved.
sc.pl.dotplot(merged, ctmarkers_simple, "clust", standard_scale="var", figsize=(22, 8), show=False)

# Save the figure that the call above left as pyplot's current figure.
plt.savefig("celltype_markers_clustWC.pdf", bbox_inches="tight")

In [None]:
# Markers among cells whose `annot` label is "Multiplet", grouped by WC cluster label.
is_multiplet = merged.obs['annot'] == 'Multiplet'
mults = merged[is_multiplet].copy()
sc.pl.dotplot(mults, ctmarkers_simple, "clust", standard_scale="var", figsize=(22, 8), show=True)

In [None]:
# Markers among cells whose `annot` label is "Unknown", grouped by WC cluster label.
is_unknown = merged.obs['annot'] == 'Unknown'
unks = merged[is_unknown].copy()
sc.pl.dotplot(unks, ctmarkers_simple, "clust", standard_scale="var", figsize=(22, 8), show=True)

In [None]:
# Same dotplot after dropping every cell whose `annot` label is "Multiplet".
not_multiplet = merged.obs['annot'] != 'Multiplet'
no_mults = merged[not_multiplet].copy()
sc.pl.dotplot(no_mults, ctmarkers_simple, "clust", standard_scale="var", figsize=(22, 8), show=True)

In [None]:
# Cells the clustering left as "Low Confidence": how `annot` labels them, and their markers
# when grouped by `annot` instead of by cluster.
is_low_conf_clust = merged.obs['clust'] == 'Low Confidence'
clust_unks = merged[is_low_conf_clust].copy()
print(clust_unks.obs['annot'].value_counts())
sc.pl.dotplot(clust_unks, ctmarkers_simple, "annot", standard_scale="var", figsize=(22, 8), show=True)

In [None]:
# Reassign cell types.
# `ct` starts as a string copy of the WC cluster label (`clust`), gets a few labels renamed,
# and then cells the clustering left as "Low Confidence" are rescued using the gating
# label (`annot`).


def _label_by_top_marker(adata, cell_mask, genes, labels):
    """Return one label per selected cell, chosen by its highest-expressed gene.

    Parameters
    ----------
    adata : AnnData-like object; ``adata[cell_mask, genes].X`` must yield a
        cells x genes matrix (sparse or dense).
    cell_mask : boolean mask over the cells of ``adata``.
    genes : var names to compare, in the same order as ``labels``.
    labels : ``labels[i]`` is returned for a cell whose largest value is ``genes[i]``.

    Returns
    -------
    numpy.ndarray of labels, one per selected cell, in cell order (empty if no
    cell is selected).

    Notes
    -----
    ``np.argmax`` returns the first maximum, so ties -- including cells with no
    expression of any listed gene -- get ``labels[0]``.
    NOTE(review): confirm that defaulting zero-expression cells to the first
    label is intended.
    """
    cell_mask = np.asarray(cell_mask, dtype=bool)
    # Subset to the selected cells and genes first, so only that small matrix is densified.
    expr = adata[cell_mask, genes].X
    if hasattr(expr, "toarray"):
        expr = expr.toarray()
    # Force a (cells, genes) layout whatever shape the backend returned.
    expr = np.asarray(expr).reshape(int(cell_mask.sum()), len(genes))
    return np.asarray(labels)[np.argmax(expr, axis=1)]


merged.obs['ct'] = merged.obs['clust'].astype(str)

# Cluster-label renames applied before the cDC override below.
for clust_label, ct_label in [
    ("NK", "Low Confidence"),
    ("NK_T", "NK"),
    ("cDC2", "cDC"),
]:
    merged.obs.loc[merged.obs["clust"] == clust_label, "ct"] = ct_label

# Any cell gated as cDC is labelled cDC, whatever its cluster label so far.
merged.obs.loc[merged.obs["annot"] == "cDC", "ct"] = "cDC"

# These renames run after the cDC override, so for these three clusters the
# cluster label wins over annot == "cDC".
for clust_label, ct_label in [
    ("Osteoblast", "Fibro/Osteo"),
    ("vSMC", "vSMC/Pericyte"),
    ("Plasma Cell", "PC"),
]:
    merged.obs.loc[merged.obs["clust"] == clust_label, "ct"] = ct_label

# Rescue low-confidence cells based on gating: direct annot -> ct mapping.
# The mask uses the original `clust` column, so cells moved to "Low Confidence"
# by the NK rename above are not rescued here.
is_low_conf = merged.obs["clust"] == "Low Confidence"
gating_to_ct = {
    "HSPC": "HSPC",
    "Erythro": "Erythroid",
    "MKC": "Megakaryocyte",
    "pDC": "pDC",
    "NK": "NK",
    "PC": "PC",
    "MSC": "MSC",
    "Fibro/Osteo": "Fibro/Osteo",
    "Adipo": "Adipocyte",
    "Endo/Pericyte": "Endothelial",
}
for annot_label, ct_label in gating_to_ct.items():
    merged.obs.loc[is_low_conf & (merged.obs["annot"] == annot_label), "ct"] = ct_label

# Rescue low-confidence cells based on gating: gated groups that map to several
# final labels are split by whichever listed marker is highest in each cell.
top_marker_rules = [
    # (annot label, marker genes, label assigned when that gene is highest)
    ("Granulo", ["MPO", "CAMP", "MMP9"], ["GMP", "Late Myeloid", "Neutrophil"]),
    ("T", ["CD4", "CD8A"], ["CD4 T", "CD8 T"]),
    ("B", ["MS4A1", "VPREB1"], ["Mature B", "Early B"]),
    ("Mc/Mp", ["CD163", "CD14"], ["Macrophage", "Monocyte"]),
]
for annot_label, genes, labels in top_marker_rules:
    mask = is_low_conf & (merged.obs["annot"] == annot_label)
    merged.obs.loc[mask, "ct"] = _label_by_top_marker(merged, mask, genes, labels)

cat_order = [
    "HSPC",
    "Erythroid",
    "Megakaryocyte",
    "GMP",
    "Late Myeloid",
    "Neutrophil",
    "Ba/Eo/Ma",
    "cDC",
    "Monocyte",
    "Macrophage",
    "pDC",
    "CD4 T",
    "CD8 T",
    "NK",
    "Early B",
    "Mature B",
    "PC",
    "MSC",
    "Fibro/Osteo",
    "Adipocyte",
    "Endothelial",
    "vSMC/Pericyte",
    "Low Confidence"]

# Labels missing from cat_order are turned into NaN by pd.Categorical without any
# error (e.g. the string "nan" produced by astype(str) for cells with no cluster
# label), so report them before they disappear.
unexpected = sorted(set(merged.obs['ct'].unique()) - set(cat_order))
if unexpected:
    n_unexpected = int(merged.obs['ct'].isin(unexpected).sum())
    print(f"WARNING: {n_unexpected} cells carry ct labels not in cat_order and will become NaN: {unexpected}")

merged.obs['ct'] = pd.Categorical(
    merged.obs['ct'], categories=cat_order, ordered=True
)

In [None]:
# Re-check the cells the clustering left as "Low Confidence", now grouped by the new `ct` label.
is_low_conf_clust = merged.obs['clust'] == 'Low Confidence'
clust_unks = merged[is_low_conf_clust].copy()
sc.pl.dotplot(clust_unks, ctmarkers_simple, "ct", standard_scale="var", figsize=(22, 8), show=True)

In [None]:
# Marker dotplot for all cells grouped by the final `ct` label, drawn without showing so it can be saved.
sc.pl.dotplot(merged, ctmarkers_simple, "ct", standard_scale="var", figsize=(22, 8), show=False)

# Save the figure that the call above left as pyplot's current figure.
plt.savefig("celltype_markers_ct.pdf", bbox_inches="tight")

In [None]:
# Cell counts per final `ct` label.
merged.obs['ct'].value_counts()

In [None]:
 # NOTE(review): hard-coded ratio, presumably (Low Confidence cells) / (all cells) copied from the
 # value_counts output above -- TODO confirm, and recompute whenever the labels change.
 613423/5726156

In [None]:
# Persist the object with the new `clust` and `ct` columns.
# NOTE(review): this overwrites the same 'merged.h5ad' that was read at the top of the notebook.
merged.write('merged.h5ad')

In [None]:
# Check markers for every sample: one dotplot page per sample, all in a single PDF.
all_samples = merged.obs['Sample'].unique()
pdf_out = "ct_per_sample_marker_dotplots.pdf"
with PdfPages(pdf_out) as pdf:
    for s in all_samples:
        print(s)
        sobj = merged[merged.obs['Sample'] == s].copy()

        # return_fig=True hands back the plot object instead of drawing immediately.
        dp = sc.pl.dotplot(
            sobj,
            var_names=ctmarkers_simple,
            groupby="ct",
            standard_scale="var",
            return_fig=True,
            show=False,
            title=s
        )

        # Render, then save the figure held on `dp.fig` explicitly. The page written
        # must not depend on what make_figure() returns or on which figure happens
        # to be pyplot's current one.
        dp.make_figure()
        pdf.savefig(dp.fig, bbox_inches="tight")
        # Close each page's figure so figures do not pile up across samples.
        plt.close(dp.fig)


In [None]:
# Colour per final cell type, listed in the same order as the `ct` categories.
ct_palette = {
    "HSPC": "#d6e376",
    "Erythroid": "#cfcfcf",
    "Megakaryocyte": "#8f8f8f",
    "GMP": "#88cf46",
    "Late Myeloid": "#4ab300",
    "Neutrophil": "#95ad74",
    "Ba/Eo/Ma": "#618038",
    "cDC": "#3bff8c",
    "Monocyte": "#3dd49f",
    "Macrophage": "#03ab70",
    "pDC": "#a5c3c4",
    "CD4 T": "#ff8400",
    "CD8 T": "#ff0000",
    "NK": "#9302d1",
    "Early B": "#7cb2e6",
    "Mature B": "#045eb5",
    "PC": "#ffbafd",
    "MSC": "#cfc10a",
    "Fibro/Osteo": "#ba9e00",
    "Adipocyte": "#ffe600",
    "Endothelial": "#cc7e7e",
    "vSMC/Pericyte": "#ad4b8e",
    "Low Confidence": "#FFFFFF",
}

# Preview the palette as a strip of unit-height bars, one per cell type.
colors = list(ct_palette.values())
plt.figure(figsize=(6, 1))
for i, c in enumerate(colors):
    plt.bar(i, 1, color=c)
plt.axis("off")
plt.show()

# Store the colours in category order under "<obs column>_colors" in `uns`.
# A category missing from ct_palette raises KeyError here.
merged.uns["ct_colors"] = [ct_palette[name] for name in merged.obs["ct"].cat.categories]

In [None]:
# Stacked bar chart of cell-type proportions per DI sample.
# Samples are ordered by Panel, then Collection; the stable mergesort keeps the
# existing row order within ties.
obs = merged.obs.copy().sort_values(by=["Panel", "Collection"], ascending=True, kind="mergesort")
cats = pd.unique(obs["DI_Sample"])  # DI samples in first-appearance order after sorting
obs["DI_Sample"] = pd.Categorical(obs["DI_Sample"], categories=cats, ordered=True)

# Cells per (DI sample, cell type), with rows put in the sample order above.
counts = pd.crosstab(index=obs['DI_Sample'], columns=obs['ct']).astype(int).loc[cats]

# Row-normalise to proportions; a sample with zero cells gives 0/0 = NaN, set to 0.
row_totals = counts.sum(axis=1)
props = counts.div(row_totals, axis=0).fillna(0.0)

fig, ax = plt.subplots(figsize=(12, 6))
bar_colors = [ct_palette.get(c, '#FFFFFF') for c in props.columns]
props.plot.bar(stacked=True, ax=ax, width=0.98, color=bar_colors)
ax.set(xlabel="sample", ylabel="proportion")
ax.margins(x=0)             # remove left/right x-axis padding
# Legend sits just outside the left edge of the axes, vertically centred.
ax.legend(title="cell type", frameon=False, loc="center right", bbox_to_anchor=(-0.02, 0.5))
fig.subplots_adjust(left=0.25)
plt.xticks(rotation=90, ha="right")
plt.tight_layout()
fig.savefig("ct_per_DI_sample_proportion_plot.pdf", bbox_inches="tight")  # save to PDF
plt.close(fig)

In [None]:
# One spatial scatter page per DI sample, coloured by final cell type.
sids = sorted(set(merged.obs['DI_Sample']))

with PdfPages("ct_per_DI_sample_scatterplots.pdf") as pdf:
    for sid in sids:
        f = merged[merged.obs["DI_Sample"] == sid].copy()
        # Build the colour list from this subset's own `ct` categories so it always
        # matches, in length and order, the categories being plotted.
        cats_all = list(f.obs["ct"].cat.categories)
        pal_all  = [ct_palette.get(c, "#FFFFFF") for c in cats_all]
        fig, ax = plt.subplots(1, 1, figsize=(6, 6), constrained_layout=True)
        sc.pl.scatter(
            f,
            x="x_centroid", y="y_centroid",
            color="ct",
            palette=pal_all,
            ax=ax,
            legend_loc="none",
            show=False,
            size=2
        )

        # "\u2014" is an em dash. The title used to contain the mis-decoded bytes of
        # that character, which showed up as garbage in the page title; the escape
        # also keeps this line safe from future re-encoding of the file.
        ax.set_title(f"{sid} \u2014 all cells")
        ax.set_aspect("equal")
        # Flip y so the plot uses image-style orientation (origin at the top).
        ax.invert_yaxis()

        # Rasterize scatter points only; text and axes stay vector in the PDF.
        for coll in ax.collections:
            coll.set_rasterized(True)

        pdf.savefig(fig)
        plt.close(fig)


In [None]:
# Peek at the first rows of the per-cell metadata before exporting it.
merged.obs.head()

In [None]:
# Export the full per-cell metadata table; the cell names are written as the CSV's index column.
merged.obs.to_csv("merged_metadata.csv")

In [None]:
# Save again so the stored object includes the `ct_colors` entry added to `uns` above.
# NOTE(review): this overwrites the same 'merged.h5ad' that was read at the top of the notebook.
merged.write("merged.h5ad")

In [None]:
# Shortened marker panel for the compact figure.
ctmarkers_short = {
    "HSPC": ["CD34", "AVP", "SPINK2"],
    "Erythroid": ["GATA1", "AHSP", "HEMGN"],
    "MKC": ["PF4", "PLEK"],
    "Granulo/Myelo": ["MPO", "ELANE", "CAMP", "LTF", "MMP9", "S100A12", "CPA3"],
    "cDC": ["CLEC10A", "CD1C", "CD1E"],
    "Mc/Mp": ["CD14", "FCN1", "FCGR3A", "CD68", "CD163", "VSIG4"],
    "pDC": ["IRF8", "RUNX2", "LILRA4"],
    "T Cell": ["CD3D", "CD3E", "TRAC", "IL7R", "CD8A", "CD247"],
    "NK Cell": ["NKG7", "GNLY", "KLRD1"],
    "B Cell": ["VPREB1", "PAX5", "CD19", "CD79A", "MS4A1"],
    "PC": ["MZB1", "SLAMF7", "TNFRSF17"],
    "MSC": ["LEPR", "KITLG", "CXCL12", "THY1"],
    "Fibro.": ["PDGFRA", "COL5A2", "FBLN1"],
    "Osteo.": ["BGLAP", "SPP1"],
    "Adipo.": ["FABP4", "ADIPOQ", "PPARG"],
    "Endo/vSMC": ["VWF", "KDR", "ENG", "FLT4", "ACTA2", "MYH11"],
}

# Dotplot of the short panel grouped by final `ct`, drawn without showing so it can be saved.
sc.pl.dotplot(merged, ctmarkers_short, "ct", standard_scale="var", figsize=(15.5, 4.5), show=False)
# Saved without bbox_inches="tight", unlike the other dotplots in this notebook.
plt.savefig('ct_markers_simplified_dotplot.pdf')

In [None]:
# barplot of ct counts
# Count cells per ct, one row per category in the categorical's own order.
ct_order = merged.obs['ct'].cat.categories
print(ct_order)

# value_counts on a categorical column lists every category (unused ones count 0);
# reindex then fixes the row order to the category order.
n_per_ct = merged.obs['ct'].value_counts(sort=False).reindex(ct_order)

# Build the table directly with the two columns the plot needs. The previous version
# read back a column named 'index', which only exists when .loc[...] drops the index
# name -- a pandas-version-dependent detail that otherwise raises KeyError.
ct_counts = pd.DataFrame({
    'ct': list(ct_order),
    'n_cells': n_per_ct.to_numpy(),
})

In [None]:
# Map colors (cell types without a palette entry fall back to white)
colors = [ct_palette.get(ct, 'white') for ct in ct_counts['ct']]

# Plot
fig, ax = plt.subplots(figsize=(8, 5))
ax.bar(ct_counts['ct'], ct_counts['n_cells'], color=colors)

ax.set_ylabel("Cell count")
ax.set_xlabel("Cell type (ct)")
ax.set_title("Cell counts per ct")
plt.xticks(rotation=90)

plt.tight_layout()

# savefig does not create missing folders, so make sure the output directory exists.
barplot_path = Path('celltyping_plots') / 'ncells_ct_barplot.pdf'
barplot_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(barplot_path)

In [None]:
# Write one annotation CSV per sample with columns: cell_id, group, color.

# Recover the per-sample barcode as the text after the last "_" of each cell name.
# NOTE(review): assumes the barcode itself contains no underscore -- TODO confirm.
merged.obs["Original_Barcode"] = merged.obs_names.str.rsplit("_", n=1).str[-1]
sids = sorted(set(merged.obs['Sample']))

# to_csv does not create missing folders, so make sure the output directory exists.
outdir = Path("annotations")
outdir.mkdir(parents=True, exist_ok=True)

for samp in sids:
    df = (
        merged.obs.loc[merged.obs["Sample"] == samp, ["Original_Barcode", "ct"]]
        .rename(columns={"Original_Barcode": "cell_id", "ct": "group"})
    )
    # Groups without a palette entry (including missing labels) are coloured black.
    df['color'] = df['group'].astype(str).map(ct_palette).fillna("#000000")
    df.to_csv(outdir / f"{samp}_ct.csv", index=False)