In [None]:
import glob
import json
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from collections import defaultdict

In [None]:
# Maps each genome accession (the per-run directory name of a .stats file) to
# the species name used for facet titles and species ordering below.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying on
# the platform's default encoding.
with open("accession_translator.json", "r", encoding="utf-8") as s:
    accession_translator = json.load(s)

In [None]:
# Global seaborn look for every figure in this notebook: default "notebook"
# scaling with a white background and grid lines.
sns.set_theme(context="notebook", style="whitegrid")

In [None]:
def parse_stats_file(stats_path):
    """Parse a gffcompare ``.stats`` file into a tidy DataFrame.

    Every ``<Name> level:`` line produces one row. The row holds the dataset
    path from the most recent ``#= Summary for dataset:`` header, the level
    name, and the sensitivity/precision values reported on that line.

    Parameters
    ----------
    stats_path : str
        Path to the ``.stats`` file.

    Returns
    -------
    pandas.DataFrame
        Columns ``datapath``, ``level``, ``sensitivity`` and ``precision``,
        with one row per level line. The columns are present even when the
        file contains no level lines, so callers can always index them.

    Raises
    ------
    ValueError
        If a level line appears before any ``#= Summary for dataset:`` header.
    """
    columns = ["datapath", "level", "sensitivity", "precision"]
    results = defaultdict(list)
    datapath = None
    with open(stats_path, 'r') as stream:
        for line in stream:
            if line.startswith("#= Summary for dataset:"):
                datapath = line.split()[-1]
            # Match the "level:" separator itself, not merely the word "level",
            # so comment lines or paths that mention "level" cannot break the
            # two-way unpacking below.
            if "level:" in line:
                if datapath is None:
                    raise ValueError(
                        f"{stats_path}: level line found before a "
                        "'#= Summary for dataset:' header"
                    )
                level, stats = line.split("level:", 1)
                # After the colon: "<sensitivity> | <precision> | ..." -- token 1
                # is the column separator and is skipped.
                fields = stats.split()
                results["datapath"].append(datapath)
                results["level"].append(level.strip())
                results["sensitivity"].append(float(fields[0]))
                results["precision"].append(float(fields[2]))
    return pd.DataFrame({col: results[col] for col in columns})

In [None]:
# NOTE(review): absolute, user-specific location -- adjust (or make configurable)
# when running this notebook elsewhere. Layout: <root>/<accession>/*.stats
STATS_GLOB = "/home/lisvad/mnt/nisin/geneml/outputs/gffcompare_benchmarking/*/*.stats"

# glob order depends on the filesystem; sorting makes the row order reproducible.
stats_paths = sorted(glob.glob(STATS_GLOB))
if not stats_paths:
    raise FileNotFoundError(f"No gffcompare .stats files match {STATS_GLOB!r}")

dfs = []
for path in stats_paths:
    # Each .stats file sits in a directory named after its genome accession.
    accession = path.split('/')[-2]
    stats_df = parse_stats_file(path)
    # The tool/run name is the parent directory of the evaluated annotation file.
    stats_df["dataset"] = stats_df["datapath"].str.split('/').str[-2]
    stats_df["species"] = accession_translator[accession]
    dfs.append(stats_df)

# Every per-file frame starts its index at 0; ignore_index keeps row labels unique.
all_stats = pd.concat(dfs, ignore_index=True)
# Species order follows accession_translator. dict.fromkeys drops duplicate
# species names, which pd.Categorical would otherwise reject as categories.
all_stats["species"] = pd.Categorical(
    all_stats["species"],
    categories=list(dict.fromkeys(accession_translator.values())),
)
all_stats = all_stats.sort_values("dataset", kind="stable")
# F1 is the harmonic mean of precision and sensitivity. When both are 0 the
# division is 0/0 -> NaN; use the conventional F1 of 0 instead.
all_stats["F1 score"] = (
    2 * all_stats["precision"] * all_stats["sensitivity"]
    / (all_stats["precision"] + all_stats["sensitivity"])
).fillna(0.0)

In [None]:
# Run directory name (as found in each dataset path) -> display name used in
# the tables and plots. The pair order defines the tool order everywhere below.
# Un-comment a pair to include that run in the comparison.
datasets = dict([
    ("augustus", "AUGUSTUS"),
    # ("braker3_noprotdb", "BRAKER3 (no orthoDB)"),
    ("braker3", "BRAKER3"),
    ("helixer", "Helixer"),
    ("GeneML800_c657g_ncbi_for_benchmarking_ep10", "geneML"),
    # ("geneml_repeats100", "geneML repeat masked 100bp"),
    # ("geneml_repeats200", "geneML repeat masked 200bp"),
])

In [None]:
# Replace run-directory names with display names, stored as an ordered
# categorical whose category order follows `datasets`. Runs not listed in
# `datasets` become NaN. Values that are already display names pass through
# unchanged, so re-running this cell does not wipe the column.
all_stats["dataset"] = pd.Categorical(
    all_stats["dataset"].astype(object).map(lambda name: datasets.get(name, name)),
    categories=list(datasets.values()),
    ordered=True,
)

In [None]:
# gffcompare accuracy levels to keep -> label shown in the output. The "Locus"
# level is reported as "Gene"; the key order sets the category order.
levels = dict(Base="Base", Exon="Exon", Locus="Gene")

In [None]:
# Keep only the accuracy levels listed in `levels`. `.copy()` makes this an
# independent frame, so the column assignment below is not a chained
# assignment on a slice of `all_stats` (which raises SettingWithCopyWarning).
selected_stats = all_stats[all_stats["level"].isin(list(levels))].copy()
# Relabel levels for display as an ordered categorical in `levels` order.
selected_stats["level"] = pd.Categorical(
    selected_stats["level"].map(levels),
    categories=list(levels.values()),
    ordered=True,
)

In [None]:
METRICS = ["sensitivity", "precision", "F1 score"]

# Mean and standard deviation of each metric across species, per (level, tool).
# observed=True reports only category combinations that actually have data and
# avoids pandas' FutureWarning about the changing default for categorical keys.
summary = (
    selected_stats
    .groupby(["level", "dataset"], observed=True)[METRICS]
    .agg(['mean', 'std'])
)

collapsed = summary.copy()
collapsed.columns = ['_'.join(col) for col in summary.columns]  # flatten columns

# Render each metric as "mean ± std" rounded to 1 decimal. The std is NaN
# (shown as "nan") for groups with a single species.
for metric in METRICS:
    collapsed[metric] = (
        collapsed[f'{metric}_mean'].round(1).astype(str)
        + " ± "
        + collapsed[f'{metric}_std'].round(1).astype(str)
    )

collapsed = collapsed[METRICS]
collapsed

In [None]:
# Fixed colour per tool, keyed by display name, so every panel and figure uses
# the same colour for the same tool. Un-comment entries together with the
# matching runs in `datasets`.
tool_palette = dict([
    ("AUGUSTUS", "#5A749F"),
    # ("BRAKER3 (no orthoDB)", "#FF9898"),
    ("BRAKER3", "#B53535"),
    ("Helixer", "#622870"),
    ("geneML", "#FFAA00"),
    # ("geneML repeat masked 100bp", "#5FC047"),
    # ("geneML repeat masked 200bp", "#237010"),
])

In [None]:
# Friendlier column names become the legend headings.
data = selected_stats.sort_values("dataset").rename(
    columns={"dataset": "Gene prediction tool", "level": "Level"}
)
# One panel per species: precision vs. sensitivity, colour = tool, marker = level.
plot = sns.relplot(
    data=data,
    x="precision",
    y="sensitivity",
    col="species",
    col_wrap=3,
    height=3,
    hue="Gene prediction tool",
    style="Level",
    palette=tool_palette,
    s=100,
    # The ordered categories keep the tool order defined in `datasets`.
    hue_order=selected_stats["dataset"].cat.categories,
)
plot.set(xlim=(0, 100), ylim=(0, 100))
plot.set_axis_labels("Precision (%)", "Sensitivity (%)")
plot.set_titles("{col_name}", size=11)
# _legend is seaborn's internal handle; it is None when no legend was drawn.
if plot._legend is not None:
    plot._legend.set_title(None)
#plt.savefig("gffcompare_plot.svg")