In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# the final STAG AML data after all cleaning

In [None]:
# Load the cleaned STAG AML dataset; the path is relative to the notebook's working directory.
H5AD_PATH = "STAG_AML_final.h5ad"
adata = sc.read(H5AD_PATH)

In [None]:
# UMAP coloured by leiden cluster, with the cluster labels drawn on the embedding itself.
plt.rcParams["figure.figsize"] = (10, 10)
umap_kwargs = {
    "s": 50,
    "frameon": False,
    "legend_loc": "on data",
    "legend_fontsize": 20,
    "legend_fontoutline": 4,
}
sc.pl.umap(adata, color=["leiden"], **umap_kwargs)

# do DE on leiden clusters

In [None]:
# Reset the stored log1p base before ranking genes.
# NOTE(review): presumably a workaround for scanpy raising KeyError: 'base' on reloaded
# .h5ad files — confirm against the installed scanpy version.
log1p_meta = adata.uns["log1p"]
log1p_meta["base"] = None

# Wilcoxon rank-sum test per leiden cluster; results are written to
# adata.uns["rank_genes_groups"] (read back in the next cell).
sc.tl.rank_genes_groups(adata, groupby="leiden", method="wilcoxon")

In [None]:
# Flatten the per-cluster ranking results into one wide table with three columns per
# cluster: "<cluster>_name", "<cluster>_pval" (adjusted p-value) and "<cluster>_logf".
result = adata.uns["rank_genes_groups"]
groups = result["names"].dtype.names

deg_columns = {}
for group in groups:
    for key in ("names", "pvals_adj", "logfoldchanges"):
        # key[:4] shortens the field name to "name" / "pval" / "logf"
        deg_columns[f"{group}_{key[:4]}"] = result[key][group]

deg = pd.DataFrame(deg_columns)

In [None]:
# Inspect the marker genes of one leiden cluster; change `i` to look at another cluster.
i = 3
name_col, pval_col, logf_col = f"{i}_name", f"{i}_pval", f"{i}_logf"

deg_sub = deg[[name_col, pval_col, logf_col]].copy()

# Ranking score: -log10(adjusted p) * |log fold change|.
# The 10**-300 offset keeps log10 finite when the p-value is exactly 0.
neg_log_pval = -np.log10(deg_sub[pval_col] + 10 ** (-300))
deg_sub["logpval_mult_logf"] = neg_log_pval * np.abs(deg_sub[logf_col])

# Keep significant, up-regulated genes only and list the best-scoring ones first.
is_significant = deg_sub[pval_col] < 0.001
is_upregulated = deg_sub[logf_col] > 1.2
deg_sub = deg_sub[is_significant & is_upregulated].sort_values(
    by="logpval_mult_logf", ascending=False
)
deg_sub.head(40)

In [None]:
# UMAP coloured by the expression of two genes picked for visual inspection.
plt.rcParams["figure.figsize"] = (10, 10)
marker_genes = ["DLK1", "CXCL2"]
umap_kwargs = {
    "s": 50,
    "frameon": False,
    "legend_loc": "on data",
    "legend_fontsize": 20,
    "legend_fontoutline": 4,
}
sc.pl.umap(adata, color=marker_genes, **umap_kwargs)

In [None]:
# print(deg_sub[deg_sub.columns[0]].tolist())

# name leiden clusters based on DEGs

In [None]:
# Manual annotation of the leiden clusters, grouped by the cell type assigned to each
# cluster (several clusters can share one label).
_clusters_by_cell_type = {
    "HSC": (0, 4, 9),
    "Proliferating": (1, 15),
    "GMP": (2, 3),
    "CD14+CD16+ Mono.": (5,),
    "Early monocytes": (6,),
    "Naive CD4/CD8 T cell": (7,),
    "Neutrophil": (8, 10),
    "NK/Eff./Mem. T cell": (11,),
    "CD16 Mono.": (12,),
    "MEP": (13,),
    "Erythrocytes": (14,),
}

# Flatten into the {leiden id (as str) -> cell type} lookup, keyed in numeric cluster order.
leiden_cell = {
    str(cluster): cell_type
    for cluster, cell_type in sorted(
        (cluster, cell_type)
        for cell_type, clusters in _clusters_by_cell_type.items()
        for cluster in clusters
    )
}

In [None]:
# Translate every cell's leiden cluster id into its annotated cell type.
# A direct dictionary lookup is used so that an unmapped cluster id raises KeyError.
adata.obs["cell_type"] = adata.obs["leiden"].apply(leiden_cell.__getitem__)

In [None]:
# UMAP coloured by the annotated cell type, with the labels drawn on the embedding itself.
plt.rcParams["figure.figsize"] = (10, 10)
umap_kwargs = {
    "s": 50,
    "frameon": False,
    "legend_loc": "on data",
    "legend_fontsize": 20,
    "legend_fontoutline": 4,
}
sc.pl.umap(adata, color=["cell_type"], **umap_kwargs)

# do DE on cell_type clusters

In [None]:
# Reset the stored log1p base before ranking genes.
# NOTE(review): presumably a workaround for scanpy raising KeyError: 'base' on reloaded
# .h5ad files — confirm against the installed scanpy version.
log1p_meta = adata.uns["log1p"]
log1p_meta["base"] = None

# Wilcoxon rank-sum test per annotated cell type; stored under its own key so the
# leiden-level results in adata.uns["rank_genes_groups"] are not overwritten.
sc.tl.rank_genes_groups(
    adata,
    groupby="cell_type",
    method="wilcoxon",
    key_added="cell_type_DEGs",
)

In [None]:
# Flatten the per-cell-type ranking results into one wide table with three columns per
# cell type: "<cell type>_name", "<cell type>_pval" (adjusted p-value) and "<cell type>_logf".
result = adata.uns["cell_type_DEGs"]
groups = result["names"].dtype.names

deg_columns = {}
for group in groups:
    for key in ("names", "pvals_adj", "logfoldchanges"):
        # key[:4] shortens the field name to "name" / "pval" / "logf"
        deg_columns[f"{group}_{key[:4]}"] = result[key][group]

deg = pd.DataFrame(deg_columns)

In [None]:
# Category labels of the cell_type annotation (requires the column to be categorical).
cell_types = adata.obs["cell_type"].cat.categories

In [None]:
# Display the cell-type labels; a later cell selects one of them by position (cell_types[1]).
cell_types

# aggregate all cell_type DEG in one dictionary

In [None]:
# Collect, for every annotated cell type, its strongest up-regulated marker genes.
ALL_DEGs = {}
for cell_type in cell_types:
    name_col = f"{cell_type}_name"
    pval_col = f"{cell_type}_pval"
    logf_col = f"{cell_type}_logf"

    deg_sub = deg[[name_col, pval_col, logf_col]].copy()

    # Score column: -log10(adjusted p) * |log fold change|; the 10**-300 offset keeps
    # log10 finite when the p-value is exactly 0. It is kept in the table for reference;
    # the ordering below is by log fold change.
    neg_log_pval = -np.log10(deg_sub[pval_col] + 10 ** (-300))
    deg_sub["logpval_mult_logf"] = neg_log_pval * np.abs(deg_sub[logf_col])

    is_significant = deg_sub[pval_col] < 0.0001
    is_upregulated = deg_sub[logf_col] > 2.5
    deg_sub = deg_sub[is_significant & is_upregulated].sort_values(
        by=logf_col, ascending=False
    )

    # Keep at most the 40 genes with the largest log fold change.
    ALL_DEGs[cell_type] = deg_sub.head(40)

In [None]:
# Pick one cell type by its position in `cell_types` and list its marker gene names
# (the first column of each ALL_DEGs table holds the gene names).
selected_cell_type = cell_types[1]
deg_sub = ALL_DEGs[selected_cell_type]
gene_col = deg_sub.columns[0]
genes = deg_sub[gene_col].tolist()

# plot the top DEGs from the entire AML dataset on top of our dataset

In [None]:
# LMPP DEG table. `deg10` is reassigned by the following cells, so run only the
# loader whose gene list you want to plot.
lmpp_csv = '../all_csvs/LMPP_DEG.csv'
deg10 = pd.read_csv(lmpp_csv, index_col=0)

In [None]:
# GMP DEG table. `deg10` is shared with the neighbouring loader cells, so run only the
# loader whose gene list you want to plot.
gmp_csv = '../all_csvs/GMP_DEG.csv'
deg10 = pd.read_csv(gmp_csv, index_col=0)

In [None]:
# Prog_Mk DEG table. This is the last of the loader cells assigning `deg10`, so it is the
# one in effect when the notebook is run top to bottom.
prog_mk_csv = '../all_csvs/Prog_Mk_DEG.csv'
deg10 = pd.read_csv(prog_mk_csv, index_col=0)

In [None]:
# Of the first 20 index entries of the loaded DEG table, keep those present in this
# dataset's variables (the result follows adata.var's own ordering).
var_index = adata.var.index
in_top20 = var_index.isin(deg10.index[:20])
genes = var_index[in_top20]

In [None]:
# One small UMAP panel per selected gene.
plt.rcParams["figure.figsize"] = (4, 4)
umap_kwargs = {
    "s": 20,
    "frameon": False,
    "legend_loc": "on data",
    "legend_fontsize": 20,
    "legend_fontoutline": 4,
}
sc.pl.umap(adata, color=genes, **umap_kwargs)

# find which clusters have the genes of interest as DEGs in the entire AML dataset

In [None]:
# DEG table for the predicted cell types of the entire AML dataset; replaces the
# per-population table previously held in `deg10`.
celltype_deg_csv = '../all_csvs/AML_predicted_celltype_DEG.csv'
deg10 = pd.read_csv(celltype_deg_csv, index_col=0)

In [None]:
# Gene names (first column of deg_sub), capped at the first 40 entries.
gene_col = deg_sub.columns[0]
all_gene_names = deg_sub[gene_col].tolist()
genes = all_gene_names[:40]

In [None]:
# For every gene of interest, report where it appears in the reference DEG table and
# print the matching group's adjusted p-value and log fold change.
for gene in genes:
    # Boolean Series over every (row, column) position of deg10 that equals the gene name.
    result = deg10.eq(gene).stack()
    hits = result[result].index
    print(15*'-',gene,22*'-')
    # Show at most the first three hits. Slicing (rather than indexing positions 0..2)
    # avoids an IndexError when the gene occurs fewer than three times, and simply
    # prints nothing when it does not occur at all.
    for row_index, column_index in hits[:3]:
        # "<group>_name" -> "<group>", used to address that group's statistics columns.
        group = column_index.split('_name')[0]
        print(deg10.loc[row_index,[group+'_pval',group+'_logf']])
        print(30*'-')