In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

import matplotlib.ticker as ticker

from torchmetrics import F1Score
import torch

In [None]:
def compute_metric(predictions, metric=None):
    """Compute one evaluation metric for a single model/dataset prediction table.

    If ``predictions`` already holds a precomputed value (a row labelled
    ``metric`` in a ``"value"`` column), that value is returned unchanged.
    Otherwise the metric is computed from per-sample columns:

    - ``"accuracy"``: mean of the ``is_correct`` column.
    - ``"f1_score"``: macro-averaged F1 over the classes that occur in the
      ``label`` column. A ``predicted_labels`` entry that is not one of those
      classes counts as a miss for the true class. It does not form a class of
      its own, so it is not averaged into the macro score.

    Args:
        predictions: DataFrame with either a precomputed ``"value"`` column
            indexed by metric name, or the per-sample columns listed above.
            It is not modified.
        metric: Metric name. Defaults to ``snakemake.params.metric``.

    Returns:
        The metric value. An F1 score on an empty table is ``nan``.

    Raises:
        ValueError: If ``metric`` is neither precomputed nor supported.
    """
    if metric is None:
        metric = snakemake.params.metric

    try:
        return predictions.loc[metric, "value"]
    except KeyError:
        # No precomputed value: fall through and compute it from the samples.
        pass

    if metric == "accuracy":
        return predictions.is_correct.mean()
    if metric == "f1_score":
        return _macro_f1(predictions["label"], predictions["predicted_labels"])
    raise ValueError(f"Unsupported metric: {metric!r}")


def _macro_f1(true_labels, predicted_labels):
    """Macro-averaged F1 over the classes present in ``true_labels``."""
    classes = pd.Index(true_labels.unique())
    n_classes = len(classes)
    if n_classes == 0:
        return float("nan")

    # Integer class codes. Predictions outside the true-label vocabulary get -1.
    # Such a prediction can never be a true positive, and it is not a class of
    # its own. Using integer codes (instead of a string sentinel) keeps numeric
    # labels working.
    y_true = classes.get_indexer(true_labels)
    y_pred = classes.get_indexer(predicted_labels)

    true_positives = np.bincount(y_true[y_pred == y_true], minlength=n_classes)
    support = np.bincount(y_true, minlength=n_classes)  # TP + FN per class
    predicted = np.bincount(y_pred[y_pred >= 0], minlength=n_classes)  # TP + FP
    # F1 = 2TP / (2TP + FP + FN) = 2TP / (support + predicted).
    # The denominator is > 0 because every class has at least one true sample.
    f1_per_class = 2 * true_positives / (support + predicted)
    return float(f1_per_class.mean())

In [None]:
# Build the metric table: one row per model, one column per dataset.
# The prediction files are expected in model-major order: all datasets of the
# first model, then all datasets of the second, and so on. The index
# arithmetic below relies on that order.
models = snakemake.params.models
datasets = snakemake.params.datasets
prediction_files = snakemake.input["predictions"]
if isinstance(prediction_files, str):
    # A lone path would otherwise be indexed character by character.
    prediction_files = [prediction_files]
prediction_files = list(prediction_files)
# Fail loudly instead of silently pairing files with the wrong model/dataset.
assert len(prediction_files) == len(models) * len(datasets), (
    f"expected {len(models) * len(datasets)} prediction files "
    f"({len(models)} models x {len(datasets)} datasets), "
    f"got {len(prediction_files)}"
)

df = pd.DataFrame(index=models, columns=datasets, dtype=float)

for model_i, model_name in enumerate(models):
    for dataset_i, dataset_name in enumerate(datasets):
        fn = prediction_files[model_i * len(datasets) + dataset_i]
        # Sanity check on the ordering assumption. It is only a substring test,
        # so it cannot tell apart datasets whose names contain one another.
        # NOTE(review): the model name is not checked against the path;
        # presumably the paths do not contain it verbatim. Confirm.
        assert dataset_name in fn, f"{dataset_name}, {fn}"
        # A freshly read frame per file, so the metric computation may not
        # affect anything else.
        df.loc[model_name, dataset_name] = compute_metric(
            pd.read_csv(fn, index_col=0)
        )

In [None]:
# Write the metric table (rows = models, columns = datasets) to the Snakemake output path.
df.to_csv(path_or_buf=snakemake.output.aggregated_predictions)

In [None]:
# Display the aggregated metric table (rows = models, columns = datasets).
df

In [None]:
# For the Tabula Sapiens benchmark, plot each model's score against its cost.
# For any other dataset selection, draw a models x datasets heatmap instead.
if snakemake.params.datasets == ["tabula_sapiens_100_cells_per_type"]:
    dataset = "tabula_sapiens_100_cells_per_type"
    # Manually computed costs and point labels, one per row of `df`.
    costs = snakemake.params.costs
    point_labels = snakemake.params.labels

    # Per-model styling, in the row order of `df`.
    # Colour encodes the model type (see legend below).
    colors = ["tab:gray"] * 6 + ["tab:green"] * 2 + ["tab:orange"] * 3
    # Marker encodes the scoring type (see legend below).
    markers = ["."] * 4 + ["*"] * 4 + ["+"] * 3
    assert len(df) == len(colors) == len(markers), (
        f"plot styling is hard-coded for {len(colors)} models, got {len(df)}"
    )

    # Random-guessing baseline = 1 / number of classes.
    # NOTE(review): 177 is presumably the number of cell types in this
    # dataset. Confirm.
    n_cell_types = 177

    # The y-axis label follows the configured metric (not always "Accuracy").
    metric_names = {"accuracy": "Accuracy", "f1_score": "Macro F1 score"}
    metric_name = metric_names.get(snakemake.params.metric, snakemake.params.metric)

    plt.figure(figsize=(8, 5))

    # Dual encoding: colour = model type, marker = scoring type.
    for i, score in enumerate(df[dataset]):
        plt.scatter(costs[i], score, color=colors[i], marker=markers[i], s=100)
        # Label each point slightly above its marker.
        plt.text(costs[i], score + 0.007, point_labels[i], fontsize=8, ha="left")

    baseline = plt.axhline(
        1 / n_cell_types,
        color="gray",
        linestyle="--",
        zorder=-1,
        label="Random baseline",
    )

    def _marker_handle(marker, color, label):
        """Marker-only legend entry (no line segment)."""
        # The edge colour is set explicitly. Unfilled markers such as "+" are
        # drawn only with their edge and would otherwise be invisible.
        return plt.Line2D(
            [0],
            [0],
            linestyle="none",
            marker=marker,
            markerfacecolor=color,
            markeredgecolor=color,
            markersize=10,
            label=label,
        )

    legend_elements = [
        _marker_handle(".", "black", "LLM generative evaluation"),
        _marker_handle("*", "black", "LLM perplexity evaluation"),
        _marker_handle("+", "black", "CellWhisperer score evaluation"),
        # Same colours as the scatter points above.
        _marker_handle("o", "tab:gray", "Text-only LLM"),
        _marker_handle("o", "tab:green", "CellWhisperer chat model"),
        _marker_handle("o", "tab:orange", "CellWhisperer embedding model"),
        baseline,
    ]
    plt.legend(handles=legend_elements, title="Legend", loc="upper right")

    # Set the log scale BEFORE the formatter. set_xscale installs the scale's
    # default formatter and would discard one configured earlier.
    plt.xscale("log")
    plain_formatter = ticker.ScalarFormatter()
    plain_formatter.set_scientific(False)
    plt.gca().xaxis.set_major_formatter(plain_formatter)

    plt.ylabel(f"→ {metric_name}")
    plt.xlabel("← Cost (USD per 1M cells)")
    plt.title("Tabula Sapiens (100 cells per type) prediction performance vs. cost")
else:
    sns.heatmap(df, annot=True)
    plt.title(snakemake.params.plot_title)

sns.despine()
plt.tight_layout()
plt.savefig(snakemake.output.aggregated_predictions_plot)