In [None]:

import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import torchmetrics
import torch
import scipy

# Apply the project-wide matplotlib style sheet handed in by Snakemake.
plt.style.use(snakemake.input.mpl_style)

In [None]:
# Load the perplexity table; the first CSV column becomes the index.
df = pd.read_csv(snakemake.input.all_perplexities, index_col=0)

# If any question id is missing, fall back to the row index for the whole column.
# `Series.isna()` works for every dtype, whereas `np.isnan` raises a TypeError
# on object/string columns.
if df["question_id"].isna().any():
    df["question_id"] = df.index

In [None]:
def quantile(group):
    """Return the fraction of mismatched responses that score worse than the match.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows belonging to one question. Must provide a ``type`` column with the
        values ``"correct"`` / ``"incorrect"`` and an ``eval_all_perplexities``
        column.

    Returns
    -------
    float
        Share of ``"incorrect"`` rows whose perplexity is strictly greater than
        the perplexity of the first ``"correct"`` row.

    Raises
    ------
    IndexError
        If the group has no ``"correct"`` row.
    """
    perplexities = group["eval_all_perplexities"]
    row_type = group["type"]

    # Only the first matched row is used as the reference value.
    matched_value = perplexities[row_type == "correct"].iloc[0]
    mismatched_values = perplexities[row_type == "incorrect"]

    n_worse = (matched_value < mismatched_values).sum()
    return n_worse / len(mismatched_values)


# Group once by question; every per-question statistic below reuses this object.
per_question = df.groupby("question_id")

# Fraction of mismatched responses with higher perplexity than the matched one.
quantiles = per_question.apply(quantile, include_groups=False)
# Perplexity of the (first) matched response of each question.
matched = per_question.apply(
    lambda group: group.loc[group.type == "correct", "eval_all_perplexities"].iloc[0],
    include_groups=False,
)
# Mean / standard deviation of the mismatched perplexities of each question.
mismatched_mean = per_question.apply(
    lambda group: group.loc[group.type == "incorrect", "eval_all_perplexities"].mean(),
    include_groups=False,
)
mismatched_std = per_question.apply(
    lambda group: group.loc[group.type == "incorrect", "eval_all_perplexities"].std(),
    include_groups=False,
)
# Response text of the matched row. `include_groups=False` is passed here as in
# every other call above, so this call no longer triggers the pandas
# deprecation warning about operating on the grouping column.
true_responses = per_question.apply(
    lambda group: group.loc[group.type == "correct", "response"].unique()[0],
    include_groups=False,
)

# Long-format table with one row per (question, matched|mismatched).
# `melt` keeps only `id` and the two `value_vars`; the `quantiles` and
# `mismatch_std` columns are dropped by it.
plotdf = pd.DataFrame(
    {
        "matched": matched,
        "quantiles": quantiles,
        "mismatched": mismatched_mean,
        "mismatch_std": mismatched_std,
        "id": quantiles.index,
    }
).melt(id_vars=["id"], value_vars=["matched", "mismatched"], value_name="perplexity")

In [None]:
# Pair every question's quantile with its matched response (used as cell type),
# average the quantile per cell type, and show the distribution of those means.
celltype_quantiles = pd.DataFrame({"quantile": quantiles, "celltype": true_responses})
mean_quantile_per_celltype = celltype_quantiles.groupby("celltype")["quantile"].mean()
sns.violinplot(x=mean_quantile_per_celltype)

In [None]:
# Styling shared by the bar plot and the violin plot in this cell.
# NOTE(review): seaborn documents `errwidth` as deprecated from v0.13 on
# (replaced by `err_kws`); kept unchanged here — confirm against the pinned version.
barplot_style = dict(
    color="#b1c25a",
    edgecolor="#424242",
    linewidth=1.5,
    errwidth=2,
    width=0.7,
)

if (
    "tabula_sapiens" in snakemake.wildcards.dataset
    and not snakemake.wildcards.dataset.endswith("_top50genes")
):
    # plot by cell type
    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(3.8, 3.5), gridspec_kw={"height_ratios": [8, 1]}, sharex=True
    )

    # Quantile per question together with its matched response; built once and
    # reused for both panels.
    celltype_quantiles = pd.DataFrame(
        {"quantile": quantiles, "celltype": true_responses}
    )

    # Upper panel: only the cell types requested via the Snakemake params
    # (compared case-insensitively, with the response prefix prepended).
    selected_celltypes = [
        (snakemake.params.response_prefix + c).lower()
        for c in snakemake.params.plot_celltypes
    ]
    subdf = celltype_quantiles[
        celltype_quantiles.celltype.str.lower().isin(selected_celltypes)
    ]
    order = subdf["celltype"].sort_values().drop_duplicates().values

    sns.barplot(
        data=subdf,
        y="celltype",
        x="quantile",
        ax=ax1,
        order=order,
        **barplot_style,
    )

    # Lower panel: distribution of the mean quantile over all cell types.
    mean_quantile_per_celltype = celltype_quantiles.groupby("celltype")[
        "quantile"
    ].mean()
    sns.violinplot(
        x=mean_quantile_per_celltype,
        ax=ax2,
        **barplot_style,
    )
    # The count in the label is taken from the plotted data instead of being
    # hard-coded, so it stays correct when the set of cell types changes.
    ax2.set_yticklabels([f"All {len(mean_quantile_per_celltype)} cell types"])
    ax1.axvline(0.5, linestyle="--", color="#424242", linewidth=1)
    ax2.axvline(0.5, linestyle="--", color="#424242", linewidth=1)
    ax = ax2
else:
    # Single bar summarising the quantiles of all questions.
    fig, ax = plt.subplots(figsize=(4.5, 0.2))
    sns.barplot(x=quantiles, ax=ax, **barplot_style)
    if snakemake.wildcards.dataset.startswith("main"):
        ax.set_yticklabels(
            ["Training-excluded ARCHS4 and CELLxGENE question-answer pairs (n=200)"]
        )
    elif "tabula_sapiens" in snakemake.wildcards.dataset:
        ax.set_yticklabels(["TabSap top 50 gene prediction"])
    ax.set_title(
        "LLM preference for correct vs mismatched embedding for test conversations",
        fontdict={"fontsize": 8},
    )

    ax.axvline(0.5, linestyle="--", color="#424242", linewidth=1)

# `ax` is the bottom-most axis in both branches and carries the x label.
ax.set(
    xlabel=f"Perplexity quantile of matched vs {df['replicate'].max() + 1} mismatched responses"
)
sns.despine()
plt.tight_layout()
fig.savefig(snakemake.output.comparison_plot)

### Render the Supp. Table xlsx

In [None]:
# Mean and standard deviation of the quantile per matched response (cell type).
supp_table = (
    pd.DataFrame({"quantile": quantiles, "celltype": true_responses})
    .groupby("celltype")["quantile"]
    # String aliases select pandas' own reducers (sample std, ddof=1).
    # Passing `np.mean` / `np.std` is deprecated in pandas and is announced to
    # call NumPy directly in the future, which would switch std to ddof=0.
    .agg(["mean", "std"])
)
supp_table.columns = ["Mean of quantile", "Standard deviation of quantile"]
supp_table.index.name = "Cell type (with prefix)"
supp_table.to_excel(snakemake.output.full_supp_table)
supp_table

In [None]:
# Log of (mean matched perplexity / mean mismatched perplexity), written as
# plain text to the Snakemake output file.
is_matched = plotdf["variable"] == "matched"
is_mismatched = plotdf["variable"] == "mismatched"

mean_matched_perplexity = plotdf.loc[is_matched, "perplexity"].mean()
mean_mismatched_perplexity = plotdf.loc[is_mismatched, "perplexity"].mean()
log_perplexity_ratio = np.log(mean_matched_perplexity / mean_mismatched_perplexity)

with open(snakemake.output.log_perplexity_ratio, "w") as f:
    f.write(str(log_perplexity_ratio))

In [None]:
# Exploratory view of the first few questions: all perplexities as a swarm and
# a violin per question, with the matched response highlighted in red.
num_questions = 5
subdf = df[df.question_id < num_questions]

fig, ax = plt.subplots(1, 1, figsize=(5, 2))

# Both variables are numeric, so the orientation has to be stated explicitly;
# without it seaborn does not infer the horizontal layout used by the violin
# plot below and the two layers would not share the same coordinates.
sns.swarmplot(
    data=subdf,
    x="eval_all_perplexities",
    y="question_id",
    orient="h",
    ax=ax,
    color="#cccccc55",
)

sns.violinplot(
    data=subdf,
    x="eval_all_perplexities",
    orient="h",
    y="question_id",
    ax=ax,
    color="gray",
)

# Look the matched perplexities up by question id rather than by row index
# label, so the markers do not depend on what the CSV index happens to contain.
# Only the first matched row per question is used.
x = list(range(num_questions))
correct_perplexities = (
    subdf[subdf["type"] == "correct"]
    .drop_duplicates("question_id")
    .set_index("question_id")["eval_all_perplexities"]
)
ax.scatter(correct_perplexities.loc[x], x, color="red")

In [None]:
# Detailed figure: perplexities of the first `num_questions` questions, with the
# matched ("correct") response drawn opaque and the mismatched ones translucent,
# annotated with each question's quantile. Saved to the Snakemake output path.
num_questions = 5
subdf = df[df.question_id < num_questions]

fig, ax = plt.subplots(1, 1, figsize=(2, 0.5))
# palette with green: {"incorrect": "#333333", "correct": "#a8b567"}
# NOTE(review): no `orient` is passed and both x and y look numeric, so seaborn
# presumably treats x as the categorical axis here. That would explain why the
# tick positions below are read back from the tick labels — confirm before
# changing anything in this cell.
sns.stripplot(
    data=subdf,
    x="eval_all_perplexities",
    hue="type",
    y="question_id",
    ax=ax,
    palette={"incorrect": "#cccccc55", "correct": "#000000ff"},
)

# Smallest and largest perplexity among the plotted rows.
xmax = subdf.eval_all_perplexities.max()
xmin = subdf.eval_all_perplexities.min()

# Set the xticks to only the first and last
ax.set_ylim([-0.5, 4.5])
ax.set_xlim(
    [xmin - 2, xmax + 2]
)  # author's note: this should be wrong (the x-axis seems to take indices), but it works
# Keep only the first and last existing tick; position and label text are read
# from the private `_x` / `_text` attributes of the tick label objects.
ticks = ax.get_xticklabels()
ax.set(
    xticks=[ticks[0]._x, ticks[-1]._x],
    xticklabels=[ticks[0]._text, ticks[-1]._text],
    yticks=range(5),
    yticklabels=range(1, 6),
)

# Annotate each row with its quantile. `quantiles.iloc[i]` is positional, so this
# assumes the first `num_questions` entries of `quantiles` are the plotted
# questions — TODO confirm.
for i in range(num_questions):
    ax.text(xmax, i, f"{quantiles.iloc[i]:.2f}")
# plt.tight_layout()

fig.savefig(snakemake.output.detailed_plot)