In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import matplotlib
# Apply the Matplotlib style sheet passed in as the `mpl_style` input of the Snakemake rule.
plt.style.use(snakemake.input.mpl_style)

In [None]:
# Performance on some examples for various tasks.
# The figure has one column per per-class performance file and one row per metric
# (top: ROC-AUC, bottom: Accuracy). In every panel the bars for the selected classes
# are followed by an "All (Macro-average)" bar.


def _metric_to_float(raw):
    """Convert a stored metric value to ``float``.

    Accepts plain numbers as well as tensor reprs such as ``"tensor(0.93)"`` or
    ``"tensor(0.93, device='cuda:0')"``.

    Raises:
        ValueError: if no leading number can be parsed from ``raw``.
    """
    text = str(raw).replace("tensor(", "").replace(")", "")
    # Anything after the first comma is a tensor attribute (device, dtype, ...), not the value.
    return float(text.split(",")[0])


# Display name of each metric -> row label of that metric in the macro-average CSV.
macroavg_rows = {"ROC-AUC": "rocauc_macroAvg", "Accuracy": "accuracy_macroAvg"}
metric_names = list(macroavg_rows)
n_tasks = len(snakemake.input.per_class_performance)

# squeeze=False keeps `axes` two-dimensional, so axes[row][col] also works for a single task.
fig, axes = plt.subplots(len(metric_names), n_tasks,
                         figsize=(1.2 * n_tasks, 2),
                         sharex="col", sharey="row", squeeze=False)

task_inputs = zip(snakemake.params.label_cols,
                  snakemake.input.per_class_performance,
                  snakemake.input.macroavg_performance,
                  snakemake.params.selected_sample_lists,
                  snakemake.params.datasets)

# `dataset` is not used below; it stays in the zip so the number of iterations is unchanged.
for i, (label_col, per_class_performance, macroavg_performance, selected_samples, dataset) in enumerate(task_inputs):
    # Both strings are stripped from the class names, so their order in the pair does not matter here.
    prefix, suffix = snakemake.params.suffix_prefix_dict[label_col]

    # Macro-averaged metrics: one row per metric, value taken from the first column.
    macroavg_raw = pd.read_csv(macroavg_performance, index_col=0)
    macroavg_df = pd.DataFrame({
        "class": ["All (Macro-average)"] * len(metric_names),
        "metric": metric_names,
        "value": [_metric_to_float(macroavg_raw.loc[row].values[0]) for row in macroavg_rows.values()],
    })

    df = pd.read_csv(per_class_performance)
    # regex=False: treat prefix/suffix as literal text regardless of the pandas version's default.
    df["class"] = (df["class"]
                   .str.replace(prefix, "", regex=False)
                   .str.replace(suffix, "", regex=False))

    plot_df = (df.loc[df["class"].isin(selected_samples), ["class", "rocauc", "accuracy"]]
                 .rename(columns={"rocauc": "ROC-AUC", "accuracy": "Accuracy"})
                 .melt(id_vars="class", value_vars=metric_names, var_name="metric", value_name="value"))
    plot_df = pd.concat([plot_df, macroavg_df], ignore_index=True)

    # Dashed reference lines: 0.5 for ROC-AUC and 1 / (number of distinct classes
    # in the per-class file) for Accuracy.
    n_classes = len(df["class"].unique())
    reference_level = {"ROC-AUC": 0.5, "Accuracy": 1 / n_classes}

    for j, metric in enumerate(metric_names):
        ax = axes[j][i]
        sns.barplot(data=plot_df[plot_df["metric"] == metric], x="class", y="value",
                    width=0.6, color="#ee9703", ax=ax)
        ax.axhline(y=reference_level[metric], color="black", linestyle="--", linewidth=0.5)
        ax.set_ylim(0, 1)
        ax.legend([], frameon=False)
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        ax.set_xlabel("")
        ax.set_ylabel(metric)

fig.subplots_adjust(hspace=0.3, wspace=0.2)
fig.savefig(snakemake.output.per_class_examples_plot)

plt.show()