In [1]:
import os
import anndata as ad
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [4]:
# Marker-name lookup used by the plotting helpers in this notebook: each key is
# the variable name used to index the first cohort, and its value is the
# variable name selected from the second cohort.
# NOTE(review): '168Er_pSTAT6' -> '168Yb_pSTAT6' is the only pair whose
# isotope prefix differs between key and value — confirm against the data.
correspondence = {
    '149Sm_CREB': '149Sm_pCREB',
    '167Er_ERK': '167Er_pERK12',
    '164Dy_IkB': '164Dy_IkB',
    '159Tb_MAPKAPK2': '159Tb_pMAPKAPK2',
    '166Er_NFkB': '166Er_pNFkB',
    '151Eu_p38': '151Eu_pp38',
    '155Gd_S6': '155Gd_pS6',
    '153Eu_STAT1': '153Eu_pSTAT1',
    '154Sm_STAT3': '154Sm_pSTAT3',
    '150Nd_STAT5': '150Nd_pSTAT5',
    '168Er_pSTAT6': '168Yb_pSTAT6',
    '174Yb_HLADR': '174Yb_HLADR',
    '169Tm_CD25': '169Tm_CD25',
}

In [58]:
def create_density_plots(dist_data, out_file, title_suffix=""):
    """Draw one KDE panel per entry of ``dist_data`` and save the figure.

    Parameters
    ----------
    dist_data : dict
        Maps a panel identifier to a dict of ``{category label: values}``.
        Only the four "Unstim ..." categories listed below are drawn; any
        other keys in the inner dict are ignored.
    out_file : str or path-like
        Destination handed to ``Figure.savefig``. When it is a path, missing
        parent directories are created first.
    title_suffix : str, optional
        Text appended to the figure-level title.

    Raises
    ------
    ValueError
        If ``dist_data`` is empty (there is nothing to lay out).
    KeyError
        If a panel is missing one of the plotted categories.
    """
    if not dist_data:
        # Without this guard the grid computation below divides by zero.
        raise ValueError("dist_data must contain at least one entry")

    sns.set_theme(style="whitegrid")
    pts_sorted = sorted(dist_data)

    # At most three panels per row.
    num_plots = len(pts_sorted)
    cols = min(3, num_plots)
    rows = int(np.ceil(num_plots / cols))

    # squeeze=False always yields a 2-D axes array, so the single-panel case
    # needs no special handling.
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(max(5 * cols, 8), 5 * rows),
        constrained_layout=True,
        squeeze=False,
    )
    flat_axes = axes.ravel()

    fig.suptitle(f"Density Plots {title_suffix}", fontsize=16)

    # Plotted categories and their line colours.
    category_colors = {
        "Unstim perio": "blue",
        "Unstim surge": "red",
        "Unstim log perio": "green",
        "Unstim log surge": "black",
    }

    for pt, ax in zip(pts_sorted, flat_axes):
        for label, color in category_colors.items():
            # Callers pass both ndarrays and pandas Series; normalise to 1-D.
            arr = np.asarray(dist_data[pt][label]).ravel()
            if arr.size > 0:
                sns.kdeplot(
                    arr,
                    ax=ax,
                    label=f"{label} (n={arr.size})",
                    color=color,
                    fill=False,  # set to True to fill the area under the curve
                    alpha=0.3,
                )

        # Include the panel identifier so panels can be told apart.
        ax.set_title(f"Perio vs surge: {pt}", fontsize=14)
        ax.set_xlabel("Value", fontsize=12)
        ax.set_ylabel("Density", fontsize=12)
        # Only build a legend when something was drawn, to avoid the
        # "no artists with labels" warning.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=10)
        ax.grid(True)

    # Drop the unused cells of the grid.
    for unused_ax in flat_axes[num_plots:]:
        fig.delaxes(unused_ax)

    if isinstance(out_file, (str, os.PathLike)):
        out_dir = os.path.dirname(os.fspath(out_file))
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # Save and close this specific figure rather than the implicit current one.
    fig.savefig(out_file, dpi=200, bbox_inches="tight")
    plt.close(fig)


def plot_result(path_cohort_1, path_cohort_2, marker, outdir_path, doms_stim):
    """Plot one marker's value distributions for two cohorts.

    Reads both h5ad files, selects ``marker`` from the first and
    ``correspondence[marker]`` from the second, splits the cells by
    ``obs["drug"]`` and hands the resulting arrays (raw and ``log1p``
    transformed unstimulated values, plus stimulated values) to
    ``create_density_plots``.

    Parameters
    ----------
    path_cohort_1, path_cohort_2 : str
        Paths of the h5ad files for the first ("perio") and second ("surge")
        cohort.
    marker : str
        Variable name in the first cohort; must be a key of ``correspondence``.
    outdir_path : str
        Output file path forwarded to ``create_density_plots``.
    doms_stim : str
        Value of ``obs["drug"]`` identifying stimulated cells in the second
        cohort.

    Returns
    -------
    None
    """

    def _marker_values(adata, drug):
        """Return the flattened ``X`` values of the cells treated with ``drug``."""
        matrix = adata[adata.obs["drug"] == drug].X
        # Sparse matrices have no ``flatten``; densify them first.
        if hasattr(matrix, "toarray"):
            matrix = matrix.toarray()
        return np.asarray(matrix).ravel()

    # "LPS" in the second cohort is paired with "P. gingivalis" in the first.
    # NOTE(review): for any other stimulation the same name is assumed in both
    # cohorts (previously ``stim_perio`` was simply unbound) — confirm.
    stim_perio = "P. gingivalis" if doms_stim == "LPS" else doms_stim

    # ``ad.read`` is a deprecated alias of ``read_h5ad``.
    cohort1 = ad.read_h5ad(path_cohort_1)[:, marker].copy()
    unstim1 = _marker_values(cohort1, "Unstim")
    stim1 = _marker_values(cohort1, stim_perio)

    cohort2 = ad.read_h5ad(path_cohort_2)[:, correspondence[marker]].copy()
    unstim2 = _marker_values(cohort2, "Unstim")
    stim2 = _marker_values(cohort2, doms_stim)

    # ``np.log2p`` does not exist; ``np.log1p`` matches ``plot_result2``.
    dist_data = {
        "Patient_1": {
            "Unstim perio": unstim1,
            "Stim True perio": stim1,
            "Unstim surge": unstim2,
            "Stim True surge": stim2,
            "Unstim log perio": np.log1p(unstim1),
            "Unstim log surge": np.log1p(unstim2),
        }
    }

    create_density_plots(dist_data, outdir_path, title_suffix="")

In [59]:
def plot_result2(path_cohort_1, path_cohort_2, marker, outdir_path, doms_stim,
                 scale=10, offset_quantile=0.05):
    """Plot one marker's distributions for two cohorts with a shifted log1p.

    Same inputs as ``plot_result``, but the unstimulated values of both
    cohorts are transformed as ``log1p((x + offset) * scale)``, where
    ``offset`` is the ``offset_quantile`` quantile of the strictly positive
    unstimulated values of the second cohort.

    Parameters
    ----------
    path_cohort_1, path_cohort_2 : str
        Paths of the h5ad files for the first ("perio") and second ("surge")
        cohort.
    marker : str
        Variable name in the first cohort; must be a key of ``correspondence``.
    outdir_path : str
        Output file path forwarded to ``create_density_plots``.
    doms_stim : str
        Value of ``obs["drug"]`` identifying stimulated cells in the second
        cohort.
    scale : float, optional
        Multiplier applied after the shift (default 10, the former constant).
    offset_quantile : float, optional
        Quantile of the positive second-cohort values used as the shift
        (default 0.05, the former constant).

    Returns
    -------
    None
    """

    def _marker_series(adata, drug, name):
        """Return the flattened ``X`` values for ``drug`` as a named Series."""
        matrix = adata[adata.obs["drug"] == drug].X
        # Sparse matrices have no ``flatten``; densify them first.
        if hasattr(matrix, "toarray"):
            matrix = matrix.toarray()
        return pd.Series(np.asarray(matrix).ravel(), name=name)

    # "LPS" in the second cohort is paired with "P. gingivalis" in the first.
    # NOTE(review): for any other stimulation the same name is assumed in both
    # cohorts (previously ``stim_perio`` was simply unbound) — confirm.
    stim_perio = "P. gingivalis" if doms_stim == "LPS" else doms_stim

    # ``ad.read`` is a deprecated alias of ``read_h5ad``.
    cohort1 = ad.read_h5ad(path_cohort_1)[:, marker].copy()
    unstim1 = _marker_series(cohort1, "Unstim", "Unstim perio")
    stim1 = _marker_series(cohort1, stim_perio, "Stim True perio")

    target2 = ad.read_h5ad(path_cohort_2)[:, correspondence[marker]].copy()
    unstim2 = _marker_series(target2, "Unstim", "Unstim surge")
    stim2 = _marker_series(target2, doms_stim, "Stim True surge")

    positive_unstim2 = unstim2[unstim2 > 0]
    if positive_unstim2.empty:
        # The quantile of an empty selection is NaN, which would turn every
        # transformed value into NaN; fall back to no shift.
        offset = 0.0
    else:
        offset = positive_unstim2.quantile(offset_quantile)

    # Apply log1p after shifting and scaling.
    log1 = np.log1p((unstim1 + offset) * scale)
    log2 = np.log1p((unstim2 + offset) * scale)
    print("Quantiles unstim2:", np.quantile(unstim2, [0, 0.25, 0.5, 0.75, 1]))
    print(offset)

    # Hand plain arrays to the plotting helper for every category.
    dist_data = {
        "Patient_1": {
            "Unstim perio": unstim1.values,
            "Stim True perio": stim1.values,
            "Unstim surge": unstim2.values,
            "Stim True surge": stim2.values,
            "Unstim log perio": log1.values,
            "Unstim log surge": log2.values,
        }
    }
    create_density_plots(dist_data, outdir_path, title_suffix="")


In [60]:
marker_list = ['159Tb_MAPKAPK2', '151Eu_p38', '155Gd_S6']
perio_stim_list_ = ['TNFa', 'P._gingivalis']
perio_cell_list_ = ['Granulocytes_(CD45-CD66+)','B-Cells_(CD19+CD3-)','Classical_Monocytes_(CD14+CD16-)','MDSCs_(lin-CD11b-CD14+HLADRlo)','mDCs_(CD11c+HLADR+)','pDCs(CD123+HLADR+)','Intermediate_Monocytes_(CD14+CD16+)','Non-classical_Monocytes_(CD14-CD16+)','CD56+CD16-_NK_Cells','CD56loCD16+NK_Cells','NK_Cells_(CD7+)','CD4_T-Cells','Tregs_(CD25+FoxP3+)','CD8_T-Cells','CD8-CD4-_T-Cells']
cell_type = 'Classical Monocytes (CD14+CD16-)'
stim = 'P._gingivalis'

# (stim, cell type) combinations that are skipped.
excluded_pairs = [
    ['P._gingivalis', 'Non-classical_Monocytes_(CD14-CD16+)'],
    ['P._gingivalis', 'NK_Cells_(CD7+)'],
    ['TNFa', 'Granulocytes_(CD45-CD66+)'],
]

# The exclusion list spells cell types with underscores while ``cell_type``
# uses spaces, so normalise before comparing; otherwise it can never match.
if [stim, cell_type.replace(' ', '_')] not in excluded_pairs:
    # None of these depend on the marker, so compute them once.
    # 'P._gingivalis' in the perio cohort is compared against 'LPS' in surge.
    doms_stim = 'LPS' if stim == 'P._gingivalis' else stim

    # Build the input paths from the selected stim / cell type so they cannot
    # drift from the labels used in the output file name. For the values above
    # these are exactly the previously hard-coded paths.
    # NOTE(review): assumes every file follows this naming pattern, with the
    # perio stim spelled with spaces instead of underscores — confirm on disk.
    perio_stim = stim.replace('_', ' ')
    path_cohort_2 = f"surge_just_concat/surge_data_{doms_stim}_{cell_type}.h5ad"
    path_cohort_1 = f"perio_just_concat/perio_data_sherlock_{perio_stim}_{cell_type}.h5ad"

    for marker in marker_list:
        output_path = f"plot_surge_vs_perio_uncorr/{doms_stim}_{cell_type}_{marker}_surge.png"

        plot_result2(path_cohort_1, path_cohort_2, marker, output_path, doms_stim)
        print(f"Plot {marker} for {cell_type} and {doms_stim}")

Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.


Quantiles unstim2: [0.         0.03517235 0.24389052 0.53200577 5.09151506]
0.03201842866837979


Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.


Plot 159Tb_MAPKAPK2 for Classical Monocytes (CD14+CD16-) and LPS


Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.


Quantiles unstim2: [0.         0.         0.07819439 0.29686595 3.34448886]
0.0214058643206954


Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.


Plot 151Eu_p38 for Classical Monocytes (CD14+CD16-) and LPS


Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.


Quantiles unstim2: [0.         0.22247922 0.52866298 0.88515298 4.00684834]
0.06318165361881256
Plot 155Gd_S6 for Classical Monocytes (CD14+CD16-) and LPS
