BASELINE MODELS TABLE

In [None]:
import pandas as pd
from pathlib import Path

metrics = pd.read_csv("../reports/metrics_noq/metrics_summary.tsv", sep="\t")

# Keep only columns you care to show
cols = [
    "model",
    "test_accuracy",
    "test_precision",
    "test_recall",
    "test_f1",
    "test_roc_auc",
]
tab_baseline = metrics[cols].copy()

# Optional: round for prettier LaTeX
tab_baseline = tab_baseline.round(3)

Path("tables").mkdir(exist_ok=True)
tab_baseline.to_latex(
    "tables/baseline_models_noq.tex",
    index=False,
    caption="Performance of baseline classifiers on the held-out test set.",
    label="tab:baseline_models_noq",
    escape=False
)

Tuned models table

In [5]:
tuned = pd.read_csv("../reports/hparam_tuning_noq/tuned_models_summary.tsv", sep="\t")

cols_tuned = [
    "model",
    "test_accuracy",
    "test_precision",
    "test_recall",
    "test_f1",
    "test_roc_auc",
]
tab_tuned = tuned[cols_tuned].copy().round(3)

Path("tables").mkdir(exist_ok=True)
tab_tuned.to_latex(
    "tables/tuned_models_noq.tex",
    index=False,
    caption="Performance of tuned classifiers on the held-out test set.",
    label="tab:tuned_models_noq",
    escape=False
)

Baseline F1 bar chart


In [6]:
from pathlib import Path

Path("figures").mkdir(parents=True, exist_ok=True)

In [7]:
import matplotlib.pyplot as plt

metrics_sorted = metrics.sort_values("test_f1", ascending=False)

plt.figure(figsize=(7, 4))
plt.barh(metrics_sorted["model"], metrics_sorted["test_f1"])
plt.xlabel("Test F1")
plt.title("Baseline models")
plt.gca().invert_yaxis()
plt.tight_layout()
plt.savefig("figures/baseline_f1.png", dpi=300, bbox_inches="tight")
plt.close()

Before vs after tuning (F1 + AUC)

In [9]:
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd

Path("figures").mkdir(parents=True, exist_ok=True)

In [11]:
from pathlib import Path

baseline_path = Path("../reports/metrics_noq/metrics_summary.tsv")
tuned_path    = Path("../reports/hparam_tuning_noq/tuned_models_summary.tsv")

baseline = pd.read_csv(baseline_path, sep="\t")
tuned    = pd.read_csv(tuned_path,    sep="\t")

# Clean up old column if present
for df in (baseline, tuned):
    if "model_base" in df.columns:
        del df["model_base"]

# Normalised key
tuned["model_base"]    = tuned["model"].str.replace("_tuned$", "", regex=True)
baseline["model_base"] = baseline["model"]

# Drop original 'model' to avoid suffix clash
for df in (baseline, tuned):
    if "model" in df.columns:
        del df["model"]

# Merge
merged = baseline.merge(
    tuned,
    on="model_base",
    suffixes=("_base", "_tuned"),
)

# Nice display name and sort by tuned F1
merged["model_name"] = merged["model_base"]
merged = merged.sort_values("test_f1_tuned", ascending=False).reset_index(drop=True)

# This is what the plotting code expects
plot_df = merged.copy()

In [12]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=False)

axes[0].bar(plot_df["model_name"], plot_df["test_f1_base"], label="baseline")
axes[0].bar(plot_df["model_name"], plot_df["test_f1_tuned"], alpha=0.7, label="tuned")
axes[0].set_xticklabels(plot_df["model_name"], rotation=45, ha="right")
axes[0].set_ylabel("Test F1")
axes[0].set_title("F1 before vs after tuning")
axes[0].legend()

axes[1].bar(plot_df["model_name"], plot_df["test_roc_auc_base"], label="baseline")
axes[1].bar(plot_df["model_name"], plot_df["test_roc_auc_tuned"], alpha=0.7, label="tuned")
axes[1].set_xticklabels(plot_df["model_name"], rotation=45, ha="right")
axes[1].set_ylabel("Test ROC–AUC")
axes[1].set_title("ROC–AUC before vs after tuning")
axes[1].legend()

plt.tight_layout()
plt.savefig("figures/f1_auc_tuning.png", dpi=300, bbox_inches="tight")
plt.close()

  axes[0].set_xticklabels(plot_df["model_name"], rotation=45, ha="right")
  axes[1].set_xticklabels(plot_df["model_name"], rotation=45, ha="right")


ROC + PR curves

In [14]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from joblib import load

from mitochime.hyperparam_search_top import load_dataset
from sklearn.metrics import (
    classification_report, confusion_matrix,
    roc_curve, roc_auc_score,
    precision_recall_curve, average_precision_score,
)

In [15]:
# 1) Load test set (no-quality)
X_test, y_test, feature_names = load_dataset("../data/processed/test_noq.tsv")

# 2) Load tuned models you want to analyse
models = {
    "catboost":           load("../models_noq_tuned/catboost_tuned.joblib"),
    "gradient_boosting":  load("../models_noq_tuned/gradient_boosting_tuned.joblib"),
    "random_forest":      load("../models_noq_tuned/random_forest_tuned.joblib"),
    "logreg_l2":          load("../models_noq_tuned/logreg_l2_tuned.joblib"),
}

In [16]:
plt.figure(figsize=(10,4))

# ROC
plt.subplot(1,2,1)
for name, model in models.items():
    if hasattr(model, "predict_proba"):
        scores = model.predict_proba(X_test)[:, 1]
    elif hasattr(model, "decision_function"):
        scores = model.decision_function(X_test)
    else:
        continue

    fpr, tpr, _ = roc_curve(y_test, scores)
    auc = roc_auc_score(y_test, scores)
    plt.plot(fpr, tpr, label=f"{name} (AUC={auc:.3f})")

plt.plot([0,1], [0,1], "k--", lw=1)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC curves")
plt.legend()

# PR
plt.subplot(1,2,2)
for name, model in models.items():
    if hasattr(model, "predict_proba"):
        scores = model.predict_proba(X_test)[:, 1]
    elif hasattr(model, "decision_function"):
        scores = model.decision_function(X_test)
    else:
        continue

    prec, rec, _ = precision_recall_curve(y_test, scores)
    ap = average_precision_score(y_test, scores)
    plt.plot(rec, prec, label=f"{name} (AP={ap:.3f})")

plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision–Recall curves")
plt.legend()

plt.tight_layout()
plt.savefig("figures/roc_pr_curves.png", dpi=300, bbox_inches="tight")
plt.close()

Confusion matrices for 4 models

In [17]:
import seaborn as sns

for name, model in models.items():
    y_pred = model.predict(X_test)
    cm = confusion_matrix(y_test, y_pred)

    plt.figure(figsize=(4,4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=["clean", "chimeric"],
                yticklabels=["clean", "chimeric"])
    plt.title(f"Confusion matrix – {name}")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.tight_layout()
    plt.savefig(f"figures/cm_{name}.png", dpi=300, bbox_inches="tight")
    plt.close()