In [None]:
import os
import re
import numpy as np
import pandas as pd
import scipy.stats
import sklearn.metrics
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import seaborn as sns
import gseapy as gp
import Cell_BLAST as cb
import exputils

In [None]:
# Restrict CUDA to the device picked by exputils.pick_gpu_lowest_memory();
# this is set before any model is loaded further down.
os.environ["CUDA_VISIBLE_DEVICES"] = exputils.pick_gpu_lowest_memory()
cb.config.RANDOM_SEED = 0
# "none" keeps figure text as text (not paths) in SVG output
plt.rcParams['svg.fonttype'] = "none"
plt.rcParams['font.family'] = "Arial"
# When True, each marker gene set is restricted to genes for which the cell type
# ranks among the three highest mean expressions (see the gene set cells below)
FILTER_GENE_SETS = True
# Root output directory for GSEA results and dot plots
PATH = "gene_grad_gsea"
os.makedirs(PATH, exist_ok=True)

In [None]:
class MidpointNormalize(colors.Normalize):  # https://matplotlib.org/tutorials/colors/colormapnorms.html
    """Piecewise-linear normalization that maps ``midpoint`` to 0.5.

    Values are interpolated linearly from ``vmin -> 0`` to ``midpoint -> 0.5``
    and from ``midpoint -> 0.5`` to ``vmax -> 1``; values outside
    ``[vmin, vmax]`` are clamped by ``np.interp``. NaN inputs are masked in
    the returned array.

    Parameters
    ----------
    vmin, vmax
        Data limits. If left as None they are taken from the first data
        passed to ``__call__`` (same convention as ``colors.Normalize``).
    midpoint
        Data value placed at the centre of the colormap. If None, the centre
        of ``[vmin, vmax]`` is used. Must lie within ``[vmin, vmax]`` for the
        interpolation breakpoints to be increasing.
    clip
        Forwarded to ``colors.Normalize``; not used by ``__call__``.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super(MidpointNormalize, self).__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        # np.interp cannot work with None breakpoints, so resolve missing
        # limits from the data first (NaNs are masked out so they do not
        # become the limits).
        if self.vmin is None or self.vmax is None:
            self.autoscale_None(np.ma.masked_invalid(value))
        midpoint = self.midpoint
        if midpoint is None:
            midpoint = (self.vmin + self.vmax) / 2
        x, y = [self.vmin, midpoint, self.vmax], [0, 0.5, 1]
        return np.ma.masked_array(np.interp(value, x, y), np.isnan(value))


def equalizing_weights(x):
    """Return per-element weights that give every distinct value equal total weight.

    Each element is weighted inversely to how often its value occurs in ``x``;
    the weights are then rescaled so that they sum to the number of elements.
    """
    _, group_index, group_sizes = np.unique(x, return_inverse=True, return_counts=True)
    # Inverse-frequency weight of each group, broadcast back to the elements
    per_element = (group_sizes.sum() / group_sizes)[group_index]
    return per_element * per_element.size / per_element.sum()

## Read PanglaoDB markers

In [None]:
# Load the PanglaoDB marker table, keeping only the four columns used below.
panglao_markers = pd.read_csv(
    "../../Datasets/marker/PanglaoDB_markers_21_Jan_2020.tsv",
    sep="\t", usecols=["species", "official gene symbol", "cell type", "organ"]
)
# NOTE(review): presumably maps PanglaoDB "cell type" names to a
# "cell_ontology_class" column (that column is used later) - confirm against the CSV.
cl_panglao_mapping = pd.read_csv("../../Datasets/marker/CL_PanglaoDB_mapping.csv")
# No `on=` is given, so pandas joins on every column the two tables share
# (default inner join: markers without a mapping are dropped).
panglao_markers = panglao_markers.merge(cl_panglao_mapping)
panglao_markers.head()

### Extract human cell type markers

In [None]:
# Markers whose species field mentions human ("Hs"). A vectorised substring
# test is used instead of np.vectorize + re.search: it also works on an empty
# table and treats a missing species value as "not human" instead of raising.
is_human = panglao_markers["species"].str.contains("Hs", regex=False, na=False)
panglao_markers_human = panglao_markers.loc[is_human, :]
panglao_markers_human.head()

### Extract and convert mouse cell type markers

In [None]:
# Human-mouse ortholog table. The file has no header row, so column names are
# supplied here; only the two gene symbol columns are kept.
human_mouse_ortholog = pd.read_csv(
    "../../Datasets/ortholog/Ensembl/orthology/human_mouse.csv", header=None,
    names=["ENSG", "human_gene_symbol", "ENMUSG", "mouse_gene_symbol", "ortholog_type"],
    usecols=["human_gene_symbol", "mouse_gene_symbol"]
)
human_mouse_ortholog.head()

In [None]:
# Markers whose species field mentions mouse ("Mm"). A vectorised substring
# test replaces np.vectorize + re.search, so an empty table or a missing
# species value no longer raises.
is_mouse = panglao_markers["species"].str.contains("Mm", regex=False, na=False)
# Match the marker symbols against the human symbols of the ortholog table
# (inner join: markers without an ortholog entry are dropped, markers with
# several orthologs are repeated once per ortholog).
panglao_markers_mouse = panglao_markers.loc[is_mouse, :].merge(
    human_mouse_ortholog,
    left_on="official gene symbol",
    right_on="human_gene_symbol"
)
# Replace the marker symbol with the mouse ortholog symbol, then drop the helper columns
panglao_markers_mouse["official gene symbol"] = panglao_markers_mouse["mouse_gene_symbol"]
panglao_markers_mouse = panglao_markers_mouse.drop(
    columns=["human_gene_symbol", "mouse_gene_symbol"]
)
panglao_markers_mouse.head()

## Pancreas

### Clean data

In [None]:
# Load the saved pancreas BLAST object, then replace its models with 16
# DIRECTi models (seed_0 ... seed_15) loaded from a separate result directory.
blast = cb.blast.BLAST.load("../../Results/Cell_BLAST/Pancreas/seed_0/blast")
blast.models = [
    cb.directi.DIRECTi.load(
        f"../../Results/Cell_BLAST/Baron_human+Xin_2016+Lawlor/dim_10_rmbatch0.01/seed_{i}"
    ) for i in range(16)
]  # We are not directly using this BLAST object for querying so it's okay to override models
blast.ref.obs["cell_ontology_class"] = pd.Categorical(blast.ref.obs["cell_ontology_class"])  # Make colors consistent

In [None]:
# Embed the reference with the first model only, then plot a UMAP of that
# embedding coloured by cell type.
blast.ref.latent = blast.models[0].inference(blast.ref)
ax = blast.ref.visualize_latent("cell_ontology_class", method="UMAP", dr_kws=dict(min_dist=0.5))

Cell type annotation of some cells may not be reliable, which could significantly confuse neighbor-based gradient analysis (especially the subgroup of ductal cells located closer to acinar cells), so we remove these ambiguous cells in advance.

In [None]:
# Per-cell silhouette score of the cell type labels in the latent space
# (stored as a new column on blast.ref.obs).
blast.ref.obs["silhouette"] = sklearn.metrics.silhouette_samples(
    blast.ref.latent, blast.ref.obs["cell_ontology_class"])
# Keep only cells with silhouette > 0.2
clean_ref = blast.ref[blast.ref.obs["silhouette"] > 0.2, :]
ax = clean_ref.visualize_latent("cell_ontology_class", method="UMAP")

In [None]:
# Cell types present both in the cleaned reference and in the human marker table
# (np.intersect1d returns them sorted and unique).
used_cell_types = np.intersect1d(
    np.unique(clean_ref.obs["cell_ontology_class"]),
    np.unique(panglao_markers_human["cell_ontology_class"])
)
used_cell_types

In [None]:
# One marker gene set per cell type: PanglaoDB markers of that cell type that
# are also among the model genes. A boolean mask is used instead of
# DataFrame.query so cell type names containing quotes cannot break the lookup.
gene_sets = {
    cell_type: np.intersect1d(
        panglao_markers_human.loc[
            panglao_markers_human["cell_ontology_class"] == cell_type,
            "official gene symbol"
        ],
        blast.models[0].genes
    ).tolist()
    for cell_type in used_cell_types
}
if FILTER_GENE_SETS:
    # Mean (normalized, log-transformed) expression of every gene per cell type.
    # observed=True drops categories that have no cells left after cleaning;
    # otherwise they show up as all-NaN rows and corrupt the ranking below.
    cell_type_mean_exprs = clean_ref.get_meta_or_var(
        clean_ref.var_names.to_numpy().tolist() + ["cell_ontology_class"],
        normalize_var=True, log_var=True
    ).groupby("cell_ontology_class", observed=True).mean()
    for cell_type in used_cell_types:
        # Keep a marker only if this cell type is among the three cell types
        # with the highest mean expression of that gene. nlargest ignores NaN
        # and breaks ties by row order.
        gene_sets[cell_type] = [
            gene for gene in gene_sets[cell_type]
            if cell_type in cell_type_mean_exprs[gene].nlargest(3).index
        ]
gene_sets

### Gradients

In [None]:
# For every cell type: query its cells against a reference made of all OTHER
# cell types, average the gene gradients returned for the hits, and run
# pre-ranked GSEA of that averaged gradient against the marker gene sets.
nes, pval, fdr = {}, {}, {}
for cell_type in used_cell_types:
    print(f"Dealing with {cell_type}...")
    used_ref = clean_ref[clean_ref.obs["cell_ontology_class"] != cell_type, :]
    used_query = clean_ref[clean_ref.obs["cell_ontology_class"] == cell_type, :]
    gene_grad = []
    for model in blast.models:
        used_blast = cb.blast.BLAST(
            [model], used_ref, distance_metric="ed"
        )  # Skip posterior distance since we are not doing any filtering
        hits = used_blast.query(used_query, n_neighbors=100, store_dataset=True)
        _gene_grad = hits.gene_gradient()
        _gene_grad = np.concatenate(_gene_grad)
        # hits.hits is used as row positions into used_ref, so select with
        # .iloc: plain Series[int_array] relies on pandas' deprecated
        # positional fallback for non-integer indexes.
        hit_cell_types = used_ref.obs["cell_ontology_class"].iloc[np.concatenate(hits.hits)]
        # Weight each gradient row so every hit cell type contributes equally
        _gene_grad = np.average(
            _gene_grad, axis=0,
            weights=equalizing_weights(hit_cell_types)
        )
        gene_grad.append(_gene_grad)
    # Average over models, then build the two-column (gene, score) ranking table
    gene_grad = np.stack(gene_grad).mean(axis=0)
    gene_grad = pd.DataFrame({0: blast.models[0].genes, 1: gene_grad})
    try:
        gsea_result = gp.prerank(
            gene_grad, gene_sets.copy(),  # gp.prerank seems to modify gene sets in-place
            outdir=f"{PATH}/pancreas/{cell_type}",
            weighted_score_type=0, min_size=10, seed=0
        )
    except Exception as err:
        # Best effort: skip this cell type, but report why GSEA failed
        print(f"GSEA failed with {cell_type}, skipped...")
        print(f"  reason: {type(err).__name__}: {err}")
        continue
    nes[cell_type] = {key: val["nes"] for key, val in gsea_result.results.items()}
    pval[cell_type] = {key: val["pval"] for key, val in gsea_result.results.items()}
    fdr[cell_type] = {key: val["fdr"] for key, val in gsea_result.results.items()}

In [None]:
# Rows: cell type that was queried; columns: marker gene sets reported by GSEA.
nes_df = pd.DataFrame.from_dict(nes, orient="index")
fdr_df = pd.DataFrame.from_dict(fdr, orient="index")
nes_df.index.name = "cell ontology class"
fdr_df.index.name = "cell ontology class"
# Keep only cell types that occur both as a row and as a column, giving a
# square matrix with identical (sorted) row and column order.
successful_cell_types = np.intersect1d(nes_df.index, nes_df.columns)
nes_df = nes_df.loc[successful_cell_types, successful_cell_types]
fdr_df = fdr_df.loc[successful_cell_types, successful_cell_types]

In [None]:
# Reshape the square NES / FDR matrices to long format. The results go into
# new variables so nes_df / fdr_df stay intact and this cell can be re-run.
id_col, marker_col = "cell ontology class", "PanglaoDB markers"
nes_long = nes_df.reset_index().melt(id_vars=id_col, var_name=marker_col, value_name="NES")
fdr_long = fdr_df.reset_index().melt(id_vars=id_col, var_name=marker_col, value_name="FDR")
# Join on the (cell type, marker set) pair explicitly
nes_fdr_df = pd.merge(nes_long, fdr_long, on=[id_col, marker_col])
# Dot size: -log10(FDR), capped at 5. FDR == 0 gives inf, which the cap
# turns into 5, so the divide-by-zero warning is silenced.
with np.errstate(divide="ignore"):
    nes_fdr_df["-log10 FDR"] = np.minimum(-np.log10(nes_fdr_df["FDR"]), 5)
nes_fdr_df.head()

In [None]:
# Colour scale centred on NES = 0. Series.min()/max() skip NaN (the builtin
# min/max are NaN-order dependent), and the limits are widened to include 0 so
# the norm's breakpoints stay increasing even if all NES share one sign.
nes_min, nes_max = nes_fdr_df["NES"].min(), nes_fdr_df["NES"].max()
hue_norm = MidpointNormalize(midpoint=0, vmin=min(0, nes_min), vmax=max(0, nes_max))
sm = plt.cm.ScalarMappable(cmap="seismic", norm=hue_norm)
sm.set_array([])

# Dot plot: colour = NES, size = capped -log10 FDR
fig, ax = plt.subplots(figsize=(5, 4))
ax = sns.scatterplot(
    x="PanglaoDB markers", y="cell ontology class", hue="NES", size="-log10 FDR",
    data=nes_fdr_df, palette="seismic", hue_norm=hue_norm, sizes=(1, 300),
    edgecolor=None, ax=ax
)
ax.invert_yaxis()
ax.tick_params(axis="x", labelrotation=90)
# The mappable is not attached to any Axes, so name the Axes explicitly
# (recent matplotlib refuses to guess)
cbar = fig.colorbar(sm, ax=ax)
cbar.set_label("NES", rotation=270)
handles, labels = ax.get_legend_handles_labels()
# Keep only the last five legend entries; the colourbar already covers the hue
# NOTE(review): the entry count depends on how seaborn builds the legend - confirm.
ax.legend(handles[-5:], labels[-5:], loc="center left", bbox_to_anchor=(1.3, 0.5), frameon=False, labelspacing=1.2)
os.makedirs(f"{PATH}/pancreas", exist_ok=True)  # do not rely on GSEA having created it
fig.savefig(f"{PATH}/pancreas/dotplot.pdf", bbox_inches="tight")

## Trachea

### Clean data

In [None]:
# Load the saved trachea BLAST object, then replace its models with 16
# DIRECTi models (seed_0 ... seed_15) loaded from a separate result directory.
blast = cb.blast.BLAST.load("../../Results/Cell_BLAST/Trachea/seed_0/blast")
blast.models = [
    cb.directi.DIRECTi.load(
        f"../../Results/Cell_BLAST/Montoro_10x/dim_10/seed_{i}"
    ) for i in range(16)
]  # We are not directly using this BLAST object for querying so it's okay to override models
blast.ref.obs["cell_ontology_class"] = pd.Categorical(blast.ref.obs["cell_ontology_class"])  # Make colors consistent

In [None]:
# Embed the reference with the first model only, then plot a UMAP of that
# embedding coloured by cell type.
blast.ref.latent = blast.models[0].inference(blast.ref)
ax = blast.ref.visualize_latent("cell_ontology_class", method="UMAP", dr_kws=dict(min_dist=0.5))

Cell type annotation of some cells may not be reliable, which could significantly confuse neighbor-based gradient analysis (as seen in the pancreas data, where a subgroup of ductal cells is located closer to acinar cells), so we remove these ambiguous cells in advance.

In [None]:
# Per-cell silhouette score of the cell type labels in the latent space
# (stored as a new column on blast.ref.obs).
blast.ref.obs["silhouette"] = sklearn.metrics.silhouette_samples(
    blast.ref.latent, blast.ref.obs["cell_ontology_class"])
# Keep only cells with silhouette > 0.2
clean_ref = blast.ref[blast.ref.obs["silhouette"] > 0.2, :]
ax = clean_ref.visualize_latent("cell_ontology_class", method="UMAP")

In [None]:
# Cell types present both in the cleaned reference and in the mouse marker table
# (np.intersect1d returns them sorted and unique).
used_cell_types = np.intersect1d(
    np.unique(clean_ref.obs["cell_ontology_class"]),
    np.unique(panglao_markers_mouse["cell_ontology_class"])
)
used_cell_types

In [None]:
# One marker gene set per cell type: PanglaoDB markers (mouse symbols) of that
# cell type that are also among the model genes. A boolean mask is used instead
# of DataFrame.query so cell type names containing quotes cannot break the lookup.
gene_sets = {
    cell_type: np.intersect1d(
        panglao_markers_mouse.loc[
            panglao_markers_mouse["cell_ontology_class"] == cell_type,
            "official gene symbol"
        ],
        blast.models[0].genes
    ).tolist()
    for cell_type in used_cell_types
}
if FILTER_GENE_SETS:
    # Mean (normalized, log-transformed) expression of every gene per cell type.
    # observed=True drops categories that have no cells left after cleaning;
    # otherwise they show up as all-NaN rows and corrupt the ranking below.
    cell_type_mean_exprs = clean_ref.get_meta_or_var(
        clean_ref.var_names.to_numpy().tolist() + ["cell_ontology_class"],
        normalize_var=True, log_var=True
    ).groupby("cell_ontology_class", observed=True).mean()
    for cell_type in used_cell_types:
        # Keep a marker only if this cell type is among the three cell types
        # with the highest mean expression of that gene. nlargest ignores NaN
        # and breaks ties by row order.
        gene_sets[cell_type] = [
            gene for gene in gene_sets[cell_type]
            if cell_type in cell_type_mean_exprs[gene].nlargest(3).index
        ]
gene_sets

### Gradients

In [None]:
# For every cell type: query its cells against a reference made of all OTHER
# cell types, average the gene gradients returned for the hits, and run
# pre-ranked GSEA of that averaged gradient against the marker gene sets.
nes, pval, fdr = {}, {}, {}
for cell_type in used_cell_types:
    print(f"Dealing with {cell_type}...")
    used_ref = clean_ref[clean_ref.obs["cell_ontology_class"] != cell_type, :]
    used_query = clean_ref[clean_ref.obs["cell_ontology_class"] == cell_type, :]
    gene_grad = []
    for model in blast.models:
        used_blast = cb.blast.BLAST(
            [model], used_ref, distance_metric="ed"
        )  # Skip posterior distance since we are not doing any filtering
        hits = used_blast.query(used_query, n_neighbors=50, store_dataset=True)
        _gene_grad = hits.gene_gradient()
        _gene_grad = np.concatenate(_gene_grad)
        # hits.hits is used as row positions into used_ref, so select with
        # .iloc: plain Series[int_array] relies on pandas' deprecated
        # positional fallback for non-integer indexes.
        hit_cell_types = used_ref.obs["cell_ontology_class"].iloc[np.concatenate(hits.hits)]
        # Weight each gradient row so every hit cell type contributes equally
        _gene_grad = np.average(
            _gene_grad, axis=0,
            weights=equalizing_weights(hit_cell_types)
        )
        gene_grad.append(_gene_grad)
    # Average over models, then build the two-column (gene, score) ranking table
    gene_grad = np.stack(gene_grad).mean(axis=0)
    gene_grad = pd.DataFrame({0: blast.models[0].genes, 1: gene_grad})
    try:
        gsea_result = gp.prerank(
            gene_grad, gene_sets.copy(),  # gp.prerank seems to modify gene sets in-place
            outdir=f"{PATH}/trachea/{cell_type}",
            weighted_score_type=0, min_size=5, seed=0
        )
    except Exception as err:
        # Best effort: skip this cell type, but report why GSEA failed
        print(f"GSEA failed with {cell_type}, skipped...")
        print(f"  reason: {type(err).__name__}: {err}")
        continue
    nes[cell_type] = {key: val["nes"] for key, val in gsea_result.results.items()}
    pval[cell_type] = {key: val["pval"] for key, val in gsea_result.results.items()}
    fdr[cell_type] = {key: val["fdr"] for key, val in gsea_result.results.items()}

In [None]:
# Rows: cell type that was queried; columns: marker gene sets reported by GSEA.
nes_df = pd.DataFrame.from_dict(nes, orient="index")
fdr_df = pd.DataFrame.from_dict(fdr, orient="index")
nes_df.index.name = "cell ontology class"
fdr_df.index.name = "cell ontology class"
# Keep only cell types that occur both as a row and as a column, giving a
# square matrix with identical (sorted) row and column order.
successful_cell_types = np.intersect1d(nes_df.index, nes_df.columns)
nes_df = nes_df.loc[successful_cell_types, successful_cell_types]
fdr_df = fdr_df.loc[successful_cell_types, successful_cell_types]

In [None]:
# Reshape the square NES / FDR matrices to long format. The results go into
# new variables so nes_df / fdr_df stay intact and this cell can be re-run.
id_col, marker_col = "cell ontology class", "PanglaoDB markers"
nes_long = nes_df.reset_index().melt(id_vars=id_col, var_name=marker_col, value_name="NES")
fdr_long = fdr_df.reset_index().melt(id_vars=id_col, var_name=marker_col, value_name="FDR")
# Join on the (cell type, marker set) pair explicitly
nes_fdr_df = pd.merge(nes_long, fdr_long, on=[id_col, marker_col])
# Dot size: -log10(FDR), capped at 6. FDR == 0 gives inf, which the cap
# turns into 6, so the divide-by-zero warning is silenced.
with np.errstate(divide="ignore"):
    nes_fdr_df["-log10 FDR"] = np.minimum(-np.log10(nes_fdr_df["FDR"]), 6)
nes_fdr_df.head()

In [None]:
# Colour scale centred on NES = 0. Series.min()/max() skip NaN (the builtin
# min/max are NaN-order dependent), and the limits are widened to include 0 so
# the norm's breakpoints stay increasing even if all NES share one sign.
nes_min, nes_max = nes_fdr_df["NES"].min(), nes_fdr_df["NES"].max()
hue_norm = MidpointNormalize(midpoint=0, vmin=min(0, nes_min), vmax=max(0, nes_max))
sm = plt.cm.ScalarMappable(cmap="seismic", norm=hue_norm)
sm.set_array([])

# Dot plot: colour = NES, size = capped -log10 FDR
fig, ax = plt.subplots(figsize=(5, 4))
ax = sns.scatterplot(
    x="PanglaoDB markers", y="cell ontology class", hue="NES", size="-log10 FDR",
    data=nes_fdr_df, palette="seismic", hue_norm=hue_norm, sizes=(1, 300),
    edgecolor=None, ax=ax
)
ax.invert_yaxis()
ax.tick_params(axis="x", labelrotation=90)
# The mappable is not attached to any Axes, so name the Axes explicitly
# (recent matplotlib refuses to guess)
cbar = fig.colorbar(sm, ax=ax)
cbar.set_label("NES", rotation=270)
handles, labels = ax.get_legend_handles_labels()
# Keep only the last five legend entries; the colourbar already covers the hue
# NOTE(review): the entry count depends on how seaborn builds the legend - confirm.
ax.legend(handles[-5:], labels[-5:], loc="center left", bbox_to_anchor=(1.3, 0.5), frameon=False, labelspacing=1.2)
os.makedirs(f"{PATH}/trachea", exist_ok=True)  # do not rely on GSEA having created it
fig.savefig(f"{PATH}/trachea/dotplot.pdf", bbox_inches="tight")

## Lung

### Clean data

In [None]:
# Load the saved lung BLAST object, then replace its models with 16
# DIRECTi models (seed_0 ... seed_15) loaded from a separate result directory.
blast = cb.blast.BLAST.load("../../Results/Cell_BLAST/Lung/seed_0/blast")
blast.models = [
    cb.directi.DIRECTi.load(
        f"../../Results/Cell_BLAST/Quake_10x_Lung/dim_10/seed_{i}"
    ) for i in range(16)
]  # We are not directly using this BLAST object for querying so it's okay to override models
blast.ref.obs["cell_ontology_class"] = pd.Categorical(blast.ref.obs["cell_ontology_class"])  # Make colors consistent

In [None]:
# Embed the reference with the first model only, then plot a UMAP of that
# embedding coloured by cell type.
blast.ref.latent = blast.models[0].inference(blast.ref)
ax = blast.ref.visualize_latent("cell_ontology_class", method="UMAP", dr_kws=dict(min_dist=0.5))

Cell type annotation of some cells may not be reliable, which could significantly confuse neighbor-based gradient analysis (as seen in the pancreas data, where a subgroup of ductal cells is located closer to acinar cells), so we remove these ambiguous cells in advance.

In [None]:
# Per-cell silhouette score of the cell type labels in the latent space
# (stored as a new column on blast.ref.obs).
blast.ref.obs["silhouette"] = sklearn.metrics.silhouette_samples(
    blast.ref.latent, blast.ref.obs["cell_ontology_class"])
# Keep cells with silhouette > 0.2 that are not annotated as stromal cells
clean_ref = blast.ref[np.logical_and(
    blast.ref.obs["silhouette"] > 0.2,
    blast.ref.obs["cell_ontology_class"] != "stromal cell"
    # Stromal cells are removed because of significant heterogeneity
), :]
ax = clean_ref.visualize_latent("cell_ontology_class", method="UMAP")

In [None]:
# Cell types present both in the cleaned reference and in the mouse marker table
# (np.intersect1d returns them sorted and unique).
used_cell_types = np.intersect1d(
    np.unique(clean_ref.obs["cell_ontology_class"]),
    np.unique(panglao_markers_mouse["cell_ontology_class"])
)
used_cell_types

In [None]:
# One marker gene set per cell type: PanglaoDB markers (mouse symbols) of that
# cell type that are also among the model genes. A boolean mask is used instead
# of DataFrame.query so cell type names containing quotes cannot break the lookup.
gene_sets = {
    cell_type: np.intersect1d(
        panglao_markers_mouse.loc[
            panglao_markers_mouse["cell_ontology_class"] == cell_type,
            "official gene symbol"
        ],
        blast.models[0].genes
    ).tolist()
    for cell_type in used_cell_types
}
if FILTER_GENE_SETS:
    # Mean (normalized, log-transformed) expression of every gene per cell type.
    # observed=True drops categories that have no cells left after cleaning
    # (e.g. the removed stromal cells, if they are still a category);
    # otherwise they show up as all-NaN rows and corrupt the ranking below.
    cell_type_mean_exprs = clean_ref.get_meta_or_var(
        clean_ref.var_names.to_numpy().tolist() + ["cell_ontology_class"],
        normalize_var=True, log_var=True
    ).groupby("cell_ontology_class", observed=True).mean()
    for cell_type in used_cell_types:
        # Keep a marker only if this cell type is among the three cell types
        # with the highest mean expression of that gene. nlargest ignores NaN
        # and breaks ties by row order.
        gene_sets[cell_type] = [
            gene for gene in gene_sets[cell_type]
            if cell_type in cell_type_mean_exprs[gene].nlargest(3).index
        ]
gene_sets

### Gradients

In [None]:
# For every cell type: query its cells against a reference made of all OTHER
# cell types, average the gene gradients returned for the hits, and run
# pre-ranked GSEA of that averaged gradient against the marker gene sets.
nes, pval, fdr = {}, {}, {}
for cell_type in used_cell_types:
    print(f"Dealing with {cell_type}...")
    used_ref = clean_ref[clean_ref.obs["cell_ontology_class"] != cell_type, :]
    used_query = clean_ref[clean_ref.obs["cell_ontology_class"] == cell_type, :]
    gene_grad = []
    for model in blast.models:
        used_blast = cb.blast.BLAST(
            [model], used_ref, distance_metric="ed"
        )  # Skip posterior distance since we are not doing any filtering
        hits = used_blast.query(used_query, n_neighbors=50, store_dataset=True)
        _gene_grad = hits.gene_gradient()
        _gene_grad = np.concatenate(_gene_grad)
        # hits.hits is used as row positions into used_ref, so select with
        # .iloc: plain Series[int_array] relies on pandas' deprecated
        # positional fallback for non-integer indexes.
        hit_cell_types = used_ref.obs["cell_ontology_class"].iloc[np.concatenate(hits.hits)]
        # Weight each gradient row so every hit cell type contributes equally
        _gene_grad = np.average(
            _gene_grad, axis=0,
            weights=equalizing_weights(hit_cell_types)
        )
        gene_grad.append(_gene_grad)
    # Average over models, then build the two-column (gene, score) ranking table
    gene_grad = np.stack(gene_grad).mean(axis=0)
    gene_grad = pd.DataFrame({0: blast.models[0].genes, 1: gene_grad})
    try:
        gsea_result = gp.prerank(
            gene_grad, gene_sets.copy(),  # gp.prerank seems to modify gene sets in-place
            outdir=f"{PATH}/lung/{cell_type}",
            weighted_score_type=0, min_size=10, seed=0
        )
    except Exception as err:
        # Best effort: skip this cell type, but report why GSEA failed
        print(f"GSEA failed with {cell_type}, skipped...")
        print(f"  reason: {type(err).__name__}: {err}")
        continue
    nes[cell_type] = {key: val["nes"] for key, val in gsea_result.results.items()}
    pval[cell_type] = {key: val["pval"] for key, val in gsea_result.results.items()}
    fdr[cell_type] = {key: val["fdr"] for key, val in gsea_result.results.items()}

In [None]:
# Rows: cell type that was queried; columns: marker gene sets reported by GSEA.
nes_df = pd.DataFrame.from_dict(nes, orient="index")
fdr_df = pd.DataFrame.from_dict(fdr, orient="index")
nes_df.index.name = "cell ontology class"
fdr_df.index.name = "cell ontology class"
# Keep only cell types that occur both as a row and as a column, giving a
# square matrix with identical (sorted) row and column order.
successful_cell_types = np.intersect1d(nes_df.index, nes_df.columns)
nes_df = nes_df.loc[successful_cell_types, successful_cell_types]
fdr_df = fdr_df.loc[successful_cell_types, successful_cell_types]

In [None]:
# Reshape the square NES / FDR matrices to long format. The results go into
# new variables so nes_df / fdr_df stay intact and this cell can be re-run.
id_col, marker_col = "cell ontology class", "PanglaoDB markers"
nes_long = nes_df.reset_index().melt(id_vars=id_col, var_name=marker_col, value_name="NES")
fdr_long = fdr_df.reset_index().melt(id_vars=id_col, var_name=marker_col, value_name="FDR")
# Join on the (cell type, marker set) pair explicitly
nes_fdr_df = pd.merge(nes_long, fdr_long, on=[id_col, marker_col])
# Dot size: -log10(FDR), capped at 6. FDR == 0 gives inf, which the cap
# turns into 6, so the divide-by-zero warning is silenced.
with np.errstate(divide="ignore"):
    nes_fdr_df["-log10 FDR"] = np.minimum(-np.log10(nes_fdr_df["FDR"]), 6)
nes_fdr_df.head()

In [None]:
# Colour scale centred on NES = 0. Series.min()/max() skip NaN (the builtin
# min/max are NaN-order dependent), and the limits are widened to include 0 so
# the norm's breakpoints stay increasing even if all NES share one sign.
nes_min, nes_max = nes_fdr_df["NES"].min(), nes_fdr_df["NES"].max()
hue_norm = MidpointNormalize(midpoint=0, vmin=min(0, nes_min), vmax=max(0, nes_max))
sm = plt.cm.ScalarMappable(cmap="seismic", norm=hue_norm)
sm.set_array([])

# Dot plot: colour = NES, size = capped -log10 FDR
fig, ax = plt.subplots(figsize=(5, 4))
ax = sns.scatterplot(
    x="PanglaoDB markers", y="cell ontology class", hue="NES", size="-log10 FDR",
    data=nes_fdr_df, palette="seismic", hue_norm=hue_norm, sizes=(1, 300),
    edgecolor=None, ax=ax
)
ax.invert_yaxis()
ax.tick_params(axis="x", labelrotation=90)
# The mappable is not attached to any Axes, so name the Axes explicitly
# (recent matplotlib refuses to guess)
cbar = fig.colorbar(sm, ax=ax)
cbar.set_label("NES", rotation=270)
handles, labels = ax.get_legend_handles_labels()
# Keep only the last five legend entries; the colourbar already covers the hue
# NOTE(review): the entry count depends on how seaborn builds the legend - confirm.
ax.legend(handles[-5:], labels[-5:], loc="center left", bbox_to_anchor=(1.3, 0.5), frameon=False, labelspacing=1.2)
os.makedirs(f"{PATH}/lung", exist_ok=True)  # do not rely on GSEA having created it
fig.savefig(f"{PATH}/lung/dotplot.pdf", bbox_inches="tight")

## Mammary gland

PanglaoDB has this exact dataset, with each of the eight donors as a separate dataset, but their cell type annotations are largely inconsistent with the original publication. E.g., according to PanglaoDB, both of the two NP donors contain: "luminal epithelial cells" and "myoepithelial cells", while the first NP donor also contains "mammary epithelial cells". However, in the original publication, the two NP donors have similar cell type composition, including: "basal cell" / "bsl", "luminal progenitor" / "lp", "hormone sensing progenitor" / "hsp", "hormone sensing differentiated" / "hsd". Except for the basal cells, all other cell types are labeled as "luminal epithelial cell of mammary gland" in terms of cell ontology.

![PanglaoDB_NP1](https://panglaodb.se/plots/SRA625553_SRS2641016.tSNE_w_labels.png)
![PanglaoDB_NP2](https://panglaodb.se/plots/SRA625553_SRS2641017.tSNE_w_labels.png)

After checking PanglaoDB marker genes, it appears that many PanglaoDB markers assigned to "Luminal epithelial cells" are indeed luminal markers (mostly hormone sensing cells), but the PanglaoDB markers in "Mammary epithelial cells" contain markers for all kinds of mammary epithelial cell types.

As such, we discard the "Mammary epithelial cells" category in PanglaoDB markers, and use "Luminal epithelial cells" markers for cell ontology "luminal epithelial cell of mammary gland", and "Myoepithelial cells" markers for cell ontology "myoepithelial cell of mammary gland".

### Clean data

In [None]:
# Load the saved mammary gland BLAST object, then replace its models with 16
# DIRECTi models (seed_0 ... seed_15) loaded from a separate result directory.
blast = cb.blast.BLAST.load("../../Results/Cell_BLAST/Mammary_Gland/seed_0/blast")
blast.models = [
    cb.directi.DIRECTi.load(
        f"../../Results/Cell_BLAST/Bach/dim_10/seed_{i}"
    ) for i in range(16)
]  # We are not directly using this BLAST object for querying so it's okay to override models
# Store the cell type labels as a categorical column
blast.ref.obs["cell_ontology_class"] = pd.Categorical(blast.ref.obs["cell_ontology_class"])

In [None]:
# Embed the reference with the first model only, then plot a UMAP of that
# embedding coloured by cell type.
blast.ref.latent = blast.models[0].inference(blast.ref)
ax = blast.ref.visualize_latent("cell_ontology_class", method="UMAP", dr_kws=dict(min_dist=0.5))

Cell type annotation of some cells may not be reliable, which could significantly confuse neighbor-based gradient analysis (as seen in the pancreas data, where a subgroup of ductal cells is located closer to acinar cells), so we remove these ambiguous cells in advance.

In [None]:
# Per-cell silhouette score of the cell type labels in the latent space
# (stored as a new column on blast.ref.obs).
blast.ref.obs["silhouette"] = sklearn.metrics.silhouette_samples(
    blast.ref.latent, blast.ref.obs["cell_ontology_class"])
# Keep only cells with silhouette > 0.25 (the other tissues in this notebook use 0.2)
clean_ref = blast.ref[blast.ref.obs["silhouette"] > 0.25, :]
ax = clean_ref.visualize_latent("cell_ontology_class", method="UMAP")

In [None]:
# Cell types present both in the cleaned reference and in the mouse marker table
# (np.intersect1d returns them sorted and unique).
used_cell_types = np.intersect1d(
    np.unique(clean_ref.obs["cell_ontology_class"]),
    np.unique(panglao_markers_mouse["cell_ontology_class"])
)
used_cell_types

In [None]:
# One marker gene set per cell type: PanglaoDB markers (mouse symbols) of that
# cell type that are also among the model genes. A boolean mask is used instead
# of DataFrame.query so cell type names containing quotes cannot break the lookup.
gene_sets = {
    cell_type: np.intersect1d(
        panglao_markers_mouse.loc[
            panglao_markers_mouse["cell_ontology_class"] == cell_type,
            "official gene symbol"
        ],
        blast.models[0].genes
    ).tolist()
    for cell_type in used_cell_types
}
if FILTER_GENE_SETS:
    # Mean (normalized, log-transformed) expression of every gene per cell type.
    # observed=True drops categories that have no cells left after cleaning;
    # otherwise they show up as all-NaN rows and corrupt the ranking below.
    cell_type_mean_exprs = clean_ref.get_meta_or_var(
        clean_ref.var_names.to_numpy().tolist() + ["cell_ontology_class"],
        normalize_var=True, log_var=True
    ).groupby("cell_ontology_class", observed=True).mean()
    for cell_type in used_cell_types:
        # Keep a marker only if this cell type is among the three cell types
        # with the highest mean expression of that gene. nlargest ignores NaN
        # and breaks ties by row order.
        gene_sets[cell_type] = [
            gene for gene in gene_sets[cell_type]
            if cell_type in cell_type_mean_exprs[gene].nlargest(3).index
        ]
gene_sets

### Gradients

In [None]:
# For every cell type: query its cells against a reference made of all OTHER
# cell types, average the gene gradients returned for the hits, and run
# pre-ranked GSEA of that averaged gradient against the marker gene sets.
nes, pval, fdr = {}, {}, {}
for cell_type in used_cell_types:
    print(f"Dealing with {cell_type}...")
    used_ref = clean_ref[clean_ref.obs["cell_ontology_class"] != cell_type, :]
    used_query = clean_ref[clean_ref.obs["cell_ontology_class"] == cell_type, :]
    gene_grad = []
    for model in blast.models:
        used_blast = cb.blast.BLAST(
            [model], used_ref, distance_metric="ed"
        )  # Skip posterior distance since we are not doing any filtering
        hits = used_blast.query(used_query, n_neighbors=100, store_dataset=True)
        _gene_grad = hits.gene_gradient()
        _gene_grad = np.concatenate(_gene_grad)
        # hits.hits is used as row positions into used_ref, so select with
        # .iloc: plain Series[int_array] relies on pandas' deprecated
        # positional fallback for non-integer indexes.
        hit_cell_types = used_ref.obs["cell_ontology_class"].iloc[np.concatenate(hits.hits)]
        # Weight each gradient row so every hit cell type contributes equally
        _gene_grad = np.average(
            _gene_grad, axis=0,
            weights=equalizing_weights(hit_cell_types)
        )
        gene_grad.append(_gene_grad)
    # Average over models, then build the two-column (gene, score) ranking table
    gene_grad = np.stack(gene_grad).mean(axis=0)
    gene_grad = pd.DataFrame({0: blast.models[0].genes, 1: gene_grad})
    try:
        gsea_result = gp.prerank(
            gene_grad, gene_sets.copy(),  # gp.prerank seems to modify gene sets in-place
            outdir=f"{PATH}/mammary_gland/{cell_type}",
            weighted_score_type=0, min_size=10, seed=0
        )
    except Exception as err:
        # Best effort: skip this cell type, but report why GSEA failed
        print(f"GSEA failed with {cell_type}, skipped...")
        print(f"  reason: {type(err).__name__}: {err}")
        continue
    nes[cell_type] = {key: val["nes"] for key, val in gsea_result.results.items()}
    pval[cell_type] = {key: val["pval"] for key, val in gsea_result.results.items()}
    fdr[cell_type] = {key: val["fdr"] for key, val in gsea_result.results.items()}

In [None]:
# Rows: cell type that was queried; columns: marker gene sets reported by GSEA.
nes_df = pd.DataFrame.from_dict(nes, orient="index")
fdr_df = pd.DataFrame.from_dict(fdr, orient="index")
nes_df.index.name = "cell ontology class"
fdr_df.index.name = "cell ontology class"
# Keep only cell types that occur both as a row and as a column, giving a
# square matrix with identical (sorted) row and column order.
successful_cell_types = np.intersect1d(nes_df.index, nes_df.columns)
nes_df = nes_df.loc[successful_cell_types, successful_cell_types]
fdr_df = fdr_df.loc[successful_cell_types, successful_cell_types]

In [None]:
# Reshape the square NES / FDR matrices to long format. The results go into
# new variables so nes_df / fdr_df stay intact and this cell can be re-run.
id_col, marker_col = "cell ontology class", "PanglaoDB markers"
nes_long = nes_df.reset_index().melt(id_vars=id_col, var_name=marker_col, value_name="NES")
fdr_long = fdr_df.reset_index().melt(id_vars=id_col, var_name=marker_col, value_name="FDR")
# Join on the (cell type, marker set) pair explicitly
nes_fdr_df = pd.merge(nes_long, fdr_long, on=[id_col, marker_col])
# Dot size: -log10(FDR), capped at 6. FDR == 0 gives inf, which the cap
# turns into 6, so the divide-by-zero warning is silenced.
with np.errstate(divide="ignore"):
    nes_fdr_df["-log10 FDR"] = np.minimum(-np.log10(nes_fdr_df["FDR"]), 6)
nes_fdr_df.head()

In [None]:
# Colour scale centred on NES = 0. Series.min()/max() skip NaN (the builtin
# min/max are NaN-order dependent), and the limits are widened to include 0 so
# the norm's breakpoints stay increasing even if all NES share one sign.
nes_min, nes_max = nes_fdr_df["NES"].min(), nes_fdr_df["NES"].max()
hue_norm = MidpointNormalize(midpoint=0, vmin=min(0, nes_min), vmax=max(0, nes_max))
sm = plt.cm.ScalarMappable(cmap="seismic", norm=hue_norm)
sm.set_array([])

# Dot plot: colour = NES, size = capped -log10 FDR
fig, ax = plt.subplots(figsize=(5, 4))
ax = sns.scatterplot(
    x="PanglaoDB markers", y="cell ontology class", hue="NES", size="-log10 FDR",
    data=nes_fdr_df, palette="seismic", hue_norm=hue_norm, sizes=(1, 300),
    edgecolor=None, ax=ax
)
ax.invert_yaxis()
ax.tick_params(axis="x", labelrotation=90)
# The mappable is not attached to any Axes, so name the Axes explicitly
# (recent matplotlib refuses to guess)
cbar = fig.colorbar(sm, ax=ax)
cbar.set_label("NES", rotation=270)
handles, labels = ax.get_legend_handles_labels()
# Keep only the last six legend entries; the colourbar already covers the hue
# NOTE(review): the entry count depends on how seaborn builds the legend - confirm.
ax.legend(handles[-6:], labels[-6:], loc="center left", bbox_to_anchor=(1.3, 0.5), frameon=False, labelspacing=1.2)
os.makedirs(f"{PATH}/mammary_gland", exist_ok=True)  # do not rely on GSEA having created it
fig.savefig(f"{PATH}/mammary_gland/dotplot.pdf", bbox_inches="tight")