In [None]:
import json
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from collections import OrderedDict

In [None]:
# Load the lookup table that is later used to translate the "genome" column
# of the benchmarking table into species names.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying on
# the platform's locale encoding.
with open("accession_translator.json", "r", encoding="utf-8") as translator_file:
    accession_translator = json.load(translator_file)

In [None]:
# Use seaborn's "whitegrid" style as the theme for the figures drawn below.
sns.set_theme(style="whitegrid")

In [None]:
# Tool column name in the benchmarking table -> label shown in the figure.
# Insertion order matters: it defines the category order of the "tool" column.
datasets = dict(
    augustus="AUGUSTUS",
    # Disabled entry, kept for reference:
    #   "braker3_noprotdb": "BRAKER3 (no orthoDB)"
    braker3="BRAKER3",
    helixer="Helixer",
    geneml="geneML",
)

In [None]:
# Load the wide benchmarking table and reshape it to long form: one row per
# (genome, tool) pair, with that tool's value in the "genes" column.
# NOTE(review): the input is a machine-specific absolute path; point it at your
# own copy of benchmarking_stats.tsv when running the notebook elsewhere.
gene_counts = (
    pd.read_csv("/home/lisvad/mnt/nisin/geneml/benchmarking_stats.tsv", sep="\t")
    # Drop the table's pre-existing "genes" column so that the name is free
    # for the value column created by melt.
    .drop(columns=["genes"])
    .melt(
        id_vars=["genome", "genome_size", "protein_coding_genes"],
        # Melt only the tools listed in `datasets`; any other column in the
        # table (e.g. "braker3_noprotdb") is ignored instead of turning into
        # rows without a tool label.
        value_vars=list(datasets),
        var_name="tool",
        value_name="genes",
    )
)

# Order the tools as listed in `datasets`, then swap the column keys for their
# display labels (the category order is carried over by the mapping).
gene_counts["tool"] = pd.Categorical(
    gene_counts["tool"], categories=list(datasets), ordered=True
)
gene_counts["tool"] = gene_counts["tool"].map(datasets)

# Translate genome identifiers to species names, ordered as they appear in
# accession_translator. Genomes missing from the translator become NaN.
gene_counts["species"] = pd.Categorical(
    gene_counts["genome"].map(accession_translator),
    categories=list(accession_translator.values()),
    ordered=True,
)

In [None]:
# Display label of each tool -> hex colour of its markers in the figure.
tool_palette = {
    "AUGUSTUS": "#5A749F",
    # Disabled entry, kept for reference:
    #   "BRAKER3 (no orthoDB)": "#FF9898"
    "BRAKER3": "#B53535",
    "Helixer": "#622870",
    "geneML": "#FFAA00",
}

In [None]:
# Figure: gene counts per species. Each tool is drawn as a coloured swarm
# marker, the reference count as a black tick, and a gray bar per species spans
# the full range of values (tools and reference).
fig, ax = plt.subplots(figsize=(10, 4))

# One swarmplot call per tool so that each tool gets its own colour and legend
# label. observed=True iterates only over tools that actually have rows and
# makes the behaviour independent of pandas' changing default for categorical
# groupers.
for tool, tool_rows in gene_counts.groupby("tool", observed=True):
    sns.swarmplot(
        data=tool_rows,
        y="species", x="genes",
        color=tool_palette[tool],
        label=tool,
        size=9,
        ax=ax,
    )

# Overlay the reference count as a vertical tick. "protein_coding_genes" is an
# id column of the melt, so it is repeated unchanged on every tool row of a
# genome; errorbar=None skips the bootstrap of an interval over those
# identical values.
sns.pointplot(
    data=gene_counts,
    y="species", x="protein_coding_genes",
    errorbar=None,
    linestyle="none", marker="|", color="black",
    markersize=12, linewidth=2,
    ax=ax,
)

# Deviation bars: for every species on the y axis, connect the smallest and the
# largest value among the tool counts and the reference count.
for y_pos, tick_label in zip(ax.get_yticks(), ax.get_yticklabels()):
    species_rows = gene_counts[gene_counts["species"] == tick_label.get_text()]
    if species_rows.empty:
        # A tick without data (e.g. a species category with no rows) gets no bar.
        continue
    ref_count = species_rows["protein_coding_genes"].iloc[0]
    x_min = min(species_rows["genes"].min(), ref_count)
    x_max = max(species_rows["genes"].max(), ref_count)
    # zorder=1 keeps the bar behind the markers.
    ax.hlines(y=y_pos, xmin=x_min, xmax=x_max, color="gray", lw=2, zorder=1)

# Explicit axis labels instead of the raw column names picked up by seaborn.
ax.set(xlabel="Number of genes", ylabel="Species")

# Legend: keep a single entry per tool label (the same label can be registered
# more than once by the swarmplot calls), then append a manual entry for the
# reference tick, which has no label of its own.
handles, labels = ax.get_legend_handles_labels()
legend_entries = dict(zip(labels, handles))
legend_entries["Reference"] = mlines.Line2D(
    [], [], color="black", marker="|", linestyle="none",
    markersize=10, label="Reference",
)

ax.legend(
    list(legend_entries.values()), list(legend_entries.keys()),
    title="Tool", bbox_to_anchor=(1.05, 1), loc="upper left",
)

plt.tight_layout()
plt.show()