In [None]:
import json
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import seaborn as sns

from scipy.stats import linregress

In [None]:
# Load the lookup table that is later used to map `genome_id` values to
# species names. The encoding is given explicitly because the default used by
# open() depends on the platform locale, whereas JSON files are UTF-8.
with open("accession_translator.json", "r", encoding="utf-8") as handle:
    accession_translator = json.load(handle)

In [None]:
# Use seaborn's "whitegrid" theme for the figures drawn in the cells below.
sns.set_theme(style="whitegrid")

In [None]:
# Maps the tool column names used in timings.tsv to the display names used in
# tables and figures. The insertion order of this dict is also the order in
# which the tools are listed in the swarm plot further down.
datasets = {
    "augustus": "AUGUSTUS",
#    "braker3_noprotdb": "BRAKER3 (no orthoDB)",  # disabled: this column is dropped when timings are loaded
    "braker3": "BRAKER3",
    "helixer": "Helixer",
    "geneml": "geneML",
}

In [None]:
# Read the wide runtime table (one column per tool), attach species names and
# reshape to long format: one row per (genome, tool) pair with its runtime.
timings = (
    pd.read_csv("/home/lisvad/mnt/nisin/geneml/timings.tsv", delimiter='\t')
    .drop(columns=["braker3_noprotdb"])
    .assign(species=lambda frame: frame["genome_id"].map(accession_translator))
    .melt(
        id_vars=["genome_id", "species"],
        value_vars=["augustus", "braker3", "geneml", "helixer"],
        var_name="tool",
        value_name="time",
    )
)

# Orderings by mean runtime, fastest first. Both are computed from the raw
# labels; neither depends on the categorical conversion done below.
tool_order = timings.groupby("tool")["time"].mean().sort_values().index
species_order = timings.groupby("species")["time"].mean().sort_values().index

# Turn "tool" into an ordered categorical, then swap the raw column names for
# their display names from `datasets`.
timings["tool"] = pd.Series(
    pd.Categorical(timings["tool"], categories=tool_order, ordered=True),
    index=timings.index,
).map(datasets)

# Species become an ordered categorical as well.
timings["species"] = pd.Categorical(
    timings["species"],
    categories=species_order,
    ordered=True,
)

In [None]:
# Mean and standard deviation of the runtime for each tool.
# `observed` is passed explicitly: "tool" was converted to a categorical
# above, and grouping on a categorical without it relies on a deprecated
# default (FutureWarning in recent pandas). Every tool has rows here, so the
# resulting table is unchanged.
timings.groupby("tool", observed=True)["time"].agg(["mean", "std"])

In [None]:
# Fixed colour per tool display name, shared by every figure in this notebook.
tool_palette = {
    "AUGUSTUS": "#5A749F",
    # "BRAKER3 (no orthoDB)": "#FF9898",
    "BRAKER3": "#B53535",
    "Helixer": "#622870",
    "geneML": "#FFAA00",
    # "geneML repeat masked 100bp": "#5FC047",
    # "geneML repeat masked 200bp": "#237010",
}

In [None]:
# List the tools in the order in which they appear in `datasets`.
timings["tool"] = pd.Categorical(
    timings["tool"],
    categories=list(datasets.values()),
    ordered=True,
)

# Swarm plot of runtimes: one row of points per tool, log-scaled time axis.
fig, ax = plt.subplots(figsize=(7, 4))
sns.swarmplot(
    data=timings,
    x="time",
    y="tool",
    hue="tool",
    palette=tool_palette,
    size=6,
    log_scale=True,
    ax=ax,
)
ax.set_xlabel("time (hours)")
ax.set_xlim(0.01, 50)
# Plain numbers on the log axis instead of powers of ten.
ax.xaxis.set_major_formatter(mtick.ScalarFormatter())

In [None]:
# Look up the size of each genome in the benchmarking table and attach it to
# the timing rows. Genome IDs without an entry in the table end up as NaN.
benchmarking_stats = pd.read_csv(
    "/home/lisvad/mnt/nisin/geneml/benchmarking_stats.tsv",
    sep="\t",
)
size_mapping = {
    genome: size
    for genome, size in zip(
        benchmarking_stats["genome"],
        benchmarking_stats["genome_size"],
    )
}
timings["genome_size"] = timings["genome_id"].map(size_mapping)

In [None]:
# Runtime against genome size, one colour per tool, with a per-tool
# power-law fit (a straight line in log-log space) drawn on top.
ax = sns.scatterplot(
    data=timings,
    x="genome_size",
    y="time",
    hue="tool",
    palette=tool_palette,
)

# Fit a log-log regression per tool.
# observed=True: "tool" is categorical, so only tools that actually have rows
# are iterated, and the deprecated implicit default is not relied upon.
for tool, group in timings.groupby("tool", observed=True):
    x = group["genome_size"]
    y = group["time"]

    # Keep strictly positive pairs only (log is undefined for <= 0).
    # Comparisons with NaN are False, so rows with a missing size or time
    # are dropped here as well.
    mask = (x > 0) & (y > 0)
    x = x[mask]
    y = y[mask]

    # A line cannot be fitted through fewer than two distinct x values;
    # linregress raises in that case, so such a tool is skipped.
    if x.nunique() < 2:
        continue

    # Fit the regression in log-log space.
    fit = linregress(np.log10(x), np.log10(y))
    slope, intercept = fit.slope, fit.intercept

    # Evaluate the fitted line over the observed range of genome sizes.
    x_line = np.linspace(x.min(), x.max(), 100)
    y_line = 10**(intercept + slope * np.log10(x_line))

    # Plot the regression line in the colour of its tool.
    ax.plot(x_line, y_line, color=tool_palette[tool], lw=2, label=f"{tool} fit (slope={slope:.2f})")

# The legend created by scatterplot predates the fit lines, so it has to be
# rebuilt for the "fit (slope=...)" labels to be shown.
ax.legend(title="tool")

# Log-log axes.
ax.set_xscale("log")
ax.set_yscale("log")

# Axis labels; the time label matches the swarm plot of the same column.
ax.set_xlabel("genome size")
ax.set_ylabel("time (hours)")

# Axis limits and ticks (plain numbers instead of powers of ten on the y axis).
ax.set_xlim(1e7, 1.8e8)
ax.set_ylim(0.01, 60)
ax.yaxis.set_major_formatter(mtick.ScalarFormatter())