In [None]:

import json

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt


def prepare_integration_df(result_metrics_dict: dict) -> pd.DataFrame:
    """
    Restructure the result_metrics_dict to a long-format pandas DataFrame for easier plotting.

    Only three metrics are kept, renamed to human-readable labels:
    ``ASW_label__batch`` -> "Batch integration score",
    ``avg_bio`` -> "Cell type integration score (avg-bio)",
    ``ASW_label`` -> "Cell type integration score (ASW)".
    All other keys of the inner dicts are dropped.

    :param result_metrics_dict: Dictionary mapping ``(dataset_name, method)`` tuples to a dict
        of integration scores keyed by metric name.
    :return: pandas DataFrame with columns ``dataset_name``, ``method``, ``metric`` and
        ``value`` (float), one row per (dataset, method, metric). An empty input yields an
        empty frame with these columns.
    :raises KeyError: if one of the three required metrics is absent from every entry.
    """
    long_columns = ["dataset_name", "method", "metric", "value"]
    if not result_metrics_dict:
        # pd.DataFrame({}).T has no level_0/level_1 columns, so the reshaping below would fail.
        return pd.DataFrame(columns=long_columns)

    # Single source of truth for raw metric key -> display label.
    metric_labels = {
        "ASW_label__batch": "Batch integration score",
        "avg_bio": "Cell type integration score (avg-bio)",
        "ASW_label": "Cell type integration score (ASW)",
    }
    id_columns = ["dataset_name", "method"]

    # The tuple keys become a 2-level (unnamed) index after transposing; reset_index exposes
    # those levels as "level_0"/"level_1", which are renamed together with the metrics.
    wide_df = (
        pd.DataFrame(result_metrics_dict)
        .T.reset_index()
        .rename(columns={"level_0": "dataset_name", "level_1": "method", **metric_labels})
    )
    # Selecting the columns explicitly drops unrelated metrics and raises KeyError if a
    # required one is missing; melt then stacks the remaining metric columns in order.
    long_df = pd.melt(
        wide_df[id_columns + list(metric_labels.values())],
        id_vars=id_columns,
        var_name="metric",
        value_name="value",
    )
    # A single non-numeric entry anywhere in the input makes the transposed frame
    # object-dtyped; cast so downstream plotting treats the scores as numbers.
    long_df["value"] = long_df["value"].astype(float)
    return long_df

In [None]:

# if adata.obs.batch.nunique() > 1:

#### Plot and Save integration metrics

result_metrics_dict = {}
models = list(snakemake.params.models)
metric_files = list(snakemake.input)
# zip() would silently drop or mis-pair entries if these lists diverge; fail loudly instead.
# NOTE(review): assumes snakemake.input is ordered exactly like params.models — confirm in the Snakefile.
if len(models) != len(metric_files):
    raise ValueError(
        f"Got {len(models)} models but {len(metric_files)} input files; "
        "each model needs exactly one metrics JSON."
    )
for model, fn in zip(models, metric_files):
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(fn, encoding="utf-8") as f:
        result_metrics_dict[(snakemake.wildcards.dataset, model)] = json.load(f)

integration_scores_df = prepare_integration_df(result_metrics_dict)
integration_scores_df.to_csv(snakemake.output.integration_scores, index=True)
integration_scores_df.head()

In [None]:
# Group each method under the base single-cell foundation model (scFM) given by this mapping,
# so each CellWhisperer variant is drawn next to its corresponding base model (hue).
# NOTE(review): methods missing from this mapping get NaN for scFM and will presumably be
# dropped from the plot by seaborn — extend the mapping when adding new models.
SCFM_BY_METHOD = {
    "geneformer": "geneformer",
    "uce": "uce",
    "scgpt": "scgpt",
    "cellwhisperer_clip_v1": "geneformer",
    "cellwhisperer_clip_v2_uce": "uce",
    "cellwhisperer_clip_v2_scgpt": "scgpt",
}
integration_scores_df["scFM"] = integration_scores_df["method"].map(SCFM_BY_METHOD)
integration_scores_df["CellWhisperer"] = integration_scores_df["method"].str.startswith("cellwhisperer")


metrics = integration_scores_df["metric"].unique()
# squeeze=False keeps `axes` a 2-D array even for a single metric; otherwise plt.subplots
# returns a bare Axes and zip() below would fail.
fig, axes = plt.subplots(len(metrics), 1, figsize=(4, 4), squeeze=False)

for ax, metric in zip(axes[:, 0], metrics):
    sns.barplot(
        data=integration_scores_df.loc[integration_scores_df["metric"] == metric],
        y="scFM",
        hue="CellWhisperer",
        x="value",
        ax=ax,
        palette="Greys",
    )
    ax.set_title(metric)
plt.tight_layout()
fig.savefig(snakemake.output.integration_scores_plot)