In [None]:
# Imports and global plotting style
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
# Apply the matplotlib style sheet passed in as a Snakemake input to all figures below.
style_sheet = snakemake.input.mpl_style
matplotlib.style.use(style_sheet)


# Macro-averaged metrics on the `disease_deduplicated` dataset, as reported by running
# `cellwhisperer test --config src/experiments/408_ablation_study3/base_config.yaml --model_ckpt results/models/jointemb/cellwhisperer_clip_v1.ckpt --seed_everything 0`
#
# Values are grouped by the modality used as classes; the full metric keys
# ('disease_deduplicated/<modality>_as_classes_<metric>_macroAvg') are rebuilt below,
# keeping the reported order (all 'text' metrics first, then 'transcriptomes').
_dataset_name = 'disease_deduplicated'
_macro_avg_by_modality = {
    'text': {
        'accuracy': 0.1899999976158142,
        'f1': 0.1536666750907898,
        'precision': 0.14154762029647827,
        'recall_at_10': 0.7900000214576721,
        'recall_at_1': 0.1899999976158142,
        'recall_at_50': 0.9700000286102295,
        'recall_at_5': 0.550000011920929,
        'rocauc': 0.9286867380142212,
    },
    'transcriptomes': {
        'accuracy': 0.23999999463558197,
        'f1': 0.17900002002716064,
        'precision': 0.15416665375232697,
        'recall_at_10': 0.7900000214576721,
        'recall_at_1': 0.23999999463558197,
        'recall_at_50': 0.9800000190734863,
        'recall_at_5': 0.5899999737739563,
        'rocauc': 0.9259596467018127,
    },
}
data = {
    f'{_dataset_name}/{modality}_as_classes_{metric}_macroAvg': value
    for modality, metrics in _macro_avg_by_modality.items()
    for metric, value in metrics.items()
}

# One float per metric, indexed by the full metric key.
series = pd.Series(data)

In [None]:


# Keep only the recall@k metrics (every key containing 'recall_at_'),
# in the same order as in `series`.
is_recall_at_metric = series.index.str.contains('recall_at_', regex=False)
recall_at_series = series[is_recall_at_metric]
recall_at_series

In [None]:
# Turn the recall@k Series into a long-format DataFrame for seaborn: one row per
# metric with columns 'value', 'class' and 'retrieval_at'.
#
# Keys look like 'disease_deduplicated/text_as_classes_recall_at_10_macroAvg'.
# 'class' is the modality used as classes ('text' / 'transcriptomes') and
# 'retrieval_at' is k as a string ('1', '5', '10', '50', matching the barplot's
# hue_order). A named-group regex is used instead of a positional split on '_':
# the dataset prefix 'disease_deduplicated' itself contains '_', so element 1 of
# such a split is 'deduplicated/text' rather than 'text'.
recall_key_pattern = r'/(?P<class>[^/]+?)_as_classes_recall_at_(?P<retrieval_at>\d+)_'
key_parts = recall_at_series.index.to_series().str.extract(recall_key_pattern)
# Fail loudly instead of silently plotting NaN labels if the key format changes.
assert key_parts.notna().all().all(), f"unexpected recall metric keys: {list(recall_at_series.index)}"

recall_at_df = pd.DataFrame({
    'value': recall_at_series.to_numpy(),
    'class': key_parts['class'].to_numpy(),
    'retrieval_at': key_parts['retrieval_at'].to_numpy(),
})
recall_at_df

In [None]:
# Grouped bar chart of recall@k per class (x) with one bar per k (hue),
# saved to the Snakemake `barplot` output.
fig, ax = plt.subplots(figsize=(2, 2))
barplot = sns.barplot(
    data=recall_at_df,
    x='class',
    y='value',
    hue='retrieval_at',
    hue_order=["1", "5", "10", "50"],
    palette='viridis',
    ax=ax,
)

# Rotate the existing category labels in place. Re-setting them with
# `set_xticklabels(get_xticklabels())` swaps in a FixedFormatter on a non-fixed
# (categorical) locator, which recent matplotlib warns about.
ax.tick_params(axis='x', labelrotation=30)
plt.setp(ax.get_xticklabels(), ha='right')

# Label each bar with its value, 9 points above the bar top.
for bar in ax.patches:
    height = bar.get_height()
    # Patches without a drawable value (zero or NaN height) get no label,
    # so no "0.00" / "nan" text appears.
    if pd.isna(height) or height == 0.0:
        continue
    ax.annotate(
        f'{height:.2f}',
        (bar.get_x() + bar.get_width() / 2., height),
        ha='center',
        va='center',
        xytext=(0, 9),
        textcoords='offset points',
    )

sns.despine(ax=ax)
fig.tight_layout()

fig.savefig(snakemake.output.barplot)

In [None]:
# Line plot input: the metrics table passed in as a Snakemake input CSV.
metrics_csv_path = snakemake.input.csv
recall_at_5_df = pd.read_csv(metrics_csv_path)

In [None]:
# Recall@5 over epochs for one logged run, saved to the Snakemake `lineplot` output.
RECALL_AT_5_COLUMN = "cellwhisperer_v1 - valfn_daniel_strictly_deduplicated_dmis-lab_biobert-v1.1_CLS_pooling/recall_at_5_macroAvg"
# Row i is plotted as epoch i + 1 (assumes one CSV row per epoch, in order — confirm).
epochs = recall_at_5_df.index + 1

fig, ax = plt.subplots(figsize=(1.5, 1.1))
sns.lineplot(
    data=recall_at_5_df,
    x=epochs,
    y=RECALL_AT_5_COLUMN,
    ax=ax,
    color="gray",
)
ax.set(xlabel="epoch", ylabel="Recall at top 5")

fig.savefig(snakemake.output.lineplot)