In [None]:
import os
from pathlib import Path
from typing import Annotated

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc
import seaborn as sns
import tifffile

from sklearn.cluster import KMeans
from skimage.color import label2rgb
from sklearn.neighbors import radius_neighbors_graph
from sklearn.neighbors import NearestNeighbors

from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import euclidean_distances
from scipy.stats import wilcoxon
from statsmodels.stats.multitest import multipletests
from scipy.stats import entropy, chi2_contingency
from matplotlib.backends.backend_pdf import PdfPages
from statannotations.Annotator import Annotator

from scipy import sparse

# Export text as real (editable) text rather than outlined paths:
# 'none' keeps SVG text as <text> elements, 42 embeds TrueType fonts in PDFs.
plt.rcParams.update({
    'svg.fonttype': 'none',
    'pdf.fonttype': 42,
})

# Every PDF written below lands in this working directory.
# NOTE(review): absolute, machine-specific path.
OUTPUT_DIR = Path('/diskmnt/Projects/myeloma_scRNA_analysis/MMY_IRD/revision/merge/compare_celltype_abundance/')
os.chdir(OUTPUT_DIR)
os.getcwd()

In [None]:
# Merged, cleaned AnnData object (no Harmony integration).
COMBINED_H5AD = '/diskmnt/Projects/myeloma_scRNA_analysis/MMY_IRD/revision/merge/no_harmony/combined_cleaned.h5ad'
merged = sc.read_h5ad(COMBINED_H5AD)

In [None]:
# Inspect the cell metadata of samples whose name contains "OSU00".
merged.obs.loc[lambda obs_df: obs_df["Sample"].str.contains("OSU00")]

In [None]:
# Map each UPN_Collection label to its Collection (one entry per label; last occurrence wins).
uc_to_collection = dict(zip(merged.obs['UPN_Collection'], merged.obs['Collection']))

In [None]:
# Time points in plotting order, and the color used for each one.
collection_order = ["NBM", "NDMM", "PT"]
timecols = dict(zip(collection_order, ["#0C7515", "#E619B9", "#CF99C3"]))

In [None]:
def remove_outliers_iqr(df, col='frac_subset', k=1.5):
    """Keep only rows whose ``col`` value lies within Tukey's fences.

    The fences are ``[Q1 - k*IQR, Q3 + k*IQR]`` with ``IQR = Q3 - Q1``. Values
    exactly on a fence are kept; rows where ``col`` is NaN are dropped.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to filter.
    col : str
        Numeric column used for the outlier screen.
    k : float
        Fence multiplier (1.5 gives the conventional Tukey fences).

    Returns
    -------
    pandas.DataFrame
        The retained rows, with original index and column order.
    """
    values = df[col]
    first_q, third_q = values.quantile([0.25, 0.75])
    spread = third_q - first_q
    keep = values.between(first_q - k * spread, third_q + k * spread)
    return df[keep]


In [None]:
def plot_fraction_boxplot(data, subset_list, title, pdf):
    """Append one PDF page of per-sample subset fractions split by Collection.

    Parameters
    ----------
    data : pandas.DataFrame
        Long table with ``subset``, ``Collection`` and ``frac_subset`` columns.
    subset_list : list of str
        Subsets to show, in x-axis order. A subset absent from ``data`` is
        drawn as an empty slot.
    title : str
        Axes title.
    pdf : matplotlib.backends.backend_pdf.PdfPages
        Open PDF that the finished figure is saved into.

    Notes
    -----
    Boxes and points are drawn after per-(subset, Collection) IQR outlier
    removal (``remove_outliers_iqr``). The Mann-Whitney tests (BH-corrected,
    only significant pairs shown) use every sample of the plotted subsets.
    Relies on the module-level ``collection_order`` and ``timecols``.
    Does nothing when ``subset_list`` is empty.
    """
    order = list(subset_list)
    if not order:
        # A zero-width figure would be rejected by matplotlib.
        return

    fig, ax = plt.subplots(figsize=(len(order), 5))
    dataplot = data[data["subset"].isin(order)]

    # Outlier trimming is for display only.
    filtered = (
        dataplot.groupby(['subset', 'Collection'], observed=True, group_keys=False)
                .apply(remove_outliers_iqr)
    )

    # `order` must be explicit. `subset` is usually a Categorical that still
    # carries every subset as a category, and seaborn would otherwise reserve an
    # (empty) x slot for each category, not just the requested ones.
    sns.boxplot(
        data=filtered,
        x="subset", y="frac_subset", order=order,
        hue="Collection", hue_order=collection_order,
        palette=timecols,
        fliersize=0, linewidth=1, ax=ax
    )

    sns.stripplot(
        data=filtered,
        x="subset", y="frac_subset", order=order,
        hue="Collection", hue_order=collection_order,
        dodge=True, alpha=1, size=2,
        palette="dark:black", ax=ax
    )

    # Box and strip layers each contribute one legend entry per Collection;
    # keep only the first set.
    n_hue = len(collection_order)
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles[:n_hue], labels[:n_hue],
              title="Collection",
              bbox_to_anchor=(1.05, 1), loc="upper left", frameon=False)

    # Test every Collection pair that has samples for a given subset.
    combos = [("NBM", "NDMM"), ("NDMM", "PT"), ("NBM", "PT")]
    pairs = []
    for a in order:
        available = set(dataplot.loc[dataplot["subset"] == a, "Collection"].unique())
        pairs.extend(((a, c1), (a, c2)) for c1, c2 in combos if {c1, c2} <= available)

    if pairs:
        # Same x/hue order as the drawn boxes, so brackets land on the right boxes.
        annotator = Annotator(
            ax, pairs, data=dataplot, x="subset", y="frac_subset",
            order=order, hue="Collection", hue_order=collection_order
        )
        annotator.configure(
            test="Mann-Whitney", text_format="star", loc="inside",
            comparisons_correction="BH", hide_non_significant=True, verbose=0
        )
        annotator.apply_and_annotate()

    ax.set_ylabel("Fraction of cells")
    ax.set_xlabel("Cell Type")
    ax.set_title(title)
    sns.despine(ax=ax)
    fig.tight_layout()
    pdf.savefig(fig, bbox_inches="tight")
    plt.close(fig)


In [None]:
# Per-sample relative abundance of every annotated subset (plasma cells
# included), restricted to the collections in `collection_order`.
obs = merged.obs.loc[
    merged.obs['Collection'].isin(collection_order),
    ['UPN_Collection', 'UPN', 'Collection', 'subset'],
].copy()
# h5ad string columns typically load as Categoricals that keep every original
# category. Prune labels of the samples filtered out above. Otherwise the
# observed=False groupby below brings them back as empty samples
# (0 cells -> 0/0 = NaN fractions).
obs['UPN_Collection'] = obs['UPN_Collection'].astype('category').cat.remove_unused_categories()

# Cell counts per sample x subset. observed=False keeps subsets that are absent
# from a sample as explicit zero counts.
counts = (
    obs.groupby(["UPN_Collection", "subset"], observed=False)
        .size()
        .reset_index(name="n_cells")
)

totals = (
    counts.groupby(["UPN_Collection"], observed=True)["n_cells"]
          .sum()
          .reset_index(name="total_cells")
)

frac_df = counts.merge(totals, on=["UPN_Collection"])
frac_df["frac_subset"] = frac_df["n_cells"] / frac_df["total_cells"]
frac_df[['UPN', 'Collection']] = frac_df['UPN_Collection'].str.split('|', expand=True)
assert frac_df["frac_subset"].notna().all(), "some samples have no cells (NaN fractions)"

# Order subsets by median NBM abundance, descending.
nbm = frac_df[frac_df['Collection'] == 'NBM']
median_nbm = (
    nbm.groupby('subset', observed=False)['frac_subset']
        .median()
        .sort_values(ascending=False)
)
subset_order = median_nbm.index.tolist()
frac_df['subset'] = pd.Categorical(frac_df['subset'], categories=subset_order, ordered=True)

# Two pages, so low-abundance subsets get their own y-axis scale. A subset
# with an undefined (NaN) median goes to the low page instead of vanishing.
split_threshold = 0.02  # median NBM fraction cutoff (2% of cells)
is_high = median_nbm > split_threshold
high_subsets = median_nbm[is_high].index.tolist()
low_subsets = median_nbm[~is_high].index.tolist()

pdf_out = "subset_relative_abundance_split_yaxes.pdf"
with PdfPages(pdf_out) as pdf:
    plot_fraction_boxplot(frac_df, high_subsets,
                          "High-abundance subsets", pdf)
    plot_fraction_boxplot(frac_df, low_subsets,
                          "Low-abundance subsets (zoomed-in)", pdf)


In [None]:
# Median NBM fraction of each subset, sorted descending. It is computed in the
# abundance cell above, where it sets the subset order and the 2% page split.
median_nbm

In [None]:
# Recompute per-sample subset fractions with plasma cells (PC) excluded, so each
# subset is expressed as a fraction of the non-PC compartment.
# Rebuilt from merged.obs, so this cell does not depend on how an earlier cell
# left `obs`. The .copy() avoids assigning into a filtered slice.
obs = merged.obs.loc[
    merged.obs['Collection'].isin(collection_order) & (merged.obs['subset'] != 'PC'),
    ['UPN_Collection', 'UPN', 'Collection', 'subset'],
].copy()
# Drop the now-empty PC category and the labels of excluded samples. Otherwise
# the observed=False groupby below would list PC, or create empty samples with
# 0/0 = NaN fractions.
obs['subset'] = obs['subset'].astype('category').cat.remove_unused_categories()
obs['UPN_Collection'] = obs['UPN_Collection'].astype('category').cat.remove_unused_categories()

# Cell counts per sample x subset (zero counts kept).
counts = (
    obs.groupby(["UPN_Collection", "subset"], observed=False)
        .size()
        .reset_index(name="n_cells")
)

totals = (
    counts.groupby(["UPN_Collection"], observed=True)["n_cells"]
          .sum()
          .reset_index(name="total_cells")
)

frac_df = counts.merge(totals, on=["UPN_Collection"])
frac_df["frac_subset"] = frac_df["n_cells"] / frac_df["total_cells"]
frac_df[['UPN', 'Collection']] = frac_df['UPN_Collection'].str.split('|', expand=True)
assert frac_df["frac_subset"].notna().all(), "some samples have no non-PC cells (NaN fractions)"

# Order subsets by median PT abundance, descending.
pt = frac_df[frac_df['Collection'] == 'PT']
median_pt = (
    pt.groupby('subset', observed=False)['frac_subset']
        .median()
        .sort_values(ascending=False)
)
subset_order = median_pt.index.tolist()
frac_df['subset'] = pd.Categorical(frac_df['subset'], categories=subset_order, ordered=True)
median_pt

In [None]:
# Hand-picked abundance tiers for the non-PC fractions (presumably chosen from
# `median_pt` above; TODO confirm). Each tier gets its own page, so the y-axis
# suits its range.
high_subsets = ['CD8T', 'CD14 Mc', 'Early Ery', 'CD4T', 'Naive B', 'Transitional B', 'gdT/NK']
mid_subsets = ['Late Ery', 'cDC', 'CD16 Mc / TAM', 'Immature B', 'Neutrophil', 'HSPC', 'Pro/Pre B', 'pDC']
low_subsets = ['T Stim/Exh', 'MSC', 'Memory B', 'CLP', 'MKC']

tiers = (
    ("High-abundance subsets", high_subsets),
    ("Mid-abundance subsets", mid_subsets),
    ("Low-abundance subsets", low_subsets),
)

pdf_out = "subset_relative_abundance_split_yaxes_exclPC.pdf"
with PdfPages(pdf_out) as pdf:
    for tier_title, tier_members in tiers:
        plot_fraction_boxplot(frac_df, tier_members, tier_title, pdf)
    

In [None]:
# Patients (UPNs) sampled at both NDMM and PT; the paired comparisons below use this set.
ndmm_pt_rows = frac_df.loc[frac_df['Collection'].isin(['NDMM', 'PT'])]
timepoints_per_upn = ndmm_pt_rows.groupby('UPN', observed=True)['Collection'].nunique()
paired_upns = timepoints_per_upn.index[timepoints_per_upn.eq(2)].tolist()
paired_upns

In [None]:
# Fixed color for each cell-type subset, shared by every subset-colored plot.
subset_palette = dict([
    ('CD14 Mc', '#64d941'),
    ('CD16 Mc / TAM', '#268f07'),
    ('HSPC', '#d6e376'),
    ('MSC', '#cfc10a'),
    ('Neutrophil', '#95ad74'),
    ('cDC', '#3bff8c'),
    ('CD4T', '#eb9449'),
    ('CD8T', '#f00e0e'),
    ('T Stim/Exh', '#8c3a3a'),
    ('gdT/NK', '#ba6ee0'),
    ('Early Ery', '#d6d6d6'),
    ('Late Ery', '#807f7d'),
    ('Immature B', '#17d4ff'),
    ('Memory B', '#032563'),
    ('Naive B', '#5fa5ed'),
    ('Pro/Pre B', '#adede7'),
    ('Transitional B', '#a2c6eb'),
    ('CLP', '#d1e5e6'),
    ('pDC', '#a5c3c4'),
    ('PC', '#ffbafd'),
    ('MKC', '#000000'),
])

In [None]:
# Wide table with one row per (UPN, subset) and the NDMM and PT fractions side by
# side, for patients sampled at both time points.
is_paired_ndmm_pt = frac_df["UPN"].isin(paired_upns) & frac_df["Collection"].isin(["NDMM", "PT"])
paired = frac_df.loc[is_paired_ndmm_pt]
wide = (
    paired.pivot_table(index=["UPN", "subset"], columns="Collection", values="frac_subset")
    .reset_index()
    # Natural log of the *percentage* plus a 0.001 pseudocount, so zero
    # abundances stay finite.
    .assign(
        NDMM_log=lambda d: np.log(d['NDMM'] * 100 + 0.001),
        PT_log=lambda d: np.log(d['PT'] * 100 + 0.001),
    )
)
wide

In [None]:
# One scatter of every paired (patient, subset) point: NDMM vs PT abundance.
# `wide` holds ln(percent of cells + 0.001), and the axis labels state that scale.
x_label = "NDMM abundance, ln(% cells + 0.001)"
y_label = "PT abundance, ln(% cells + 0.001)"

pdf_out = "NDMM_vs_PT_paired_subset_scatter.pdf"
with PdfPages(pdf_out) as pdf:
    fig, ax = plt.subplots(figsize=(6, 6))

    # Points colored by subset.
    sns.scatterplot(
        data=wide,
        x="NDMM_log",
        y="PT_log",
        hue="subset",
        palette=subset_palette,
        s=50,
        edgecolor="none",
        alpha=0.85,
        ax=ax
    )

    # Shared square limits with a 1 log-unit margin, plus a y = x reference line.
    log_vals = wide[["NDMM_log", "PT_log"]]
    lim = (log_vals.min().min() - 1, log_vals.max().max() + 1)
    ax.plot(lim, lim, color="black", linestyle="--", lw=1)
    ax.set_xlim(lim)
    ax.set_ylim(lim)
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title("Paired subset abundances (NDMM vs PT)")
    # The subset legend is long; put it outside the axes so it does not hide points.
    if ax.get_legend() is not None:
        sns.move_legend(ax, "upper left", bbox_to_anchor=(1.02, 1), frameon=False)
    sns.despine(ax=ax)
    fig.tight_layout()
    pdf.savefig(fig, bbox_inches="tight")
    plt.close(fig)

In [None]:
# Row(s) holding the single largest PT fraction.
wide.loc[wide['PT'].eq(wide['PT'].max())]

In [None]:
# One NDMM-vs-PT scatter per paired patient, one page each.
# `wide` holds ln(percent of cells + 0.001), and the axis labels state that scale.
x_label = "NDMM abundance, ln(% cells + 0.001)"
y_label = "PT abundance, ln(% cells + 0.001)"
# Limits span all patients (1 log-unit margin), so pages are directly comparable;
# they are computed once, outside the loop.
log_vals = wide[["NDMM_log", "PT_log"]]
lim = (log_vals.min().min() - 1, log_vals.max().max() + 1)

pdf_out = "NDMM_vs_PT_paired_subset_scatter_persample.pdf"
with PdfPages(pdf_out) as pdf:
    for upn in wide['UPN'].unique():
        fig, ax = plt.subplots(figsize=(6, 6))

        # This patient's subsets, colored by subset.
        sns.scatterplot(
            data=wide[wide['UPN'] == upn],
            x="NDMM_log",
            y="PT_log",
            hue="subset",
            palette=subset_palette,
            s=50,
            edgecolor="none",
            alpha=0.85,
            ax=ax
        )

        # y = x reference line.
        ax.plot(lim, lim, color="black", linestyle="--", lw=1)
        ax.set_xlim(lim)
        ax.set_ylim(lim)
        ax.set_aspect("equal", adjustable="box")
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(f"Paired subset abundances (NDMM vs PT) in {upn}")
        # Keep the long subset legend off the data.
        if ax.get_legend() is not None:
            sns.move_legend(ax, "upper left", bbox_to_anchor=(1.02, 1), frameon=False)
        sns.despine(ax=ax)
        fig.tight_layout()
        pdf.savefig(fig, bbox_inches="tight")
        plt.close(fig)

In [None]:
# One NDMM-vs-PT scatter per cell-type subset (one point per paired patient).
# `wide` holds ln(percent of cells + 0.001), and the axis labels state that scale.
x_label = "NDMM abundance, ln(% cells + 0.001)"
y_label = "PT abundance, ln(% cells + 0.001)"
# Shared limits across all subsets (1 log-unit margin), computed once.
log_vals = wide[["NDMM_log", "PT_log"]]
lim = (log_vals.min().min() - 1, log_vals.max().max() + 1)

pdf_out = "NDMM_vs_PT_paired_subset_scatter_perCT.pdf"
with PdfPages(pdf_out) as pdf:
    for ct in wide['subset'].unique():
        fig, ax = plt.subplots(figsize=(6, 6))

        # All points share one subset, so a legend would be redundant.
        sns.scatterplot(
            data=wide[wide['subset'] == ct],
            x="NDMM_log",
            y="PT_log",
            hue="subset",
            palette=subset_palette,
            s=50,
            edgecolor="none",
            legend=False,
            alpha=0.85,
            ax=ax
        )

        # y = x reference line.
        ax.plot(lim, lim, color="black", linestyle="--", lw=1)
        ax.set_xlim(lim)
        ax.set_ylim(lim)
        ax.set_aspect("equal", adjustable="box")
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(f"Paired subset abundances (NDMM vs PT) in {ct}")
        sns.despine(ax=ax)
        fig.tight_layout()
        pdf.savefig(fig, bbox_inches="tight")
        plt.close(fig)

In [None]:
# Paired NDMM -> PT comparison of each subset's fraction (as held in the current
# `frac_df`), one PDF page per subset.
paired_order = ["NDMM", "PT"]
paired = frac_df[
    frac_df['UPN'].isin(paired_upns) & frac_df['Collection'].isin(paired_order)
].copy()
# Fixed, ordered categories put NDMM at x=0 and PT at x=1 (the paired lines
# depend on this) and keep other collections out of the plot.
paired['Collection'] = pd.Categorical(paired['Collection'], categories=paired_order, ordered=True)
paired['subset'] = paired['subset'].astype('category').cat.remove_unused_categories()
# The paired Wilcoxon test pairs NDMM and PT values by position, so sort by patient.
paired = paired.sort_values(['UPN', 'Collection'])

pdf_out = "subset_relative_abundance_exclPCs_paired_boxplots.pdf"
with PdfPages(pdf_out) as pdf:
    for ct in paired['subset'].cat.categories:
        sub = paired[paired['subset'] == ct]
        # Keep only patients that have both time points for this subset.
        sub = sub[sub.groupby('UPN')['Collection'].transform('nunique') == 2]
        if sub.empty:
            continue

        fig, ax = plt.subplots(figsize=(2, 3))

        sns.boxplot(
            data=sub,
            x="Collection",
            y="frac_subset",
            hue="Collection",
            palette=timecols,
            fliersize=0,
            linewidth=1,
            ax=ax,
        )

        sns.stripplot(
            data=sub,
            x="Collection",
            y="frac_subset",
            hue="Collection",
            alpha=1,
            size=3,
            palette="dark:black",
            ax=ax,
        )
        # hue duplicates x here, so a legend only clutters the small panel.
        if ax.get_legend() is not None:
            ax.get_legend().remove()

        # One gray line per patient, joining NDMM (x=0) to PT (x=1).
        for _, g in sub.groupby('UPN'):
            ax.plot([0, 1], g.sort_values('Collection')['frac_subset'].to_numpy(),
                    color='gray', alpha=1, linewidth=1)

        ndmm_vals = sub.loc[sub['Collection'] == 'NDMM', 'frac_subset'].to_numpy()
        pt_vals = sub.loc[sub['Collection'] == 'PT', 'frac_subset'].to_numpy()
        # Skip the test when every patient is unchanged (e.g. a subset absent at
        # both time points). The Wilcoxon statistic is undefined then, and SciPy
        # may raise. Tests are not corrected across subsets.
        if (ndmm_vals != pt_vals).any():
            annotate = Annotator(
                ax,
                [("NDMM", "PT")],
                data=sub,
                x="Collection",
                y="frac_subset",
                order=paired_order,
            )
            annotate.configure(
                test='Wilcoxon',  # paired Wilcoxon signed-rank
                text_format='star',
                loc='outside',
                hide_non_significant=True
            )
            annotate.apply_and_annotate()

        ax.set_ylabel("Fraction of cells")
        ax.set_xlabel("Collection")
        ax.set_title(f"{ct}")
        sns.despine(ax=ax)
        fig.tight_layout()

        pdf.savefig(fig, bbox_inches="tight")
        plt.close(fig)

In [None]:
# Per-sample CD4:CD8 T-cell ratio for the paired patients at NDMM and PT.
# The explicit Collection filter keeps this independent of how `paired` was built.
cd_ratio = (
    paired[
        paired["subset"].isin(["CD4T", "CD8T"])
        & paired["Collection"].isin(["NDMM", "PT"])
    ]
    .pivot_table(
        index=["UPN", "Collection"],
        columns="subset",
        values="frac_subset",
        fill_value=0,
        observed=True,  # only observed (UPN, Collection) rows and subset columns
    )
    .reset_index()
)
# The 1e-6 pseudocount avoids division by zero for samples without CD8 T cells.
# Such samples get very large ratios, so check them before interpreting.
cd_ratio["CD4_CD8_ratio"] = cd_ratio["CD4T"] / (cd_ratio["CD8T"] + 1e-6)
cd_ratio

In [None]:
# Paired NDMM vs PT comparison of the per-sample CD4:CD8 T-cell ratio.
ratio_df = cd_ratio[cd_ratio["Collection"].isin(["NDMM", "PT"])].copy()
# Ordered categories put NDMM at x=0 and PT at x=1, matching the paired lines.
ratio_df["Collection"] = pd.Categorical(ratio_df["Collection"], categories=["NDMM", "PT"], ordered=True)
# Keep complete pairs only, sorted by patient: the paired Wilcoxon test pairs
# the NDMM and PT values by position.
ratio_df = ratio_df[ratio_df.groupby("UPN")["Collection"].transform("nunique") == 2]
ratio_df = ratio_df.sort_values(["UPN", "Collection"])

pdf_out = "CD4CD8Tratio_paired_boxplots.pdf"
with PdfPages(pdf_out) as pdf:
    fig, ax = plt.subplots(figsize=(2, 3))

    sns.boxplot(
        data=ratio_df,
        x="Collection",
        y="CD4_CD8_ratio",
        hue="Collection",
        palette=timecols,
        fliersize=0,
        linewidth=1,
        ax=ax,
    )

    sns.stripplot(
        data=ratio_df,
        x="Collection",
        y="CD4_CD8_ratio",
        hue="Collection",
        alpha=1,
        size=3,
        palette="dark:black",
        ax=ax,
    )
    # hue duplicates x, so the legend is redundant.
    if ax.get_legend() is not None:
        ax.get_legend().remove()

    # One gray line per patient, joining NDMM (x=0) to PT (x=1).
    for _, g in ratio_df.groupby("UPN"):
        ax.plot([0, 1], g.sort_values("Collection")["CD4_CD8_ratio"].to_numpy(),
                color='gray', alpha=1, linewidth=1)

    if not ratio_df.empty:
        annotate = Annotator(
            ax,
            [("NDMM", "PT")],
            data=ratio_df,
            x="Collection",
            y="CD4_CD8_ratio",
            order=["NDMM", "PT"],
        )
        annotate.configure(
            test='Wilcoxon',  # paired Wilcoxon signed-rank
            text_format='star',
            loc='outside',
            hide_non_significant=True
        )
        annotate.apply_and_annotate()

    ax.set_ylabel("CD4:CD8 ratio")
    ax.set_xlabel("Collection")
    ax.set_title("CD4:CD8 ratio")
    sns.despine(ax=ax)
    fig.tight_layout()

    pdf.savefig(fig, bbox_inches="tight")
    plt.close(fig)

In [None]:
# Lineage-level (`lin`) version of the paired NDMM vs PT comparison, with plasma
# cells (PC) excluded from each sample's denominator.
# The .copy() avoids assigning into a filtered slice.
obs = merged.obs.loc[
    merged.obs['Collection'].isin(collection_order) & (merged.obs['lin'] != 'PC'),
    ['UPN_Collection', 'UPN', 'Collection', 'lin'],
].copy()
obs['lin'] = obs['lin'].astype('category').cat.remove_unused_categories()
# Prune labels of samples from filtered-out collections. Otherwise the
# observed=False groupby below resurrects them as empty samples (NaN fractions).
obs['UPN_Collection'] = obs['UPN_Collection'].astype('category').cat.remove_unused_categories()

# Per-sample fraction of each lineage (observed=False keeps zero counts).
counts = (
    obs.groupby(["UPN_Collection", "lin"], observed=False)
        .size()
        .reset_index(name="n_cells")
)

totals = (
    counts.groupby(["UPN_Collection"], observed=True)["n_cells"]
          .sum()
          .reset_index(name="total_cells")
)

frac_df = counts.merge(totals, on=["UPN_Collection"])
frac_df["frac_lin"] = frac_df["n_cells"] / frac_df["total_cells"]
frac_df[['UPN', 'Collection']] = frac_df['UPN_Collection'].str.split('|', expand=True)
assert frac_df["frac_lin"].notna().all(), "some samples have no non-PC cells (NaN fractions)"

# Order lineages by median NBM abundance, descending.
nbm = frac_df[frac_df['Collection'] == 'NBM']
median_nbm = (
    nbm.groupby('lin', observed=False)['frac_lin']
        .median()
        .sort_values(ascending=False)
)
lin_order = median_nbm.index.tolist()
frac_df['lin'] = pd.Categorical(frac_df['lin'], categories=lin_order, ordered=True)

# Paired patients only. Ordered NDMM/PT categories fix x positions (0, 1), and
# sorting by patient aligns the values the paired Wilcoxon test matches by position.
paired = frac_df[
    frac_df['UPN'].isin(paired_upns) & frac_df['Collection'].isin(["NDMM", "PT"])
].copy()
paired['Collection'] = pd.Categorical(paired['Collection'], categories=["NDMM", "PT"], ordered=True)
paired['lin'] = paired['lin'].astype('category').cat.remove_unused_categories()
paired = paired.sort_values(['UPN', 'Collection'])

# One figure per lineage.
pdf_out = "lineage_relative_abundance_exclPCs_paired_boxplots.pdf"
with PdfPages(pdf_out) as pdf:
    for ct in paired['lin'].cat.categories:
        sub = paired[paired['lin'] == ct]
        # Keep only patients that have both time points for this lineage.
        sub = sub[sub.groupby('UPN')['Collection'].transform('nunique') == 2]
        if sub.empty:
            continue

        fig, ax = plt.subplots(figsize=(2, 3))

        sns.boxplot(
            data=sub,
            x="Collection",
            y="frac_lin",
            hue="Collection",
            palette=timecols,
            fliersize=0,
            linewidth=1,
            ax=ax,
        )

        sns.stripplot(
            data=sub,
            x="Collection",
            y="frac_lin",
            hue="Collection",
            alpha=1,
            size=3,
            palette="dark:black",
            ax=ax,
        )
        # hue duplicates x, so the legend is redundant.
        if ax.get_legend() is not None:
            ax.get_legend().remove()

        # One gray line per patient, joining NDMM (x=0) to PT (x=1).
        for _, g in sub.groupby('UPN'):
            ax.plot([0, 1], g.sort_values('Collection')['frac_lin'].to_numpy(),
                    color='gray', alpha=1, linewidth=1)

        ndmm_vals = sub.loc[sub['Collection'] == 'NDMM', 'frac_lin'].to_numpy()
        pt_vals = sub.loc[sub['Collection'] == 'PT', 'frac_lin'].to_numpy()
        # Skip the test when no patient changes: the Wilcoxon statistic is
        # undefined for all-zero differences, and SciPy may raise.
        # Tests are not corrected across lineages.
        if (ndmm_vals != pt_vals).any():
            annotate = Annotator(
                ax,
                [("NDMM", "PT")],
                data=sub,
                x="Collection",
                y="frac_lin",
                order=["NDMM", "PT"],
            )
            annotate.configure(
                test='Wilcoxon',  # paired Wilcoxon signed-rank
                text_format='star',
                loc='outside',
                hide_non_significant=True
            )
            annotate.apply_and_annotate()

        ax.set_ylabel("Fraction of cells")
        ax.set_xlabel("Collection")
        ax.set_title(f"{ct}")
        sns.despine(ax=ax)
        fig.tight_layout()

        pdf.savefig(fig, bbox_inches="tight")
        plt.close(fig)

In [None]:
# Number of cells per lineage (`lin`) across the whole merged dataset (all collections, PCs included).
merged.obs['lin'].value_counts()