In [None]:
# start coding here
import logging
import json
import os
from collections import defaultdict

import seaborn as sns
import pandas as pd
import anndata
import matplotlib

import numpy as np


# Send log records both to the log file provided by Snakemake and to the
# notebook's stream output.
log_handlers = [
    logging.FileHandler(snakemake.log[0]),  # type: ignore [reportUndefinedVariable]
    logging.StreamHandler(),
]
logging.basicConfig(
    handlers=log_handlers,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Apply the matplotlib style sheet supplied as a Snakemake input.
matplotlib.style.use(snakemake.input.mpl_style)

In [None]:
# Maps each raw key inside a review's "scores" dict to the column name used in
# the rest of the notebook. The order here determines the order of the score
# columns in the resulting DataFrame (and therefore in the plots).
SCORE_COLUMNS = {
    "generation_reference_responses": "gpt4_all_info",
    "generation_llava_responses": "CellWhisperer",
    # "generation_llava_responses_with_top_genes": "llava_with_top_genes",
    "generation_llava_responses_text_only": "llava_text_only",
    "generation_gpt4transcriptome_responses": "gpt4_top_genes",
}
# Non-score columns that are split off into `df_metadata`.
METADATA_COLUMNS = ["question_id", "answer_ids", "content", "category", "dataset"]

# The evaluation file holds one JSON document per line.
reviews = []
with open(snakemake.input.evaluation) as f:
    for review_str in f:
        if not review_str.strip():
            # Skip blank lines instead of failing in json.loads.
            continue
        review = json.loads(review_str)

        # Flatten the nested scores into top-level, renamed keys.
        scores = review.pop("scores")
        for raw_key, column in SCORE_COLUMNS.items():
            review[column] = scores.pop(raw_key)

        # Question IDs containing "SRX" are attributed to archs4_geo,
        # everything else to cellxgene_census.
        review["dataset"] = (
            "archs4_geo" if "SRX" in review["question_id"] else "cellxgene_census"
        )
        # Store answer_ids as a tuple rather than the list produced by JSON parsing.
        review["answer_ids"] = tuple(review["answer_ids"])
        reviews.append(review)

# `df` keeps only the per-source score columns; `df_metadata` keeps the rest.
# Both are indexed by the review "id".
df = pd.DataFrame(reviews).set_index("id")
df_metadata = df[METADATA_COLUMNS].copy()
df = df.drop(columns=METADATA_COLUMNS)

In [None]:
# Express every score relative to the "gpt4_all_info" score of the same row.
reference_scores = df["gpt4_all_info"]
normed_df = df.div(reference_scores, axis="index")
normed_df

In [None]:
# Violin plot of the raw score distribution, one violin per response source.
raw_scores_long = df.melt(var_name="llm_type", value_name="scores")
ax = sns.violinplot(x="llm_type", y="scores", data=raw_scores_long)
# ax.set(ylim=[0, 1])
# Tilt the tick labels so they stay readable.
ax.set_xticklabels(ax.get_xticklabels(), rotation=15, ha="right")

In [None]:
# Long-format table of normalized scores; the reference column is dropped
# because it was the divisor used for normalization.
relative_scores = normed_df.drop(columns=["gpt4_all_info"])
plot_df = relative_scores.melt(var_name="llm_type", value_name="normalized_score")
plot_df.head()

In [None]:

# Bar chart of the normalized scores, one bar per response source.
ax = sns.barplot(x="llm_type", y="normalized_score", data=plot_df)
# ax.set(ylim=[0, 1])

In [None]:
# TODO test the failed ones
# df[df.normed < 0.5].iloc[5]["content"]

In [None]:
# Show the first review (metadata plus scores) where CellWhisperer scored below 4.
low_score_mask = df["CellWhisperer"] < 4
df_metadata.join(df).loc[low_score_mask].iloc[0]

In [None]:
# Show the "content" field of the first review where CellWhisperer scored below 4.
low_score_mask = df["CellWhisperer"] < 4
df_metadata.loc[low_score_mask, "content"].iloc[0]

In [None]:
# `anndata` is already imported in the notebook's first cell, so it is not
# re-imported here.
# Open the ARCHS4 AnnData file with backed="r" (read-only backed mode).
adata = anndata.read_h5ad(snakemake.input.archs4_data, backed="r")

In [None]:
# Index of all observations whose `singlecellprobability` exceeds 0.1.
# TODO can also use 0.5 (not much difference)
single_cell_filter = "singlecellprobability > 0.1"
single_cells = adata.obs.query(single_cell_filter).index

In [None]:
# Display the selected observation index for a quick sanity check.
single_cells

In [None]:
# Derive each row's sample ID from its question ID by keeping everything after
# the first underscore.
question_ids = df.join(df_metadata)["question_id"]
df["sample_id"] = question_ids.apply(lambda qid: qid.split("_", maxsplit=1)[1])

In [None]:
# Flag rows whose sample ID is contained in the `single_cells` index.
df["singlecell"] = df["sample_id"].isin(single_cells)

In [None]:
# Flag rows by membership of their sample ID in the sample lists passed in as
# Snakemake params.
df["is_complex"] = df["sample_id"].isin(snakemake.params.complex_samples)
df["is_detailed"] = df["sample_id"].isin(snakemake.params.detailed_samples)

# Assign one label per row. np.select picks the first matching condition, so
# "detailed" takes precedence over "complex"; rows matching neither are "normal".
df["group"] = np.select(
    [df["is_detailed"], df["is_complex"]],
    ["detailed question", "complex question"],
    default="normal question",
)

In [None]:
# Count how many rows are flagged as detailed versus not.
df["is_detailed"].value_counts()

In [None]:
# Preview the first rows of the score table with the added annotation columns.
df.head()

In [None]:
# Fixed hex colour per response source.
colors = dict(
    gpt4_all_info="#7c7c7c",
    CellWhisperer="#b1c25a",
    llava_text_only="#8c9464",
    gpt4_top_genes="#648e94",  # from their logo
)

In [None]:
import matplotlib.pyplot as plt

# Score columns to reshape into long format.
llm_columns = [
    "gpt4_all_info",
    "CellWhisperer",  # "llava_with_top_genes",
    "llava_text_only",  # Mistral
    "gpt4_top_genes",
]

# grouped_scores = df.drop(df.index[df.singlecell])  # leads to similar results
# Attach the dataset label, melt to one row per (review, response source) and
# rename the grouping column for display on the plot axis.
grouped_scores = (
    df.copy()
    .join(df_metadata[["dataset"]])
    .melt(
        id_vars=["group", "dataset"],
        value_vars=llm_columns,
        var_name="llm_type",
        value_name="score",
    )
    .rename(columns={"group": "question type"})
)
# Keep only the first word of the group label ("detailed question" -> "detailed").
grouped_scores["question type"] = grouped_scores["question type"].apply(
    lambda label: label.split(" ")[0]
)

# 'detailed' has an unfair advantage. gpt4_all_info is not informative. see Methods section for details
is_excluded = (grouped_scores["llm_type"] == "gpt4_all_info") | (
    grouped_scores["question type"] == "detailed"
)
plot_df = grouped_scores.loc[~is_excluded].copy()

In [None]:
# Display the long-format score table built in the previous cell.
plot_df

In [None]:
matplotlib.style.use(snakemake.input.mpl_style)

# Two side-by-side panels sharing the score axis: scores split by question
# type (left) and by dataset (right), coloured by response source.
fig, axes = plt.subplots(
    1, 2, figsize=(3, 2), sharey=True
)  # , gridspec_kw={'width_ratios': [3, 2]})
for ax, x_var in zip(axes, ["question type", "dataset"]):
    # Box-plot-only options (flierprops / fliersize) are intentionally not
    # passed: they are not violinplot parameters.
    sns.violinplot(
        data=plot_df,
        x=x_var,
        hue="llm_type",
        y="score",
        ax=ax,
        palette=colors,
        linewidth=1,
    )
    ax.set_xticklabels(ax.get_xticklabels(), ha="right", rotation=30)
    sns.despine()


# Remove the per-panel legends created by seaborn; a single shared legend is
# added to the figure below.
for ax in axes:
    panel_legend = ax.get_legend()
    if panel_legend is not None:
        panel_legend.remove()

handles, labels = axes[-1].get_legend_handles_labels()
fig.legend(handles, labels, loc="lower center", bbox_to_anchor=(0.5, 1.0), ncol=2)

# Adjust the subplots to make room for the legend
# NOTE(review): tight_layout() recomputes the subplot margins, so the top=0.85
# adjustment may be overridden, and the legend is anchored above the figure's
# top edge -- confirm the saved files actually include the legend.
fig.subplots_adjust(top=0.85)

fig.tight_layout()
# Save the figure at the declared output path and additionally as a PNG next to it.
fig.savefig(snakemake.output.overview_plot)
fig.savefig(snakemake.output.overview_plot + ".png")

# logging.info(f"Overall score: {df.normed.mean()}. without single cells: {df[~df.singlecell].normed.mean()}")

# the presence of single cells in the training dataset overall impacted these scores only minorly (exclusion of 'cells with single cell probability > 0.1' in test set: 0.65 -> 0.63)

In [None]:
# sns.barplot(data=df[~df.singlecell], y="normed", x="dataset", hue="group")