In [None]:

import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib

matplotlib.style.use(snakemake.input.mpl_style)


def violin_strip(*args, **kwargs):
    """Draw a gray violin plot with the individual observations overlaid.

    All positional and keyword arguments are forwarded unchanged to both
    ``sns.violinplot`` and ``sns.stripplot``. The styling keywords fixed here
    (``color``, ``inner``, ``cut``, ``alpha``, ``size``, ``dodge``) must not be
    supplied by the caller; a duplicate keyword raises ``TypeError``.
    """
    # Background layer: violin outline only (no inner annotation).
    violin_style = {"color": "gray", "inner": None, "cut": 0}
    # Foreground layer: small, semi-transparent black points.
    strip_style = {"alpha": 0.4, "size": 1, "color": "black", "dodge": True}

    sns.violinplot(*args, **kwargs, **violin_style)
    sns.stripplot(*args, **kwargs, **strip_style)


# Select the plotting callable named by the `plot_type` workflow *parameter*
# (not the `plot_type` wildcard that is used further down for titles/branches).
# Each entry is later called with seaborn-style keyword arguments
# (data, x, y, hue, ax, order). An unknown name raises a KeyError here.
plot_fn = {
    "violin": lambda *args, **kwargs: sns.violinplot(*args, **kwargs, cut=0),
    "boxplot": sns.boxplot,
    "violinstrip": violin_strip,
}[snakemake.params.plot_type]

In [None]:

# Display the configurations being compared (one parameter mapping per entry).
snakemake.params.comparison_sequence

In [None]:
keys = snakemake.params.comparison_sequence[0]

# Parameters shared by all configurations: a key qualifies when every entry of
# the comparison sequence carries the same value for it. The value is taken
# from the first entry.
common_fields = {
    key: value
    for key, value in snakemake.params.comparison_sequence[0].items()
    if len({entry[key] for entry in snakemake.params.comparison_sequence}) == 1
}
common_fields

In [None]:
# Load one perplexity table per configuration and tag every row with the
# parameters that produced it.
dfs = []
config_order = []  # for plotting later
for params, fn in zip(
    snakemake.params.comparison_sequence, snakemake.input.perplexities
):
    # The label is built from the parameters that differ between configurations.
    # Values go through str() so that non-string parameters (numbers, booleans)
    # do not make str.join() raise a TypeError.
    config_str = "__".join(
        str(v) for k, v in params.items() if k not in common_fields
    )

    frame = pd.read_csv(fn, index_col=0)
    frame["configuration"] = config_str
    # Broadcast every parameter as a constant column (used as group keys later).
    for key, value in params.items():
        frame[key] = value
    dfs.append(frame)

    config_order.append(config_str)

df = pd.concat(dfs)
df

In [None]:
# Fall back to the row index as question id when any id is missing.
# Series.isna() is used instead of np.isnan() because np.isnan raises a
# TypeError on object-dtype columns, whereas isna() handles every dtype.
if df["question_id"].isna().any():
    df["question_id"] = df.index

In [None]:
import json

# Load the evaluation dataset; it is indexed by question id below.
with open(snakemake.input.evaluation_dataset) as f:
    evaluation_dataset = json.load(f)
# NOTE(review): this bare expression is not the last one in the cell, so
# presumably nothing is displayed for it.
evaluation_dataset

# Attach the reference response to every row: the value of the second
# conversation turn of the matching dataset entry, with every occurrence of the
# configured response prefix removed.
df["response"] = df["question_id"].map(
    lambda i: evaluation_dataset[i]["conversations"][1]["value"].replace(
        snakemake.params.response_prefix, ""
    )
)

In [None]:
# Inspect the first row to check the columns assembled so far.
df.iloc[0]

## Strategy

1. Compute the expected perplexity (ppl) ratio between the correct and the incorrect embeddings, i.e. the ppl of the correct embedding divided by the mean ppl over the 30 mismatched cases
2. Plot method-grouped violin plots of these ratios

### For main and TabSap

- violinplots across all responses

### For TabSap

- Make 20 plots (4x5) for the selected ones
- For the full 177 cell types:
  - Just take the ranks (how often was method A first, how often was B first, how often was C first, ...)


In [None]:
# Build `ppl_df`: one row per (configuration, question) holding the metric
# selected by the `plot_metric` wildcard in a column of the same name.
if snakemake.wildcards.plot_metric == "log2_ppl_ratio":
    # Compute the expected perplexity ratio for each model
    def ppl_ratio(group):
        """Return the correct perplexity divided by the mean incorrect perplexity.

        `group` holds the rows of one (configuration, question) combination and
        must contain exactly one row of type "correct".
        """
        correct = group.loc[group.type == "correct", "eval_all_perplexities"]
        assert len(correct) == 1, "There should be exactly one correct answer"
        incorrect = group.loc[group.type == "incorrect", "eval_all_perplexities"]

        return correct.iloc[0] / incorrect.mean()

    ppl_ratios = (
        df.groupby(
            ["configuration", "question_id"]
            + list(snakemake.params.comparison_sequence[0].keys())
        )
        .apply(ppl_ratio)
        .reset_index()  # the ratio ends up in the unnamed column `0`
    )
    # Look up the response text belonging to each (question, configuration).
    response = df.drop_duplicates(
        subset=["question_id", "configuration", "response"]
    ).set_index(["question_id", "configuration"])["response"]
    ppl_ratios["response"] = ppl_ratios.apply(
        lambda row: response.loc[row["question_id"], row["configuration"]], axis=1
    )
    ppl_ratios[snakemake.wildcards.plot_metric] = np.log2(ppl_ratios[0])

    # The index was already reset above; drop it instead of materialising a
    # meaningless extra "index" column.
    ppl_df = ppl_ratios.reset_index(drop=True)
elif snakemake.wildcards.plot_metric == "log2_correct_ppl":
    # just show the (log2) of the (correct) perplexity
    ppl_df = (
        df[df.type == "correct"]
        .groupby(
            ["configuration", "question_id", "response"]
            + list(snakemake.params.comparison_sequence[0].keys())
        )["eval_all_perplexities"]
        .mean()
        .apply(np.log2)
        .reset_index()
    )
    ppl_df[snakemake.wildcards.plot_metric] = ppl_df["eval_all_perplexities"]
else:
    # Fail here with a clear message rather than with a NameError on `ppl_df`
    # in a later cell.
    raise ValueError(
        f"Unsupported plot_metric: {snakemake.wildcards.plot_metric!r}"
    )

In [None]:
# Plot individual performance: one distribution per configuration.

x = "configuration"
hue = None
order = config_order
# Scale the figure width with the number of configurations shown.
width = 0.7 * ppl_df[x].nunique()

fig, ax = plt.subplots(figsize=(width, 2.5))
if snakemake.wildcards.plot_metric == "log2_ppl_ratio":
    # Baseline: a log2 ratio of 0 means matched and mismatched perplexities are equal.
    ax.axhline(0, color="gray", linestyle="--")
plot_fn(
    data=ppl_df,
    x=x,
    y=snakemake.wildcards.plot_metric,
    hue=hue,
    ax=ax,
    order=order,
)

ylabel = (
    r"$\leftarrow;  \log_2 perplexity(\text{matched} / \text{mismatched})$"
    if snakemake.wildcards.plot_metric == "log2_ppl_ratio"
    else r"$\leftarrow;  \log_2 perplexity$"
)
# str() keeps the title working when a shared parameter is not a string.
common_description = ", ".join(map(str, common_fields.values()))
ax.set(
    title=f"{snakemake.wildcards.plot_type}, {snakemake.wildcards.dataset}  ({common_description})",
    ylabel=ylabel,
)

# Rename labels a little; configurations without an entry keep their label.
display_names = {
    "cellwhisperer_clip_v1__without50topgenes": "CellWhisperer",
    "cellwhisperer_clip_v1__without50topgenesresponsepermuted": "CellWhisperer (response permuted)",
    "cellwhisperer_clip_v1__with50topgenes": "CellWhisperer with top-50 genes",
}
xticklabels = ax.get_xticklabels()
for label in xticklabels:
    label.set_text(display_names.get(label.get_text(), label.get_text()))

ax.set_xticklabels(xticklabels, rotation=45, ha="right")

if hue is not None:
    ax.legend(loc="lower left")

# plt.tight_layout()

# Save in the declared output format and additionally as PNG next to it.
fig.savefig(snakemake.output.individual_performances)
fig.savefig(snakemake.output.individual_performances + ".png")

In [None]:
# Display the x tick labels of `ax` to check the label renaming.
ax.get_xticklabels()

It is a bit disturbing that Mistral-7B__NONE with the top 50 genes (blue) is better than any multimodal LLM.
The reason, though: in the original Mistral, the negatives (wrong embeddings), as well as the positives, lead to much larger ppls.

Would it thus be better to normalize by the mean over all negatives?


### Relative comparisons


In [None]:
from itertools import combinations

# Pairwise win/loss matrix: entry [a, b] is the number of questions on which
# configuration a has the lower metric value than b, minus the number on which
# it has the higher one. The matrix is therefore antisymmetric.
comparison_matrix = pd.DataFrame(0, index=config_order, columns=config_order)

# Compare configurations question by question
for question_id, group in ppl_df.groupby("question_id"):
    # Pair each configuration with its metric value once per question instead
    # of re-filtering the group for every pair.
    scored = zip(group["configuration"], group[snakemake.wildcards.plot_metric])
    for (config1, value1), (config2, value2) in combinations(scored, 2):
        if value1 < value2:
            delta = 1
        elif value1 > value2:
            delta = -1
        else:
            # Ties (and comparisons involving NaN) leave the matrix unchanged.
            continue
        comparison_matrix.loc[config1, config2] += delta
        comparison_matrix.loc[config2, config1] -= delta

# Normalize the matrix to get percentages
# comparison_matrix = comparison_matrix.div(comparison_matrix.sum(axis=1), axis=0) * 100

# Plot heatmap
relative_fig, relative_ax = plt.subplots(figsize=(3.5, 3))
sns.heatmap(
    comparison_matrix,
    annot=True,
    cmap="RdBu",
    fmt=".0f",
    center=0,
    cbar_kws={"label": "Relative performance"},
    ax=relative_ax,
)
relative_ax.set(
    title=f"Configuration Performance Comparison {snakemake.wildcards.plot_metric}",
    xlabel="Configuration",
    ylabel="Configuration",
)

# For "llava_cw_vs_geneformer" the output file is written by the next cell instead.
if snakemake.wildcards.plot_type != "llava_cw_vs_geneformer":
    relative_fig.savefig(snakemake.output.relative_performances)
    relative_fig.savefig(snakemake.output.relative_performances + ".png")

In [None]:
# Special case: reduce the comparison matrix to a bar plot and save that as the
# relative-performance output.
if snakemake.wildcards.plot_type == "llava_cw_vs_geneformer":

    # Split the "__"-joined configuration labels into MultiIndex levels so the
    # matrix can be sliced by the first label component.
    plot_df = comparison_matrix.copy()
    plot_df.index = pd.MultiIndex.from_tuples(
        list(map(tuple, comparison_matrix.index.str.split("__")))
    )
    plot_df.columns = pd.MultiIndex.from_tuples(
        list(map(tuple, comparison_matrix.columns.str.split("__")))
    )

    # Rows whose first component is "cellwhisperer_clip_v1" against columns
    # whose first component is "geneformer".
    plot_df = plot_df.loc["cellwhisperer_clip_v1", "geneformer"]
    # Keep only the diagonal of that sub-matrix.
    # NOTE(review): presumably this pairs configurations whose remaining label
    # components match — it relies on rows and columns being in the same order
    # and equal in number; confirm.
    diagonal_series = pd.Series(np.diag(plot_df), index=plot_df.index)
    # NOTE(review): no `ax` is passed, so pandas draws on pyplot's current axes.
    ax = diagonal_series.plot(kind="bar")
    ax.set(
        ylabel="Num elements where CW is better.",
        title="point-wise comparison between CW and geneformer",
    )
    ax.axhline(0, color="black")
    plt.savefig(snakemake.output.relative_performances)
    plt.savefig(snakemake.output.relative_performances + ".png")

### By cell type (works for tab sap only)


- Make 20 plots (4x5) for the selected ones
- For the full 177 cell types:
  - Just take the ranks (how often was method A first, how often was B first, how often was C first, ...)



In [None]:
# One panel per selected cell type (at most 4 x 5 = 20; zip() stops at the
# shorter of the two sequences).
fig, axes = plt.subplots(4, 5, figsize=(20, 16), sharex=True, sharey=True)
for ax, celltype in zip(axes.ravel(), snakemake.params.plot_celltypes):
    # Plot the selected metric for the responses naming this cell type.
    # `order` is fixed explicitly: the panels share their x axis, so every panel
    # must place the configurations at the same positions (and in the same order
    # as the main figure), even when a configuration has no data for this type.
    plot_fn(
        data=ppl_df.loc[ppl_df.response == celltype].reset_index(),
        x="configuration",
        y=snakemake.wildcards.plot_metric,
        ax=ax,
        order=config_order,
    )
    ax.set_title(celltype)

# For lowest row, set x labels
for ax in axes[-1]:
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90, ha="right")

fig.tight_layout()

# plt.savefig(snakemake.output.by_celltype)