In [None]:
import warnings
# avoid DeprecationWarning: np.find_common_type is deprecated due to pandas version (needed by other packages)
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pandas.core.algorithms")

In [None]:
import logging
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import json


from cellwhisperer.validation.integration.functions import eval_scib_metrics
from server.common.colors import CSS4_NAMED_COLORS 
from zero_shot_validation_scripts.utils import TABSAP_WELLSTUDIED_COLORMAPPING, PANCREAS_ORDER, SUFFIX_PREFIX_DICT

from zero_shot_validation_scripts.dataset_preparation import load_and_preprocess_dataset

In [None]:
#### Parameters ####

# Style sheet and dataset name are both supplied by the Snakemake rule that
# runs this notebook.
matplotlib.style.use(snakemake.input.mpl_style)

# vector_friendly keeps PDFs of scatter plots much smaller in size while
# still being high-quality.
sc.set_figure_params(vector_friendly=True, dpi_save=500)

dataset_name = snakemake.wildcards.dataset

In [None]:
result_metrics_dict = {}

#### Load data
# Model-specific embeddings/features to attach to the AnnData object. Each
# value is a (path, key) pair handed to the loader.
# NOTE(review): presumably `key` names the array inside the file at `path` -
# confirm against load_and_preprocess_dataset.
_obsm_paths = {
    "X_umap_on_neighbors": (snakemake.input.umap, "neighbors"),
    "X_features": (snakemake.input.processed_dataset, "transcriptome_embeds"),
    # "X_geneformer": snakemake.input.TODO
}

adata = load_and_preprocess_dataset(
    dataset_name=dataset_name,
    read_count_table_path=snakemake.input.raw_read_count_table,
    obsm_paths=_obsm_paths,
)
logging.info(f"Data loaded and preprocessed. Shape: {adata.shape}")



In [None]:
# Compute the integration metrics on the "X_features" embedding, using the
# "celltype" column as label and the "batch" column as batch annotation.
results = eval_scib_metrics(
    adata, label_key="celltype", batch_key="batch", embedding_key="X_features"
)

# Persist the scores as pretty-printed JSON at the workflow's output path.
# NOTE(review): assumes every value in `results` is JSON-serializable - confirm.
with open(snakemake.output.integration_scores, "w") as scores_file:
    json.dump(results, scores_file, indent=4)

In [None]:
# Assign every cell type a CSS4 named color, in order of first appearance.
# Indexing with a modulo cycles through the color list as often as needed;
# the previous "subtract the length once" wrap-around raised IndexError as
# soon as there were more than twice as many cell types as colors.
_css4_colors = list(CSS4_NAMED_COLORS.values())  # built once, not per cell type
celltype_palette = {
    celltype: _css4_colors[i % len(_css4_colors)]
    for i, celltype in enumerate(adata.obs.celltype.unique())
}
if "tabula_sapiens" in dataset_name:
    # update the celltype palette with the well-studied cell types
    celltype_palette.update(TABSAP_WELLSTUDIED_COLORMAPPING)

In [None]:
# Two-panel figure of the "X_umap_on_neighbors" embedding: the left panel is
# colored by batch, the right panel by cell type. Each panel title reports the
# corresponding integration score(s).

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

# Left panel: batch membership.
sc.pl.embedding(
    adata,
    basis="X_umap_on_neighbors",
    color="batch",
    ax=axes[0],
    s=10,
    alpha=0.5,
    frameon=False,
    legend_loc="right margin",
    legend_fontsize=8,
    legend_fontoutline=2,
    show=False,
)

# The batch score may be missing from `results`; show "NA" in that case.
try:
    _raw_batch_score = results["ASW_label__batch"]
except KeyError:
    batch_integration_score = "NA"
else:
    batch_integration_score = round(_raw_batch_score, 2)
axes[0].set_title(f"batch integration score= {batch_integration_score}")

# Right panel: cell-type annotation, drawn with the per-celltype palette.
sc.pl.embedding(
    adata,
    basis="X_umap_on_neighbors",
    color="celltype",
    palette=celltype_palette,
    ax=axes[1],
    s=10,
    alpha=0.5,
    frameon=False,
    legend_loc="right margin",
    legend_fontsize=8,
    legend_fontoutline=2,
    show=False,
)

# Both scores are required here: a missing key raises KeyError.
asw_label = round(results["ASW_label"], 2)
avg_bio = round(results["avg_bio"], 2)
axes[1].set_title(f"ASW_label= {asw_label}\n avg_bio= {avg_bio}")
# With more than 50 cell types, drop the legend of the celltype panel.
# get_legend() returns None when no legend is attached to the axes, so guard
# against that instead of failing with AttributeError.
if adata.obs.celltype.nunique() > 50:
    _celltype_legend = axes[1].get_legend()
    if _celltype_legend is not None:
        _celltype_legend.remove()

# Set the figure title *before* tight_layout so the layout pass can reserve
# room for it; adding it afterwards lets it overlap the (two-line) panel titles.
fig.suptitle(dataset_name)
fig.tight_layout()

# dpi=None defers to the configured savefig dpi; the PNG is rendered at 900 dpi.
fig.savefig(snakemake.output.embedding_plots_zero_shot_comparison_pdf, dpi=None)
fig.savefig(snakemake.output.embedding_plots_zero_shot_comparison_png, dpi=900)
plt.show()
plt.close(fig)