In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

In [None]:
# Apply seaborn's "whitegrid" theme to the figures drawn in the cells below.
sns.set_theme(style="whitegrid")

In [None]:
# Read the per-genome table, discard the "genes" column, give the two gene-count
# columns dataset names, then melt wide -> long so each row is one
# (genome, dataset) pair with its gene count in a new "genes" column.
genome_stats = (
    pd.read_csv("/home/lisvad/mnt/nisin/geneml/genome_stats.tsv", sep="\t")
    .drop("genes", axis=1)
    .rename(
        columns={
            "protein_coding_genes": "original_annotation",
            "geneml_genes": "geneml",
        }
    )
    .melt(
        id_vars=["genome", "genome_size"],
        value_vars=["original_annotation", "geneml"],
        var_name="dataset",
        value_name="genes",
    )
)
genome_stats

In [None]:
# Hex colour for each value of the "dataset" column; passed as `palette=` to the plots below.
colour_palette = dict(
    original_annotation="#20558a",
    geneml="#FFAA00",
)

In [None]:
# Swarm of gene counts, one row per dataset, on a log-scaled axis.
sns.swarmplot(
    data=genome_stats,
    x="genes",
    y="dataset",
    hue="dataset",
    palette=colour_palette,
    size=2.5,
    log_scale=True,
)

In [None]:
# Split violin of gene counts (one half per dataset) on a log-scaled axis,
# with quartile lines drawn inside each half.
sns.violinplot(
    data=genome_stats,
    x="genes",
    hue="dataset",
    palette=colour_palette,
    split=True,
    gap=0.05,
    inner="quart",
    saturation=1,
    log_scale=True,
)

In [None]:
def facet_scatter_with_reg(x, y, color=None, **kwargs):
    """Scatter ``y`` against ``x`` on log-log axes and overlay a log-space linear fit.

    Draws on the current matplotlib axes (``plt.gca()``), which is how it is
    used when handed to ``FacetGrid.map`` further down in this cell.

    Parameters
    ----------
    x, y : array-like
        Paired values, converted to float arrays.
    color : optional
        Colour used for both the points and the fitted line.
    **kwargs
        Forwarded to ``Axes.scatter``. ``alpha`` defaults to 0.5 but may be
        overridden by the caller.

    Notes
    -----
    The line is an ordinary least-squares fit of ``log10(y)`` on ``log10(x)``.
    Only pairs where both values are finite and strictly positive enter the
    fit: ``log10`` of anything else is NaN or -inf, which
    ``LinearRegression.fit`` rejects with a ``ValueError``. All points are
    still handed to ``scatter``. When fewer than two usable pairs remain the
    line is skipped.
    """
    # Function-scope import; the notebook's import cell provides the same name.
    from sklearn.linear_model import LinearRegression

    ax = plt.gca()

    # Convert to numpy
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # Scatter. `alpha` is a default rather than a fixed keyword so that a
    # caller-supplied alpha does not collide with it.
    kwargs.setdefault("alpha", 0.5)
    ax.scatter(x, y, color=color, **kwargs)

    # Log-log scale
    ax.set_xscale("log")
    ax.set_yscale("log")

    # Keep only pairs whose logarithms are finite.
    usable = np.isfinite(x) & np.isfinite(y) & (x > 0) & (y > 0)
    if usable.sum() < 2:
        # Not enough points to define a line; leave the scatter on its own.
        return

    # Regression in log space
    x_log = np.log10(x[usable]).reshape(-1, 1)
    y_log = np.log10(y[usable])
    model = LinearRegression().fit(x_log, y_log)

    # Evaluate the fit on 200 evenly spaced log10(x) values spanning the data,
    # then map back to data units for plotting on the log axes.
    x_range = np.linspace(x_log.min(), x_log.max(), 200).reshape(-1, 1)
    y_pred = model.predict(x_range)

    ax.plot(10**x_range.ravel(), 10**y_pred, color=color, lw=2)

# One panel per dataset: log-log scatter of genes vs genome_size with a fitted line.
g = sns.FacetGrid(
    genome_stats,
    col="dataset",
    hue="dataset",
    palette=colour_palette,
    height=6,
    aspect=1,
)
g.map(facet_scatter_with_reg, "genome_size", "genes")
# y=1.25 is the top edge of the top marginal axes added below (1.05 + 0.2).
g.set_titles("dataset: {col_name}", y=1.25)

# 20 log-spaced bin edges per variable, computed over the whole table so that
# every facet's histograms share the same bins.
x_bins, y_bins = (
    np.logspace(
        np.log10(genome_stats[column].min()),
        np.log10(genome_stats[column].max()),
        20,
    )
    for column in ("genome_size", "genes")
)

# Attach a top and a right marginal histogram to each panel.
for col_val, ax in g.axes_dict.items():
    colour = colour_palette[col_val]
    sub = genome_stats[genome_stats["dataset"] == col_val]

    # Top: genome_size distribution, aligned with the panel's x-axis.
    ax_x = ax.inset_axes([0, 1.05, 1, 0.2], transform=ax.transAxes)
    ax_x.hist(sub["genome_size"], bins=x_bins, color=colour)
    ax_x.set_xscale("log")
    ax_x.set_xlim(ax.get_xlim())
    ax_x.axis("off")

    # Right: genes distribution, aligned with the panel's y-axis.
    ax_y = ax.inset_axes([1.05, 0, 0.2, 1], transform=ax.transAxes)
    ax_y.hist(sub["genes"], bins=y_bins, orientation="horizontal", color=colour)
    ax_y.set_yscale("log")
    ax_y.set_ylim(ax.get_ylim())
    ax_y.axis("off")

plt.tight_layout()
plt.show()