In [None]:
# start coding here
from cellwhisperer.config import get_path

In [None]:
# start coding here

import matplotlib.pyplot as plt
import seaborn
import matplotlib
import pandas as pd

# Column keys of the two strictly-deduplicated recall@5 scores that get combined.
transcriptome_recall_col = "valfn_human_disease_strictly_deduplicated_dmis-lab_biobert-v1.1_CLS_pooling/transcriptomes_as_classes_recall_at_5_macroAvg"
text_recall_col = "valfn_human_disease_strictly_deduplicated_dmis-lab_biobert-v1.1_CLS_pooling/text_as_classes_recall_at_5_macroAvg"


def _strip_last_suffix(run_name):
    """Return `run_name` without its final `_`-separated token."""
    return run_name.rsplit("_", 1)[0]


def _min_max_scale(values):
    """Rescale `values` linearly so its minimum maps to 0 and its maximum to 1."""
    lowest = values.min()
    return (values - lowest) / (values.max() - lowest)


# One row per run (index = run name), one column per metric.
df = pd.read_csv(snakemake.input.csv, index_col=0)

# Combined recall@5: sum of both retrieval directions.
df["valfn_human_disease_dedup/recall_at_5_macroAvg"] = (
    df[transcriptome_recall_col] + df[text_recall_col]
)

# Restrict to the metrics requested by the workflow.
df = df.loc[:, snakemake.params.metrics]

# Everything before the last underscore of the run name identifies the config.
df["run_config"] = df.index.map(_strip_last_suffix)
# df["seed"] = df.index.map(lambda x: x.rsplit("_", 1)[1]).astype(int)

# Long format: one row per (run, metric) pair.
df = df.melt(id_vars=["run_config"], var_name="metric", value_name="value")

# 0-1 normalize each metric across all configs and seeds
df["normalized_value"] = df.groupby("metric")["value"].transform(_min_max_scale)

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Rank the run configurations by their mean normalized score (ascending).
means = df.groupby("run_config")["normalized_value"].mean().reset_index()
sorted_means = means.sort_values("normalized_value")

# Turn run_config into an ordered categorical following that ranking, so that
# seaborn draws the bars in ranked order.
df["run_config"] = pd.Categorical(
    df["run_config"], categories=sorted_means["run_config"], ordered=True
)
df = df.sort_values("run_config")

# Horizontal bar plot: one gray bar per run configuration.
fig, ax = plt.subplots(figsize=(6, 3))
sns.barplot(data=df, y="run_config", x="normalized_value", ax=ax, color="gray")

# Highlight the "CellWhisperer" bar (`#ee9703`).
# Bar patches carry numeric axis coordinates, not category labels: seaborn
# places the i-th category at y == i, so the matching bar is the one whose
# vertical center sits at the label's index in the category order.
label = "full_model"
config_order = list(df["run_config"].cat.categories)
if label in config_order:  # skip the highlight when the model is not plotted
    label_position = config_order.index(label)
    for patch in ax.patches:
        bar_center = patch.get_y() + patch.get_height() / 2
        if abs(bar_center - label_position) < 0.25:
            patch.set_facecolor("#ee9703")  # Set the color for the specific bar
            break

# Overlay the individual observations, colored by metric.
sns.stripplot(
    data=df,
    y="run_config",
    x="normalized_value",
    ax=ax,
    hue="metric",
    palette=sns.cubehelix_palette(start=0.5, rot=-0.75),
)

ax.set(xlim=[0, 1], xlabel="0-1-normalized scores")
# Set the x-tick labels with rotation
# ax.set_xticklabels(ax.get_xticklabels(), rotation=30, ha="right")


# Use tight_layout to adjust the plot
plt.tight_layout()

fig.savefig(snakemake.output.all_models_comparison)

In [None]:
import matplotlib.pyplot as plt
import seaborn

# Relies on `df` (long format with run_config / metric / normalized_value)
# built by the preceding cells.
# matplotlib.style.use(get_path(["plot_style"]))

# Cast run_config to plain strings (an earlier cell makes it categorical).
df["run_config"] = df["run_config"].astype(str)

# Keep only the rows of the models selected by the workflow.
is_top_model = df.run_config.isin(snakemake.params.top_models)

fig, ax = plt.subplots(figsize=(4, 4))
seaborn.lineplot(
    data=df[is_top_model],
    x="metric",
    y="normalized_value",
    hue="run_config",
    ax=ax,
)

# Axis cosmetics: no x label, fixed 0-1 y range, slanted metric names.
ax.set_xlabel("")
ax.set_ylabel("Normalized value")
ax.set_ylim((0, 1))
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

# Two-column legend anchored at the bottom center of the axes.
legend = ax.legend(ncol=2, loc="lower center", fancybox=True)


plt.tight_layout()
# fig.savefig(snakemake.output.top_models_metrics_details)

In [None]:
from IPython.core.display import HTML, Image

In [None]:
import plotly.express as px
import pandas as pd

# Mean normalized score per (run_config, metric) for the selected top models,
# ignoring rows that contain any missing value.
plot_df = (
    df[df.run_config.isin(snakemake.params.top_models) & (~df.isna().any(axis=1))]
    .groupby(["run_config", "metric"])["normalized_value"]
    .mean()
    .dropna()
    .to_frame()
    .reset_index()
)


# Radar chart: one closed, filled trace per run configuration.
fig = px.line_polar(
    plot_df,
    r="normalized_value",
    theta="metric",
    color="run_config",
    line_close=True,
    color_discrete_map={
        "no_census_data": "#123456",
        "no_archs4_data": "#654321",
        "scgpt": "#216543",
        # Same "CellWhisperer" color that the bar plot uses for this model.
        "full_model": "#ee9703",
    },
)
# Hide the radial axis (and its tick labels) entirely.
fig.update_layout(
    polar=dict(radialaxis=dict(showticklabels=False, visible=False))
)
fig.update_traces(fill="toself")

# Write a PNG preview next to the declared output, then the output itself.
png_preview_path = snakemake.output.top_models_metrics_details + ".png"
fig.write_image(png_preview_path, width=600, height=300, scale=1)
fig.write_image(
    snakemake.output.top_models_metrics_details, width=600, height=300, scale=1
)
# HTML(fig.to_html())
# Last expression of the cell: show the PNG preview inline.
Image(png_preview_path)